Skip to content

Longllmlingua

LongLLMLinguaPostprocessor #

Bases: BaseNodePostprocessor

Optimization of nodes.

Compresses node text using the prompt-compression method from the LongLLMLingua paper.

Source code in llama-index-integrations/postprocessor/llama-index-postprocessor-longllmlingua/llama_index/postprocessor/longllmlingua/base.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class LongLLMLinguaPostprocessor(BaseNodePostprocessor):
    """Optimization of nodes.

    Compress node text with an ``llmlingua.PromptCompressor``, following the
    LongLLMLingua paper.

    The retrieved nodes are flattened to text, split into paragraphs, and
    handed to the compressor together with the instruction string and the
    query. The compressed output is split back into paragraphs and returned as
    new ``TextNode`` objects.

    Attributes:
        metadata_mode: Metadata mode used when reading each node's content.
        instruction_str: Instruction passed to the compressor.
        target_token: Target number of compressed tokens.
        rank_method: Ranking method passed to the compressor.
        additional_compress_kwargs: Extra keyword arguments forwarded verbatim
            to ``compress_prompt``.
    """

    metadata_mode: MetadataMode = Field(
        default=MetadataMode.ALL, description="Metadata mode."
    )
    instruction_str: str = Field(
        default=DEFAULT_INSTRUCTION_STR, description="Instruction string."
    )
    target_token: int = Field(
        default=300, description="Target number of compressed tokens."
    )
    rank_method: str = Field(default="longllmlingua", description="Ranking method.")
    additional_compress_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional compress kwargs."
    )

    # The llmlingua.PromptCompressor instance created in __init__.
    _llm_lingua: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = "NousResearch/Llama-2-7b-hf",
        device_map: str = "cuda",
        model_config: Optional[dict] = None,
        open_api_config: Optional[dict] = None,
        metadata_mode: MetadataMode = MetadataMode.ALL,
        instruction_str: str = DEFAULT_INSTRUCTION_STR,
        target_token: int = 300,
        rank_method: str = "longllmlingua",
        additional_compress_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """LongLLMLingua Compressor for Node Context.

        Args:
            model_name: Model name forwarded to ``PromptCompressor``.
            device_map: Device map forwarded to ``PromptCompressor``.
            model_config: Model config forwarded to ``PromptCompressor``;
                ``None`` is treated as an empty dict.
            open_api_config: Forwarded to ``PromptCompressor`` as
                ``open_api_config``; ``None`` is treated as an empty dict.
            metadata_mode: Metadata mode used when reading node content.
            instruction_str: Instruction passed to the compressor.
            target_token: Target number of compressed tokens.
            rank_method: Ranking method passed to the compressor.
            additional_compress_kwargs: Extra keyword arguments for
                ``compress_prompt``; ``None`` is treated as an empty dict.
        """
        # Imported lazily so the module can be imported without llmlingua.
        from llmlingua import PromptCompressor

        # Normalize optional dicts BEFORE they are used: the
        # ``additional_compress_kwargs`` field is a non-Optional Dict, so
        # None must not reach the base-class constructor. Fresh dicts per call
        # also avoid sharing one mutable default across instances.
        model_config = {} if model_config is None else model_config
        open_api_config = {} if open_api_config is None else open_api_config
        additional_compress_kwargs = (
            {} if additional_compress_kwargs is None else additional_compress_kwargs
        )

        super().__init__(
            metadata_mode=metadata_mode,
            instruction_str=instruction_str,
            target_token=target_token,
            rank_method=rank_method,
            additional_compress_kwargs=additional_compress_kwargs,
        )

        self._llm_lingua = PromptCompressor(
            model_name=model_name,
            device_map=device_map,
            model_config=model_config,
            open_api_config=open_api_config,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "LongLLMLinguaPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Optimize a node text given the query by shortening the node text.

        Args:
            nodes: Nodes whose content should be compressed.
            query_bundle: Query whose ``query_str`` is passed to the
                compressor as the question. Required.

        Returns:
            New ``NodeWithScore`` objects, one per paragraph of the compressed
            prompt. They wrap fresh ``TextNode`` objects built from text only
            and are created without a score.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Query bundle is required.")
        if not nodes:
            # Nothing to compress; skip the compressor call entirely.
            return []

        context_texts = [n.get_content(metadata_mode=self.metadata_mode) for n in nodes]
        # split by "\n\n" (recommended by LongLLMLingua authors)
        new_context_texts = [
            c for context in context_texts for c in context.split("\n\n")
        ]

        # You can use it this way, although the question-aware fine-grained
        # compression hasn't been enabled.
        compressed_prompt = self._llm_lingua.compress_prompt(
            new_context_texts,
            instruction=self.instruction_str,
            question=query_bundle.query_str,
            target_token=self.target_token,
            rank_method=self.rank_method,
            **self.additional_compress_kwargs,
        )

        compressed_prompt_txt = compressed_prompt["compressed_prompt"]

        # Drop the first and last paragraphs: the compressor is expected to
        # place the instruction at the top and the question at the bottom.
        # NOTE(review): this assumes neither contains "\n\n" itself — confirm.
        compressed_prompt_txt_list = compressed_prompt_txt.split("\n\n")[1:-1]

        # return nodes for each list
        return [
            NodeWithScore(node=TextNode(text=t)) for t in compressed_prompt_txt_list
        ]