class PremAI(LLM):
    """PremAI LLM Provider.

    Thin llama-index wrapper around the Prem SDK chat-completions API.
    Only synchronous `chat` / `stream_chat` (and their completion
    decorators) are implemented; all async entry points raise
    NotImplementedError.
    """

    # Pydantic-declared settings; all of them are forwarded to the Prem
    # chat-completions endpoint via `_model_kwargs`.
    project_id: int = Field(
        description=(
            "The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
        )
    )
    premai_api_key: Optional[str] = Field(
        description="Prem AI API Key. Get it here: https://app.premai.io/api_keys/"
    )
    model: Optional[str] = Field(
        description=(
            "Name of the model. This is an optional parameter. The default model is the one deployed from Prem's LaunchPad. An example: https://app.premai.io/projects/<project-id>/launchpad. If model name is other than default model then it will override the calls from the model deployed from launchpad."
        ),
    )
    system_prompt: Optional[str] = Field(
        description=(
            "System prompts helps the model to guide the generation and the way it acts. Default system prompt is the one set on your deployed LaunchPad model under the specified project."
        ),
    )
    max_tokens: Optional[int] = Field(
        description=("The max number of tokens to output from the LLM. ")
    )
    temperature: Optional[float] = Field(
        description="Model temperature. Value should be >= 0 and <= 1.0"
    )
    max_retries: Optional[int] = Field(
        description="Max number of retries to call the API"
    )
    repositories: Optional[dict] = Field(
        description="Add valid repository ids. This will be overriding existing connected repositories (if any) and will use RAG with the connected repos."
    )
    additional_kwargs: Optional[dict] = Field(
        description="Add any additional kwargs. This may override your existing settings."
    )

    # Prem SDK client, created in __init__; not part of the pydantic model.
    _client: "Prem" = PrivateAttr()

    def __init__(
        self,
        project_id: int,
        premai_api_key: Optional[str] = None,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        # FIX: was annotated Optional[str] although the default and the
        # declared field type are int.
        max_tokens: Optional[int] = 128,
        temperature: Optional[float] = 0.1,
        max_retries: Optional[int] = 1,
        repositories: Optional[dict] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs,
    ):
        """Resolve the API key, validate it, and build the Prem client.

        Raises:
            ValueError: if no API key is passed and `PREMAI_API_KEY` is unset.
        """
        callback_manager = callback_manager or CallbackManager([])
        # Explicit argument wins over the PREMAI_API_KEY env var.
        api_key = get_from_param_or_env("api_key", premai_api_key, "PREMAI_API_KEY", "")
        if not api_key:
            raise ValueError(
                "You must provide an API key to use premai. "
                "You can either pass it in as an argument or set it `PREMAI_API_KEY`. You can get your API key here: https://app.premai.io/api_keys/"
            )

        # Fold loose **kwargs into additional_kwargs; explicit kwargs win.
        additional_kwargs = {**(additional_kwargs or {}), **kwargs}

        super().__init__(
            project_id=project_id,
            temperature=temperature,
            max_tokens=max_tokens,
            model=model,
            api_key=api_key,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            additional_kwargs=additional_kwargs,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            max_retries=max_retries,
            repositories=repositories,
        )
        self._client = Prem(api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        return "PremAI_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        # TODO: We need to fetch information from prem-sdk here
        return LLMMetadata(
            num_output=self.max_tokens,
            is_chat_model=True,
            temperature=self.temperature,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Model-level parameters sent with every request."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "system_prompt": self.system_prompt,
            "repositories": self.repositories,
        }

    def _get_all_kwargs(self, **kwargs) -> Dict[str, Any]:
        """Merge model kwargs with per-call kwargs for the Prem API.

        Drops OpenAI-style parameters Prem does not accept, lets per-call
        kwargs override model-level ones, then strips None/empty values so
        they are not sent to the API.
        """
        kwargs_to_ignore = [
            "top_p",
            "tools",
            "frequency_penalty",
            "presence_penalty",
            "logit_bias",
            "stop",
            "seed",
        ]
        keys_to_remove = []
        for key in kwargs:
            if key in kwargs_to_ignore:
                print(f"WARNING: Parameter {key} is not supported in kwargs.")
                keys_to_remove.append(key)
        for key in keys_to_remove:
            kwargs.pop(key)

        # Per-call kwargs override model-level defaults.
        all_kwargs = {**self._model_kwargs, **kwargs}
        for key in list(self._model_kwargs.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send a chat request and return the first candidate as a ChatResponse.

        Raises:
            ChatPremError: if the response has no choices or a choice
                carries no role.
        """
        all_kwargs = self._get_all_kwargs(**{**self.additional_kwargs, **kwargs})
        chat_messages, all_kwargs = prepare_messages_before_chat(
            messages=messages, **all_kwargs
        )
        response = self._client.chat.completions.create(
            project_id=self.project_id, messages=chat_messages, **all_kwargs
        )
        if not response.choices:
            raise ChatPremError("ChatResponse must have at least one candidate")

        choice = response.choices[0]
        role = choice.message.role
        if role is None:
            raise ChatPremError(f"ChatResponse {choice} must have a role.")

        content = choice.message.content or ""
        return ChatResponse(
            message=ChatMessage(role=role, content=content),
            raw={
                "role": role,
                "content": content,
                # FIX: document_chunks is None when no repositories are
                # attached to the request; iterate an empty list instead
                # of crashing with TypeError.
                "document_chunks": [
                    chunk.to_dict() for chunk in (response.document_chunks or [])
                ],
            },
        )

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response, yielding cumulative-content ChatResponses."""
        all_kwargs = self._get_all_kwargs(**{**self.additional_kwargs, **kwargs})
        chat_messages, all_kwargs = prepare_messages_before_chat(
            messages=messages, **all_kwargs
        )
        response_generator = self._client.chat.completions.create(
            project_id=self.project_id,
            messages=chat_messages,
            stream=True,
            **all_kwargs,
        )

        def gen() -> ChatResponseGen:
            content = ""
            role = MessageRole.ASSISTANT
            for chunk in response_generator:
                delta = chunk.choices[0].delta
                # Some stream chunks carry no content (e.g. role-only or
                # terminal chunks) — skip them.
                if delta is None or delta["content"] is None:
                    continue
                chunk_content = delta["content"]
                content += chunk_content
                yield ChatResponse(
                    message=ChatMessage(content=content, role=role),
                    delta=chunk_content,
                )

        return gen()

    def achat(self, *args: Any, **kwargs: Any):
        # FIX: accept the normal chat arguments so callers get the intended
        # NotImplementedError instead of a TypeError about arity.
        raise NotImplementedError(
            "Current version of premai does not support async calls."
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Text-completion entry point, implemented on top of `chat`."""
        complete_fn = chat_to_completion_decorator(self.chat)
        kwargs["is_completion"] = True
        return complete_fn(prompt, **kwargs)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming text-completion, implemented on top of `stream_chat`."""
        stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat)
        return stream_complete_fn(prompt, **kwargs)

    def acomplete(self, *args: Any, **kwargs: Any):
        raise NotImplementedError(
            "Current version of premai does not support async calls."
        )

    def astream_complete(self, *args: Any, **kwargs: Any):
        raise NotImplementedError(
            "Current version of premai does not support async calls."
        )

    def astream_chat(self, *args: Any, **kwargs: Any):
        raise NotImplementedError(
            "Current version of premai does not support async calls."
        )