19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class JinaEmbedding(BaseEmbedding):
    """JinaAI class for embeddings.

    Posts texts to the Jina AI embedding endpoint (``API_URL``) and returns
    one embedding vector per item of the response, ordered by each item's
    ``index`` field.

    Args:
        model (str): Model for embedding.
            Defaults to `jina-embeddings-v2-base-en`
        embed_batch_size (int): Forwarded unchanged to the base class.
        api_key (Optional[str]): The JinaAI API key. It is resolved through
            ``get_from_param_or_env`` with the ``JINAAI_API_KEY`` name and
            ``""`` as the default.
        callback_manager (Optional[CallbackManager]): Forwarded unchanged to
            the base class.
        **kwargs: Any extra keyword arguments, forwarded to the base class.
    """

    api_key: str = Field(default=None, description="The JinaAI API key.")
    model: str = Field(
        default="jina-embeddings-v2-base-en",
        description="The model to use when calling Jina AI API",
    )

    # HTTP session used by the synchronous code path; created in __init__.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "jina-embeddings-v2-base-en",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        # Resolve the key once, up front, so the model field, the sync session
        # header and the async header all carry the same value. Previously the
        # session header was built from the raw argument, which produced
        # "Bearer None" whenever the key was not passed explicitly.
        resolved_api_key = get_from_param_or_env(
            "api_key", api_key, "JINAAI_API_KEY", ""
        )
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            api_key=resolved_api_key,
            **kwargs,
        )
        self.api_key = resolved_api_key
        self.model = model
        self._session = requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {resolved_api_key}",
                "Accept-Encoding": "identity",
            }
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this class."""
        return "JinaAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _parse_response(self, resp: Any) -> List[List[float]]:
        """Extract the embedding vectors from a decoded API response.

        Args:
            resp: The JSON-decoded response body.

        Returns:
            The ``embedding`` value of every item under ``data``, sorted by
            each item's ``index``.

        Raises:
            RuntimeError: If the response has no ``data`` entry. The message
                is the payload's ``detail`` value when present, otherwise the
                payload itself.
        """
        if "data" not in resp:
            # Not every error payload is guaranteed to carry "detail"; fall
            # back to the whole payload rather than failing with a KeyError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)
        # Sort resulting embeddings by index
        ordered = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
        # Return just the embeddings
        return [item["embedding"] for item in ordered]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Raises:
            RuntimeError: If the API response contains no ``data`` entry.
        """
        # Call Jina AI Embedding API
        resp = self._session.post(  # type: ignore
            API_URL, json={"input": texts, "model": self.model}
        ).json()
        return self._parse_response(resp)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Raises:
            aiohttp.ClientResponseError: If the HTTP status indicates failure.
            RuntimeError: If the API response contains no ``data`` entry.
        """
        import aiohttp

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept-Encoding": "identity",
        }
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(
                API_URL,
                json={"input": texts, "model": self.model},
                headers=headers,
            ) as response:
                # Check the status before decoding: an error response is not
                # necessarily JSON, and decoding first would hide the HTTP
                # failure behind a decode error.
                response.raise_for_status()
                resp = await response.json()
        return self._parse_response(resp)
|