# text_splitter.py
"""Utilities for token-budget-aware translation text splitting."""
from __future__ import annotations
from typing import Callable, List, Optional
TokenLengthFn = Callable[[str], int]
_CJK_LANGS = {"zh", "ja", "ko"}
_STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"}
_WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"}
_CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"}
_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
_NO_SPACE_AFTER = tuple("([{$#@/<")
def is_cjk_language(lang: Optional[str]) -> bool:
return str(lang or "").strip().lower() in _CJK_LANGS

def compute_safe_input_token_limit(
    *,
    max_input_length: int,
    max_new_tokens: int,
    decoding_length_mode: str = "fixed",
    decoding_length_extra: int = 0,
    reserve_input_tokens: int = 8,
    reserve_output_tokens: int = 8,
) -> int:
    """Derive a conservative source-token budget for translation splitting.

    We keep a small reserve for tokenizer special tokens on the input side.
    If the decode side is much tighter than the encode side, we also cap the
    source budget based on decode settings so we split before the model is
    likely to truncate.
    """
    input_limit = max(8, int(max_input_length) - max(0, int(reserve_input_tokens)))
    decode_mode = str(decoding_length_mode or "fixed").strip().lower()
    if int(max_new_tokens) <= 0:
        return input_limit
    if decode_mode == "source":
        # In "source" mode the decode budget scales with the input, so only
        # the extra headroom needs to be subtracted.
        output_limit = max(8, int(max_new_tokens) - max(0, int(decoding_length_extra)))
        return max(8, min(input_limit, output_limit))
    if int(max_new_tokens) >= int(max_input_length):
        return input_limit
    output_limit = max(8, int(max_new_tokens) - max(0, int(reserve_output_tokens)))
    return max(8, min(input_limit, output_limit))
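
# Illustrative budgets (assumed values; real limits depend on the model):
#   compute_safe_input_token_limit(max_input_length=512, max_new_tokens=256)
#     -> min(512 - 8, 256 - 8) = 248
#   compute_safe_input_token_limit(max_input_length=512, max_new_tokens=0)
#     -> 512 - 8 = 504  (decode-side cap disabled)
#   compute_safe_input_token_limit(max_input_length=512, max_new_tokens=300,
#                                  decoding_length_mode="source",
#                                  decoding_length_extra=20)
#     -> min(504, 300 - 20) = 280  (decode budget tracks source length)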

def split_text_for_translation(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
) -> List[str]:
    """Split long text into a few translation-friendly segments.

    The splitter prefers sentence boundaries, then clause boundaries, then
    whitespace, and only falls back to character-based splitting when needed.
    """
    if not text:
        return [text]
    if token_length_fn(text) <= max_tokens:
        return [text]
    segments = _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0)
    return [segment for segment in segments if segment]
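
# Example (a whitespace counter stands in for a real tokenizer here):
#   split_text_for_translation(
#       "One two. Three four five six. Seven",
#       max_tokens=4,
#       token_length_fn=lambda s: len(s.split()),
#   )
#   -> ["One two. ", "Three four five six. ", "Seven"]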

def join_translated_segments(
    segments: List[Optional[str]],
    *,
    target_lang: Optional[str],
    original_text: str,
) -> Optional[str]:
    """Reassemble translated segments with a language-appropriate separator."""
    parts = [segment.strip() for segment in segments if segment and segment.strip()]
    if not parts:
        return None
    # CJK targets are joined without spaces; other languages get a space, or
    # a newline when the original text was line-structured.
    separator = "" if is_cjk_language(target_lang) else " "
    if "\n" in original_text and separator:
        separator = "\n"
    merged = parts[0]
    for part in parts[1:]:
        if not separator:
            merged += part
            continue
        # Avoid inserting a space next to punctuation that hugs its neighbor.
        if merged.endswith(_NO_SPACE_AFTER) or part.startswith(_NO_SPACE_BEFORE):
            merged += part
            continue
        merged += separator + part
    return merged.strip() or None
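
# Examples:
#   join_translated_segments(["你好", "世界"], target_lang="zh", original_text="x")
#     -> "你好世界"  (CJK targets are joined without a separator)
#   join_translated_segments(["Hello", "world"], target_lang="en", original_text="x")
#     -> "Hello world"
#   A "\n" in original_text promotes the separator to a newline for non-CJK targets.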

def _split_recursive(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
    level: int,
) -> List[str]:
    if token_length_fn(text) <= max_tokens:
        return [text]
    if level >= 3:
        # No boundary level left; fall back to character-based splitting.
        return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)
    pieces = _split_by_level(text, level)
    if len(pieces) <= 1:
        # This level produced no split; escalate to the next (weaker) level.
        return _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)
    # Greedily pack adjacent pieces into the budget; recurse on overflow.
    merged: List[str] = []
    buffer = ""
    for piece in pieces:
        candidate = buffer + piece if buffer else piece
        if token_length_fn(candidate) <= max_tokens:
            buffer = candidate
            continue
        if buffer:
            merged.extend(
                _split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)
            )
            buffer = piece
            continue
        merged.extend(_split_recursive(piece, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1))
    if buffer:
        merged.extend(_split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1))
    return merged
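
# Illustrative trace (whitespace counter, max_tokens=4): "A b. C d e f. G"
# splits at level 0 into ["A b. ", "C d e f. ", "G"]. The buffer flushes
# "A b. " on its own because appending "C d e f. " would exceed the budget,
# "C d e f. " then fits exactly, and "G" becomes the final segment.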

def _split_by_level(text: str, level: int) -> List[str]:
    """Split text after every boundary recognized at the given level."""
    parts: List[str] = []
    start = 0
    index = 0
    while index < len(text):
        boundary_end = _match_boundary(text, index, level)
        if boundary_end is None:
            index += 1
            continue
        if boundary_end > start:
            parts.append(text[start:boundary_end])
        start = boundary_end
        index = boundary_end
    if start < len(text):
        parts.append(text[start:])
    return [part for part in parts if part]

def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
    """Return the end index of a boundary starting at ``index``, or ``None``."""
    char = text[index]
    if level == 0:
        # Sentence-level boundaries, with special handling for "." so that
        # decimals such as "3.14" are not treated as sentence ends.
        if char in _STRONG_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        if char == "." and _is_sentence_period(text, index):
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 1:
        if char in _WEAK_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 2 and char.isspace():
        # Whitespace level: treat a run of whitespace as one boundary.
        end = index + 1
        while end < len(text) and text[end].isspace():
            end += 1
        return end
    return None

def _consume_boundary_tail(text: str, index: int) -> int:
    """Extend a boundary over trailing closing characters and whitespace."""
    end = index
    while end < len(text) and text[end] in _CLOSING_CHARS:
        end += 1
    while end < len(text) and text[end].isspace():
        end += 1
    return end

def _is_sentence_period(text: str, index: int) -> bool:
    prev_char = text[index - 1] if index > 0 else ""
    next_char = text[index + 1] if index + 1 < len(text) else ""
    # A digit on both sides means a decimal point, not a sentence end.
    if prev_char.isdigit() and next_char.isdigit():
        return False
    if not next_char:
        return True
    return next_char.isspace() or next_char in _CLOSING_CHARS

def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    """Character-based fallback: repeatedly cut the largest fitting prefix."""
    segments: List[str] = []
    remaining = text
    while remaining:
        if token_length_fn(remaining) <= max_tokens:
            segments.append(remaining)
            break
        cut = _largest_prefix_within_limit(remaining, max_tokens=max_tokens, token_length_fn=token_length_fn)
        refined_cut = _refine_cut(remaining, cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if refined_cut <= 0:
            refined_cut = max(1, cut)
        segments.append(remaining[:refined_cut])
        remaining = remaining[refined_cut:]
    return segments

def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Binary-search the longest prefix whose token count fits the budget."""
    low = 1
    high = len(text)
    best = 1
    while low <= high:
        mid = (low + high) // 2
        if token_length_fn(text[:mid]) <= max_tokens:
            best = mid
            low = mid + 1
            continue
        high = mid - 1
    return best
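
# Note: the bisection assumes token counts are non-decreasing in prefix
# length, which holds for typical subword tokenizers. Even if a merge at the
# cut point shifts a count, every recorded ``best`` was measured directly
# against the budget, so a non-monotonic counter costs efficiency, not
# correctness.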

def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Nudge a hard cut backwards onto nearby whitespace or punctuation.

    Looks back at most 32 characters and never past the midpoint of the cut,
    so refinement cannot degenerate into tiny segments. Falls back to the
    original cut when no suitable boundary is found.
    """
    lower_bound = max(1, cut - 32)
    for candidate in range(cut, lower_bound - 1, -1):
        boundary_char = text[candidate - 1]
        if boundary_char.isspace() or boundary_char in _STRONG_BOUNDARIES or boundary_char in _WEAK_BOUNDARIES:
            if candidate >= max(1, cut // 2) and token_length_fn(text[:candidate]) <= max_tokens:
                return candidate
    return cut
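
if __name__ == "__main__":
    # Minimal smoke test with a naive whitespace token counter. Real callers
    # would pass a tokenizer-backed counter instead, e.g. one that measures
    # len(tokenizer.encode(text)) for the actual translation model.
    def _count(text: str) -> int:
        return max(1, len(text.split()))

    budget = compute_safe_input_token_limit(max_input_length=512, max_new_tokens=256)
    sample = "First sentence here. Second sentence follows. Third one ends it."
    pieces = split_text_for_translation(sample, max_tokens=5, token_length_fn=_count)
    print("budget:", budget)   # budget: 248
    print("pieces:", pieces)   # three sentence-aligned segments
    print("joined:", join_translated_segments(pieces, target_lang="en", original_text=sample))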