Skip to content

Replicate

ReplicateMultiModal

Bases: MultiModalLLM

Source code in llama-index-integrations/multi_modal_llms/llama-index-multi-modal-llms-replicate/llama_index/multi_modal_llms/replicate/base.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
class ReplicateMultiModal(MultiModalLLM):
    """Multi-modal LLM that runs a vision-language model hosted on Replicate.

    Only single-image completion is supported: when several image documents
    are passed, only the first one is sent to the model. The chat endpoints
    are not implemented and raise ``NotImplementedError``.
    """

    model: str = Field(description="The Multi-Modal model to use from Replicate.")
    temperature: float = Field(
        description="The temperature to use for sampling. Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic."
    )
    max_new_tokens: int = Field(
        description=" The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt"
    )
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    prompt_key: str = Field(description="The key to use for the prompt in API calls.")
    image_key: str = Field(description="The key to use for the image in API calls.")
    top_p: float = Field(
        description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens."
    )
    num_beams: int = Field(description="Number of beams for beam search decoding.")
    repetition_penalty: float = Field(
        description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it."
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Replicate API."
    )

    _messages_to_prompt: Callable = PrivateAttr()
    _completion_to_prompt: Callable = PrivateAttr()

    def __init__(
        self,
        model: str = REPLICATE_MULTI_MODAL_LLM_MODELS["fuyu-8b"],
        temperature: float = 0.75,
        max_new_tokens: int = 512,
        num_input_files: int = 1,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        prompt_key: str = "prompt",
        image_key: str = "image",
        repetition_penalty: Optional[float] = 1.0,
        num_beams: Optional[int] = 1,
        top_p: Optional[float] = 0.9,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Configure the Replicate model and its sampling parameters.

        Args:
            model: Replicate model reference; must be one of the values of
                ``REPLICATE_MULTI_MODAL_LLM_MODELS`` (checked at call time).
            temperature, max_new_tokens, repetition_penalty, num_beams, top_p:
                Sampling parameters forwarded to Replicate on every call.
            num_input_files: Passed through to the base class unchanged.
            additional_kwargs: Extra inputs merged over the sampling parameters.
            context_window: Reported in ``metadata`` and sent as ``max_length``.
            prompt_key, image_key: Input names used for the prompt and image.
            messages_to_prompt: Formatter for chat messages; defaults to
                ``generic_messages_to_prompt``.
            completion_to_prompt: Formatter applied to completion prompts;
                defaults to the identity function.
            callback_manager: Forwarded to the base class.
        """
        super().__init__(
            model=model,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            num_input_files=num_input_files,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            top_p=top_p,
            additional_kwargs=additional_kwargs or {},
            context_window=context_window,
            prompt_key=prompt_key,
            image_key=image_key,
            callback_manager=callback_manager,
        )
        # Must match the declared PrivateAttr name exactly; a misspelled name
        # would silently create a different attribute and leave this one unset.
        self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        self._completion_to_prompt = completion_to_prompt or (lambda x: x)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier used for this class."""
        return "replicate_multi_modal_llm"

    @property
    def metadata(self) -> MultiModalLLMMetadata:
        """Multi Modal LLM metadata."""
        return MultiModalLLMMetadata(
            context_window=self.context_window,
            num_output=DEFAULT_NUM_OUTPUTS,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Sampling inputs sent on every call; ``additional_kwargs`` override them."""
        base_kwargs: Dict[str, Any] = {
            "temperature": self.temperature,
            "max_length": self.context_window,
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "repetition_penalty": self.repetition_penalty,
            "top_p": self.top_p,
        }
        return {
            **base_kwargs,
            **self.additional_kwargs,
        }

    def _get_multi_modal_chat_messages(
        self, prompt: str, image_document: ImageNode, **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the Replicate ``input`` payload for one prompt and one image.

        A local ``image_path`` takes precedence over ``image_url``. For a local
        path the file is opened in binary mode and the open handle is put in
        the payload; the caller owns that handle and is responsible for
        closing it. Per-call ``kwargs`` override the model kwargs.

        Raises:
            FileNotFoundError: if the local file does not exist, or if the node
                has neither an image path nor an image URL.
        """
        if image_document.image_path:
            # load local image file and pass file handler to replicate
            try:
                image: Any = open(image_document.image_path, "rb")  # noqa: SIM115
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    "Could not load local image file. Please check whether the file exists"
                ) from err
        elif image_document.image_url:
            # pass the remote image url to replicate as-is
            image = image_document.image_url
        else:
            raise FileNotFoundError(
                "Could not load image file. Please check whether the file exists"
            )
        return {
            self.prompt_key: prompt,
            self.image_key: image,
            **self._model_kwargs,
            **kwargs,
        }

    def _run_replicate(
        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
    ) -> Any:
        """Validate inputs, submit one prediction to Replicate, return its output.

        Shared by the sync and async streaming endpoints. Any local image file
        opened for the request is closed before returning.

        Raises:
            ImportError: if the ``replicate`` package is not installed.
            ValueError: if ``self.model`` is not a known Replicate multi-modal
                model, or if ``image_documents`` is empty.
            FileNotFoundError: if the first image document cannot be loaded.
        """
        try:
            import replicate
        except ImportError as err:
            raise ImportError(
                "Could not import replicate library. "
                "Please install replicate with `pip install replicate`"
            ) from err

        # Validate before opening any local file so a bad model name cannot
        # leak a file handle.
        if self.model not in REPLICATE_MULTI_MODAL_LLM_MODELS.values():
            raise ValueError(
                f"Unknown model {self.model!r}. Please provide a valid Replicate Multi-Modal model name in:"
                f" {', '.join(REPLICATE_MULTI_MODAL_LLM_MODELS.values())}"
            )
        if not image_documents:
            raise ValueError(
                "ReplicateMultiModal requires at least one image document."
            )

        # TODO: at the current moment, only support uploading one image document
        if len(image_documents) > 1:
            _logger.warning(
                "ReplicateMultiModal currently only supports uploading one image document; "
                "we are using the first image document for completion."
            )

        image_document = image_documents[0]
        input_dict = self._get_multi_modal_chat_messages(
            self._completion_to_prompt(prompt),
            image_document,
            **kwargs,
        )
        # The handle opened above for a local path (None for URL inputs or
        # when the caller supplied the image value via kwargs).
        opened_file = (
            input_dict.get(self.image_key)
            if image_document.image_path and self.image_key not in kwargs
            else None
        )
        try:
            return replicate.run(self.model, input=input_dict)
        finally:
            # NOTE(review): relies on replicate.run having read/uploaded the
            # file by the time it returns (the client encodes inputs when it
            # creates the prediction) — confirm for the pinned replicate version.
            if opened_file is not None and hasattr(opened_file, "close"):
                opened_file.close()

    def complete(
        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
    ) -> CompletionResponse:
        """Run a completion and return the final accumulated response.

        If the model yields no output, an empty completion is returned.
        """
        final_response: Optional[CompletionResponse] = None
        for response in self.stream_complete(prompt, image_documents, **kwargs):
            final_response = response
        if final_response is None:
            return CompletionResponse(text="")
        final_response.delta = None
        return final_response

    def stream_complete(
        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream completion chunks; each carries the new delta and full text so far."""
        response_iter = self._run_replicate(prompt, image_documents, **kwargs)

        def gen() -> CompletionResponseGen:
            text = ""
            for delta in response_iter:
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                )

        return gen()

    def chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Not supported by this integration."""
        raise NotImplementedError

    def stream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseGen:
        """Not supported by this integration."""
        raise NotImplementedError

    # ===== Async Endpoints =====

    async def acomplete(
        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
    ) -> CompletionResponse:
        """Async wrapper around ``complete``; the underlying request is blocking."""
        return self.complete(prompt, image_documents, **kwargs)

    async def astream_complete(
        self, prompt: str, image_documents: Sequence[ImageNode], **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async variant of ``stream_complete``; iterates Replicate's output synchronously."""
        response_iter = self._run_replicate(prompt, image_documents, **kwargs)

        async def gen() -> CompletionResponseAsyncGen:
            text = ""
            for delta in response_iter:
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                )

        return gen()

    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Not supported by this integration."""
        raise NotImplementedError

    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        """Not supported by this integration."""
        raise NotImplementedError

metadata property

Multi-modal LLM metadata.