Cleanlab

CleanlabTLM #

Bases: CustomLLM

Source code in llama-index-integrations/llms/llama-index-llms-cleanlab/llama_index/llms/cleanlab/base.py
class CleanlabTLM(CustomLLM):
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
    )
    max_tokens: int = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate in TLM response.",
    )
    model: str = Field(
        default="gpt-3.5-turbo-16k", description="The base model to use."
    )
    quality_preset: str = Field(
        default="medium", description="Pre-defined configuration to use for TLM."
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        quality_preset: Optional[str] = "medium",
        options: Optional[Dict] = None,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            additional_kwargs=additional_kwargs or {},
            callback_manager=callback_manager,
        )

        self.quality_preset = quality_preset
        use_options = options is not None
        # Check for user overrides in options dict
        if use_options:
            if options.get("model") is not None:
                self.model = options.get("model")
                if self.model == "gpt-4":
                    self.context_window = 8192
                elif self.model == "gpt-3.5-turbo-16k":
                    self.context_window = 16385
                else:
                    # ValueError is raised by Studio object for non-supported models
                    # Set context_window to dummy (default) value
                    self.context_window = DEFAULT_CONTEXT_WINDOW
            else:
                self.context_window = DEFAULT_CONTEXT_WINDOW

            if options.get("max_tokens") is not None:
                self.max_tokens = options.get("max_tokens")
            else:
                self.max_tokens = DEFAULT_MAX_TOKENS
        else:
            self.model = DEFAULT_MODEL
            self.context_window = DEFAULT_CONTEXT_WINDOW
            self.max_tokens = DEFAULT_MAX_TOKENS

        api_key = get_from_param_or_env("api_key", api_key, "CLEANLAB_API_KEY")

        studio = Studio(api_key=api_key)
        self._client = studio.TLM(
            quality_preset=self.quality_preset, options=options if use_options else None
        )

    @classmethod
    def class_name(cls) -> str:
        return "CleanlabTLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_tokens,
            model_name=self.model,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Prompt TLM for a response and trustworthiness score
        response: Dict[str, str] = self._client.prompt(prompt)

        return CompletionResponse(
            text=response["response"],
            additional_kwargs={
                "trustworthiness_score": response["trustworthiness_score"]
            },
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Prompt TLM for a response and trustworthiness score
        response = self._client.prompt(prompt)
        output = json.dumps(response)

        # TODO: figure out how to stream additional_kwargs. Workaround: dump `trustworthiness_score` as str
        # Stream the output
        output_str = ""
        for token in output:
            output_str += token
            yield CompletionResponse(text=output_str, delta=token)

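A minimal sketch of calling complete(), assuming `llm` is a CleanlabTLM instance as in the earlier sketch and that the prompt is purely illustrative:

# Assumes `llm` is a CleanlabTLM instance (see the instantiation sketch above).
response = llm.complete("What is the boiling point of water at sea level?")
print(response.text)                                          # TLM's answer
print(response.additional_kwargs["trustworthiness_score"])    # Cleanlab's trustworthiness score

Note that stream_complete() currently serializes the whole response dict (text plus trustworthiness score) to a JSON string and yields it character by character, per the TODO in the source above, so the score arrives inside the streamed text rather than in additional_kwargs.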
metadata property #

metadata: LLMMetadata

Get LLM metadata.
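As the source above shows, the property reports the values resolved in the constructor. A quick illustrative check, assuming the `llm` instance from the earlier sketch:

meta = llm.metadata
print(meta.model_name)      # e.g. "gpt-4" when overridden via options
print(meta.context_window)  # 8192 for "gpt-4", 16385 for "gpt-3.5-turbo-16k"
print(meta.num_output)      # the max_tokens value resolved in the constructor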