8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
class NetmindEmbedding(BaseEmbedding):
    """Embedding model that calls Netmind through its OpenAI-compatible API.

    Synchronous requests go through an ``OpenAI`` client and asynchronous
    requests through an ``AsyncOpenAI`` client; both are created in
    ``__init__`` and pointed at ``api_base``.
    """

    api_base: str = Field(
        default="https://api.netmind.ai/inference-api/openai/v1",
        description="The base URL for the Netmind API.",
    )
    api_key: str = Field(
        default="",
        description="The API key for the Netmind API. If not set, will attempt to use the NETMIND_API_KEY environment variable.",
    )
    timeout: float = Field(
        default=120, description="The timeout for the API request in seconds.", ge=0
    )
    max_retries: int = Field(
        default=3,
        description="The maximum number of retries for the API request.",
        ge=0,
    )

    def __init__(
        self,
        model_name: str,
        timeout: Optional[float] = 120,
        max_retries: Optional[int] = 3,
        api_key: Optional[str] = None,
        api_base: str = "https://api.netmind.ai/inference-api/openai/v1",
        **kwargs: Any,
    ) -> None:
        """Create the embedding model and its sync/async API clients.

        Args:
            model_name: Model identifier sent as ``model`` on every request.
            timeout: Request timeout in seconds, handed to both clients.
            max_retries: Retry budget handed to both clients.
            api_key: Netmind API key. When falsy, the ``NETMIND_API_KEY``
                environment variable is used instead.
            api_base: Base URL of the Netmind OpenAI-compatible endpoint.
            **kwargs: Forwarded unchanged to the base class initializer.

        Raises:
            ValueError: If no API key is given and ``NETMIND_API_KEY`` is
                not set.
        """
        api_key = api_key or os.environ.get("NETMIND_API_KEY")
        if api_key is None:
            # Fail with an actionable message instead of letting ``None``
            # reach the ``api_key: str`` field and surface as an opaque
            # validation error.
            raise ValueError(
                "No Netmind API key provided. Pass `api_key` or set the "
                "NETMIND_API_KEY environment variable."
            )

        # Forward timeout/max_retries so the declared fields reflect what the
        # caller asked for. ``None`` is skipped because the fields are typed
        # as plain float/int; in that case the field keeps its default while
        # the clients below still receive the caller's ``None``.
        field_overrides = {}
        if timeout is not None:
            field_overrides["timeout"] = timeout
        if max_retries is not None:
            field_overrides["max_retries"] = max_retries

        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **field_overrides,
            **kwargs,
        )

        # Both clients share identical connection settings.
        client_options = {
            "api_key": api_key,
            "base_url": self.api_base,
            "timeout": timeout,
            "max_retries": max_retries,
        }
        self._client = OpenAI(**client_options)
        self._aclient = AsyncOpenAI(**client_options)

    def _get_text_embedding(self, text: str) -> Embedding:
        """Return the embedding of a single text (one synchronous request)."""
        response = self._client.embeddings.create(
            input=[text],
            model=self.model_name,
        )
        return response.data[0].embedding

    def _get_query_embedding(self, query: str) -> Embedding:
        """Return the embedding of a query.

        Queries are embedded exactly like texts, so this delegates to
        ``_get_text_embedding``.
        """
        return self._get_text_embedding(query)

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Return one embedding per input text using a single batch request.

        An empty ``texts`` list yields ``[]`` without contacting the API.
        """
        if not texts:
            return []
        response = self._client.embeddings.create(
            input=texts,
            model=self.model_name,
        )
        return [item.embedding for item in response.data]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Async variant of ``_get_text_embedding``."""
        response = await self._aclient.embeddings.create(
            input=[text],
            model=self.model_name,
        )
        return response.data[0].embedding

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Async variant of ``_get_query_embedding``.

        Delegates to ``_aget_text_embedding`` because queries and texts are
        embedded the same way.
        """
        return await self._aget_text_embedding(query)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Async variant of ``_get_text_embeddings``.

        An empty ``texts`` list yields ``[]`` without contacting the API.
        """
        if not texts:
            return []
        response = await self._aclient.embeddings.create(
            input=texts,
            model=self.model_name,
        )
        return [item.embedding for item in response.data]
|