21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
class TokenTextSplitter(MetadataAwareTextSplitter):
    """Implementation of splitting text that looks at word tokens.

    Text is first broken into pieces that fit within the chunk size (as
    measured by the configured tokenizer) and the pieces are then merged
    back into chunks. When a new chunk is started, at most ``chunk_overlap``
    tokens of the previous chunk are carried over into it.
    """

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE,
        description="The token chunk size for each chunk.",
        gt=0,
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
        # "ge" is the pydantic keyword for ">="; "gte" is not a constraint and
        # left negative overlaps unvalidated.
        ge=0,
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List = Field(
        default_factory=list, description="Additional separators for splitting."
    )

    # Callable applied to a string; the length of its result is the token count.
    _tokenizer: Callable = PrivateAttr()
    # Split functions tried in order by ``_split``.
    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ):
        """Initialize with parameters.

        Args:
            chunk_size: Maximum number of tokens per chunk.
            chunk_overlap: Maximum number of tokens carried over from one
                chunk into the next.
            tokenizer: Callable used to count tokens (the length of its
                result is the count). Defaults to ``get_tokenizer()``.
            callback_manager: Defaults to ``CallbackManager([])``.
            separator: First separator tried when splitting.
            backup_separators: Separators tried, in order, after
                ``separator``; ``None`` or empty means no backups.
            include_metadata: Forwarded unchanged to the parent initializer.
            include_prev_next_rel: Forwarded unchanged to the parent
                initializer.
            id_func: Defaults to ``default_id_func``.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        # NOTE: the ``["\n"]`` default is a single list shared by all calls.
        # This method only reads it; the signature is kept as-is for callers.
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        callback_manager = callback_manager or CallbackManager([])
        id_func = id_func or default_id_func
        self._tokenizer = tokenizer or get_tokenizer()

        # Splitting order: main separator, backup separators, then characters.
        all_seps = [separator] + (backup_separators or [])
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def from_defaults(
        cls,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
        callback_manager: Optional[CallbackManager] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        id_func: Optional[Callable[[int, Document], str]] = None,
    ) -> "TokenTextSplitter":
        """Initialize with default parameters.

        Every argument is passed straight to the constructor; no tokenizer
        is passed, so the constructor's default tokenizer is used.
        """
        callback_manager = callback_manager or CallbackManager([])
        return cls(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            callback_manager=callback_manager,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            id_func=id_func,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this class as a string."""
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str.

        Args:
            text: The text to split.
            metadata_str: Metadata string whose token count (plus
                ``DEFAULT_METADATA_FORMAT_LEN``) is subtracted from the
                chunk size before splitting.

        Returns:
            The chunks produced with the reduced chunk size.

        Raises:
            ValueError: If the metadata leaves no room for any text.
        """
        metadata_len = len(self._tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            # Still workable, but the chunks will be very small; tell the user.
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )

        return self._split_text(text, chunk_size=effective_chunk_size)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size.

        An empty string is returned unchanged as a single-element list
        without emitting a chunking event. Otherwise the split and merge
        steps run inside a ``CBEventType.CHUNKING`` callback event.
        """
        if text == "":
            return [text]

        with self.callback_manager.event(
            CBEventType.CHUNKING, payload={EventPayload.CHUNKS: [text]}
        ) as event:
            splits = self._split(text, chunk_size)
            chunks = self._merge(splits, chunk_size)

            event.on_end(
                payload={EventPayload.CHUNKS: chunks},
            )

        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.

        A piece that none of the split functions can divide any further is
        returned as-is even if it is still larger than ``chunk_size``.
        """
        if len(self._tokenizer(text)) <= chunk_size:
            return [text]

        # Use the first split function that actually divides the text.
        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        # Nothing could divide the text: recursing on the very same string
        # would never terminate, so hand the oversized piece back unchanged.
        if len(splits) == 1 and splits[0] == text:
            return [text]

        new_splits = []
        for split in splits:
            split_len = len(self._tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.

        Args:
            splits: Pieces of text, in order, to be joined into chunks.
            chunk_size: Maximum number of tokens per chunk.

        Returns:
            The non-empty chunks, each stripped of surrounding whitespace.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self._tokenizer(split))
            if split_len > chunk_size:
                # Single message with lazy arguments; passing two separate
                # strings made logging treat the second one as a format arg.
                _logger.warning(
                    "Got a split of size %s, larger than chunk size %s.",
                    split_len,
                    chunk_size,
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                #   1. the current chunk length is less than chunk overlap
                #   2. the total length is less than chunk size
                # The emptiness check keeps an oversized split from popping
                # from an empty list.
                while cur_chunk and (
                    cur_len > self.chunk_overlap
                    or cur_len + split_len > chunk_size
                ):
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self._tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
|