Skip to content

Predibase

PredibaseLLM

Bases: CustomLLM

Predibase LLM.

To use it, install the predibase Python package and have your Predibase API key available, either passed as predibase_api_key or set in the PREDIBASE_API_TOKEN environment variable.

The model_name parameter is the Predibase "serverless" base_model ID (see https://docs.predibase.com/user-guide/inference/models for the catalog).

The optional adapter_id parameter is the Predibase ID or the HuggingFace ID of a fine-tuned LLM adapter whose base model is given by the model_name parameter. The fine-tuned adapter must be compatible with that base model; otherwise, an error is raised. If the fine-tuned adapter is hosted at Predibase, adapter_version can be specified (omitting it selects the latest version).

Examples:

pip install llama-index-llms-predibase

import os

os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

from llama_index.llms.predibase import PredibaseLLM

llm = PredibaseLLM(
    model_name="mistral-7b",
    adapter_id="my-adapter-id",  # optional parameter
    adapter_version=3,  # optional parameter (applies to Predibase only)
    temperature=0.3,
    max_new_tokens=512,
)
response = llm.complete("Hello World!")
print(str(response))
Source code in llama-index-integrations/llms/llama-index-llms-predibase/llama_index/llms/predibase/base.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class PredibaseLLM(CustomLLM):
    """Predibase LLM.

    To use, you should have the ``predibase`` python package installed,
    and have your Predibase API key.  The key is taken from the
    ``predibase_api_key`` argument or, when that is empty, from the
    ``PREDIBASE_API_TOKEN`` environment variable.

    The `model_name` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).

    An optional `adapter_id` parameter is the Predibase ID or the HuggingFace ID
    of a fine-tuned LLM adapter whose base model is given by `model_name`.  The
    fine-tuned adapter must be compatible with that base model; otherwise, an
    error is raised.  If the fine-tuned adapter is hosted at Predibase,
    `adapter_version` can be specified (omitting it gives the latest version).

    Per-call keyword arguments passed to ``complete`` (for example
    ``temperature=0.9``) override the instance-level generation settings.

    Examples:
        `pip install llama-index-llms-predibase`

        ```python
        import os

        os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"

        from llama_index.llms.predibase import PredibaseLLM

        llm = PredibaseLLM(
            model_name="mistral-7b",
            adapter_id="my-adapter-id",  # optional parameter
            adapter_version=3,  # optional parameter (applies to Predibase only)
            temperature=0.3,
            max_new_tokens=512,
        )
        response = llm.complete("Hello World!")
        print(str(response))
        ```
    """

    model_name: str = Field(description="The Predibase base model to use.")
    predibase_api_key: str = Field(description="The Predibase API key to use.")
    adapter_id: Optional[str] = Field(
        default=None,
        description="The optional Predibase ID or HuggingFace ID of a fine-tuned adapter to use.",
    )
    # Integer, matching the ``__init__`` signature and the documented example
    # (``adapter_version=3``).  A ``str`` annotation here would reject ints
    # under pydantic v2, which does not coerce int -> str.
    adapter_version: Optional[int] = Field(
        default=None,
        description="The optional version number of fine-tuned adapter use (applies to Predibase only).",
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    # ``ge``/``le`` are the pydantic constraint keywords; the previous
    # ``gte``/``lte`` spelling was not recognized, so the range went unchecked.
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    # Client returned by ``initialize_client`` (a ``predibase.PredibaseClient``).
    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str,
        predibase_api_key: Optional[str] = None,
        adapter_id: Optional[str] = None,
        adapter_version: Optional[int] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_TEMPERATURE,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create the LLM wrapper and connect a Predibase client.

        Args:
            model_name: Predibase serverless base-model ID.
            predibase_api_key: API key; falls back to ``PREDIBASE_API_TOKEN``.
            adapter_id: Optional Predibase or HuggingFace adapter ID.
            adapter_version: Optional Predibase adapter version (latest if omitted).
            max_new_tokens: Default number of tokens to generate.
            temperature: Default sampling temperature (0.0-1.0).
            context_window: Context size reported in ``metadata``.

        Raises:
            ValueError: If no API key is supplied or found in the environment,
                or if the Predibase session rejects the key.
            ImportError: If the ``predibase`` package is not installed.
        """
        predibase_api_key = predibase_api_key or os.environ.get("PREDIBASE_API_TOKEN")
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and would let a missing key through silently.
        if not predibase_api_key:
            raise ValueError(
                "A Predibase API key is required: pass `predibase_api_key` "
                "or set the `PREDIBASE_API_TOKEN` environment variable."
            )

        super().__init__(
            model_name=model_name,
            adapter_id=adapter_id,
            adapter_version=adapter_version,
            predibase_api_key=predibase_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

        # Private attributes must be assigned after ``super().__init__``.
        self._client = self.initialize_client(predibase_api_key)

    @staticmethod
    def initialize_client(predibase_api_key: str) -> Any:
        """Build a ``PredibaseClient`` bound to the public Predibase endpoints.

        Raises:
            ImportError: If the ``predibase`` package is not installed.
            ValueError: If session creation raises ``ValueError`` (treated as
                an invalid API key).
        """
        try:
            from predibase import PredibaseClient
            from predibase.pql import get_session
            from predibase.pql.api import Session

            session: Session = get_session(
                token=predibase_api_key,
                gateway="https://api.app.predibase.com/v1",
                serving_endpoint="serving.app.predibase.com",
            )
            return PredibaseClient(session=session)
        except ImportError as e:
            raise ImportError(
                "Could not import Predibase Python package. "
                "Please install it with `pip install predibase`."
            ) from e
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e

    @classmethod
    def class_name(cls) -> str:
        return "PredibaseLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_name,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponse":
        """Generate a completion for ``prompt`` on the configured deployment.

        Args:
            prompt: Input text.
            formatted: When ``False``, ``completion_to_prompt`` (if set) is
                applied to ``prompt`` first; callers that already formatted
                the prompt pass ``True``.
            **kwargs: Extra generation options forwarded to Predibase; they
                override the instance's ``max_new_tokens``/``temperature``.

        Returns:
            ``CompletionResponse`` whose ``text`` is the generated response.
        """
        from predibase.pql.api import ServerResponseError
        from predibase.resource.llm.interface import (
            HuggingFaceLLM,
            LLMDeployment,
        )
        from predibase.resource.llm.response import GeneratedResponse
        from predibase.resource.model import Model

        # The constructor accepts ``completion_to_prompt``; honour it for
        # prompts that have not been formatted by the caller yet.
        if not formatted and self.completion_to_prompt is not None:
            prompt = self.completion_to_prompt(prompt)

        base_llm_deployment: LLMDeployment = self._client.LLM(
            uri=f"pb://deployments/{self.model_name}"
        )

        # Instance defaults first, so per-call kwargs override them instead of
        # being silently discarded.
        options: Dict[str, Union[str, float]] = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            **kwargs,
        }

        result: GeneratedResponse
        if self.adapter_id:
            # Try to retrieve the fine-tuned adapter from a Predibase
            # repository; if Predibase does not recognize the ID, load it
            # from a HuggingFace repository instead.
            adapter_model: Union[Model, HuggingFaceLLM]
            try:
                adapter_model = self._client.get_model(
                    name=self.adapter_id,
                    version=self.adapter_version,
                    model_id=None,
                )
            except ServerResponseError:
                adapter_model = self._client.LLM(uri=f"hf://{self.adapter_id}")
            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                prompt=prompt,
                options=options,
            )
        else:
            result = base_llm_deployment.generate(
                prompt=prompt,
                options=options,
            )

        return CompletionResponse(text=result.response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> "CompletionResponseGen":
        """Streaming is not implemented for this integration."""
        raise NotImplementedError("PredibaseLLM does not support streaming.")

metadata property

metadata: LLMMetadata

Get LLM metadata.