class VertexTextEmbedding(BaseEmbedding):
    """Text embeddings produced by a Vertex AI ``TextEmbeddingModel``.

    Documents (``is_query=False``) and queries (``is_query=True``) are first
    turned into embedding requests by ``_get_embedding_request`` according to
    ``embed_mode``. They are then sent to the loaded model together with
    ``additional_kwargs``.
    """

    embed_mode: VertexEmbeddingMode = Field(description="The embedding mode to use.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Vertex."
    )

    _model: TextEmbeddingModel = PrivateAttr()

    def __init__(
        self,
        model_name: str = "textembedding-gecko@003",
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        embed_mode: VertexEmbeddingMode = VertexEmbeddingMode.RETRIEVAL_MODE,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialise Vertex AI and load the embedding model.

        Args:
            model_name: Model name passed to ``TextEmbeddingModel.from_pretrained``.
            project: Forwarded to ``init_vertexai``.
            location: Forwarded to ``init_vertexai``.
            credentials: Forwarded to ``init_vertexai``.
            embed_mode: Controls how ``_get_embedding_request`` builds requests.
            embed_batch_size: Batch size handed to ``BaseEmbedding``.
            callback_manager: Defaults to an empty ``CallbackManager``.
            additional_kwargs: Extra keyword arguments forwarded to every
                ``get_embeddings`` / ``get_embeddings_async`` call.
        """
        init_vertexai(project=project, location=location, credentials=credentials)
        callback_manager = callback_manager or CallbackManager([])
        additional_kwargs = additional_kwargs or {}
        super().__init__(
            embed_mode=embed_mode,
            additional_kwargs=additional_kwargs,
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
        )
        # Assigned after super().__init__() so the model's private-attribute
        # storage has been initialised.
        self._model = TextEmbeddingModel.from_pretrained(model_name)

    @classmethod
    def class_name(cls) -> str:
        return "VertexTextEmbedding"

    def _embed_texts(self, texts: List[str], is_query: bool) -> List[Embedding]:
        """Embed ``texts`` synchronously, one vector per input, in order.

        Returns an empty list for empty input without contacting the service.
        """
        if not texts:
            return []
        request = _get_embedding_request(
            texts=texts, embed_mode=self.embed_mode, is_query=is_query
        )
        embeddings = self._model.get_embeddings(request, **self.additional_kwargs)
        return [embedding.values for embedding in embeddings]

    async def _aembed_texts(
        self, texts: List[str], is_query: bool
    ) -> List[Embedding]:
        """Async counterpart of ``_embed_texts``."""
        if not texts:
            return []
        request = _get_embedding_request(
            texts=texts, embed_mode=self.embed_mode, is_query=is_query
        )
        embeddings = await self._model.get_embeddings_async(
            request, **self.additional_kwargs
        )
        return [embedding.values for embedding in embeddings]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed a batch of documents."""
        return self._embed_texts(texts, is_query=False)

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a single document."""
        return self._embed_texts([text], is_query=False)[0]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Asynchronously embed a single document."""
        return (await self._aembed_texts([text], is_query=False))[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Asynchronously embed a batch of documents."""
        return await self._aembed_texts(texts, is_query=False)

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a single query (request built with ``is_query=True``)."""
        return self._embed_texts([query], is_query=True)[0]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Asynchronously embed a single query."""
        return (await self._aembed_texts([query], is_query=True))[0]