20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a
    query. When several retrievers are selected, their results are merged and
    de-duplicated by node id: if the same node id is returned more than once,
    the entry from the retriever processed last wins, while the node keeps the
    position of its first occurrence.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options
            based on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        llm (Optional[LLM]): LLM stored on the router; defaults to
            ``Settings.llm``.
        objects (Optional[List[IndexNode]]): Forwarded to ``BaseRetriever``.
        object_map (Optional[dict]): Forwarded to ``BaseRetriever``.
        verbose (bool): Forwarded to ``BaseRetriever``.
    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        llm: Optional[LLM] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self._llm = llm or Settings.llm
        self._selector = selector
        # Retrievers and their metadata are kept as parallel lists: the selector
        # only sees ``self._metadatas`` and answers with indices into that
        # order, which are then used to index ``self._retrievers``.
        self._retrievers: List[BaseRetriever] = [x.retriever for x in retriever_tools]
        self._metadatas = [x.metadata for x in retriever_tools]
        super().__init__(
            callback_manager=Settings.callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        llm: Optional[LLM] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> "RouterRetriever":
        """Build a router, creating an LLM-backed selector when none is given.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            llm: LLM used to build the default selector and stored on the
                router; defaults to ``Settings.llm``.
            selector: Explicit selector. When given, ``select_multi`` is
                ignored and ``llm`` is only stored.
            select_multi: Whether the default selector may pick several
                retrievers at once.
            objects: Forwarded to the constructor.
            object_map: Forwarded to the constructor.
            verbose: Forwarded to the constructor.

        Returns:
            A configured ``RouterRetriever``.
        """
        llm = llm or Settings.llm
        selector = selector or get_selector_from_llm(llm, is_multi=select_multi)
        return cls(
            selector,
            retriever_tools,
            llm=llm,
            objects=objects,
            object_map=object_map,
            verbose=verbose,
        )

    def _resolve_retrievers(self, result) -> List[BaseRetriever]:
        """Map a selector result to the chosen candidate retrievers.

        ``result`` is the object returned by ``self._selector.select`` /
        ``aselect``; its ``inds``/``reasons`` (or ``ind``/``reason`` for a single
        choice) are read here. Every index is validated before any retriever
        runs, so a bad selection fails fast instead of after partial work.

        Raises:
            ValueError: If the selector produced no usable selection, or any
                selected index is outside the range of candidate retrievers
                (a negative index would otherwise silently wrap around).
        """
        if len(result.inds) > 1:
            choices = [
                (ind, result.reasons[i]) for i, ind in enumerate(result.inds)
            ]
        else:
            try:
                choices = [(result.ind, result.reason)]
            except ValueError as e:
                raise ValueError("Failed to select retriever") from e

        selected: List[BaseRetriever] = []
        for ind, reason in choices:
            if not 0 <= ind < len(self._retrievers):
                raise ValueError(
                    f"Failed to select retriever: index {ind} is out of range "
                    f"for {len(self._retrievers)} candidate retrievers."
                )
            logger.info("Selecting retriever %s: %s.", ind, reason)
            selected.append(self._retrievers[ind])
        return selected

    @staticmethod
    def _merge_results(
        result_lists: Sequence[List[NodeWithScore]],
    ) -> List[NodeWithScore]:
        """Flatten per-retriever results, de-duplicating by ``node.node_id``.

        A later occurrence of a node id replaces the earlier one's value but
        keeps the earlier one's position (plain ``dict`` assignment semantics).
        """
        merged = {}
        for nodes in result_lists:
            for node in nodes:
                merged[node.node.node_id] = node
        return list(merged.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select retriever(s) for ``query_bundle`` and return their merged nodes."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)
            retrievers = self._resolve_retrievers(result)
            nodes = self._merge_results(
                [retriever.retrieve(query_bundle) for retriever in retrievers]
            )
            query_event.on_end(payload={EventPayload.NODES: nodes})
        return nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``_retrieve``; selected retrievers run concurrently."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)
            retrievers = self._resolve_retrievers(result)
            # asyncio.gather preserves argument order, so merging stays
            # deterministic and matches the synchronous path.
            result_lists = await asyncio.gather(
                *(retriever.aretrieve(query_bundle) for retriever in retrievers)
            )
            nodes = self._merge_results(result_lists)
            query_event.on_end(payload={EventPayload.NODES: nodes})
        return nodes
|