import asyncio
from typing import Any, Callable, List, Optional, Sequence

# Imports required by the code below; module paths assume the
# llama_index.core package layout.
from llama_index.core.async_utils import run_async_tasks
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.llms import LLM
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
    DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.response_synthesizers.base import BaseSynthesizer
from llama_index.core.types import RESPONSE_TEXT_TYPE

class Accumulate(BaseSynthesizer):
    """Accumulate responses from multiple text chunks.

    Each text chunk is answered independently with the same prompt, and the
    individual answers are concatenated rather than refined into one.
    """

    def __init__(
        self,
        llm: Optional[LLM] = None,
        callback_manager: Optional[CallbackManager] = None,
        prompt_helper: Optional[PromptHelper] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        streaming: bool = False,
        use_async: bool = False,
    ) -> None:
        super().__init__(
            llm=llm,
            callback_manager=callback_manager,
            prompt_helper=prompt_helper,
            streaming=streaming,
        )
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
        self._use_async = use_async
        self._output_cls = output_cls  # type: ignore

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"text_qa_template": self._text_qa_template}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]

    def flatten_list(self, md_array: List[List[Any]]) -> List[Any]:
        """Flatten a list of lists into a single flat list."""
        return [item for sublist in md_array for item in sublist]

    def _format_response(self, outputs: List[Any], separator: str) -> str:
        responses: List[str] = []
        for response in outputs:
            responses.append(response or "Empty Response")

        return separator.join(
            [f"Response {index + 1}: {item}" for index, item in enumerate(responses)]
        )

    async def aget_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        separator: str = "\n---------------------\n",
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Apply the same prompt to text chunks and return async responses."""
        if self._streaming:
            raise ValueError("Unable to stream in Accumulate response mode")

        # Build one list of prediction coroutines per chunk, then flatten the
        # nested lists and await everything concurrently.
        tasks = [
            self._give_responses(
                query_str, text_chunk, use_async=True, **response_kwargs
            )
            for text_chunk in text_chunks
        ]

        flattened_tasks = self.flatten_list(tasks)
        outputs = await asyncio.gather(*flattened_tasks)

        return self._format_response(outputs, separator)

    def get_response(
        self,
        query_str: str,
        text_chunks: Sequence[str],
        separator: str = "\n---------------------\n",
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        """Apply the same prompt to text chunks and return responses."""
        if self._streaming:
            raise ValueError("Unable to stream in Accumulate response mode")

        tasks = [
            self._give_responses(
                query_str, text_chunk, use_async=self._use_async, **response_kwargs
            )
            for text_chunk in text_chunks
        ]

        outputs = self.flatten_list(tasks)

        # When use_async is set, the entries are coroutines that still need to run.
        if self._use_async:
            outputs = run_async_tasks(outputs)

        return self._format_response(outputs, separator)

    def _give_responses(
        self,
        query_str: str,
        text_chunk: str,
        use_async: bool = False,
        **response_kwargs: Any,
    ) -> List[Any]:
        """Give responses given a query and a corresponding text chunk."""
        text_qa_template = self._text_qa_template.partial_format(query_str=query_str)

        # Repack the chunk so every piece fits within the prompt's context window.
        text_chunks = self._prompt_helper.repack(text_qa_template, [text_chunk])

        predictor: Callable
        if self._output_cls is None:
            # Plain text prediction; returns coroutines when use_async is True.
            predictor = self._llm.apredict if use_async else self._llm.predict

            return [
                predictor(
                    text_qa_template,
                    context_str=cur_text_chunk,
                    **response_kwargs,
                )
                for cur_text_chunk in text_chunks
            ]
        else:
            # Structured output: predict directly into the provided pydantic class.
            predictor = (
                self._llm.astructured_predict  # type: ignore
                if use_async
                else self._llm.structured_predict  # type: ignore
            )

            return [
                predictor(
                    self._output_cls,
                    text_qa_template,
                    context_str=cur_text_chunk,
                    **response_kwargs,
                )
                for cur_text_chunk in text_chunks
            ]
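
# A minimal usage sketch, not part of the original module: it assumes the
# llama_index.core package layout used in the imports above and relies on
# MockLLM so it runs offline. Each chunk produces its own "Response N: ..."
# entry in the combined output string.
if __name__ == "__main__":
    from llama_index.core.llms import MockLLM

    synthesizer = Accumulate(llm=MockLLM(max_tokens=32))
    answer = synthesizer.get_response(
        query_str="What color is the sky?",
        text_chunks=[
            "The sky is blue on a clear day.",
            "At sunset the sky often turns orange and red.",
        ],
    )
    print(answer)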