Commit 294c3d0a416199bb36b427a593bdb01a2899a098
1 parent
46ce858d
实现第一版“按模型预算智能分句”的基础能力。
改动: 新增分句与预算工具:translation/text_splitter.py 接入 HF 本地后端:translation/backends/local_seq2seq.py (line 157) 接入 CT2 本地后端:translation/backends/local_ctranslate2.py (line 301) 补了测试:tests/test_translation_local_backends.py 我先把代码里实际限制梳理了一遍,关键配置在 config/config.yaml (line 133): nllb-200-distilled-600m: max_input_length=256,max_new_tokens=64,并且是 ct2_decoding_length_mode=source + extra=8。现在按这个配置计算出的保守输入预算是 56 token。 opus-mt-zh-en: max_input_length=256,max_new_tokens=256。现在保守输入预算是 248 token。 opus-mt-en-zh: 同上,也是 248 token。 这版分句策略是: 先按强边界切:。!?!?;;…、换行、英文句号 不够再按弱边界切:,,、::()()[]【】/| 再不够才按空白切 最后才做 token 预算下的硬切 超长时会“分句翻译后再回拼”,中文目标语言默认无空格回拼,英文等默认按空格回拼,尽量别切太碎 验证: python3 -m compileall translation tests/test_translation_local_backends.py 已通过
Showing
4 changed files
with
481 additions
and
2 deletions
Show diff stats
tests/test_translation_local_backends.py
| ... | ... | @@ -2,6 +2,7 @@ import torch |
| 2 | 2 | |
| 3 | 3 | from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend |
| 4 | 4 | from translation.service import TranslationService |
| 5 | +from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation | |
| 5 | 6 | |
| 6 | 7 | |
| 7 | 8 | class _FakeBatch(dict): |
| ... | ... | @@ -167,3 +168,84 @@ def test_translation_service_preloads_enabled_backends(monkeypatch): |
| 167 | 168 | |
| 168 | 169 | backend = service.get_backend("opus-mt-en-zh") |
| 169 | 170 | assert backend.model == "opus-mt-en-zh" |
| 171 | + | |
| 172 | + | |
def test_compute_safe_input_token_limit_uses_decode_constraints():
    """NLLB's source-tracking decode cap and Opus's fixed cap yield the documented budgets."""
    assert compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=64,
        decoding_length_mode="source",
        decoding_length_extra=8,
    ) == 56
    assert compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=256,
    ) == 248
| 187 | + | |
| 188 | + | |
def test_split_text_for_translation_prefers_sentence_boundaries():
    """Long CJK text splits on sentence punctuation, round-trips, and respects the budget."""
    text = (
        "这是一条很长的中文商品描述,包含材质、尺码和适用场景。"
        "适合春夏通勤,也适合日常出街穿搭;"
        "如果长度超了,应该优先按完整语义分句,而不是切成很碎的小片段。"
    )

    segments = split_text_for_translation(text, max_tokens=36, token_length_fn=len)

    assert len(segments) >= 2
    assert "".join(segments) == text
    assert max(len(segment) for segment in segments) <= 36
    assert segments[0].endswith(("。", ";"))
| 206 | + | |
| 207 | + | |
class _SegmentingMarianBackend(MarianMTTranslationBackend):
    """Marian backend stub: no model load, char-count tokens, recorded batches."""

    def _load_model(self):
        # Skip real model loading; just record what gets translated.
        self.translated_batches = []

    def _token_count(self, text, target_lang, source_lang=None):
        # Character count stands in for real tokenizer lengths.
        del target_lang, source_lang
        return len(text)

    def _translate_batch(self, texts, target_lang, source_lang=None):
        del source_lang
        self.translated_batches.append(list(texts))
        wrap = "<{}>" if target_lang == "zh" else "[{}]"
        return [wrap.format(text.strip()) for text in texts]
| 222 | + | |
| 223 | + | |
def test_local_backend_splits_oversized_text_before_translation():
    """Oversized input is split before hitting the backend, then re-joined."""
    backend = _SegmentingMarianBackend(
        name="opus-mt-en-zh",
        model_id="Helsinki-NLP/opus-mt-en-zh",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-en-zh",
        device="cpu",
        torch_dtype="float32",
        batch_size=8,
        max_input_length=24,
        max_new_tokens=24,
        num_beams=1,
        source_langs=["en"],
        target_langs=["zh"],
    )
    text = (
        "This soft cotton dress is breathable and lightweight, "
        "works well for spring travel and everyday wear, "
        "and should be split on natural clause boundaries when it gets too long."
    )

    translated = backend.translate(text, source_lang="en", target_lang="zh")

    batches = backend.translated_batches
    assert translated is not None
    assert len(batches) == 1
    assert len(batches[0]) >= 2
    assert all(len(chunk) <= 16 for chunk in batches[0])
    assert translated == "".join(f"<{chunk.strip()}>" for chunk in batches[0])
translation/backends/local_ctranslate2.py
| ... | ... | @@ -14,6 +14,11 @@ from typing import Dict, List, Optional, Sequence, Union |
| 14 | 14 | from transformers import AutoTokenizer |
| 15 | 15 | |
| 16 | 16 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES |
| 17 | +from translation.text_splitter import ( | |
| 18 | + compute_safe_input_token_limit, | |
| 19 | + join_translated_segments, | |
| 20 | + split_text_for_translation, | |
| 21 | +) | |
| 17 | 22 | |
| 18 | 23 | logger = logging.getLogger(__name__) |
| 19 | 24 | |
| ... | ... | @@ -296,6 +301,82 @@ class LocalCTranslate2TranslationBackend: |
| 296 | 301 | outputs.append(self._decode_tokens(processed)) |
| 297 | 302 | return outputs |
| 298 | 303 | |
| 304 | + def _token_count( | |
| 305 | + self, | |
| 306 | + text: str, | |
| 307 | + target_lang: str, | |
| 308 | + source_lang: Optional[str] = None, | |
| 309 | + ) -> int: | |
| 310 | + encoded = self._encode_source_tokens([text], source_lang, target_lang) | |
| 311 | + return len(encoded[0]) if encoded else 0 | |
| 312 | + | |
| 313 | + def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: | |
| 314 | + del target_lang, source_lang | |
| 315 | + return compute_safe_input_token_limit( | |
| 316 | + max_input_length=self.max_input_length, | |
| 317 | + max_new_tokens=self.max_new_tokens, | |
| 318 | + decoding_length_mode=self.ct2_decoding_length_mode, | |
| 319 | + decoding_length_extra=self.ct2_decoding_length_extra, | |
| 320 | + ) | |
| 321 | + | |
| 322 | + def _split_text_if_needed( | |
| 323 | + self, | |
| 324 | + text: str, | |
| 325 | + target_lang: str, | |
| 326 | + source_lang: Optional[str] = None, | |
| 327 | + ) -> List[str]: | |
| 328 | + limit = self._effective_input_token_limit(target_lang, source_lang) | |
| 329 | + return split_text_for_translation( | |
| 330 | + text, | |
| 331 | + max_tokens=limit, | |
| 332 | + token_length_fn=lambda value: self._token_count( | |
| 333 | + value, | |
| 334 | + target_lang=target_lang, | |
| 335 | + source_lang=source_lang, | |
| 336 | + ), | |
| 337 | + ) | |
| 338 | + | |
| 339 | + def _translate_with_segmentation( | |
| 340 | + self, | |
| 341 | + texts: List[str], | |
| 342 | + target_lang: str, | |
| 343 | + source_lang: Optional[str] = None, | |
| 344 | + ) -> List[Optional[str]]: | |
| 345 | + segment_plans: List[List[str]] = [] | |
| 346 | + flat_segments: List[str] = [] | |
| 347 | + for text in texts: | |
| 348 | + if not text.strip(): | |
| 349 | + segment_plans.append([]) | |
| 350 | + continue | |
| 351 | + segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) | |
| 352 | + segment_plans.append(segments) | |
| 353 | + flat_segments.extend(segments) | |
| 354 | + | |
| 355 | + translated_segments = ( | |
| 356 | + self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang) | |
| 357 | + if flat_segments | |
| 358 | + else [] | |
| 359 | + ) | |
| 360 | + outputs: List[Optional[str]] = [] | |
| 361 | + offset = 0 | |
| 362 | + for original_text, segments in zip(texts, segment_plans): | |
| 363 | + if not segments: | |
| 364 | + outputs.append(None if not original_text.strip() else original_text) | |
| 365 | + continue | |
| 366 | + current = translated_segments[offset:offset + len(segments)] | |
| 367 | + offset += len(segments) | |
| 368 | + if len(segments) == 1: | |
| 369 | + outputs.append(current[0]) | |
| 370 | + continue | |
| 371 | + outputs.append( | |
| 372 | + join_translated_segments( | |
| 373 | + current, | |
| 374 | + target_lang=target_lang, | |
| 375 | + original_text=original_text, | |
| 376 | + ) | |
| 377 | + ) | |
| 378 | + return outputs | |
| 379 | + | |
| 299 | 380 | def translate( |
| 300 | 381 | self, |
| 301 | 382 | text: Union[str, Sequence[str]], |
| ... | ... | @@ -312,7 +393,7 @@ class LocalCTranslate2TranslationBackend: |
| 312 | 393 | if not any(item.strip() for item in chunk): |
| 313 | 394 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] |
| 314 | 395 | continue |
| 315 | - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang)) | |
| 396 | + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang)) | |
| 316 | 397 | return outputs[0] if is_single else outputs |
| 317 | 398 | |
| 318 | 399 | ... | ... |
translation/backends/local_seq2seq.py
| ... | ... | @@ -11,6 +11,11 @@ import torch |
| 11 | 11 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| 12 | 12 | |
| 13 | 13 | from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES |
| 14 | +from translation.text_splitter import ( | |
| 15 | + compute_safe_input_token_limit, | |
| 16 | + join_translated_segments, | |
| 17 | + split_text_for_translation, | |
| 18 | +) | |
| 14 | 19 | |
| 15 | 20 | logger = logging.getLogger(__name__) |
| 16 | 21 | |
| ... | ... | @@ -149,6 +154,91 @@ class LocalSeq2SeqTranslationBackend: |
| 149 | 154 | outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True) |
| 150 | 155 | return [item.strip() if item and item.strip() else None for item in outputs] |
| 151 | 156 | |
| 157 | + def _token_count( | |
| 158 | + self, | |
| 159 | + text: str, | |
| 160 | + target_lang: str, | |
| 161 | + source_lang: Optional[str] = None, | |
| 162 | + ) -> int: | |
| 163 | + tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang) | |
| 164 | + with self._lock: | |
| 165 | + encoded = self.tokenizer( | |
| 166 | + [text], | |
| 167 | + truncation=False, | |
| 168 | + padding=False, | |
| 169 | + **tokenizer_kwargs, | |
| 170 | + ) | |
| 171 | + input_ids = encoded["input_ids"] | |
| 172 | + first_item = input_ids[0] | |
| 173 | + if hasattr(first_item, "shape"): | |
| 174 | + return int(first_item.shape[-1]) | |
| 175 | + return len(first_item) | |
| 176 | + | |
| 177 | + def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int: | |
| 178 | + del target_lang, source_lang | |
| 179 | + return compute_safe_input_token_limit( | |
| 180 | + max_input_length=self.max_input_length, | |
| 181 | + max_new_tokens=self.max_new_tokens, | |
| 182 | + ) | |
| 183 | + | |
| 184 | + def _split_text_if_needed( | |
| 185 | + self, | |
| 186 | + text: str, | |
| 187 | + target_lang: str, | |
| 188 | + source_lang: Optional[str] = None, | |
| 189 | + ) -> List[str]: | |
| 190 | + limit = self._effective_input_token_limit(target_lang, source_lang) | |
| 191 | + return split_text_for_translation( | |
| 192 | + text, | |
| 193 | + max_tokens=limit, | |
| 194 | + token_length_fn=lambda value: self._token_count( | |
| 195 | + value, | |
| 196 | + target_lang=target_lang, | |
| 197 | + source_lang=source_lang, | |
| 198 | + ), | |
| 199 | + ) | |
| 200 | + | |
| 201 | + def _translate_with_segmentation( | |
| 202 | + self, | |
| 203 | + texts: List[str], | |
| 204 | + target_lang: str, | |
| 205 | + source_lang: Optional[str] = None, | |
| 206 | + ) -> List[Optional[str]]: | |
| 207 | + segment_plans: List[List[str]] = [] | |
| 208 | + flat_segments: List[str] = [] | |
| 209 | + for text in texts: | |
| 210 | + if not text.strip(): | |
| 211 | + segment_plans.append([]) | |
| 212 | + continue | |
| 213 | + segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang) | |
| 214 | + segment_plans.append(segments) | |
| 215 | + flat_segments.extend(segments) | |
| 216 | + | |
| 217 | + translated_segments = ( | |
| 218 | + self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang) | |
| 219 | + if flat_segments | |
| 220 | + else [] | |
| 221 | + ) | |
| 222 | + outputs: List[Optional[str]] = [] | |
| 223 | + offset = 0 | |
| 224 | + for original_text, segments in zip(texts, segment_plans): | |
| 225 | + if not segments: | |
| 226 | + outputs.append(None if not original_text.strip() else original_text) | |
| 227 | + continue | |
| 228 | + current = translated_segments[offset:offset + len(segments)] | |
| 229 | + offset += len(segments) | |
| 230 | + if len(segments) == 1: | |
| 231 | + outputs.append(current[0]) | |
| 232 | + continue | |
| 233 | + outputs.append( | |
| 234 | + join_translated_segments( | |
| 235 | + current, | |
| 236 | + target_lang=target_lang, | |
| 237 | + original_text=original_text, | |
| 238 | + ) | |
| 239 | + ) | |
| 240 | + return outputs | |
| 241 | + | |
| 152 | 242 | def translate( |
| 153 | 243 | self, |
| 154 | 244 | text: Union[str, Sequence[str]], |
| ... | ... | @@ -165,7 +255,7 @@ class LocalSeq2SeqTranslationBackend: |
| 165 | 255 | if not any(item.strip() for item in chunk): |
| 166 | 256 | outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item] |
| 167 | 257 | continue |
| 168 | - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang)) | |
| 258 | + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang)) | |
| 169 | 259 | return outputs[0] if is_single else outputs |
| 170 | 260 | |
| 171 | 261 | ... | ... |
| ... | ... | @@ -0,0 +1,226 @@ |
"""Token-budget-aware splitting helpers for translation backends."""

from __future__ import annotations

from typing import Callable, List, Optional

# Callable mapping a text fragment to its token count under the active tokenizer.
TokenLengthFn = Callable[[str], int]

# Target languages whose scripts are written without inter-word spaces.
_CJK_LANGS = {"zh", "ja", "ko"}
# Sentence-level punctuation (plus newline); preferred split points.
_STRONG_BOUNDARIES = set("\n。!?!?;;…")
# Clause-level punctuation and brackets; second-choice split points.
_WEAK_BOUNDARIES = set(",,、::()()[]【】/|")
# Quote/bracket closers that may trail a boundary and stay attached to it.
_CLOSING_CHARS = set("\"'”’)]})】》」』")
# When joining with spaces, no separator is inserted before/after these characters.
_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
_NO_SPACE_AFTER = tuple("([{$#@/<")
| 15 | + | |
| 16 | + | |
def is_cjk_language(lang: Optional[str]) -> bool:
    """Return True when *lang* names a space-free CJK target language."""
    normalized = str(lang or "").strip().lower()
    return normalized in _CJK_LANGS
| 19 | + | |
| 20 | + | |
def compute_safe_input_token_limit(
    *,
    max_input_length: int,
    max_new_tokens: int,
    decoding_length_mode: str = "fixed",
    decoding_length_extra: int = 0,
    reserve_input_tokens: int = 8,
    reserve_output_tokens: int = 8,
) -> int:
    """Derive a conservative source-token budget for translation splitting.

    A small reserve is kept for tokenizer special tokens on the encode side.
    When the decode side is tighter than the encode side, the source budget
    is additionally capped by the decode settings so text gets split before
    the model would truncate its output. The result is never below 8.
    """
    encode_budget = max(8, int(max_input_length) - max(0, int(reserve_input_tokens)))
    new_tokens = int(max_new_tokens)
    if new_tokens <= 0:
        # No explicit decode cap: the encode side is the only constraint.
        return encode_budget
    mode = str(decoding_length_mode or "fixed").strip().lower()
    if mode == "source":
        # Decode length tracks source length plus a fixed extra; budget the
        # source so the tracked decode length stays within max_new_tokens.
        decode_budget = max(8, new_tokens - max(0, int(decoding_length_extra)))
    elif new_tokens >= int(max_input_length):
        # Decode side is at least as generous as the encode side.
        return encode_budget
    else:
        decode_budget = max(8, new_tokens - max(0, int(reserve_output_tokens)))
    return max(8, min(encode_budget, decode_budget))
| 49 | + | |
| 50 | + | |
def split_text_for_translation(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
) -> List[str]:
    """Split long text into a few translation-friendly segments.

    Sentence boundaries are preferred, then clause boundaries, then
    whitespace; character-level hard splitting is the last resort. Text
    that already fits the budget (or is empty) comes back as one segment.
    """
    if not text or token_length_fn(text) <= max_tokens:
        return [text]
    pieces = _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0)
    return [piece for piece in pieces if piece]
| 69 | + | |
| 70 | + | |
def join_translated_segments(
    segments: List[Optional[str]],
    *,
    target_lang: Optional[str],
    original_text: str,
) -> Optional[str]:
    """Re-join per-segment translations into a single output string.

    CJK targets are joined without separators; other targets use spaces, or
    newlines when the original text was multi-line. No separator is inserted
    next to punctuation that should stay attached. Returns None when every
    segment is empty.
    """
    parts: List[str] = []
    for segment in segments:
        if segment:
            stripped = segment.strip()
            if stripped:
                parts.append(stripped)
    if not parts:
        return None

    if is_cjk_language(target_lang):
        separator = ""
    elif "\n" in original_text:
        separator = "\n"
    else:
        separator = " "

    pieces = [parts[0]]
    for part in parts[1:]:
        glue = separator
        # Suppress the separator around punctuation-adjacent joins.
        if not glue or pieces[-1].endswith(_NO_SPACE_AFTER) or part.startswith(_NO_SPACE_BEFORE):
            glue = ""
        pieces.append(glue + part)
    merged = "".join(pieces).strip()
    return merged or None
| 94 | + | |
| 95 | + | |
def _split_recursive(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
    level: int,
) -> List[str]:
    """Split *text* under the token budget, escalating through boundary levels.

    Level 0 = sentence boundaries, 1 = clause boundaries, 2 = whitespace;
    level 3 and beyond falls back to token-budget hard splitting.
    """
    if token_length_fn(text) <= max_tokens:
        return [text]
    if level >= 3:
        return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)

    def descend(chunk: str) -> List[str]:
        # Re-split an over-budget chunk at the next (finer) boundary level.
        return _split_recursive(chunk, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)

    pieces = _split_by_level(text, level)
    if len(pieces) <= 1:
        return descend(text)

    results: List[str] = []
    pending = ""
    for piece in pieces:
        combined = pending + piece
        if token_length_fn(combined) <= max_tokens:
            # Greedily pack adjacent pieces while the budget allows.
            pending = combined
        elif pending:
            results.extend(descend(pending))
            pending = piece
        else:
            results.extend(descend(piece))
    if pending:
        results.extend(descend(pending))
    return results
| 129 | + | |
| 130 | + | |
def _split_by_level(text: str, level: int) -> List[str]:
    """Cut *text* into pieces at the boundaries defined for *level*.

    Each piece keeps its trailing boundary characters, so concatenating all
    pieces reproduces *text* exactly.
    """
    pieces: List[str] = []
    segment_start = 0
    cursor = 0
    length = len(text)
    while cursor < length:
        boundary_end = _match_boundary(text, cursor, level)
        if boundary_end is None:
            cursor += 1
        else:
            pieces.append(text[segment_start:boundary_end])
            segment_start = cursor = boundary_end
    if segment_start < length:
        pieces.append(text[segment_start:])
    return [piece for piece in pieces if piece]
| 147 | + | |
| 148 | + | |
def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
    """Return the end offset of a boundary starting at *index*, else None.

    Level 0 matches sentence punctuation (including a sentence-final ASCII
    period), level 1 matches clause punctuation, level 2 matches whitespace
    runs. The returned offset includes trailing closers and whitespace.
    """
    char = text[index]
    if level == 0:
        ends_sentence = char in _STRONG_BOUNDARIES or (
            char == "." and _is_sentence_period(text, index)
        )
        return _consume_boundary_tail(text, index + 1) if ends_sentence else None
    if level == 1:
        return _consume_boundary_tail(text, index + 1) if char in _WEAK_BOUNDARIES else None
    if level == 2 and char.isspace():
        end = index + 1
        length = len(text)
        while end < length and text[end].isspace():
            end += 1
        return end
    return None
| 167 | + | |
| 168 | + | |
def _consume_boundary_tail(text: str, index: int) -> int:
    """Advance past closing quotes/brackets, then whitespace, after a boundary."""
    end = index
    length = len(text)
    while end < length and text[end] in _CLOSING_CHARS:
        end += 1
    while end < length and text[end].isspace():
        end += 1
    return end
| 176 | + | |
| 177 | + | |
def _is_sentence_period(text: str, index: int) -> bool:
    """Heuristic: does the ASCII period at *index* end a sentence?

    Periods between digits (decimals like "1.5") do not count; a period at
    end-of-text, before whitespace, or before a closing character does.
    """
    preceding = text[index - 1] if index > 0 else ""
    following = text[index + 1] if index + 1 < len(text) else ""
    if preceding.isdigit() and following.isdigit():
        return False
    return not following or following.isspace() or following in _CLOSING_CHARS
| 186 | + | |
| 187 | + | |
def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    """Last-resort split: cut the largest budget-fitting prefixes in sequence.

    Cut points are nudged back onto nearby whitespace/punctuation when
    possible so hard cuts still land on somewhat natural positions.
    """
    pieces: List[str] = []
    rest = text
    while rest:
        if token_length_fn(rest) <= max_tokens:
            pieces.append(rest)
            return pieces
        raw_cut = _largest_prefix_within_limit(rest, max_tokens=max_tokens, token_length_fn=token_length_fn)
        cut = _refine_cut(rest, raw_cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if cut <= 0:
            # Guarantee forward progress even if refinement degenerates.
            cut = max(1, raw_cut)
        pieces.append(rest[:cut])
        rest = rest[cut:]
    return pieces
| 202 | + | |
| 203 | + | |
def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Binary-search the longest prefix length whose token count fits the budget.

    Always returns at least 1 so callers make forward progress even when a
    single character already exceeds the budget.
    """
    low, high = 1, len(text)
    best = 1
    while low <= high:
        mid = (low + high) // 2
        if token_length_fn(text[:mid]) <= max_tokens:
            best = mid
            low = mid + 1
        else:
            high = mid - 1
    return best
| 216 | + | |
| 217 | + | |
def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Nudge a hard cut backwards onto a nearby natural boundary.

    Scans up to 32 characters back from *cut* for a whitespace or punctuation
    boundary that still fits the token budget and does not shrink the piece
    below half the raw cut (to avoid overly small fragments). Returns *cut*
    unchanged when no acceptable boundary is found.
    """
    lower_bound = max(1, cut - 32)
    min_keep = max(1, cut // 2)
    for candidate in range(cut, lower_bound - 1, -1):
        boundary_char = text[candidate - 1]
        is_boundary = (
            boundary_char.isspace()
            or boundary_char in _STRONG_BOUNDARIES
            or boundary_char in _WEAK_BOUNDARIES
        )
        if is_boundary and candidate >= min_keep and token_length_fn(text[:candidate]) <= max_tokens:
            return candidate
    # Fix: the original kept a running `best = max(best, candidate)`, but
    # candidates only decrease from `cut` (best's initial value), so it was
    # dead code — the failure result is always `cut`.
    return cut