22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327 | class PredibaseLLM(CustomLLM):
"""Predibase LLM.
To use, you should have the ``predibase`` python package installed,
and have your Predibase API key.
The `model_name` parameter is the Predibase "serverless" base_model ID
(see https://docs.predibase.com/user-guide/inference/models for the catalog).
An optional `adapter_id` parameter is the Predibase ID or the HuggingFace ID
of a fine-tuned LLM adapter, whose base model is the `model` parameter; the
fine-tuned adapter must be compatible with its base model; otherwise, an
error is raised. If the fine-tuned adapter is hosted at Predibase,
`adapter_version` must be specified.
Examples:
`pip install llama-index-llms-predibase`
```python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from llama_index.llms.predibase import PredibaseLLM
llm = PredibaseLLM(
model_name="mistral-7b",
predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)
adapter_id="my-adapter-id", # optional parameter
adapter_version=3, # optional parameter (applies to Predibase only)
api_token, # optional parameter for accessing services hosting adapters (e.g., HuggingFace)
temperature=0.3,
max_new_tokens=512,
)
response = llm.complete("Hello World!")
print(str(response))
```
"""
model_name: str = Field(description="The Predibase base model to use.")
predibase_api_key: str = Field(description="The Predibase API key to use.")
predibase_sdk_version: str = Field(
default=None,
description="The optional version (string) of the Predibase SDK (defaults to the latest if not specified).",
)
adapter_id: str = Field(
default=None,
description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
)
adapter_version: str = Field(
default=None,
description="The optional version number of fine-tuned adapter use (applies to Predibase only).",
)
api_token: str = Field(
default=None,
description="The adapter hosting service API key to use.",
)
max_new_tokens: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="The number of tokens to generate.",
gt=0,
)
temperature: float = Field(
default=DEFAULT_TEMPERATURE,
description="The temperature to use for sampling.",
gte=0.0,
lte=1.0,
)
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The number of context tokens available to the LLM.",
gt=0,
)
_client: Any = PrivateAttr()
def __init__(
self,
model_name: str,
predibase_api_key: Optional[str] = None,
predibase_sdk_version: Optional[str] = None,
adapter_id: Optional[str] = None,
adapter_version: Optional[int] = None,
api_token: Optional[str] = None,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
temperature: float = DEFAULT_TEMPERATURE,
context_window: int = DEFAULT_CONTEXT_WINDOW,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
) -> None:
predibase_api_key = (
predibase_api_key
if predibase_api_key
else os.environ.get("PREDIBASE_API_TOKEN")
)
if not predibase_api_key:
raise ValueError(
'Your "PREDIBASE_API_TOKEN" is empty. Please generate a valid "PREDIBASE_API_TOKEN" in your Predibase account.'
)
super().__init__(
model_name=model_name,
predibase_api_key=predibase_api_key,
predibase_sdk_version=predibase_sdk_version,
adapter_id=adapter_id,
adapter_version=adapter_version,
api_token=api_token,
max_new_tokens=max_new_tokens,
temperature=temperature,
context_window=context_window,
callback_manager=callback_manager,
system_prompt=system_prompt,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
pydantic_program_mode=pydantic_program_mode,
output_parser=output_parser,
)
self._client: Union["PredibaseClient", "Predibase"] = self.initialize_client()
def initialize_client(
self,
) -> Union["PredibaseClient", "Predibase"]:
try:
if self._is_deprecated_sdk_version():
from predibase import PredibaseClient
from predibase.pql import get_session
from predibase.pql.api import Session
session: Session = get_session(
token=self.predibase_api_key,
gateway="https://api.app.predibase.com/v1",
serving_endpoint="serving.app.predibase.com",
)
return PredibaseClient(session=session)
from predibase import Predibase
os.environ["PREDIBASE_GATEWAY"] = "https://api.app.predibase.com"
return Predibase(api_token=self.predibase_api_key)
except ValueError as e:
raise ValueError(
'Your "PREDIBASE_API_TOKEN" is not correct. Please try again.'
) from e
@classmethod
def class_name(cls) -> str:
return "PredibaseLLM"
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
model_name=self.model_name,
)
@llm_completion_callback()
def complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> "CompletionResponse":
options: Dict[str, Union[str, float]] = {
**{
"api_token": self.api_token,
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
},
**(kwargs or {}),
}
response_text: str
if self._is_deprecated_sdk_version():
from predibase.pql.api import ServerResponseError
from predibase.resource.llm.interface import (
HuggingFaceLLM,
LLMDeployment,
)
from predibase.resource.llm.response import GeneratedResponse
from predibase.resource.model import Model
base_llm_deployment: LLMDeployment = self._client.LLM(
uri=f"pb://deployments/{self.model_name}"
)
result: GeneratedResponse
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
adapter_model: Union[Model, HuggingFaceLLM]
try:
adapter_model = self._client.get_model(
name=self.adapter_id,
version=self.adapter_version,
model_id=None,
)
except ServerResponseError:
# Predibase does not recognize the adapter ID (query HuggingFace).
adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
result = base_llm_deployment.with_adapter(model=adapter_model).generate(
prompt=prompt,
options=options,
)
else:
result = base_llm_deployment.generate(
prompt=prompt,
options=options,
)
response_text = result.response
else:
import requests
from lorax.client import Client as LoraxClient
from lorax.errors import GenerationError
from lorax.types import Response
lorax_client: LoraxClient = self._client.deployments.client(
deployment_ref=self.model_name
)
response: Response
if self.adapter_id:
"""
Attempt to retrieve the fine-tuned adapter from a Predibase repository.
If absent, then load the fine-tuned adapter from a HuggingFace repository.
"""
if self.adapter_version:
# Since the adapter version is provided, query the Predibase repository.
pb_adapter_id: str = f"{self.adapter_id}/{self.adapter_version}"
options.pop("api_token", None)
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=pb_adapter_id,
**options,
)
except GenerationError as ge:
raise ValueError(
f'An adapter with the ID "{pb_adapter_id}" cannot be found in the Predibase repository of fine-tuned adapters.'
) from ge
else:
# The adapter version is omitted, hence look for the adapter ID in the HuggingFace repository.
try:
response = lorax_client.generate(
prompt=prompt,
adapter_id=self.adapter_id,
adapter_source="hub",
**options,
)
except GenerationError as ge:
raise ValueError(
f"""Either an adapter with the ID "{self.adapter_id}" cannot be found in a HuggingFace repository, \
or it is incompatible with the base model (please make sure that the adapter configuration is consistent).
"""
) from ge
else:
try:
response = lorax_client.generate(
prompt=prompt,
**options,
)
except requests.JSONDecodeError as jde:
raise ValueError(
f"""An LLM with the deployment ID "{self.model_name}" cannot be found at Predibase \
(please refer to "https://docs.predibase.com/user-guide/inference/models" for the list of supported models).
"""
) from jde
response_text = response.generated_text
return CompletionResponse(text=response_text)
@llm_completion_callback()
def stream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> "CompletionResponseGen":
raise NotImplementedError
def _is_deprecated_sdk_version(self) -> bool:
try:
import semantic_version
from semantic_version.base import Version
from predibase.version import __version__ as current_version
sdk_semver_deprecated: Version = semantic_version.Version(
version_string="2024.4.8"
)
actual_current_version: str = self.predibase_sdk_version or current_version
sdk_semver_current: Version = semantic_version.Version(
version_string=actual_current_version
)
return not (
(sdk_semver_current > sdk_semver_deprecated)
or ("+dev" in actual_current_version)
)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install semantic_version predibase`."
) from e
|