Bases: BaseNodePostprocessor
Source code in llama-index-integrations/postprocessor/llama-index-postprocessor-flashrank-rerank/llama_index/postprocessor/flashrank_rerank/base.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class FlashRankRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with a FlashRank ``Ranker``.

    Every node's content is scored against the query string, the score is
    written back onto the node, and the ``top_n`` best-scoring nodes are
    returned in descending score order.
    """

    model: str = Field(
        description="FlashRank model name.", default="ms-marco-TinyBERT-L-2-v2"
    )
    top_n: int = Field(
        description="Number of nodes to return sorted by score.", default=20
    )
    max_length: int = Field(
        description="Maximum length of passage text passed to the reranker.",
        default=512,
    )
    # Created in ``model_post_init`` from ``model`` and ``max_length``.
    _reranker: Ranker = PrivateAttr()

    @override
    def model_post_init(self, context: Any, /) -> None:  # pyright: ignore[reportAny]
        """Build the FlashRank ranker from ``model`` and ``max_length``."""
        self._reranker = Ranker(model_name=self.model, max_length=self.max_length)

    @classmethod
    @override
    def class_name(cls) -> str:
        """Return the name used to identify this postprocessor class."""
        return "FlashRankRerank"

    @dispatcher.span
    @override
    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> list[NodeWithScore]:
        """Rerank ``nodes`` against the query and keep the ``top_n`` best.

        The input nodes are mutated in place: each node's ``score`` is
        overwritten with the score produced by the reranker.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` each node is scored
                against.

        Returns:
            At most ``top_n`` of the input nodes, highest score first. An
            empty list is returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``, or if the reranker
                returns a different number of scores than nodes supplied.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        rerank_request = RerankRequest(
            query=query_bundle.query_str,
            passages=[
                {
                    "id": node.node.id_,
                    "text": node.node.get_content(metadata_mode=MetadataMode.EMBED),
                }
                for node in nodes
            ],
        )

        # Emit an instrumentation event describing the rerank about to run.
        dispatcher.event(
            FlashRerankingQueryEvent(
                nodes=nodes,
                model_name=self.model,
                query_str=query_bundle.query_str,
                top_k=self.top_n,
            )
        )

        results = self._reranker.rerank(rerank_request)

        # Fail fast on a malformed result before any node is touched.
        if len(results) != len(nodes):
            msg = "Number of scores and nodes do not match."
            raise ValueError(msg)

        # Coerce to a builtin float. NOTE(review): the ranker presumably
        # returns NumPy scalar scores (confirm against flashrank); those are
        # not JSON-serializable and would otherwise leak into the nodes.
        scores_by_id = {
            result["id"]: float(result["score"]) for result in results
        }
        for node in nodes:
            node.score = scores_by_id[node.node.id_]

        # Stable descending sort: ties keep their original relative order.
        ranked = sorted(nodes, key=lambda n: n.score or 0.0, reverse=True)
        new_nodes = ranked[: self.top_n]

        dispatcher.event(FlashRerankEndEvent(nodes=new_nodes))
        return new_nodes
|