20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class RetrieverTool(AsyncBaseTool):
    """Retriever tool.

    A tool making use of a retriever.

    Args:
        retriever (BaseRetriever): A retriever.
        metadata (ToolMetadata): The associated metadata of the tool.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
            node postprocessors.

    """

    def __init__(
        self,
        retriever: BaseRetriever,
        metadata: ToolMetadata,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    ) -> None:
        self._retriever = retriever
        self._metadata = metadata
        # Normalise ``None`` to an empty list so the postprocessing loop can
        # iterate unconditionally.
        self._node_postprocessors = node_postprocessors or []

    @classmethod
    def from_defaults(
        cls,
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> "RetrieverTool":
        """Build a tool, falling back to the default name and description.

        A ``name`` or ``description`` that is ``None`` or empty is replaced by
        ``DEFAULT_NAME`` / ``DEFAULT_DESCRIPTION`` respectively.
        """
        name = name or DEFAULT_NAME
        description = description or DEFAULT_DESCRIPTION

        metadata = ToolMetadata(name=name, description=description)
        return cls(
            retriever=retriever,
            metadata=metadata,
            node_postprocessors=node_postprocessors,
        )

    @property
    def retriever(self) -> BaseRetriever:
        """The retriever this tool wraps."""
        return self._retriever

    @property
    def metadata(self) -> ToolMetadata:
        """The metadata (name, description) of this tool."""
        return self._metadata

    def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Retrieve nodes for the given inputs and format them as tool output.

        Positional and keyword arguments are stringified into a single query
        string, passed to the retriever, and the retrieved nodes are run
        through the configured node postprocessors.

        Raises:
            ValueError: If neither positional nor keyword arguments are given.
        """
        query_str = self._build_query_str(args, kwargs)
        docs = self._retriever.retrieve(query_str)
        docs = self._apply_node_postprocessors(docs, QueryBundle(query_str))
        return self._build_tool_output(query_str, docs)

    async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Async counterpart of :meth:`call`, using ``aretrieve``.

        Raises:
            ValueError: If neither positional nor keyword arguments are given.
        """
        query_str = self._build_query_str(args, kwargs)
        docs = await self._retriever.aretrieve(query_str)
        docs = self._apply_node_postprocessors(docs, QueryBundle(query_str))
        return self._build_tool_output(query_str, docs)

    def as_langchain_tool(self) -> "LlamaIndexTool":
        """Not supported for this tool; always raises ``NotImplementedError``."""
        raise NotImplementedError("`as_langchain_tool` not implemented here.")

    @staticmethod
    def _build_query_str(args: tuple, kwargs: dict) -> str:
        """Join the call arguments into the query string sent to the retriever.

        The result is the comma-joined positional values, a newline, the
        comma-joined ``"<key> is <value>"`` pairs, and a trailing newline.
        """
        # ``*args`` / ``**kwargs`` are never ``None`` and each part always
        # contributes a newline, so comparing the built string to "" can never
        # detect a call without inputs; check the inputs themselves instead.
        if not args and not kwargs:
            raise ValueError("Cannot call query engine without inputs")

        positional = ", ".join([str(arg) for arg in args])
        keyword = ", ".join([f"{k!s} is {v!s}" for k, v in kwargs.items()])
        return f"{positional}\n{keyword}\n"

    def _build_tool_output(
        self, query_str: str, docs: List[NodeWithScore]
    ) -> ToolOutput:
        """Render the nodes into text and wrap everything in a ``ToolOutput``."""
        sections = []
        for doc in docs:
            # The templates are overridden on a copy so the node held by
            # ``docs`` (returned as ``raw_output``) keeps its own templates.
            node_copy = doc.node.model_copy()
            node_copy.text_template = "{metadata_str}\n{content}"
            node_copy.metadata_template = "{key} = {value}"
            sections.append(node_copy.get_content(MetadataMode.LLM) + "\n\n")

        return ToolOutput(
            content="".join(sections),
            tool_name=self.metadata.name,
            # Record the query actually sent to the retriever (previously the
            # builtin ``input`` function was stored here by mistake).
            raw_input={"input": query_str},
            raw_output=docs,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run the nodes through each configured postprocessor, in order."""
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes
|