323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
class SummaryExtractor(BaseExtractor):
    """
    Summary extractor. Node-level extractor with adjacent sharing.

    Extracts `section_summary`, `prev_section_summary`, `next_section_summary`
    metadata fields.

    Args:
        llm (Optional[LLM]): LLM used to generate summaries. When omitted,
            `llm_predictor` is used, and `Settings.llm` after that.
        llm_predictor (Optional[LLMPredictorType]): deprecated alias of `llm`.
        summaries (Optional[List[str]]): list of summaries to extract:
            'self', 'prev', 'next'. Defaults to `["self"]` when omitted.
        prompt_template (str): template for summary extraction
        num_workers (int): forwarded to the base extractor and used as the
            worker count when running the summary jobs.

    Raises:
        ValueError: if `summaries` contains a value other than 'self',
            'prev' or 'next'.
    """

    llm: LLMPredictorType = Field(description="The LLM to use for generation.")
    summaries: List[str] = Field(
        description="List of summaries to extract: 'self', 'prev', 'next'"
    )
    prompt_template: str = Field(
        default=DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        description="Template to use when generating summaries.",
    )

    # Flags derived from `summaries` once, at construction time.
    _self_summary: bool = PrivateAttr()
    _prev_summary: bool = PrivateAttr()
    _next_summary: bool = PrivateAttr()

    def __init__(
        self,
        llm: Optional[LLM] = None,
        # TODO: llm_predictor arg is deprecated
        llm_predictor: Optional[LLMPredictorType] = None,
        summaries: Optional[List[str]] = None,
        prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        **kwargs: Any,
    ):
        # A `None` sentinel replaces the former `["self"]` default so that no
        # list object is shared between calls.
        if summaries is None:
            summaries = ["self"]

        # validation
        if not all(s in ("self", "prev", "next") for s in summaries):
            raise ValueError("summaries must be one of ['self', 'prev', 'next']")

        super().__init__(
            llm=llm or llm_predictor or Settings.llm,
            summaries=summaries,
            prompt_template=prompt_template,
            num_workers=num_workers,
            **kwargs,
        )

        # Private attributes are assigned only after the model has been
        # initialised: under pydantic v2 the private-attribute storage does
        # not exist before `BaseModel.__init__` has run, and assigning
        # earlier raises AttributeError.
        self._self_summary = "self" in summaries
        self._prev_summary = "prev" in summaries
        self._next_summary = "next" in summaries

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this extractor class."""
        return "SummaryExtractor"

    async def _agenerate_node_summary(self, node: BaseNode) -> str:
        """
        Generate a summary for a node.

        Returns an empty string when `is_text_node_only` is set and `node` is
        not a `TextNode`; otherwise returns the LLM output with surrounding
        whitespace stripped.
        """
        if self.is_text_node_only and not isinstance(node, TextNode):
            return ""

        context_str = node.get_content(metadata_mode=self.metadata_mode)
        summary = await self.llm.apredict(
            PromptTemplate(template=self.prompt_template), context_str=context_str
        )
        return summary.strip()

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """
        Build one metadata dict per node from the generated summaries.

        A summary is generated for every node; each node's dict then receives
        `prev_section_summary`, `next_section_summary` and/or
        `section_summary` depending on which summaries were requested.
        Empty summaries are not written.

        Raises:
            ValueError: if any element of `nodes` is not a `TextNode`.
        """
        if not all(isinstance(node, TextNode) for node in nodes):
            raise ValueError("Only `TextNode` is allowed for `Summary` extractor")

        node_summaries_jobs = [self._agenerate_node_summary(node) for node in nodes]

        # NOTE(review): the indexing below assumes `run_jobs` returns results
        # in the same order as the submitted jobs -- confirm against its docs.
        node_summaries = await run_jobs(
            node_summaries_jobs,
            show_progress=self.show_progress,
            workers=self.num_workers,
        )

        # Extract node-level summary metadata
        metadata_list: List[Dict] = [{} for _ in nodes]
        last_index = len(nodes) - 1
        for i, metadata in enumerate(metadata_list):
            # The first node has no predecessor and the last has no successor,
            # so those neighbour keys are simply omitted for them.
            if i > 0 and self._prev_summary and node_summaries[i - 1]:
                metadata["prev_section_summary"] = node_summaries[i - 1]
            if i < last_index and self._next_summary and node_summaries[i + 1]:
                metadata["next_section_summary"] = node_summaries[i + 1]
            if self._self_summary and node_summaries[i]:
                metadata["section_summary"] = node_summaries[i]

        return metadata_list
|