Bases: BaseNodePostprocessor
Source code in llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class CohereRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes through a Cohere client.

    The node contents are sent, together with the query string, to the
    client's ``rerank`` method and the returned results are converted back
    into ``NodeWithScore`` objects carrying the relevance score.
    """

    model: str = Field(description="Cohere model name.")
    top_n: int = Field(description="Top N nodes to return.")

    # Cohere client instance; created in ``__init__``.
    _client: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "rerank-english-v2.0",
        api_key: Optional[str] = None,
    ):
        """Create the reranker and its Cohere client.

        Args:
            top_n: Value passed as ``top_n`` to the rerank call.
            model: Cohere model name passed to the rerank call.
            api_key: Cohere API key. When falsy, the ``COHERE_API_KEY``
                environment variable is used instead.

        Raises:
            ValueError: If no API key is passed and ``COHERE_API_KEY`` is
                not set in the environment.
            ImportError: If the ``cohere`` package cannot be imported.
        """
        try:
            api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError:
            # ``os.environ`` is a mapping: a missing variable raises KeyError
            # (not IndexError), so this is the exception to translate.
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable "
            )
        try:
            from cohere import Client
        except ImportError:
            raise ImportError(
                "Cannot import cohere package, please `pip install cohere`."
            )

        # Initialise the model fields before assigning the private attribute;
        # NOTE(review): pydantic v2 rejects private-attribute assignment before
        # the base initialiser has run, while v1 accepts either order.
        super().__init__(top_n=top_n, model=model)
        self._client = Client(api_key=api_key)

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "CohereRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query using the Cohere client.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the reranker.

        Returns:
            New ``NodeWithScore`` objects, one per result returned by the
            client, in the order the client returned them. An empty list is
            returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            # Nothing to rerank; skip the remote call entirely.
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            results = self._client.rerank(
                model=self.model,
                top_n=self.top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

            # ``result.index`` refers to the position in ``texts`` (and thus
            # in ``nodes``), so it maps each result back to its source node.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result.index].node,
                    score=result.relevance_score,
                )
                for result in results
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
|