25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a query.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        llm (Optional[LLM]): LLM stored on the router. When omitted it is
            resolved from ``Settings`` / ``service_context``.
        service_context (Optional[ServiceContext]): Legacy context used to
            resolve the LLM and the callback manager.
        objects (Optional[List[IndexNode]]): Forwarded to ``BaseRetriever``.
        object_map (Optional[dict]): Forwarded to ``BaseRetriever``.
        verbose (bool): Forwarded to ``BaseRetriever``.
    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        llm: Optional[LLM] = None,
        service_context: Optional[ServiceContext] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self._llm = llm or llm_from_settings_or_context(Settings, service_context)
        self._selector = selector
        # The selector chooses among tool metadata, while retrieval runs on the
        # wrapped retrievers; both lists are built from the same sequence, so
        # a selected index refers to the same tool in each.
        self._retrievers: List[BaseRetriever] = [x.retriever for x in retriever_tools]
        self._metadatas = [x.metadata for x in retriever_tools]
        super().__init__(
            callback_manager=callback_manager_from_settings_or_context(
                Settings, service_context
            ),
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        llm: Optional[LLM] = None,
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> "RouterRetriever":
        """Build a router, creating an LLM-backed selector if none is given.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            llm: LLM for the router and the default selector; resolved from
                ``Settings`` / ``service_context`` when omitted.
            service_context: Legacy context used for LLM / callback resolution.
            selector: Explicit selector. When provided, ``select_multi`` is
                ignored.
            select_multi: Passed as ``is_multi`` to ``get_selector_from_llm``
                when no ``selector`` is given.
            objects: Forwarded to the constructor (and ``BaseRetriever``).
            object_map: Forwarded to the constructor (and ``BaseRetriever``).
            verbose: Forwarded to the constructor (and ``BaseRetriever``).

        Returns:
            A configured ``RouterRetriever``.
        """
        llm = llm or llm_from_settings_or_context(Settings, service_context)
        selector = selector or get_selector_from_llm(llm, is_multi=select_multi)
        return cls(
            selector,
            retriever_tools,
            llm=llm,
            service_context=service_context,
            objects=objects,
            object_map=object_map,
            verbose=verbose,
        )

    def _select_retrievers(self, result) -> List[BaseRetriever]:
        """Map a selector result to the chosen retrievers, logging each choice.

        Args:
            result: The object returned by ``self._selector.select`` /
                ``aselect``; its ``inds`` / ``reasons`` (multi) or ``ind`` /
                ``reason`` (single) attributes are read.

        Returns:
            The selected retrievers, in selection order.

        Raises:
            ValueError: If reading the single-selection ``result.ind`` /
                ``result.reason`` raises ``ValueError``.
            IndexError: If a selected index is out of range for the candidates.
        """
        if len(result.inds) > 1:
            selected = []
            for i, engine_ind in enumerate(result.inds):
                logger.info(f"Selecting retriever {engine_ind}: {result.reasons[i]}.")
                selected.append(self._retrievers[engine_ind])
            return selected

        try:
            selected_retriever = self._retrievers[result.ind]
            logger.info(f"Selecting retriever {result.ind}: {result.reason}.")
        except ValueError as e:
            raise ValueError("Failed to select retriever") from e
        return [selected_retriever]

    @staticmethod
    def _merge_by_node_id(result_lists: List[List[NodeWithScore]]) -> dict:
        """Merge per-retriever results into one dict keyed by node id.

        If a node id appears more than once, the later occurrence replaces the
        earlier value while the key keeps its first-seen position (plain
        ``dict`` assignment semantics).
        """
        merged: dict = {}
        for nodes in result_lists:
            for n in nodes:
                merged[n.node.node_id] = n
        return merged

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select retriever(s) for the query, run them, and dedupe by node id."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)
            retrievers = self._select_retrievers(result)
            retrieved_results = self._merge_by_node_id(
                [retriever.retrieve(query_bundle) for retriever in retrievers]
            )
            # Hand callbacks a concrete list, not a live ``dict_values`` view,
            # so handlers can index, slice, or serialize the payload.
            query_event.on_end(
                payload={EventPayload.NODES: list(retrieved_results.values())}
            )
        return list(retrieved_results.values())

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``_retrieve``; multiple selections run concurrently."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)
            retrievers = self._select_retrievers(result)
            if len(retrievers) > 1:
                # Fan out to every selected retriever at once.
                results_of_results = await asyncio.gather(
                    *(retriever.aretrieve(query_bundle) for retriever in retrievers)
                )
            else:
                results_of_results = [await retrievers[0].aretrieve(query_bundle)]
            retrieved_results = self._merge_by_node_id(results_of_results)
            # Hand callbacks a concrete list, not a live ``dict_values`` view.
            query_event.on_end(
                payload={EventPayload.NODES: list(retrieved_results.values())}
            )
        return list(retrieved_results.values())