from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from google.cloud import aiplatform
from google.oauth2 import service_account
from llama_index.core.base.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
    Embedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager


class BaseIOHandler(ABC):
    """Serializes request payloads for, and parses responses from, a Vertex AI endpoint.

    NOTE: the original handler definitions are not part of this excerpt; this
    interface and the default handler below are minimal stand-ins inferred from
    how ``VertexEndpointEmbedding`` uses them.
    """

    @abstractmethod
    def serialize_input(self, request: List[str]) -> List[Any]:
        """Convert a batch of input texts into prediction instances."""

    @abstractmethod
    def deserialize_output(self, response: Any) -> List[Embedding]:
        """Extract one embedding vector per input from the prediction response."""


class IOHandler(BaseIOHandler):
    """Default stand-in handler: pass texts through as instances and read vectors
    from ``response.predictions`` (assumed to hold one vector per input)."""

    def serialize_input(self, request: List[str]) -> List[Any]:
        return list(request)

    def deserialize_output(self, response: Any) -> List[Embedding]:
        return list(response.predictions)


DEFAULT_IO_HANDLER = IOHandler()

class VertexEndpointEmbedding(BaseEmbedding):
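    """Embeddings generated by a model deployed to a Vertex AI prediction endpoint.

    Batches of texts are sent to the endpoint via ``aiplatform.Endpoint.predict``
    (or ``predict_async``); the configured IO handler converts texts to prediction
    instances and parses vectors out of the response.
    """
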
endpoint_id: str = Field(description="Vertex AI endpoint ID")
project_id: str = Field(description="GCP Project ID")
location: str = Field(description="GCP Region for Vertex AI")
endpoint_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="Additional kwargs for the predict request.",
)
model_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description="kwargs to pass to the model.",
)
content_handler: BaseIOHandler = Field(
default=DEFAULT_IO_HANDLER,
description="used to format input/output",
)
service_account_file: Optional[str] = Field(
default=None, description="Path to the service account JSON file."
)
service_account_info: Optional[Dict[str, str]] = Field(
default=None, description="Directly provide service account credentials."
)
timeout: Optional[float] = Field(
default=60.0,
description="Timeout for API requests in seconds.",
ge=0,
)
_client: aiplatform.Endpoint = PrivateAttr()
    _verbose: bool = PrivateAttr()

    def __init__(
self,
endpoint_id: str,
project_id: str,
location: str,
content_handler: BaseIOHandler = DEFAULT_IO_HANDLER,
        endpoint_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
service_account_file: Optional[str] = None,
service_account_info: Optional[Dict[str, str]] = None,
timeout: Optional[float] = 60.0,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
):
        super().__init__(
            endpoint_id=endpoint_id,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            project_id=project_id,
            location=location,
            content_handler=content_handler,
            endpoint_kwargs=endpoint_kwargs or {},
            model_kwargs=model_kwargs or {},
            service_account_file=service_account_file,
            service_account_info=service_account_info,
            timeout=timeout,
        )
# Initialize the client
if service_account_file:
credentials = service_account.Credentials.from_service_account_file(
service_account_file
)
elif service_account_info:
credentials = service_account.Credentials.from_service_account_info(
service_account_info
)
else:
credentials = None # Use default application credentials if not provided
try:
self._client = aiplatform.Endpoint(
endpoint_name=endpoint_id,
project=project_id,
location=location,
credentials=credentials,
)
        except Exception as e:
            raise ValueError(
                "Failed to initialize the Vertex AI endpoint client. "
                "Please verify the provided credentials."
            ) from e
        self._verbose = verbose

    @classmethod
    def class_name(cls) -> str:
        return "VertexEndpointEmbedding"

def _get_embedding(self, payload: List[str], **kwargs: Any) -> List[Embedding]:
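        """Embed a batch of texts with a single call to the deployed Vertex AI endpoint."""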
        # Merge configured endpoint kwargs with the request timeout, and model
        # kwargs with any per-call overrides.
        endpoint_kwargs = {**self.endpoint_kwargs, "timeout": self.timeout}
        model_kwargs = {**self.model_kwargs, **kwargs}
# Directly send the input payload to the endpoint
response = self._client.predict(
instances=self.content_handler.serialize_input(payload),
parameters=model_kwargs,
**endpoint_kwargs
)
# Assuming response contains the embeddings in a field called 'predictions'
        return self.content_handler.deserialize_output(response)

    async def _aget_embedding(
self, payload: List[str], **kwargs: Any
) -> List[Embedding]:
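        """Async variant of ``_get_embedding`` using ``Endpoint.predict_async``."""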
        # Merge configured endpoint kwargs with the request timeout, and model
        # kwargs with any per-call overrides.
        endpoint_kwargs = {**self.endpoint_kwargs, "timeout": self.timeout}
        model_kwargs = {**self.model_kwargs, **kwargs}
# Directly send the input payload to the endpoint
response = await self._client.predict_async(
instances=self.content_handler.serialize_input(payload),
parameters=model_kwargs,
**endpoint_kwargs
)
# Assuming response contains the embeddings in a field called 'predictions'
return self.content_handler.deserialize_output(response)
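
    # The wrappers below are the hooks BaseEmbedding calls; they normalize
    # newlines to spaces and delegate to the batch embedding methods above.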
    def _get_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        query = query.replace("\n", " ")
        return self._get_embedding([query], **kwargs)[0]

    def _get_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        text = text.replace("\n", " ")
        return self._get_embedding([text], **kwargs)[0]

    def _get_text_embeddings(self, texts: List[str], **kwargs: Any) -> List[Embedding]:
        texts = [text.replace("\n", " ") for text in texts]
        return self._get_embedding(texts, **kwargs)

    async def _aget_query_embedding(self, query: str, **kwargs: Any) -> Embedding:
        query = query.replace("\n", " ")
        # Await the batch call before indexing; subscripting the coroutine directly would fail.
        return (await self._aget_embedding([query], **kwargs))[0]

    async def _aget_text_embedding(self, text: str, **kwargs: Any) -> Embedding:
        text = text.replace("\n", " ")
        return (await self._aget_embedding([text], **kwargs))[0]

    async def _aget_text_embeddings(
        self, texts: List[str], **kwargs: Any
    ) -> List[Embedding]:
        texts = [text.replace("\n", " ") for text in texts]
        return await self._aget_embedding(texts, **kwargs)
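

# Example usage (a minimal sketch; the endpoint ID, project, and region below are
# placeholders and assume an embedding model is already deployed to the endpoint):
#
#   embed_model = VertexEndpointEmbedding(
#       endpoint_id="1234567890",
#       project_id="my-gcp-project",
#       location="us-central1",
#   )
#   vector = embed_model.get_text_embedding("Hello, Vertex AI!")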