22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class PredibaseLLM(CustomLLM):
    """Predibase LLM.

    Completion-only LLM that sends prompts to a Predibase deployment through
    ``predibase.PredibaseClient``. Streaming is not implemented.

    Examples:
        `pip install llama-index-llms-predibase`

        ```python
        import os

        os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

        from llama_index.llms.predibase import PredibaseLLM

        llm = PredibaseLLM(
            model_name="llama-2-13b", temperature=0.3, max_new_tokens=512
        )
        response = llm.complete("Hello World!")
        print(str(response))
        ```
    """

    model_name: str = Field(description="The Predibase model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        # pydantic spells inclusive bounds ``ge``/``le``; the previous
        # ``gte``/``lte`` keywords were not constraints, so the range was
        # never actually validated.
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    # Predibase client built by ``initialize_client``; kept private so it is
    # not treated as a model field.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the LLM and its Predibase client.

        Args:
            model_name: Name of the Predibase deployment to prompt.
            predibase_api_key: API token. When omitted or empty, the
                ``PREDIBASE_API_TOKEN`` environment variable is used.
            max_new_tokens: Number of tokens to generate.
            temperature: Sampling temperature.
            context_window: Number of context tokens available to the LLM.
            callback_manager: Forwarded unchanged to the base class.
            system_prompt: Forwarded unchanged to the base class.
            messages_to_prompt: Forwarded unchanged to the base class.
            completion_to_prompt: Forwarded unchanged to the base class.
            pydantic_program_mode: Forwarded unchanged to the base class.
            output_parser: Forwarded unchanged to the base class.

        Raises:
            ValueError: If no API key is supplied and the environment
                variable is unset or empty, or if the client rejects the key.
            ImportError: If the ``predibase`` package is not installed.
        """
        # An explicit argument wins; otherwise fall back to the environment.
        predibase_api_key = predibase_api_key or os.environ.get(
            "PREDIBASE_API_TOKEN"
        )
        # Raise explicitly instead of asserting: asserts vanish under
        # ``python -O`` and would also let an empty token through.
        if not predibase_api_key:
            raise ValueError(
                "A Predibase API key is required. Pass `predibase_api_key` or "
                "set the PREDIBASE_API_TOKEN environment variable."
            )

        super().__init__(
            model_name=model_name,
            predibase_api_key=predibase_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )
        # Private attributes must be assigned only after the model has been
        # initialised; doing it earlier fails on pydantic v2.
        self._client = self.initialize_client(predibase_api_key)

    @staticmethod
    def initialize_client(predibase_api_key: str) -> Any:
        """Build a ``PredibaseClient`` authenticated with the given token.

        Args:
            predibase_api_key: Token passed to ``PredibaseClient(token=...)``.

        Returns:
            The constructed ``PredibaseClient`` instance.

        Raises:
            ImportError: If the ``predibase`` package cannot be imported.
            ValueError: If constructing the client raises ``ValueError``.
        """
        # Imported lazily so the dependency is only needed when a client is
        # actually created. Kept in its own ``try`` so that only a failed
        # import is reported as a missing package.
        try:
            from predibase import PredibaseClient
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e

        try:
            return PredibaseClient(token=predibase_api_key)
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this class."""
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        """Run a single completion against the configured deployment.

        Args:
            prompt: Text sent to the deployment.
            formatted: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``CompletionResponse`` whose ``text`` is the deployment's
            response.
        """
        llm = self._client.LLM(f"pb://deployments/{self.model_name}")
        results = llm.prompt(
            prompt, max_new_tokens=self.max_new_tokens, temperature=self.temperature
        )
        return CompletionResponse(text=results.response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        """Streaming completion; not implemented for this LLM.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|