from llama_index.llms.vllm import Vllm


# Prompt formatters for the Mistral-Instruct chat template.
def messages_to_prompt(messages):
    joined = "\n".join(str(message) for message in messages)
    return f"<s>[INST] {joined} [/INST] </s>\n"


def completion_to_prompt(completion):
    return f"<s>[INST] {completion} [/INST] </s>\n"


# Run Mistral-7B-Instruct locally across 4 GPUs via vLLM.
llm = Vllm(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    tensor_parallel_size=4,
    max_new_tokens=256,
    vllm_kwargs={"swap_space": 1, "gpu_memory_utilization": 0.5},
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

llm.complete("What is a black hole?")
Source code in llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
class Vllm(LLM):
    r"""Vllm LLM.

    This class runs a vLLM model locally (in-process) when no ``api_url`` is
    given; otherwise ``_client`` stays ``None`` and subclasses (see
    ``VllmServer``) talk to a remote vLLM server.

    Examples:
        `pip install llama-index-llms-vllm`

        ```python
        from llama_index.llms.vllm import Vllm

        # specific functions to format for mistral instruct
        def messages_to_prompt(messages):
            prompt = "\n".join([str(x) for x in messages])
            return f"<s>[INST] {prompt} [/INST] </s>\n"

        def completion_to_prompt(completion):
            return f"<s>[INST] {completion} [/INST] </s>\n"

        llm = Vllm(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            tensor_parallel_size=4,
            max_new_tokens=256,
            vllm_kwargs={"swap_space": 1, "gpu_memory_utilization": 0.5},
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
        )

        llm.complete("What is a black hole?")
        ```
    """

    model: Optional[str] = Field(description="The HuggingFace Model to use.")

    temperature: float = Field(description="The temperature to use for sampling.")

    tensor_parallel_size: Optional[int] = Field(
        default=1,
        description="The number of GPUs to use for distributed execution with tensor parallelism.",
    )

    trust_remote_code: Optional[bool] = Field(
        default=True,
        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
    )

    n: int = Field(
        default=1,
        description="Number of output sequences to return for the given prompt.",
    )

    best_of: Optional[int] = Field(
        default=None,
        description="Number of output sequences that are generated from the prompt.",
    )

    presence_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on whether they appear in the generated text so far.",
    )

    frequency_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on their frequency in the generated text so far.",
    )

    top_p: float = Field(
        default=1.0,
        description="Float that controls the cumulative probability of the top tokens to consider.",
    )

    top_k: int = Field(
        default=-1,
        description="Integer that controls the number of top tokens to consider.",
    )

    use_beam_search: bool = Field(
        default=False,
        description="Whether to use beam search instead of sampling.",
    )

    stop: Optional[List[str]] = Field(
        default=None,
        description="List of strings that stop the generation when they are generated.",
    )

    ignore_eos: bool = Field(
        default=False,
        description="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.",
    )

    max_new_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate per output sequence.",
    )

    logprobs: Optional[int] = Field(
        default=None,
        description="Number of log probabilities to return per output token.",
    )

    dtype: str = Field(
        default="auto",
        description="The data type for the model weights and activations.",
    )

    download_dir: Optional[str] = Field(
        default=None,
        description="Directory to download and load the weights. (Default to the default cache dir of huggingface)",
    )

    vllm_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Holds any model parameters valid for `vllm.LLM` call not explicitly specified.",
    )

    api_url: str = Field(description="The api url for vllm server")

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        temperature: float = 1.0,
        tensor_parallel_size: int = 1,
        trust_remote_code: bool = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        vllm_kwargs: Optional[Dict[str, Any]] = None,
        api_url: Optional[str] = "",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Load a local vLLM engine (or prepare for server mode if api_url is set).

        Raises:
            ImportError: if no ``api_url`` is given and `vllm` is not installed.
        """
        # FIX: the default used to be a shared mutable `{}`; normalize None here.
        vllm_kwargs = {} if vllm_kwargs is None else vllm_kwargs
        if not api_url:
            # No server URL: load the model in-process with the vllm engine.
            try:
                from vllm import LLM as VLLModel
            except ImportError:
                raise ImportError(
                    "Could not import vllm python package. "
                    "Please install it with `pip install vllm`."
                )
            self._client = VLLModel(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                download_dir=download_dir,
                **vllm_kwargs,
            )
        else:
            # Server mode: no local engine; subclasses issue HTTP requests.
            self._client = None
        callback_manager = callback_manager or CallbackManager([])
        super().__init__(
            model=model,
            temperature=temperature,
            # FIX: these two were previously not forwarded, so the pydantic
            # fields silently kept their defaults and misreported the config.
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=trust_remote_code,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registered class name."""
        return "Vllm"

    @property
    def metadata(self) -> LLMMetadata:
        """Minimal metadata: only the model name is known here."""
        return LLMMetadata(model_name=self.model)

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Sampling parameters derived from the configured fields.

        Keys match the ``vllm.SamplingParams`` constructor arguments.
        """
        base_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "n": self.n,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "use_beam_search": self.use_beam_search,
            "best_of": self.best_of,
            "ignore_eos": self.ignore_eos,
            "stop": self.stop,
            "logprobs": self.logprobs,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }
        # Return a copy so callers can overlay per-call kwargs safely.
        return dict(base_kwargs)

    def __del__(self) -> None:
        """Best-effort teardown of the local engine and its GPU state."""
        # FIX: guard against server mode (_client is None) and partially
        # constructed instances (_client never assigned) — the old code
        # raised AttributeError/ImportError from __del__ in those cases.
        if getattr(self, "_client", None) is None:
            return
        import torch

        if torch.cuda.is_available():
            from vllm.model_executor.parallel_utils.parallel_state import (
                destroy_model_parallel,
            )

            destroy_model_parallel()
            del self._client
            torch.cuda.synchronize()

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call kwargs over the configured sampling defaults."""
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat by formatting messages into a single prompt and completing it."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Generate a completion for `prompt` with the local vllm engine."""
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # Build sampling parameters for this single-prompt generation.
        sampling_params = SamplingParams(**params)
        outputs = self._client.generate([prompt], sampling_params)
        return CompletionResponse(text=outputs[0].outputs[0].text)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        # Streaming is not supported by the in-process engine wrapper.
        raise (ValueError("Not Implemented"))

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        # Streaming is not supported by the in-process engine wrapper.
        raise (ValueError("Not Implemented"))

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async chat; delegates to the synchronous implementation."""
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async complete; delegates to the synchronous implementation."""
        return self.complete(prompt, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        # Streaming is not supported by the in-process engine wrapper.
        raise (ValueError("Not Implemented"))

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        # Streaming is not supported by the in-process engine wrapper.
        raise (ValueError("Not Implemented"))
This class connects to a vLLM server (the non-OpenAI API versions).
If you are using the OpenAI-compatible vLLM server, please use the OpenAILike LLM class instead.
Examples:
pip install llama-index-llms-vllm
from llama_index.llms.vllm import VllmServer


# Prompt formatters for the Mistral-Instruct chat template.
def messages_to_prompt(messages):
    joined = "\n".join(str(message) for message in messages)
    return f"<s>[INST] {joined} [/INST] </s>\n"


def completion_to_prompt(completion):
    return f"<s>[INST] {completion} [/INST] </s>\n"


# Connect to an already-running (non-OpenAI) vLLM server at `api_url`.
llm = VllmServer(
    api_url=api_url,
    max_new_tokens=256,
    temperature=0.1,
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

llm.complete("What is a black hole?")
Source code in llama-index-integrations/llms/llama-index-llms-vllm/llama_index/llms/vllm/base.py
class VllmServer(Vllm):
    r"""Vllm LLM.

    This class connects to a vLLM server (non-openai versions).

    If using the OpenAI-API vLLM server, please see the `OpenAILike` LLM class.

    Examples:
        `pip install llama-index-llms-vllm`

        ```python
        from llama_index.llms.vllm import VllmServer

        # specific functions to format for mistral instruct
        def messages_to_prompt(messages):
            prompt = "\n".join([str(x) for x in messages])
            return f"<s>[INST] {prompt} [/INST] </s>\n"

        def completion_to_prompt(completion):
            return f"<s>[INST] {completion} [/INST] </s>\n"

        llm = VllmServer(
            api_url=api_url,
            max_new_tokens=256,
            temperature=0.1,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
        )

        llm.complete("What is a black hole?")
        ```
    """

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        api_url: str = "http://localhost:8000",
        temperature: float = 1.0,
        tensor_parallel_size: Optional[int] = 1,
        trust_remote_code: Optional[bool] = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        vllm_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Configure a client for a remote vLLM server at `api_url`.

        No local model is loaded; all generation goes over HTTP.
        """
        # Server mode: no local engine.
        self._client = None
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or (lambda x: x)
        callback_manager = callback_manager or CallbackManager([])
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            # FIX: the default used to be a shared mutable `{}`; pass a fresh
            # dict when the caller gave none.
            vllm_kwargs=vllm_kwargs if vllm_kwargs is not None else {},
            api_url=api_url,
            callback_manager=callback_manager,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the registered class name."""
        return "VllmServer"

    def __del__(self) -> None:
        # Nothing to tear down: the model lives in the remote server process.
        ...

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Request a (non-streaming) completion from the vLLM server."""
        params = {**self._model_kwargs, **kwargs}

        # Build the request payload: sampling parameters plus the prompt.
        sampling_params = dict(**params)
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=False)
        output = get_response(response)
        return CompletionResponse(text=output[0])

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion from the vLLM server, chunk by chunk."""
        params = {**self._model_kwargs, **kwargs}

        sampling_params = dict(**params)
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=True)

        def gen() -> CompletionResponseGen:
            # The server streams the growing full text (prompt included);
            # track the previously-seen length to compute each delta.
            prev_prefix_len = len(prompt)
            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            ):
                if chunk:
                    data = json.loads(chunk.decode("utf-8"))
                    increasing_concat = data["text"][0]
                    pref = prev_prefix_len
                    prev_prefix_len = len(increasing_concat)
                    yield CompletionResponse(
                        text=increasing_concat, delta=increasing_concat[pref:]
                    )

        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async complete; delegates to the synchronous implementation."""
        return self.complete(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async wrapper over the synchronous streaming implementation."""
        # FIX: removed dead code that built sampling params here — the request
        # is made inside stream_complete, which already merges kwargs.

        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, **kwargs):
                yield message

        return gen()

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat response by formatting messages into one prompt."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async wrapper over the synchronous chat streaming implementation."""

        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()