class CustomLLM(LLM):
    """Simple abstract base class for custom LLMs.

    Subclasses must implement ``complete``, ``stream_complete`` and
    ``metadata`` (plus ``__init__`` if they need custom construction).
    Every other entry point defined here is derived from the two completion
    methods:

    * ``chat`` / ``stream_chat`` render the message list into a single prompt
      with ``self.messages_to_prompt`` (which must not be ``None``) and
      forward it to ``complete`` / ``stream_complete`` with ``formatted=True``.
    * ``achat``, ``astream_chat``, ``acomplete`` and ``astream_complete`` call
      their synchronous counterparts directly. They never await anything, so
      the work runs on, and blocks, the calling event loop. Subclasses with a
      natively asynchronous backend should override them.
    """

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Answer a chat by rendering it to one prompt and calling ``complete``.

        Args:
            messages: Conversation history to send to the model.
            **kwargs: Forwarded unchanged to ``complete``.

        Returns:
            The completion converted into a ``ChatResponse``.

        Raises:
            AssertionError: If ``self.messages_to_prompt`` is ``None``.
        """
        # A prompt renderer is required to turn chat messages into a completion prompt.
        assert self.messages_to_prompt is not None, "messages_to_prompt must be set"
        prompt = self.messages_to_prompt(messages)
        # formatted=True: the prompt is already rendered; presumably this tells
        # ``complete`` not to apply its own prompt template again (confirm in LLM).
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat reply by rendering the messages and calling ``stream_complete``.

        Args:
            messages: Conversation history to send to the model.
            **kwargs: Forwarded unchanged to ``stream_complete``.

        Returns:
            A generator of chat responses adapted from the completion stream.

        Raises:
            AssertionError: If ``self.messages_to_prompt`` is ``None``.
        """
        # A prompt renderer is required to turn chat messages into a completion prompt.
        assert self.messages_to_prompt is not None, "messages_to_prompt must be set"
        prompt = self.messages_to_prompt(messages)
        completion_response_gen = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response_gen)

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        """Async wrapper around ``chat``.

        Calls the synchronous ``chat`` directly, so it blocks the event loop
        for the whole request.
        """
        return self.chat(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        """Async wrapper around ``stream_chat``.

        Awaiting this coroutine returns an async generator. ``stream_chat``
        is only invoked once that generator is first iterated. Each item is
        pulled from the synchronous stream, which blocks the event loop while
        the item is produced.
        """

        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        # NOTE: convert generator to async generator
        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async wrapper around ``complete``.

        Calls the synchronous ``complete`` directly, so it blocks the event
        loop for the whole request.

        Args:
            prompt: Prompt text to complete.
            formatted: Forwarded unchanged to ``complete``.
            **kwargs: Forwarded unchanged to ``complete``.
        """
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async wrapper around ``stream_complete``.

        Awaiting this coroutine returns an async generator that lazily drives
        the synchronous ``stream_complete`` generator, one item per step.

        Args:
            prompt: Prompt text to complete.
            formatted: Forwarded unchanged to ``stream_complete``.
            **kwargs: Forwarded unchanged to ``stream_complete``.
        """

        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield message

        # NOTE: convert generator to async generator
        return gen()

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string ``"custom_llm"``."""
        return "custom_llm"