Skip to content

Flashrank rerank

FlashRankRerank

Bases: BaseNodePostprocessor

Source code in `llama-index-integrations/postprocessor/llama-index-postprocessor-flashrank-rerank/llama_index/postprocessor/flashrank_rerank/base.py`
(Lines 30–99 of that file.)
class FlashRankRerank(BaseNodePostprocessor):
    """Rerank retrieved nodes with a FlashRank ``Ranker``.

    Each node's content (rendered with ``MetadataMode.EMBED``) is scored
    against the query string, the score is written back onto the node, and
    at most ``top_n`` nodes are returned in descending score order.
    """

    model: str = Field(
        description="FlashRank model name.", default="ms-marco-TinyBERT-L-2-v2"
    )
    top_n: int = Field(
        description="Number of nodes to return sorted by score.", default=20
    )
    max_length: int = Field(
        description="Maximum length of passage text passed to the reranker.",
        default=512,
    )

    # Ranker built from ``model`` and ``max_length`` in ``model_post_init``.
    _reranker: Ranker = PrivateAttr()

    @override
    def model_post_init(self, context: Any, /) -> None:  # pyright: ignore[reportAny]
        """Construct the FlashRank ranker once pydantic has validated fields."""
        # Chain to the parent models so any post-init logic they define runs.
        super().model_post_init(context)
        self._reranker = Ranker(model_name=self.model, max_length=self.max_length)

    @classmethod
    @override
    def class_name(cls) -> str:
        """Return the component name used for serialization."""
        return "FlashRankRerank"

    @dispatcher.span
    @override
    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> list[NodeWithScore]:
        """Score ``nodes`` against the query and return the best ``top_n``.

        Args:
            nodes: Candidate nodes. Each node's ``score`` is overwritten in
                place with its reranker score.
            query_bundle: Query to rank against; required.

        Returns:
            Up to ``top_n`` nodes sorted by descending reranker score, or an
            empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is None, or if the reranker
                returns a different number of results than there are nodes.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        # Passages are keyed by their position in ``nodes`` rather than by
        # node id, so nodes that share an id still each get their own score.
        query_and_nodes = RerankRequest(
            query=query_bundle.query_str,
            passages=[
                {
                    "id": idx,
                    "text": node.node.get_content(metadata_mode=MetadataMode.EMBED),
                }
                for idx, node in enumerate(nodes)
            ],
        )
        # Notify instrumentation handlers that reranking is starting.
        dispatcher.event(
            FlashRerankingQueryEvent(
                nodes=nodes,
                model_name=self.model,
                query_str=query_bundle.query_str,
                top_k=self.top_n,
            )
        )
        results = self._reranker.rerank(query_and_nodes)

        # Validate before mapping so a short/long result list fails clearly.
        if len(results) != len(nodes):
            msg = "Number of scores and nodes do not match."
            raise ValueError(msg)

        score_by_idx = {result["id"]: result["score"] for result in results}
        for idx, node in enumerate(nodes):
            # The reranker may hand back NumPy scalars; store a plain float so
            # the node's score behaves like any other float downstream.
            node.score = float(score_by_idx[idx])

        # Python's sort is stable with reverse=True, so ties keep input order.
        new_nodes = sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
        dispatcher.event(FlashRerankEndEvent(nodes=new_nodes))

        return new_nodes