class UpstageEmbedding(OpenAIEmbedding):
    """Class for Upstage embeddings.

    Thin wrapper over ``OpenAIEmbedding`` that points the OpenAI-compatible
    client at the Upstage API and routes queries / passages to the engine
    pair returned by ``get_engine(model)``.
    """

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Upstage API."
    )
    api_key: str = Field(description="The Upstage API key.")
    api_base: Optional[str] = Field(
        default=DEFAULT_UPSTAGE_API_BASE, description="The base URL for Upstage API."
    )
    dimensions: Optional[int] = Field(
        None,
        description="Not supported yet. The number of dimensions the resulting output embeddings should have.",
    )

    def __init__(
        self,
        model: str = "embedding",
        embed_batch_size: int = 100,
        dimensions: Optional[int] = None,
        # NOTE: annotation fixed to Optional[...] — the default is None.
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Upstage embedding client.

        Args:
            model: Base Upstage embedding model name (a trailing "-query" /
                "-passage" suffix is stripped with a warning).
            embed_batch_size: Per-request batch size; must not exceed
                ``MAX_EMBED_BATCH_SIZE``.
            dimensions: Not supported yet; emits a warning if given and is
                forwarded via ``additional_kwargs``.
            additional_kwargs: Extra kwargs passed to the embeddings API.
            api_key / api_base: Upstage credentials; resolved (including env
                fallback) by ``resolve_upstage_credentials``.
            max_retries, timeout, reuse_client, callback_manager,
            default_headers: Forwarded to ``OpenAIEmbedding``.
            http_client: Optional pre-configured httpx client.
            **kwargs: Legacy aliases ``upstage_api_key`` / ``model_name`` are
                popped here; the rest is forwarded to the superclass.

        Raises:
            ValueError: If ``embed_batch_size`` exceeds ``MAX_EMBED_BATCH_SIZE``.
        """
        additional_kwargs = additional_kwargs or {}

        if dimensions is not None:
            warnings.warn("Received dimensions argument. This is not supported yet.")
            additional_kwargs["dimensions"] = dimensions

        if embed_batch_size > MAX_EMBED_BATCH_SIZE:
            raise ValueError(
                f"embed_batch_size should be less than or equal to {MAX_EMBED_BATCH_SIZE}."
            )

        # Legacy keyword alias support.
        if "upstage_api_key" in kwargs:
            api_key = kwargs.pop("upstage_api_key")
        api_key, api_base = resolve_upstage_credentials(
            api_key=api_key, api_base=api_base
        )

        if "model_name" in kwargs:
            model = kwargs.pop("model_name")

        # The query/passage variant is selected per call via get_engine(),
        # so callers must pass the base model name; strip the suffix if present.
        if model.endswith(("-query", "-passage")):
            model = model.rsplit("-", 1)[0]
            logger.warning(
                f"Model name should not end with '-query' or '-passage'. The suffix has been removed. "
                f"Model name: {model}"
            )

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            model_name=model,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            **kwargs,
        )

        # Clients are created lazily by _get_client()/_get_aclient().
        self._client = None
        self._aclient = None
        self._http_client = http_client
        self._query_engine, self._text_engine = get_engine(model)

    @classmethod  # fixed: was missing, but the signature takes ``cls``
    def class_name(cls) -> str:
        return "UpstageEmbedding"

    def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
        """Build the kwargs used to construct the (a)sync OpenAI-compatible client."""
        return {
            "api_key": self.api_key,
            "base_url": self.api_base,
            "max_retries": self.max_retries,
            "timeout": self.timeout,
            "default_headers": self.default_headers,
            "http_client": self._async_http_client if is_async else self._http_client,
        }

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        client = self._get_client()
        # Newlines are normalized to spaces before embedding (matches batch path).
        text = query.replace("\n", " ")
        return (
            client.embeddings.create(
                input=text, model=self._query_engine, **self.additional_kwargs
            )
            .data[0]
            .embedding
        )

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        client = self._get_aclient()
        text = query.replace("\n", " ")
        return (
            (
                await client.embeddings.create(
                    input=text, model=self._query_engine, **self.additional_kwargs
                )
            )
            .data[0]
            .embedding
        )

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        client = self._get_client()
        # Fixed: normalize newlines like the query and batch paths do, so the
        # same text embeds identically via single and batch calls.
        text = text.replace("\n", " ")
        return (
            client.embeddings.create(
                input=text, model=self._text_engine, **self.additional_kwargs
            )
            .data[0]
            .embedding
        )

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        client = self._get_aclient()
        # Fixed: normalize newlines (consistency with the batch path).
        text = text.replace("\n", " ")
        return (
            (
                await client.embeddings.create(
                    input=text, model=self._text_engine, **self.additional_kwargs
                )
            )
            .data[0]
            .embedding
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        # Fixed: an empty input made batch_size 0 and range(0, 0, 0) raise
        # ValueError; return early instead.
        if not texts:
            return []
        client = self._get_client()
        batch_size = min(self.embed_batch_size, len(texts))
        texts = [text.replace("\n", " ") for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = client.embeddings.create(
                input=batch, model=self._text_engine, **self.additional_kwargs
            )
            embeddings.extend([r.embedding for r in response.data])
        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        # Fixed: guard against empty input (see _get_text_embeddings).
        if not texts:
            return []
        client = self._get_aclient()
        batch_size = min(self.embed_batch_size, len(texts))
        texts = [text.replace("\n", " ") for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = await client.embeddings.create(
                input=batch, model=self._text_engine, **self.additional_kwargs
            )
            embeddings.extend([r.embedding for r in response.data])
        return embeddings