55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class TitleExtractor(BaseExtractor):
    """Title extractor. Useful for long documents. Extracts `document_title`
    metadata field.

    Nodes are grouped by ``ref_doc_id``. For each group, the LLM is prompted
    with ``node_template`` once per node (up to the first ``nodes`` nodes of
    that group) to get title clues. The clues are joined with ", " and sent
    through ``combine_template`` to produce one title for the group. Every
    input node then receives the title of its group.

    Args:
        llm (Optional[LLM]): LLM used for both prompts. If not given, falls
            back to ``llm_predictor`` and then to ``Settings.llm``.
        nodes (int): number of leading nodes per document (per
            ``ref_doc_id``) to use for title extraction; must be >= 1.
        node_template (str): template for node-level title clues extraction
        combine_template (str): template for combining node-level clues into
            a document-level title
        num_workers (int): forwarded to ``run_jobs`` as ``workers`` when
            running the node-level prompts.
    """

    is_text_node_only: bool = False  # can work for mixture of text and non-text nodes
    llm: SerializeAsAny[LLM] = Field(description="The LLM to use for generation.")
    nodes: int = Field(
        default=5,
        description="The number of nodes to extract titles from.",
        gt=0,
    )
    node_template: str = Field(
        default=DEFAULT_TITLE_NODE_TEMPLATE,
        description="The prompt template to extract titles with.",
    )
    combine_template: str = Field(
        default=DEFAULT_TITLE_COMBINE_TEMPLATE,
        description="The prompt template to merge titles with.",
    )

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLM] = None,
        nodes: int = 5,
        node_template: str = DEFAULT_TITLE_NODE_TEMPLATE,
        combine_template: str = DEFAULT_TITLE_COMBINE_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ValueError: if ``nodes`` is less than 1.
        """
        if nodes < 1:
            # Fail early with a message that names the real parameter; the
            # field's ``gt=0`` constraint would reject it later as well.
            raise ValueError("nodes must be >= 1")
        super().__init__(
            # Prefer the explicit llm, then the deprecated alias, then the
            # globally configured default.
            llm=llm or llm_predictor or Settings.llm,
            nodes=nodes,
            node_template=node_template,
            combine_template=combine_template,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TitleExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Return one ``{"document_title": ...}`` dict per input node, in input order."""
        nodes_by_doc_id = self.separate_nodes_by_ref_id(nodes)
        titles_by_doc_id = await self.extract_titles(nodes_by_doc_id)
        # Every ref_doc_id seen has a key (see separate_nodes_by_ref_id), so
        # nodes beyond the per-document cap still get their document's title.
        return [{"document_title": titles_by_doc_id[node.ref_doc_id]} for node in nodes]

    def filter_nodes(self, nodes: Sequence[BaseNode]) -> List[BaseNode]:
        """Drop non-TextNode nodes when ``is_text_node_only`` is set; else keep all."""
        return [
            node
            for node in nodes
            if not self.is_text_node_only or isinstance(node, TextNode)
        ]

    def separate_nodes_by_ref_id(self, nodes: Sequence[BaseNode]) -> Dict:
        """Group nodes by ``ref_doc_id``, keeping at most ``self.nodes`` per group.

        A key is created for every ``ref_doc_id`` encountered, even when its
        later nodes are dropped by the cap. Nodes whose ``ref_doc_id`` is
        None all land in a single shared group.
        """
        separated_items: Dict[Optional[str], List[BaseNode]] = {}
        for node in nodes:
            group = separated_items.setdefault(node.ref_doc_id, [])
            if len(group) < self.nodes:
                group.append(node)
        return separated_items

    async def extract_titles(self, nodes_by_doc_id: Dict) -> Dict:
        """Map each group key to a single title combined from its node-level clues.

        Groups are processed one after another; the node-level prompts
        within a group are dispatched via ``get_title_candidates``.
        """
        titles_by_doc_id = {}
        for key, nodes in nodes_by_doc_id.items():
            title_candidates = await self.get_title_candidates(nodes)
            combined_titles = ", ".join(title_candidates)
            titles_by_doc_id[key] = await self.llm.apredict(
                PromptTemplate(template=self.combine_template),
                context_str=combined_titles,
            )
        return titles_by_doc_id

    async def get_title_candidates(self, nodes: List[BaseNode]) -> List[str]:
        """Prompt the LLM with ``node_template`` once per node; return the replies."""
        # NOTE(review): is_text_node_only is False, yet this reads `.text`
        # through cast(TextNode, ...); cast does no runtime check, so a node
        # type without a `.text` attribute would fail here — confirm inputs.
        title_jobs = [
            self.llm.apredict(
                PromptTemplate(template=self.node_template),
                context_str=cast(TextNode, node).text,
            )
            for node in nodes
        ]
        return await run_jobs(
            title_jobs, show_progress=self.show_progress, workers=self.num_workers
        )
|