Commit 294c3d0a416199bb36b427a593bdb01a2899a098
1 parent
46ce858d
实现第一版“按模型预算智能分句”的基础能力。
改动: 新增分句与预算工具:translation/text_splitter.py 接入 HF 本地后端:translation/backends/local_seq2seq.py (line 157) 接入 CT2 本地后端:translation/backends/local_ctranslate2.py (line 301) 补了测试:tests/test_translation_local_backends.py 我先把代码里实际限制梳理了一遍,关键配置在 config/config.yaml (line 133): nllb-200-distilled-600m: max_input_length=256,max_new_tokens=64,并且是 ct2_decoding_length_mode=source + extra=8。现在按这个配置计算出的保守输入预算是 56 token。 opus-mt-zh-en: max_input_length=256,max_new_tokens=256。现在保守输入预算是 248 token。 opus-mt-en-zh: 同上,也是 248 token。 这版分句策略是: 先按强边界切:。!?!?;;…、换行、英文句号 不够再按弱边界切:,,、::()()[]【】/| 再不够才按空白切 最后才做 token 预算下的硬切 超长时会“分句翻译后再回拼”,中文目标语言默认无空格回拼,英文等默认按空格回拼,尽量别切太碎 验证: python3 -m compileall translation tests/test_translation_local_backends.py 已通过
Showing
4 changed files
with
481 additions
and
2 deletions
Show diff stats
tests/test_translation_local_backends.py
| @@ -2,6 +2,7 @@ import torch | @@ -2,6 +2,7 @@ import torch | ||
| 2 | 2 | ||
| 3 | from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend | 3 | from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend |
| 4 | from translation.service import TranslationService | 4 | from translation.service import TranslationService |
| 5 | +from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation | ||
| 5 | 6 | ||
| 6 | 7 | ||
| 7 | class _FakeBatch(dict): | 8 | class _FakeBatch(dict): |
| @@ -167,3 +168,84 @@ def test_translation_service_preloads_enabled_backends(monkeypatch): | @@ -167,3 +168,84 @@ def test_translation_service_preloads_enabled_backends(monkeypatch): | ||
| 167 | 168 | ||
| 168 | backend = service.get_backend("opus-mt-en-zh") | 169 | backend = service.get_backend("opus-mt-en-zh") |
| 169 | assert backend.model == "opus-mt-en-zh" | 170 | assert backend.model == "opus-mt-en-zh" |
| 171 | + | ||
| 172 | + | ||
def test_compute_safe_input_token_limit_uses_decode_constraints():
    """NLLB's decode-follows-source mode must shrink the budget; Opus must not."""
    nllb_limit = compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=64,
        decoding_length_mode="source",
        decoding_length_extra=8,
    )
    opus_limit = compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=256,
    )

    assert (nllb_limit, opus_limit) == (56, 248)
| 187 | + | ||
| 188 | + | ||
def test_split_text_for_translation_prefers_sentence_boundaries():
    """Oversized text is cut on sentence punctuation and re-joins losslessly."""
    text = (
        "这是一条很长的中文商品描述,包含材质、尺码和适用场景。"
        "适合春夏通勤,也适合日常出街穿搭;"
        "如果长度超了,应该优先按完整语义分句,而不是切成很碎的小片段。"
    )

    segments = split_text_for_translation(text, max_tokens=36, token_length_fn=len)

    # Multiple segments, no characters lost, every segment within budget.
    assert len(segments) >= 2
    assert "".join(segments) == text
    assert max(len(segment) for segment in segments) <= 36
    # The first cut must land on a sentence-level boundary.
    assert segments[0].endswith(("。", ";"))
| 206 | + | ||
| 207 | + | ||
class _SegmentingMarianBackend(MarianMTTranslationBackend):
    """Marian backend stub: records translated batches instead of loading a model."""

    def _load_model(self):
        # No real model is loaded; we only capture what would be translated.
        self.translated_batches = []

    def _token_count(self, text, target_lang, source_lang=None):
        # Character count stands in for tokenizer length in tests.
        del target_lang, source_lang
        return len(text)

    def _translate_batch(self, texts, target_lang, source_lang=None):
        del source_lang
        self.translated_batches.append(list(texts))
        if target_lang == "zh":
            return [f"<{text.strip()}>" for text in texts]
        return [f"[{text.strip()}]" for text in texts]
| 223 | + | ||
def test_local_backend_splits_oversized_text_before_translation():
    """A single long input is segmented once, translated in one batch, and re-joined."""
    backend = _SegmentingMarianBackend(
        name="opus-mt-en-zh",
        model_id="Helsinki-NLP/opus-mt-en-zh",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-en-zh",
        device="cpu",
        torch_dtype="float32",
        batch_size=8,
        max_input_length=24,
        max_new_tokens=24,
        num_beams=1,
        source_langs=["en"],
        target_langs=["zh"],
    )

    text = (
        "This soft cotton dress is breathable and lightweight, "
        "works well for spring travel and everyday wear, "
        "and should be split on natural clause boundaries when it gets too long."
    )

    result = backend.translate(text, source_lang="en", target_lang="zh")

    batches = backend.translated_batches
    assert result is not None
    # Everything went through exactly one batched call.
    assert len(batches) == 1
    # The text was split into several pieces, each within the 16-token budget
    # implied by max_input_length=24 minus the 8-token reserve.
    assert len(batches[0]) >= 2
    assert all(len(segment) <= 16 for segment in batches[0])
    # CJK target: pieces are re-joined with no separator.
    assert result == "".join(f"<{segment.strip()}>" for segment in batches[0])
translation/backends/local_ctranslate2.py
| @@ -14,6 +14,11 @@ from typing import Dict, List, Optional, Sequence, Union | @@ -14,6 +14,11 @@ from typing import Dict, List, Optional, Sequence, Union | ||
| 14 | from transformers import AutoTokenizer | 14 | from transformers import AutoTokenizer |
| 15 | 15 | ||
| 16 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES | 16 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES |
| 17 | +from translation.text_splitter import ( | ||
| 18 | + compute_safe_input_token_limit, | ||
| 19 | + join_translated_segments, | ||
| 20 | + split_text_for_translation, | ||
| 21 | +) | ||
| 17 | 22 | ||
| 18 | logger = logging.getLogger(__name__) | 23 | logger = logging.getLogger(__name__) |
| 19 | 24 | ||
| @@ -296,6 +301,82 @@ class LocalCTranslate2TranslationBackend: | @@ -296,6 +301,82 @@ class LocalCTranslate2TranslationBackend: | ||
| 296 | outputs.append(self._decode_tokens(processed)) | 301 | outputs.append(self._decode_tokens(processed)) |
| 297 | return outputs | 302 | return outputs |
| 298 | 303 | ||
| 304 | + def _token_count( | ||
| 305 | + self, | ||
| 306 | + text: str, | ||
| 307 | + target_lang: str, | ||
| 308 | + source_lang: Optional[str] = None, | ||
| 309 | + ) -> int: | ||
| 310 | + encoded = self._encode_source_tokens([text], source_lang, target_lang) | ||
| 311 | + return len(encoded[0]) if encoded else 0 | ||
| 312 | + | ||
| 313 | + def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: | ||
| 314 | + del target_lang, source_lang | ||
| 315 | + return compute_safe_input_token_limit( | ||
| 316 | + max_input_length=self.max_input_length, | ||
| 317 | + max_new_tokens=self.max_new_tokens, | ||
| 318 | + decoding_length_mode=self.ct2_decoding_length_mode, | ||
| 319 | + decoding_length_extra=self.ct2_decoding_length_extra, | ||
| 320 | + ) | ||
| 321 | + | ||
| 322 | + def _split_text_if_needed( | ||
| 323 | + self, | ||
| 324 | + text: str, | ||
| 325 | + target_lang: str, | ||
| 326 | + source_lang: Optional[str] = None, | ||
| 327 | + ) -> List[str]: | ||
| 328 | + limit = self._effective_input_token_limit(target_lang, source_lang) | ||
| 329 | + return split_text_for_translation( | ||
| 330 | + text, | ||
| 331 | + max_tokens=limit, | ||
| 332 | + token_length_fn=lambda value: self._token_count( | ||
| 333 | + value, | ||
| 334 | + target_lang=target_lang, | ||
| 335 | + source_lang=source_lang, | ||
| 336 | + ), | ||
| 337 | + ) | ||
| 338 | + | ||
| 339 | + def _translate_with_segmentation( | ||
| 340 | + self, | ||
| 341 | + texts: List[str], | ||
| 342 | + target_lang: str, | ||
| 343 | + source_lang: Optional[str] = None, | ||
| 344 | + ) -> List[Optional[str]]: | ||
| 345 | + segment_plans: List[List[str]] = [] | ||
| 346 | + flat_segments: List[str] = [] | ||
| 347 | + for text in texts: | ||
| 348 | + if not text.strip(): | ||
| 349 | + segment_plans.append([]) | ||
| 350 | + continue | ||
| 351 | + segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) | ||
| 352 | + segment_plans.append(segments) | ||
| 353 | + flat_segments.extend(segments) | ||
| 354 | + | ||
| 355 | + translated_segments = ( | ||
| 356 | + self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang) | ||
| 357 | + if flat_segments | ||
| 358 | + else [] | ||
| 359 | + ) | ||
| 360 | + outputs: List[Optional[str]] = [] | ||
| 361 | + offset = 0 | ||
| 362 | + for original_text, segments in zip(texts, segment_plans): | ||
| 363 | + if not segments: | ||
| 364 | + outputs.append(None if not original_text.strip() else original_text) | ||
| 365 | + continue | ||
| 366 | + current = translated_segments[offset:offset + len(segments)] | ||
| 367 | + offset += len(segments) | ||
| 368 | + if len(segments) == 1: | ||
| 369 | + outputs.append(current[0]) | ||
| 370 | + continue | ||
| 371 | + outputs.append( | ||
| 372 | + join_translated_segments( | ||
| 373 | + current, | ||
| 374 | + target_lang=target_lang, | ||
| 375 | + original_text=original_text, | ||
| 376 | + ) | ||
| 377 | + ) | ||
| 378 | + return outputs | ||
| 379 | + | ||
| 299 | def translate( | 380 | def translate( |
| 300 | self, | 381 | self, |
| 301 | text: Union[str, Sequence[str]], | 382 | text: Union[str, Sequence[str]], |
| @@ -312,7 +393,7 @@ class LocalCTranslate2TranslationBackend: | @@ -312,7 +393,7 @@ class LocalCTranslate2TranslationBackend: | ||
| 312 | if not any(item.strip() for item in chunk): | 393 | if not any(item.strip() for item in chunk): |
| 313 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] | 394 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] |
| 314 | continue | 395 | continue |
| 315 | - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang)) | 396 | + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang)) |
| 316 | return outputs[0] if is_single else outputs | 397 | return outputs[0] if is_single else outputs |
| 317 | 398 | ||
| 318 | 399 |
translation/backends/local_seq2seq.py
| @@ -11,6 +11,11 @@ import torch | @@ -11,6 +11,11 @@ import torch | ||
| 11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | 11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| 12 | 12 | ||
| 13 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES | 13 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES |
| 14 | +from translation.text_splitter import ( | ||
| 15 | + compute_safe_input_token_limit, | ||
| 16 | + join_translated_segments, | ||
| 17 | + split_text_for_translation, | ||
| 18 | +) | ||
| 14 | 19 | ||
| 15 | logger = logging.getLogger(__name__) | 20 | logger = logging.getLogger(__name__) |
| 16 | 21 | ||
| @@ -149,6 +154,91 @@ class LocalSeq2SeqTranslationBackend: | @@ -149,6 +154,91 @@ class LocalSeq2SeqTranslationBackend: | ||
| 149 | outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True) | 154 | outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True) |
| 150 | return [item.strip() if item and item.strip() else None for item in outputs] | 155 | return [item.strip() if item and item.strip() else None for item in outputs] |
| 151 | 156 | ||
| 157 | + def _token_count( | ||
| 158 | + self, | ||
| 159 | + text: str, | ||
| 160 | + target_lang: str, | ||
| 161 | + source_lang: Optional[str] = None, | ||
| 162 | + ) -> int: | ||
| 163 | + tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang) | ||
| 164 | + with self._lock: | ||
| 165 | + encoded = self.tokenizer( | ||
| 166 | + [text], | ||
| 167 | + truncation=False, | ||
| 168 | + padding=False, | ||
| 169 | + **tokenizer_kwargs, | ||
| 170 | + ) | ||
| 171 | + input_ids = encoded["input_ids"] | ||
| 172 | + first_item = input_ids[0] | ||
| 173 | + if hasattr(first_item, "shape"): | ||
| 174 | + return int(first_item.shape[-1]) | ||
| 175 | + return len(first_item) | ||
| 176 | + | ||
| 177 | + def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: | ||
| 178 | + del target_lang, source_lang | ||
| 179 | + return compute_safe_input_token_limit( | ||
| 180 | + max_input_length=self.max_input_length, | ||
| 181 | + max_new_tokens=self.max_new_tokens, | ||
| 182 | + ) | ||
| 183 | + | ||
| 184 | + def _split_text_if_needed( | ||
| 185 | + self, | ||
| 186 | + text: str, | ||
| 187 | + target_lang: str, | ||
| 188 | + source_lang: Optional[str] = None, | ||
| 189 | + ) -> List[str]: | ||
| 190 | + limit = self._effective_input_token_limit(target_lang, source_lang) | ||
| 191 | + return split_text_for_translation( | ||
| 192 | + text, | ||
| 193 | + max_tokens=limit, | ||
| 194 | + token_length_fn=lambda value: self._token_count( | ||
| 195 | + value, | ||
| 196 | + target_lang=target_lang, | ||
| 197 | + source_lang=source_lang, | ||
| 198 | + ), | ||
| 199 | + ) | ||
| 200 | + | ||
| 201 | + def _translate_with_segmentation( | ||
| 202 | + self, | ||
| 203 | + texts: List[str], | ||
| 204 | + target_lang: str, | ||
| 205 | + source_lang: Optional[str] = None, | ||
| 206 | + ) -> List[Optional[str]]: | ||
| 207 | + segment_plans: List[List[str]] = [] | ||
| 208 | + flat_segments: List[str] = [] | ||
| 209 | + for text in texts: | ||
| 210 | + if not text.strip(): | ||
| 211 | + segment_plans.append([]) | ||
| 212 | + continue | ||
| 213 | + segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) | ||
| 214 | + segment_plans.append(segments) | ||
| 215 | + flat_segments.extend(segments) | ||
| 216 | + | ||
| 217 | + translated_segments = ( | ||
| 218 | + self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang) | ||
| 219 | + if flat_segments | ||
| 220 | + else [] | ||
| 221 | + ) | ||
| 222 | + outputs: List[Optional[str]] = [] | ||
| 223 | + offset = 0 | ||
| 224 | + for original_text, segments in zip(texts, segment_plans): | ||
| 225 | + if not segments: | ||
| 226 | + outputs.append(None if not original_text.strip() else original_text) | ||
| 227 | + continue | ||
| 228 | + current = translated_segments[offset:offset + len(segments)] | ||
| 229 | + offset += len(segments) | ||
| 230 | + if len(segments) == 1: | ||
| 231 | + outputs.append(current[0]) | ||
| 232 | + continue | ||
| 233 | + outputs.append( | ||
| 234 | + join_translated_segments( | ||
| 235 | + current, | ||
| 236 | + target_lang=target_lang, | ||
| 237 | + original_text=original_text, | ||
| 238 | + ) | ||
| 239 | + ) | ||
| 240 | + return outputs | ||
| 241 | + | ||
| 152 | def translate( | 242 | def translate( |
| 153 | self, | 243 | self, |
| 154 | text: Union[str, Sequence[str]], | 244 | text: Union[str, Sequence[str]], |
| @@ -165,7 +255,7 @@ class LocalSeq2SeqTranslationBackend: | @@ -165,7 +255,7 @@ class LocalSeq2SeqTranslationBackend: | ||
| 165 | if not any(item.strip() for item in chunk): | 255 | if not any(item.strip() for item in chunk): |
| 166 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] | 256 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] |
| 167 | continue | 257 | continue |
| 168 | - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang)) | 258 | + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang)) |
| 169 | return outputs[0] if is_single else outputs | 259 | return outputs[0] if is_single else outputs |
| 170 | 260 | ||
| 171 | 261 |
| @@ -0,0 +1,226 @@ | @@ -0,0 +1,226 @@ | ||
| 1 | +"""Utilities for token-budget-aware translation text splitting.""" | ||
| 2 | + | ||
| 3 | +from __future__ import annotations | ||
| 4 | + | ||
| 5 | +from typing import Callable, List, Optional | ||
| 6 | + | ||
| 7 | +TokenLengthFn = Callable[[str], int] | ||
| 8 | + | ||
| 9 | +_CJK_LANGS = {"zh", "ja", "ko"} | ||
| 10 | +_STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"} | ||
| 11 | +_WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"} | ||
| 12 | +_CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"} | ||
| 13 | +_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'') | ||
| 14 | +_NO_SPACE_AFTER = tuple("([{$#@/<") | ||
| 15 | + | ||
| 16 | + | ||
def is_cjk_language(lang: Optional[str]) -> bool:
    """Return True when *lang* names a CJK language (zh/ja/ko), case-insensitively."""
    normalized = str(lang or "").strip().lower()
    return normalized in _CJK_LANGS
| 19 | + | ||
| 20 | + | ||
def compute_safe_input_token_limit(
    *,
    max_input_length: int,
    max_new_tokens: int,
    decoding_length_mode: str = "fixed",
    decoding_length_extra: int = 0,
    reserve_input_tokens: int = 8,
    reserve_output_tokens: int = 8,
) -> int:
    """Compute a conservative source-token budget for translation splitting.

    A small reserve is kept on the encode side for tokenizer special tokens.
    When the decode side is tighter than the encode side, the budget is also
    capped by the decode settings so text is split before the model would be
    forced to truncate. The result is never below 8 tokens.
    """

    encode_budget = max(8, int(max_input_length) - max(0, int(reserve_input_tokens)))
    mode = str(decoding_length_mode or "fixed").strip().lower()
    new_tokens = int(max_new_tokens)

    if new_tokens <= 0:
        # No meaningful decode cap configured: only the encoder limit applies.
        return encode_budget
    if mode == "source":
        # Decode length follows the source length plus a fixed extra, so the
        # source itself must fit inside the max_new_tokens window.
        decode_budget = max(8, new_tokens - max(0, int(decoding_length_extra)))
        return max(8, min(encode_budget, decode_budget))
    if new_tokens >= int(max_input_length):
        # Decode side is at least as roomy as the encode side.
        return encode_budget
    decode_budget = max(8, new_tokens - max(0, int(reserve_output_tokens)))
    return max(8, min(encode_budget, decode_budget))
| 49 | + | ||
| 50 | + | ||
def split_text_for_translation(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
) -> List[str]:
    """Split long text into a few translation-friendly segments.

    Sentence boundaries are preferred, then clause boundaries, then
    whitespace; character-level hard splitting is the last resort. Text
    that already fits the budget is returned as a single segment.
    """

    if not text:
        return [text]
    if token_length_fn(text) <= max_tokens:
        return [text]
    pieces = _split_recursive(
        text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0
    )
    return [piece for piece in pieces if piece]
| 69 | + | ||
| 70 | + | ||
def join_translated_segments(
    segments: List[Optional[str]],
    *,
    target_lang: Optional[str],
    original_text: str,
) -> Optional[str]:
    """Re-join per-segment translations into one string.

    CJK targets are concatenated without separators; other targets use a
    space, or a newline when the original text itself was multi-line.
    Returns None when every segment is empty.
    """

    cleaned = [item.strip() for item in segments if item and item.strip()]
    if not cleaned:
        return None

    if is_cjk_language(target_lang):
        separator = ""
    elif "\n" in original_text:
        separator = "\n"
    else:
        separator = " "

    result = cleaned[0]
    for piece in cleaned[1:]:
        glue = separator
        # Suppress the separator around punctuation/bracket adjacency.
        if glue and (result.endswith(_NO_SPACE_AFTER) or piece.startswith(_NO_SPACE_BEFORE)):
            glue = ""
        result += glue + piece
    return result.strip() or None
| 94 | + | ||
| 95 | + | ||
def _split_recursive(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
    level: int,
) -> List[str]:
    """Recursively split *text*, escalating through boundary levels.

    Level 0 uses sentence boundaries, level 1 clause boundaries, level 2
    whitespace; at level 3 the token-budget hard split takes over.
    """

    if token_length_fn(text) <= max_tokens:
        return [text]
    if level >= 3:
        return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)

    def _descend(chunk: str) -> List[str]:
        # Re-split a chunk with the next (finer) boundary level.
        return _split_recursive(
            chunk, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1
        )

    pieces = _split_by_level(text, level)
    if len(pieces) <= 1:
        # This level produced no cut; try the finer one.
        return _descend(text)

    result: List[str] = []
    pending = ""
    for piece in pieces:
        combined = pending + piece if pending else piece
        if token_length_fn(combined) <= max_tokens:
            # Greedily accumulate pieces while they still fit the budget.
            pending = combined
        elif pending:
            result.extend(_descend(pending))
            pending = piece
        else:
            # Single piece already over budget: split it at a finer level.
            result.extend(_descend(piece))
    if pending:
        result.extend(_descend(pending))
    return result
| 129 | + | ||
| 130 | + | ||
def _split_by_level(text: str, level: int) -> List[str]:
    """Cut *text* after every boundary of the given level, keeping all characters."""
    segments: List[str] = []
    cursor = 0
    position = 0
    while position < len(text):
        boundary = _match_boundary(text, position, level)
        if boundary is None:
            position += 1
            continue
        if boundary > cursor:
            segments.append(text[cursor:boundary])
            cursor = boundary
        position = boundary
    if cursor < len(text):
        segments.append(text[cursor:])
    return [segment for segment in segments if segment]
| 147 | + | ||
| 148 | + | ||
| 149 | +def _match_boundary(text: str, index: int, level: int) -> Optional[int]: | ||
| 150 | + char = text[index] | ||
| 151 | + if level == 0: | ||
| 152 | + if char in _STRONG_BOUNDARIES: | ||
| 153 | + return _consume_boundary_tail(text, index + 1) | ||
| 154 | + if char == "." and _is_sentence_period(text, index): | ||
| 155 | + return _consume_boundary_tail(text, index + 1) | ||
| 156 | + return None | ||
| 157 | + if level == 1: | ||
| 158 | + if char in _WEAK_BOUNDARIES: | ||
| 159 | + return _consume_boundary_tail(text, index + 1) | ||
| 160 | + return None | ||
| 161 | + if level == 2 and char.isspace(): | ||
| 162 | + end = index + 1 | ||
| 163 | + while end < len(text) and text[end].isspace(): | ||
| 164 | + end += 1 | ||
| 165 | + return end | ||
| 166 | + return None | ||
| 167 | + | ||
| 168 | + | ||
| 169 | +def _consume_boundary_tail(text: str, index: int) -> int: | ||
| 170 | + end = index | ||
| 171 | + while end < len(text) and text[end] in _CLOSING_CHARS: | ||
| 172 | + end += 1 | ||
| 173 | + while end < len(text) and text[end].isspace(): | ||
| 174 | + end += 1 | ||
| 175 | + return end | ||
| 176 | + | ||
| 177 | + | ||
| 178 | +def _is_sentence_period(text: str, index: int) -> bool: | ||
| 179 | + prev_char = text[index - 1] if index > 0 else "" | ||
| 180 | + next_char = text[index + 1] if index + 1 < len(text) else "" | ||
| 181 | + if prev_char.isdigit() and next_char.isdigit(): | ||
| 182 | + return False | ||
| 183 | + if not next_char: | ||
| 184 | + return True | ||
| 185 | + return next_char.isspace() or next_char in _CLOSING_CHARS | ||
| 186 | + | ||
| 187 | + | ||
def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    """Split by token budget alone, nudging each cut toward a nearby boundary."""
    chunks: List[str] = []
    rest = text
    while rest:
        if token_length_fn(rest) <= max_tokens:
            chunks.append(rest)
            break
        base_cut = _largest_prefix_within_limit(rest, max_tokens=max_tokens, token_length_fn=token_length_fn)
        cut = _refine_cut(rest, base_cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if cut <= 0:
            # Defensive: always make forward progress.
            cut = max(1, base_cut)
        chunks.append(rest[:cut])
        rest = rest[cut:]
    return chunks
| 202 | + | ||
| 203 | + | ||
| 204 | +def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int: | ||
| 205 | + low = 1 | ||
| 206 | + high = len(text) | ||
| 207 | + best = 1 | ||
| 208 | + while low <= high: | ||
| 209 | + mid = (low + high) // 2 | ||
| 210 | + if token_length_fn(text[:mid]) <= max_tokens: | ||
| 211 | + best = mid | ||
| 212 | + low = mid + 1 | ||
| 213 | + continue | ||
| 214 | + high = mid - 1 | ||
| 215 | + return best | ||
| 216 | + | ||
| 217 | + | ||
| 218 | +def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int: | ||
| 219 | + best = cut | ||
| 220 | + lower_bound = max(1, cut - 32) | ||
| 221 | + for candidate in range(cut, lower_bound - 1, -1): | ||
| 222 | + if text[candidate - 1].isspace() or text[candidate - 1] in _STRONG_BOUNDARIES or text[candidate - 1] in _WEAK_BOUNDARIES: | ||
| 223 | + if candidate >= max(1, cut // 2) and token_length_fn(text[:candidate]) <= max_tokens: | ||
| 224 | + return candidate | ||
| 225 | + best = max(best, candidate) | ||
| 226 | + return best |