From 294c3d0a416199bb36b427a593bdb01a2899a098 Mon Sep 17 00:00:00 2001
From: tangwang
Date: Thu, 19 Mar 2026 09:51:06 +0800
Subject: [PATCH] Implement the first version of basic token-budget-aware smart sentence splitting

---
 tests/test_translation_local_backends.py  |  82 +++++++++++++++++++++
 translation/backends/local_ctranslate2.py |  83 ++++++++++++++++++++-
 translation/backends/local_seq2seq.py     |  92 ++++++++++++++++++++++-
 translation/text_splitter.py              | 226 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 481 insertions(+), 2 deletions(-)
 create mode 100644 translation/text_splitter.py

diff --git a/tests/test_translation_local_backends.py b/tests/test_translation_local_backends.py
index 6dbfcb6..c9c6c58 100644
--- a/tests/test_translation_local_backends.py
+++ b/tests/test_translation_local_backends.py
@@ -2,6 +2,7 @@ import torch
 
 from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend
 from translation.service import TranslationService
+from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation
 
 
 class _FakeBatch(dict):
@@ -167,3 +168,84 @@ def test_translation_service_preloads_enabled_backends(monkeypatch):
 
     backend = service.get_backend("opus-mt-en-zh")
     assert backend.model == "opus-mt-en-zh"
+
+
+def test_compute_safe_input_token_limit_uses_decode_constraints():
+    nllb_limit = compute_safe_input_token_limit(
+        max_input_length=256,
+        max_new_tokens=64,
+        decoding_length_mode="source",
+        decoding_length_extra=8,
+    )
+    opus_limit = compute_safe_input_token_limit(
+        max_input_length=256,
+        max_new_tokens=256,
+    )
+
+    assert nllb_limit == 56
+    assert opus_limit == 248
+
+
+def test_split_text_for_translation_prefers_sentence_boundaries():
+    text = (
+        "这是一条很长的中文商品描述,包含材质、尺码和适用场景。"
+        "适合春夏通勤,也适合日常出街穿搭;"
+        "如果长度超了,应该优先按完整语义分句,而不是切成很碎的小片段。"
+    )
+
+    segments = split_text_for_translation(
+        text,
+        max_tokens=36,
+        token_length_fn=len,
+    )
+
+    assert len(segments) >= 2
+    assert "".join(segments) == text
+    assert all(len(segment) <= 36 for segment in segments)
+    assert segments[0].endswith(("。", ";"))
+
+
+class _SegmentingMarianBackend(MarianMTTranslationBackend):
+    def _load_model(self):
+        self.translated_batches = []
+
+    def _token_count(self, text, target_lang, source_lang=None):
+        del target_lang, source_lang
+        return len(text)
+
+    def _translate_batch(self, texts, target_lang, source_lang=None):
+        del source_lang
+        self.translated_batches.append(list(texts))
+        if target_lang == "zh":
+            return [f"<{text.strip()}>" for text in texts]
+        return [f"[{text.strip()}]" for text in texts]
+
+
+def test_local_backend_splits_oversized_text_before_translation():
+    backend = _SegmentingMarianBackend(
+        name="opus-mt-en-zh",
+        model_id="Helsinki-NLP/opus-mt-en-zh",
+        model_dir="./models/translation/Helsinki-NLP/opus-mt-en-zh",
+        device="cpu",
+        torch_dtype="float32",
+        batch_size=8,
+        max_input_length=24,
+        max_new_tokens=24,
+        num_beams=1,
+        source_langs=["en"],
+        target_langs=["zh"],
+    )
+
+    text = (
+        "This soft cotton dress is breathable and lightweight, "
+        "works well for spring travel and everyday wear, "
+        "and should be split on natural clause boundaries when it gets too long."
+    )
+
+    result = backend.translate(text, source_lang="en", target_lang="zh")
+
+    assert result is not None
+    assert len(backend.translated_batches) == 1
+    assert len(backend.translated_batches[0]) >= 2
+    assert all(len(piece) <= 16 for piece in backend.translated_batches[0])
+    assert result == "".join(f"<{piece.strip()}>" for piece in backend.translated_batches[0])
diff --git a/translation/backends/local_ctranslate2.py b/translation/backends/local_ctranslate2.py
index 638a6d7..210bdbf 100644
--- a/translation/backends/local_ctranslate2.py
+++ b/translation/backends/local_ctranslate2.py
@@ -14,6 +14,11 @@ from typing import Dict, List, Optional, Sequence, Union
 from transformers import AutoTokenizer
 
 from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES
+from translation.text_splitter import (
+    compute_safe_input_token_limit,
+    join_translated_segments,
+    split_text_for_translation,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -296,6 +301,82 @@ class LocalCTranslate2TranslationBackend:
             outputs.append(self._decode_tokens(processed))
         return outputs
 
+    def _token_count(
+        self,
+        text: str,
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> int:
+        encoded = self._encode_source_tokens([text], source_lang, target_lang)
+        return len(encoded[0]) if encoded else 0
+
+    def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int:
+        del target_lang, source_lang
+        return compute_safe_input_token_limit(
+            max_input_length=self.max_input_length,
+            max_new_tokens=self.max_new_tokens,
+            decoding_length_mode=self.ct2_decoding_length_mode,
+            decoding_length_extra=self.ct2_decoding_length_extra,
+        )
+
+    def _split_text_if_needed(
+        self,
+        text: str,
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> List[str]:
+        limit = self._effective_input_token_limit(target_lang, source_lang)
+        return split_text_for_translation(
+            text,
+            max_tokens=limit,
+            token_length_fn=lambda value: self._token_count(
+                value,
+                target_lang=target_lang,
+                source_lang=source_lang,
+            ),
+        )
+
+    def _translate_with_segmentation(
+        self,
+        texts: List[str],
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> List[Optional[str]]:
+        segment_plans: List[List[str]] = []
+        flat_segments: List[str] = []
+        for text in texts:
+            if not text.strip():
+                segment_plans.append([])
+                continue
+            segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang)
+            segment_plans.append(segments)
+            flat_segments.extend(segments)
+
+        translated_segments = (
+            self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang)
+            if flat_segments
+            else []
+        )
+        outputs: List[Optional[str]] = []
+        offset = 0
+        for original_text, segments in zip(texts, segment_plans):
+            if not segments:
+                outputs.append(None if not original_text.strip() else original_text)
+                continue
+            current = translated_segments[offset:offset + len(segments)]
+            offset += len(segments)
+            if len(segments) == 1:
+                outputs.append(current[0])
+                continue
+            outputs.append(
+                join_translated_segments(
+                    current,
+                    target_lang=target_lang,
+                    original_text=original_text,
+                )
+            )
+        return outputs
+
     def translate(
         self,
         text: Union[str, Sequence[str]],
@@ -312,7 +393,7 @@ class LocalCTranslate2TranslationBackend:
             if not any(item.strip() for item in chunk):
                 outputs.extend([None if not item.strip() else item for item in chunk])  # type: ignore[list-item]
                 continue
-            outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
+            outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang))
 
         return outputs[0] if is_single else outputs
 
diff --git a/translation/backends/local_seq2seq.py b/translation/backends/local_seq2seq.py
index 187a395..d15acbe 100644
--- a/translation/backends/local_seq2seq.py
+++ b/translation/backends/local_seq2seq.py
@@ -11,6 +11,11 @@ import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES
+from translation.text_splitter import (
+    compute_safe_input_token_limit,
+    join_translated_segments,
+    split_text_for_translation,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -149,6 +154,91 @@ class LocalSeq2SeqTranslationBackend:
         outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
         return [item.strip() if item and item.strip() else None for item in outputs]
 
+    def _token_count(
+        self,
+        text: str,
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> int:
+        tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang)
+        with self._lock:
+            encoded = self.tokenizer(
+                [text],
+                truncation=False,
+                padding=False,
+                **tokenizer_kwargs,
+            )
+        input_ids = encoded["input_ids"]
+        first_item = input_ids[0]
+        if hasattr(first_item, "shape"):
+            return int(first_item.shape[-1])
+        return len(first_item)
+
+    def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int:
+        del target_lang, source_lang
+        return compute_safe_input_token_limit(
+            max_input_length=self.max_input_length,
+            max_new_tokens=self.max_new_tokens,
+        )
+
+    def _split_text_if_needed(
+        self,
+        text: str,
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> List[str]:
+        limit = self._effective_input_token_limit(target_lang, source_lang)
+        return split_text_for_translation(
+            text,
+            max_tokens=limit,
+            token_length_fn=lambda value: self._token_count(
+                value,
+                target_lang=target_lang,
+                source_lang=source_lang,
+            ),
+        )
+
+    def _translate_with_segmentation(
+        self,
+        texts: List[str],
+        target_lang: str,
+        source_lang: Optional[str] = None,
+    ) -> List[Optional[str]]:
+        segment_plans: List[List[str]] = []
+        flat_segments: List[str] = []
+        for text in texts:
+            if not text.strip():
+                segment_plans.append([])
+                continue
+            segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang)
+            segment_plans.append(segments)
+            flat_segments.extend(segments)
+
+        translated_segments = (
+            self._translate_batch(flat_segments, target_lang=target_lang, source_lang=source_lang)
+            if flat_segments
+            else []
+        )
+        outputs: List[Optional[str]] = []
+        offset = 0
+        for original_text, segments in zip(texts, segment_plans):
+            if not segments:
+                outputs.append(None if not original_text.strip() else original_text)
+                continue
+            current = translated_segments[offset:offset + len(segments)]
+            offset += len(segments)
+            if len(segments) == 1:
+                outputs.append(current[0])
+                continue
+            outputs.append(
+                join_translated_segments(
+                    current,
+                    target_lang=target_lang,
+                    original_text=original_text,
+                )
+            )
+        return outputs
+
     def translate(
         self,
         text: Union[str, Sequence[str]],
@@ -165,7 +255,7 @@ class LocalSeq2SeqTranslationBackend:
             if not any(item.strip() for item in chunk):
                 outputs.extend([None if not item.strip() else item for item in chunk])  # type: ignore[list-item]
                 continue
-            outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
+            outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang))
 
         return outputs[0] if is_single else outputs
 
diff --git a/translation/text_splitter.py b/translation/text_splitter.py
new file mode 100644
index 0000000..630d53f
--- /dev/null
+++ b/translation/text_splitter.py
@@ -0,0 +1,226 @@
+"""Utilities for token-budget-aware translation text splitting."""
+
+from __future__ import annotations
+
+from typing import Callable, List, Optional
+
+TokenLengthFn = Callable[[str], int]
+
+_CJK_LANGS = {"zh", "ja", "ko"}
+_STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"}
+_WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"}
+_CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"}
+_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
+_NO_SPACE_AFTER = tuple("([{$#@/<")
+
+
+def is_cjk_language(lang: Optional[str]) -> bool:
+    return str(lang or "").strip().lower() in _CJK_LANGS
+
+
+def compute_safe_input_token_limit(
+    *,
+    max_input_length: int,
+    max_new_tokens: int,
+    decoding_length_mode: str = "fixed",
+    decoding_length_extra: int = 0,
+    reserve_input_tokens: int = 8,
+    reserve_output_tokens: int = 8,
+) -> int:
+    """Derive a conservative source-token budget for translation splitting.
+
+    We keep a small reserve for tokenizer special tokens on the input side. If
+    the decode side is much tighter than the encode side, we also cap the
+    source budget based on decode settings so we split before the model is
+    likely to truncate.
+    """
+
+    input_limit = max(8, int(max_input_length) - max(0, int(reserve_input_tokens)))
+    decode_mode = str(decoding_length_mode or "fixed").strip().lower()
+    if int(max_new_tokens) <= 0:
+        return input_limit
+    if decode_mode == "source":
+        output_limit = max(8, int(max_new_tokens) - max(0, int(decoding_length_extra)))
+        return max(8, min(input_limit, output_limit))
+    if int(max_new_tokens) >= int(max_input_length):
+        return input_limit
+    output_limit = max(8, int(max_new_tokens) - max(0, int(reserve_output_tokens)))
+    return max(8, min(input_limit, output_limit))
+
+
+def split_text_for_translation(
+    text: str,
+    *,
+    max_tokens: int,
+    token_length_fn: TokenLengthFn,
+) -> List[str]:
+    """Split long text into a few translation-friendly segments.
+
+    The splitter prefers sentence boundaries, then clause boundaries, then
+    whitespace, and only falls back to character-based splitting when needed.
+ """ + + if not text: + return [text] + if token_length_fn(text) <= max_tokens: + return [text] + segments = _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0) + return [segment for segment in segments if segment] + + +def join_translated_segments( + segments: List[Optional[str]], + *, + target_lang: Optional[str], + original_text: str, +) -> Optional[str]: + parts = [segment.strip() for segment in segments if segment and segment.strip()] + if not parts: + return None + separator = "" if is_cjk_language(target_lang) else " " + if "\n" in original_text and separator: + separator = "\n" + + merged = parts[0] + for part in parts[1:]: + if not separator: + merged += part + continue + if merged.endswith(_NO_SPACE_AFTER) or part.startswith(_NO_SPACE_BEFORE): + merged += part + continue + merged += separator + part + return merged.strip() or None + + +def _split_recursive( + text: str, + *, + max_tokens: int, + token_length_fn: TokenLengthFn, + level: int, +) -> List[str]: + if token_length_fn(text) <= max_tokens: + return [text] + if level >= 3: + return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn) + + pieces = _split_by_level(text, level) + if len(pieces) <= 1: + return _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1) + + merged: List[str] = [] + buffer = "" + for piece in pieces: + candidate = buffer + piece if buffer else piece + if token_length_fn(candidate) <= max_tokens: + buffer = candidate + continue + if buffer: + merged.extend( + _split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1) + ) + buffer = piece + continue + merged.extend(_split_recursive(piece, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)) + if buffer: + merged.extend(_split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)) + return merged + + +def _split_by_level(text: str, level: int) -> List[str]: + parts: List[str] = [] + start = 0 + index = 0 + while index < len(text): + boundary_end = _match_boundary(text, index, level) + if boundary_end is None: + index += 1 + continue + if boundary_end > start: + parts.append(text[start:boundary_end]) + start = boundary_end + index = boundary_end + if start < len(text): + parts.append(text[start:]) + return [part for part in parts if part] + + +def _match_boundary(text: str, index: int, level: int) -> Optional[int]: + char = text[index] + if level == 0: + if char in _STRONG_BOUNDARIES: + return _consume_boundary_tail(text, index + 1) + if char == "." 
+        if char == "." and _is_sentence_period(text, index):
+            return _consume_boundary_tail(text, index + 1)
+        return None
+    if level == 1:
+        if char in _WEAK_BOUNDARIES:
+            return _consume_boundary_tail(text, index + 1)
+        return None
+    if level == 2 and char.isspace():
+        end = index + 1
+        while end < len(text) and text[end].isspace():
+            end += 1
+        return end
+    return None
+
+
+def _consume_boundary_tail(text: str, index: int) -> int:
+    end = index
+    while end < len(text) and text[end] in _CLOSING_CHARS:
+        end += 1
+    while end < len(text) and text[end].isspace():
+        end += 1
+    return end
+
+
+def _is_sentence_period(text: str, index: int) -> bool:
+    prev_char = text[index - 1] if index > 0 else ""
+    next_char = text[index + 1] if index + 1 < len(text) else ""
+    if prev_char.isdigit() and next_char.isdigit():
+        return False
+    if not next_char:
+        return True
+    return next_char.isspace() or next_char in _CLOSING_CHARS
+
+
+def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
+    segments: List[str] = []
+    remaining = text
+    while remaining:
+        if token_length_fn(remaining) <= max_tokens:
+            segments.append(remaining)
+            break
+        cut = _largest_prefix_within_limit(remaining, max_tokens=max_tokens, token_length_fn=token_length_fn)
+        refined_cut = _refine_cut(remaining, cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
+        if refined_cut <= 0:
+            refined_cut = max(1, cut)
+        segments.append(remaining[:refined_cut])
+        remaining = remaining[refined_cut:]
+    return segments
+
+
+def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
+    low = 1
+    high = len(text)
+    best = 1
+    while low <= high:
+        mid = (low + high) // 2
+        if token_length_fn(text[:mid]) <= max_tokens:
+            best = mid
+            low = mid + 1
+            continue
+        high = mid - 1
+    return best
+
+
+def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
+    best = cut
+    lower_bound = max(1, cut - 32)
+    for candidate in range(cut, lower_bound - 1, -1):
+        if text[candidate - 1].isspace() or text[candidate - 1] in _STRONG_BOUNDARIES or text[candidate - 1] in _WEAK_BOUNDARIES:
+            if candidate >= max(1, cut // 2) and token_length_fn(text[:candidate]) <= max_tokens:
+                return candidate
+        best = max(best, candidate)
+    return best
-- 
libgit2 0.21.2
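
Usage sketch, not part of the patch: assuming the patch is applied so that translation.text_splitter is importable, the budget helper and the splitter compose as below. The example strings and numbers are made up for illustration, and plain len() stands in for a real tokenizer's token-count callback, so the budgets here are character counts rather than tokens.

from translation.text_splitter import (
    compute_safe_input_token_limit,
    join_translated_segments,
    split_text_for_translation,
)

# Decode length tracks the source ("source" mode), so the source budget is
# capped by max_new_tokens minus the extra: 64 - 8 = 56, as in the new test.
limit = compute_safe_input_token_limit(
    max_input_length=256,
    max_new_tokens=64,
    decoding_length_mode="source",
    decoding_length_extra=8,
)
assert limit == 56

# Split an over-long description on sentence boundaries, then rejoin the
# per-segment results; a real backend would translate each segment first.
text = "第一句话比较长,包含不少细节。第二句话继续补充说明;第三句话收尾。"
segments = split_text_for_translation(text, max_tokens=20, token_length_fn=len)
assert "".join(segments) == text
print(join_translated_segments(segments, target_lang="zh", original_text=text))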