class CleanlabTLM(CustomLLM):
    """
    Cleanlab TLM.

    Examples:
        `pip install llama-index-llms-cleanlab`

        ```python
        from llama_index.llms.cleanlab import CleanlabTLM

        llm = CleanlabTLM(quality_preset="best", api_key=api_key)
        resp = llm.complete("Who is Paul Graham?")
        print(resp)
        ```

    """
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
    )
    max_tokens: int = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate in the TLM response.",
    )
    model: str = Field(default=DEFAULT_MODEL, description="The base model to use.")
    quality_preset: str = Field(
        default="medium", description="Pre-defined configuration to use for TLM."
    )

    _client: Any = PrivateAttr()
    def __init__(
        self,
        api_key: Optional[str] = None,
        quality_preset: Optional[str] = "medium",
        options: Optional[Dict] = None,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            additional_kwargs=additional_kwargs or {},
            callback_manager=callback_manager,
        )
        self.quality_preset = quality_preset
        use_options = options is not None

        # Check for user overrides in the options dict
        if use_options:
            if options.get("model") is not None:
                self.model = options.get("model")
                if self.model == "gpt-4":
                    self.context_window = 8192
                elif self.model == "gpt-3.5-turbo-16k":
                    self.context_window = 16385
                elif self.model in ["gpt-4o-mini", "gpt-4o"]:
                    self.context_window = 131072
                elif self.model in [
                    "claude-3-haiku",
                    "claude-3-sonnet",
                    "claude-3.5-sonnet",
                ]:
                    self.context_window = 204800
                else:
                    # A ValueError is raised by the Studio object for unsupported models,
                    # so fall back to the dummy (default) context window here.
                    self.context_window = DEFAULT_CONTEXT_WINDOW
            else:
                self.context_window = DEFAULT_CONTEXT_WINDOW

            if options.get("max_tokens") is not None:
                self.max_tokens = options.get("max_tokens")
            else:
                self.max_tokens = DEFAULT_MAX_TOKENS
        else:
            self.model = DEFAULT_MODEL
            self.context_window = DEFAULT_CONTEXT_WINDOW
            self.max_tokens = DEFAULT_MAX_TOKENS

        api_key = get_from_param_or_env("api_key", api_key, "CLEANLAB_API_KEY")
        studio = Studio(api_key=api_key)
        self._client = studio.TLM(
            quality_preset=self.quality_preset,
            options=options if use_options else None,
        )
    @classmethod
    def class_name(cls) -> str:
        return "CleanlabTLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_tokens,
            model_name=self.model,
        )
    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Prompt TLM for a response and trustworthiness score
        response: Dict[str, str] = self._client.prompt(prompt)
        return CompletionResponse(
            text=response["response"],
            additional_kwargs={
                "trustworthiness_score": response["trustworthiness_score"]
            },
        )
    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Prompt TLM for a response and trustworthiness score
        response = self._client.prompt(prompt)
        output = json.dumps(response)
        # TODO: figure out how to stream additional_kwargs.
        # Workaround: dump `trustworthiness_score` into the streamed text as a str.

        # Stream the output character by character
        output_str = ""
        for token in output:
            output_str += token
            yield CompletionResponse(text=output_str, delta=token)
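

# Illustrative usage sketch (not part of the integration itself): it shows how a
# caller might read the trustworthiness score that `complete()` places in
# `additional_kwargs`. It assumes the CLEANLAB_API_KEY environment variable is set;
# the quality preset and prompt below are arbitrary examples.
if __name__ == "__main__":
    llm = CleanlabTLM(quality_preset="medium")
    resp = llm.complete("Who is Paul Graham?")
    print(resp.text)
    # `complete()` attaches the TLM trustworthiness score to the response.
    print(resp.additional_kwargs["trustworthiness_score"])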