class CohereEmbedding(MultiModalEmbedding):
    """CohereEmbedding uses the Cohere API to generate embeddings for text and images.

    Text can be embedded with any model accepted by ``__init__``; image
    embedding is only allowed for models listed in ``V3_MODELS``. Both text and
    image requests ask Cohere for the configured ``embedding_type`` so that text
    and image vectors come back in the same representation.
    """

    # Instance variables initialized via Pydantic's mechanism
    api_key: str = Field(description="The Cohere API key.")
    truncate: str = Field(description="Truncation type - START/ END/ NONE")
    input_type: Optional[str] = Field(
        default=None,
        description="Model Input type. If not provided, search_document and search_query are used when needed.",
    )
    embedding_type: str = Field(
        description="Embedding type. If not provided float embedding_type is used when needed."
    )

    # Clients are created lazily by _get_client / _get_async_client.
    _client: cohere.Client = PrivateAttr()
    _async_client: cohere.AsyncClient = PrivateAttr()
    _base_url: Optional[str] = PrivateAttr()
    _timeout: Optional[float] = PrivateAttr()
    _httpx_client: Optional[httpx.Client] = PrivateAttr()
    _httpx_async_client: Optional[httpx.AsyncClient] = PrivateAttr()

    def __init__(
        self,
        # deprecated
        cohere_api_key: Optional[str] = None,
        api_key: Optional[str] = None,
        model_name: str = "embed-english-v3.0",
        truncate: str = "END",
        input_type: Optional[str] = None,
        embedding_type: str = "float",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = None,
        httpx_client: Optional[httpx.Client] = None,
        httpx_async_client: Optional[httpx.AsyncClient] = None,
        num_workers: Optional[int] = None,
        **kwargs: Any,
    ):
        """Create a Cohere embedding model wrapper.

        Args:
            cohere_api_key (Optional[str]): Deprecated alias for ``api_key``;
                only used when ``api_key`` is falsy.
            api_key (Optional[str]): The Cohere API key.
            model_name (str): The name of the model used to generate
                embeddings. Must be a key of ``VALID_MODEL_INPUT_TYPES``.
            truncate (str): Truncation strategy applied to input text. Must be
                one of ``VALID_TRUNCATE_OPTIONS`` (e.g. 'START', 'END', 'NONE').
            input_type (Optional[str]): Optional fixed input type sent to the
                model (e.g. 'search_query', 'search_document',
                'classification', 'clustering'). Must be allowed for
                ``model_name`` by ``VALID_MODEL_INPUT_TYPES``. When unset, the
                per-call default ('search_query' for queries,
                'search_document' for documents) is used.
            embedding_type (str): Embedding representation requested from
                Cohere. Must be allowed for ``model_name`` by
                ``VALID_MODEL_EMBEDDING_TYPES``.
            embed_batch_size (int): Batch size forwarded to the base class.
            callback_manager (Optional[CallbackManager]): Forwarded to the
                base class.
            base_url (Optional[str]): Passed to the Cohere sync/async clients.
            timeout (Optional[float]): Passed to the Cohere sync/async clients.
            httpx_client (Optional[httpx.Client]): Passed to the sync client.
            httpx_async_client (Optional[httpx.AsyncClient]): Passed to the
                async client.
            num_workers (Optional[int]): Forwarded to the base class.
            **kwargs: Extra fields forwarded to the base class.

        Raises:
            ValueError: If ``model_name``, ``input_type``, ``embedding_type``
                or ``truncate`` is not valid for the chosen model.
        """
        # Validate model_name, input_type, embedding_type and truncate up front
        if model_name not in VALID_MODEL_INPUT_TYPES:
            raise ValueError(f"{model_name} is not a valid model name")
        if input_type not in VALID_MODEL_INPUT_TYPES[model_name]:
            raise ValueError(
                f"{input_type} is not a valid input type for the provided model."
            )
        if embedding_type not in VALID_MODEL_EMBEDDING_TYPES[model_name]:
            raise ValueError(
                f"{embedding_type} is not a valid embedding type for the provided model."
            )
        if truncate not in VALID_TRUNCATE_OPTIONS:
            raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")

        super().__init__(
            # ``api_key`` takes precedence over the deprecated ``cohere_api_key``.
            api_key=api_key or cohere_api_key,
            model_name=model_name,
            input_type=input_type,
            embedding_type=embedding_type,
            truncate=truncate,
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            num_workers=num_workers,
            **kwargs,
        )
        # Clients are built on first use so construction does no network setup.
        self._client = None
        self._async_client = None
        self._base_url = base_url
        self._timeout = timeout
        self._httpx_client = httpx_client
        self._httpx_async_client = httpx_async_client

    def _get_client(self) -> cohere.Client:
        """Return the sync Cohere client, creating and caching it on first call."""
        if self._client is None:
            self._client = cohere.Client(
                api_key=self.api_key,
                client_name="llama_index",
                base_url=self._base_url,
                timeout=self._timeout,
                httpx_client=self._httpx_client,
            )
        return self._client

    def _get_async_client(self) -> cohere.AsyncClient:
        """Return the async Cohere client, creating and caching it on first call."""
        if self._async_client is None:
            self._async_client = cohere.AsyncClient(
                api_key=self.api_key,
                client_name="llama_index",
                base_url=self._base_url,
                timeout=self._timeout,
                httpx_client=self._httpx_async_client,
            )
        return self._async_client

    @classmethod
    def class_name(cls) -> str:
        return "CohereEmbedding"

    def _image_to_base64_data_url(
        self, image_input: Union[str, Path, BytesIO]
    ) -> str:
        """Convert an image to a base64 Data URL.

        Args:
            image_input: Path to an image file (``str`` or ``Path``), or a
                ``BytesIO`` holding encoded image bytes.

        Returns:
            A string of the form ``data:image/<subtype>;base64,<payload>``.

        Raises:
            ValueError: If the input type is unsupported or the image format is
                not accepted by ``_validate_image_format``.
        """
        if isinstance(image_input, (str, Path)):
            # If it's a string or Path, assume it's a file path; the format is
            # taken from the file extension.
            image_path = str(image_input)
            file_extension = os.path.splitext(image_path)[1][1:].lower()
            with open(image_path, "rb") as f:
                image_data = f.read()
        elif isinstance(image_input, BytesIO):
            # Rewind before decoding so PIL sees the whole stream even when the
            # caller left the cursor elsewhere (e.g. right after writing to it).
            image_input.seek(0)
            image = Image.open(image_input)
            file_extension = image.format.lower()
            image_input.seek(0)  # Reset the BytesIO stream to the beginning
            image_data = image_input.read()
        else:
            raise ValueError("Unsupported input type. Must be a file path or BytesIO.")

        if not self._validate_image_format(file_extension):
            raise ValueError(f"Unsupported image format: {file_extension}")

        # "jpg" is a file extension, not a registered MIME subtype; JPEG data
        # URLs use "image/jpeg".
        mime_subtype = "jpeg" if file_extension == "jpg" else file_extension
        enc_img = base64.b64encode(image_data).decode("utf-8")
        return f"data:image/{mime_subtype};base64,{enc_img}"

    def _validate_image_format(self, file_type: str) -> bool:
        """Return True if ``file_type`` (case-insensitive) is in SUPPORTED_IMAGE_FORMATS."""
        return file_type.lower() in SUPPORTED_IMAGE_FORMATS

    def _select_embeddings(self, embeddings: Any) -> List[List[float]]:
        """Pick the configured ``embedding_type`` from a response's ``embeddings``.

        Raises:
            ValueError: If the response carries no embeddings of that type,
                instead of silently returning ``None``.
        """
        selected = getattr(embeddings, self.embedding_type, None)
        if selected is None:
            raise ValueError(
                f"Cohere response did not contain '{self.embedding_type}' embeddings."
            )
        return selected

    def _text_embed_kwargs(self, texts: List[str], input_type: str) -> dict:
        """Build the keyword arguments shared by sync and async text embed calls."""
        request = {
            "texts": texts,
            "model": self.model_name,
            "embedding_types": [self.embedding_type],
            "truncate": self.truncate,
        }
        # input_type is only sent for models in V3_MODELS; a user-configured
        # input_type overrides the per-call default.
        if self.model_name in V3_MODELS:
            request["input_type"] = self.input_type or input_type
        return request

    def _image_embed_kwargs(self, image_path: ImageType, input_type: str) -> dict:
        """Validate the model for image input and build the embed call arguments.

        Raises:
            ValueError: If the model is not in ``V3_MODELS`` or the image
                cannot be converted to a data URL.
        """
        if self.model_name not in V3_MODELS:
            raise ValueError(
                f"{self.model_name} is not a valid multi-modal embedding model. Supported models are {V3_MODELS}"
            )
        return {
            "model": self.model_name,
            "images": [self._image_to_base64_data_url(image_path)],
            "input_type": input_type,
            # Request the same representation as text so both live in one space.
            "embedding_types": [self.embedding_type],
        }

    def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
        """Embed a batch of texts with Cohere; returns one vector per text."""
        client = self._get_client()
        response = client.embed(**self._text_embed_kwargs(texts, input_type))
        return self._select_embeddings(response.embeddings)

    async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]:
        """Asynchronously embed a batch of texts with Cohere; one vector per text."""
        async_client = self._get_async_client()
        response = await async_client.embed(
            **self._text_embed_kwargs(texts, input_type)
        )
        return self._select_embeddings(response.embeddings)

    def _embed_image(self, image_path: ImageType, input_type: str) -> List[List[float]]:
        """Embed a single image with Cohere; returns a one-element batch."""
        client = self._get_client()
        response = client.embed(**self._image_embed_kwargs(image_path, input_type))
        return self._select_embeddings(response.embeddings)

    async def _aembed_image(
        self, image_path: ImageType, input_type: str
    ) -> List[List[float]]:
        """Asynchronously embed a single image with Cohere; returns a one-element batch."""
        async_client = self._get_async_client()
        # NOTE(review): the image file is still read synchronously here.
        response = await async_client.embed(
            **self._image_embed_kwargs(image_path, input_type)
        )
        return self._select_embeddings(response.embeddings)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding. For query embeddings, input_type='search_query'."""
        return self._embed([query], input_type="search_query")[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async. For query embeddings, input_type='search_query'."""
        return (await self._aembed([query], input_type="search_query"))[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding (input_type='search_document')."""
        return self._embed([text], input_type="search_document")[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async (input_type='search_document')."""
        return (await self._aembed([text], input_type="search_document"))[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings (input_type='search_document')."""
        return self._embed(texts, input_type="search_document")

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings async (input_type='search_document')."""
        return await self._aembed(texts, input_type="search_document")

    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """Get image embedding (input_type='image')."""
        return self._embed_image(img_file_path, "image")[0]

    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        """Get image embedding async (input_type='image')."""
        return (await self._aembed_image(img_file_path, "image"))[0]