Skip to content

Sentence optimizer

Node PostProcessor module.

SentenceEmbeddingOptimizer

Bases: BaseNodePostprocessor

Optimization of a text chunk given the query by shortening the input text.

Source code in llama-index-core/llama_index/core/postprocessor/optimizer.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class SentenceEmbeddingOptimizer(BaseNodePostprocessor):
    """Shorten each node's text to the sentences most relevant to the query.

    Every node's content is split into sentences, each sentence is embedded,
    and only the sentences selected by ``get_top_k_embeddings`` (limited by
    ``percentile_cutoff`` and/or ``threshold_cutoff``) are kept, each together
    with a window of neighbouring sentences for context.
    """

    percentile_cutoff: Optional[float] = Field(
        description="Percentile cutoff for the top k sentences to use."
    )
    threshold_cutoff: Optional[float] = Field(
        description="Threshold cutoff for similarity for each sentence to use."
    )

    _embed_model: BaseEmbedding = PrivateAttr()
    _tokenizer_fn: Callable[[str], List[str]] = PrivateAttr()

    context_before: Optional[int] = Field(
        description="Number of sentences before retrieved sentence for further context"
    )

    context_after: Optional[int] = Field(
        description="Number of sentences after retrieved sentence for further context"
    )

    def __init__(
        self,
        embed_model: Optional[BaseEmbedding] = None,
        percentile_cutoff: Optional[float] = None,
        threshold_cutoff: Optional[float] = None,
        tokenizer_fn: Optional[Callable[[str], List[str]]] = None,
        context_before: Optional[int] = None,
        context_after: Optional[int] = None,
    ):
        """Create the optimizer.

        Args:
            embed_model: Embedding model used for the query and the sentences.
                Falls back to ``Settings.embed_model``, then to
                ``OpenAIEmbedding`` if that is ``None``.
            percentile_cutoff: Fraction of a node's sentences to keep, e.g.
                ``0.5`` keeps ``int(0.5 * number_of_sentences)`` sentences.
            threshold_cutoff: Similarity cutoff forwarded to
                ``get_top_k_embeddings`` as ``similarity_cutoff``.
                Both cutoffs may be used together.
            tokenizer_fn: Callable splitting a text into sentences. Defaults
                to NLTK's ``PunktSentenceTokenizer().tokenize``.
            context_before: Number of sentences to include before each
                selected sentence; ``None`` is treated as 1.
            context_after: Number of sentences to include after each
                selected sentence; ``None`` is treated as 1.

        Raises:
            ImportError: If no embedding model is available and
                ``llama-index-embeddings-openai`` is not installed.

        Example:

        .. code-block:: python

            from llama_index.core.postprocessor.optimizer import (
                SentenceEmbeddingOptimizer,
            )

            optimizer = SentenceEmbeddingOptimizer(
                percentile_cutoff=0.5,
                threshold_cutoff=0.7,
            )
            query_engine = index.as_query_engine(
                node_postprocessors=[optimizer]
            )
            response = query_engine.query("<query_str>")
        """
        super().__init__(
            percentile_cutoff=percentile_cutoff,
            threshold_cutoff=threshold_cutoff,
            context_after=context_after,
            context_before=context_before,
        )
        self._embed_model = embed_model or Settings.embed_model
        if self._embed_model is None:
            # Keep the try body limited to the import so that only a genuinely
            # missing package is reported as "package not found".
            try:
                from llama_index.embeddings.openai import (
                    OpenAIEmbedding,
                )  # pants: no-infer-dep
            except ImportError as err:
                raise ImportError(
                    "`llama-index-embeddings-openai` package not found, "
                    "please run `pip install llama-index-embeddings-openai`"
                ) from err

            self._embed_model = OpenAIEmbedding()

        if tokenizer_fn is None:
            # Imported lazily so nltk is only needed for the default tokenizer.
            import nltk

            tokenizer = nltk.tokenize.PunktSentenceTokenizer()
            tokenizer_fn = tokenizer.tokenize
        self._tokenizer_fn = tokenizer_fn

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "SentenceEmbeddingOptimizer"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Shorten each node's text to its query-relevant sentences.

        Args:
            nodes: Nodes whose content is rewritten in place.
            query_bundle: Query to compare sentences against. If ``None``,
                ``nodes`` is returned untouched. If it has no embedding yet,
                one is computed and stored on the bundle.

        Returns:
            The same ``nodes`` list, with each node's content replaced by the
            selected sentences (plus context) joined with spaces.

        Raises:
            ValueError: If no sentence is selected for some node.
        """
        if query_bundle is None or not nodes:
            return nodes

        # The query embedding does not depend on the node; compute it once.
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                self._embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        # Resolve the context-window defaults into locals instead of writing
        # them back to the instance, so postprocessing never alters the
        # configured fields.
        context_before = 1 if self.context_before is None else self.context_before
        context_after = 1 if self.context_after is None else self.context_after
        threshold = self.threshold_cutoff

        for node_with_score in nodes:
            text = node_with_score.node.get_content(metadata_mode=MetadataMode.LLM)
            split_text = self._tokenizer_fn(text)
            num_sentences = len(split_text)

            text_embeddings = self._embed_model._get_text_embeddings(split_text)

            num_top_k = None
            if self.percentile_cutoff is not None:
                num_top_k = int(num_sentences * self.percentile_cutoff)

            top_similarities, top_idxs = get_top_k_embeddings(
                query_embedding=query_bundle.embedding,
                embeddings=text_embeddings,
                similarity_fn=self._embed_model.similarity,
                similarity_top_k=num_top_k,
                embedding_ids=list(range(len(text_embeddings))),
                similarity_cutoff=threshold,
            )

            if len(top_idxs) == 0:
                raise ValueError("Optimizer returned zero sentences.")

            # For each selected sentence, take the surrounding window, clamped
            # to the bounds of the sentence list.
            top_sentences = [
                " ".join(
                    split_text[
                        max(idx - context_before, 0) : min(
                            idx + context_after + 1, num_sentences
                        )
                    ]
                )
                for idx in top_idxs
            ]

            # Only build the debug strings when they will actually be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"> Top {len(top_idxs)} sentences with scores:\n")
                for rank, (sentence, similarity) in enumerate(
                    zip(top_sentences, top_similarities)
                ):
                    logger.debug(f"{rank}. {sentence} ({similarity})")

            node_with_score.node.set_content(" ".join(top_sentences))

        return nodes