Skip to content

Multi step

MultiStepQueryEngine #

Bases: BaseQueryEngine

Multi-step query engine.

This query engine can operate over an existing base query engine, along with the multi-step query transform.

Parameters:

Name Type Description Default
query_engine BaseQueryEngine

A BaseQueryEngine object.

required
query_transform StepDecomposeQueryTransform

A StepDecomposeQueryTransform object.

required
response_synthesizer Optional[BaseSynthesizer]

A BaseSynthesizer object.

None
num_steps Optional[int]

Number of steps to run the multi-step query.

3
early_stopping bool

Whether to stop early if the stop function returns True.

True
index_summary str

A string summary of the index.

'None'
stop_fn Optional[Callable[[Dict], bool]]

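A stop function that takes in a dictionary of information (currently holding the decomposed sub-query under the key `query_bundle`) and returns a boolean; returning True ends the multi-step loop.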

None
Source code in llama-index-core/llama_index/core/query_engine/multistep_query_engine.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class MultiStepQueryEngine(BaseQueryEngine):
    """Multi-step query engine.

    This query engine can operate over an existing base query engine,
    along with the multi-step query transform. On every step the transform
    produces a new sub-query from the original query and the reasoning
    gathered so far, the base query engine answers it, and the collected
    question/answer pairs are finally handed to the response synthesizer.

    Args:
        query_engine (BaseQueryEngine): A BaseQueryEngine object.
        query_transform (StepDecomposeQueryTransform): A StepDecomposeQueryTransform
            object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object. When omitted, one is built with ``get_response_synthesizer``
            using the callback manager of ``query_engine``.
        num_steps (Optional[int]): Number of steps to run the multi-step query.
            ``None`` removes the step limit, which is only allowed together
            with ``early_stopping=True``.
        early_stopping (bool): Whether to stop early if the stop function returns True.
            When False the stop function is not consulted and only
            ``num_steps`` ends the loop.
        index_summary (str): A string summary of the index.
        stop_fn (Optional[Callable[[Dict], bool]]): A stop function that takes in a
            dictionary of information and returns a boolean. The dictionary
            currently holds the sub-query under the key ``"query_bundle"``.
            Defaults to ``default_stop_fn``.

    Raises:
        ValueError: If ``early_stopping`` is False and ``num_steps`` is None.

    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: StepDecomposeQueryTransform,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        num_steps: Optional[int] = 3,
        early_stopping: bool = True,
        index_summary: str = "None",
        stop_fn: Optional[Callable[[Dict], bool]] = None,
    ) -> None:
        self._query_engine = query_engine
        self._query_transform = query_transform
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=self._query_engine.callback_manager
        )

        self._index_summary = index_summary
        self._num_steps = num_steps
        self._early_stopping = early_stopping
        # TODO: make interface to stop function better
        self._stop_fn = stop_fn or default_stop_fn
        # num_steps must be provided if early_stopping is False, otherwise
        # nothing would ever terminate the multi-step loop.
        if not self._early_stopping and self._num_steps is None:
            raise ValueError("Must specify num_steps if early_stopping is False.")

        callback_manager = self._query_engine.callback_manager
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "response_synthesizer": self._response_synthesizer,
            "query_transform": self._query_transform,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Run the multi-step loop, then synthesize the final response."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)

            final_response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            # Expose the intermediate (sub-question, response) pairs to callers.
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async counterpart of ``_query``.

        Sub-queries are awaited through the base engine's ``aquery`` so they
        do not block the event loop while they run.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = await self._aquery_multistep(query_bundle)

            final_response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            # Expose the intermediate (sub-question, response) pairs to callers.
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    def _combine_queries(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> QueryBundle:
        """Combine queries.

        Runs the query transform on the original query, passing the reasoning
        accumulated so far and the index summary as transform metadata.
        """
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": self._index_summary,
        }
        return self._query_transform(query_bundle, metadata=transform_metadata)

    def _next_sub_query(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> Optional[QueryBundle]:
        """Return the next sub-query, or None when the loop should stop early."""
        updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)

        # TODO: make stop logic better
        # The stop function is only honoured when early stopping is enabled;
        # otherwise num_steps alone bounds the loop.
        stop_dict = {"query_bundle": updated_query_bundle}
        if self._early_stopping and self._stop_fn(stop_dict):
            return None
        return updated_query_bundle

    @staticmethod
    def _record_step(
        sub_query_str: str,
        cur_response: RESPONSE_TYPE,
        text_chunks: List[str],
        source_nodes: List[NodeWithScore],
        response_metadata: Dict[str, Any],
    ) -> str:
        """Store one answered sub-query and return its reasoning-trace entry.

        Appends to ``text_chunks``, ``source_nodes`` and
        ``response_metadata["sub_qa"]`` in place.
        """
        text_chunks.append(f"\nQuestion: {sub_query_str}\nAnswer: {cur_response!s}")
        source_nodes.extend(cur_response.source_nodes)
        response_metadata["sub_qa"].append((sub_query_str, cur_response))
        return f"- {sub_query_str}\n- {cur_response!s}\n"

    def _query_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run query combiner.

        Returns:
            A tuple ``(nodes, source_nodes, metadata)``: ``nodes`` wraps each
            sub-question/answer text in a ``TextNode``, ``source_nodes``
            collects the source nodes of every sub-response, and
            ``metadata["sub_qa"]`` lists the ``(sub_query_str, response)`` pairs.
        """
        prev_reasoning = ""
        cur_steps = 0

        final_response_metadata: Dict[str, Any] = {"sub_qa": []}
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []

        # With num_steps=None only the stop function can end the loop.
        while self._num_steps is None or cur_steps < self._num_steps:
            updated_query_bundle = self._next_sub_query(query_bundle, prev_reasoning)
            if updated_query_bundle is None:
                break

            cur_response = self._query_engine.query(updated_query_bundle)

            prev_reasoning += self._record_step(
                updated_query_bundle.query_str,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        nodes = [
            NodeWithScore(node=TextNode(text=text_chunk)) for text_chunk in text_chunks
        ]
        return nodes, source_nodes, final_response_metadata

    async def _aquery_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Async variant of ``_query_multistep``.

        Identical control flow and return value; each sub-query is awaited
        via ``self._query_engine.aquery``. The query transform itself is still
        invoked synchronously.
        """
        prev_reasoning = ""
        cur_steps = 0

        final_response_metadata: Dict[str, Any] = {"sub_qa": []}
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []

        # With num_steps=None only the stop function can end the loop.
        while self._num_steps is None or cur_steps < self._num_steps:
            updated_query_bundle = self._next_sub_query(query_bundle, prev_reasoning)
            if updated_query_bundle is None:
                break

            cur_response = await self._query_engine.aquery(updated_query_bundle)

            prev_reasoning += self._record_step(
                updated_query_bundle.query_str,
                cur_response,
                text_chunks,
                source_nodes,
                final_response_metadata,
            )
            cur_steps += 1

        nodes = [
            NodeWithScore(node=TextNode(text=text_chunk)) for text_chunk in text_chunks
        ]
        return nodes, source_nodes, final_response_metadata