Commit 294c3d0a416199bb36b427a593bdb01a2899a098

Authored by tangwang
1 parent 46ce858d

实现第一版“按模型预算智能分句”的基础能力。

改动:

新增分句与预算工具:translation/text_splitter.py
接入 HF 本地后端:translation/backends/local_seq2seq.py (line 157)
接入 CT2 本地后端:translation/backends/local_ctranslate2.py (line 301)
补了测试:tests/test_translation_local_backends.py
我先把代码里实际限制梳理了一遍,关键配置在 config/config.yaml (line
133):

nllb-200-distilled-600m: max_input_length=256,max_new_tokens=64,并且
ct2_decoding_length_mode=source、ct2_decoding_length_extra=8。
按该配置计算出的保守输入预算是 56 token(64 − 8)。
opus-mt-zh-en:
max_input_length=256,max_new_tokens=256。现在保守输入预算是 248 token。
opus-mt-en-zh: 同上,也是 248 token。
这版分句策略是:

先按强边界切:。!?!?;;…、换行、英文句号
不够再按弱边界切:,,、::()()[]【】/|
再不够才按空白切
最后才做 token 预算下的硬切
超长时会“分句翻译后再回拼”:目标语言为中/日/韩等 CJK 语言时默认无空格回拼,其他语言默认按空格回拼(原文含换行时按换行回拼),并尽量避免切得太碎。
验证:

python3 -m compileall translation
tests/test_translation_local_backends.py 已通过
tests/test_translation_local_backends.py
... ... @@ -2,6 +2,7 @@ import torch
2 2  
3 3 from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend
4 4 from translation.service import TranslationService
  5 +from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation
5 6  
6 7  
7 8 class _FakeBatch(dict):
... ... @@ -167,3 +168,84 @@ def test_translation_service_preloads_enabled_backends(monkeypatch):
167 168  
168 169 backend = service.get_backend("opus-mt-en-zh")
169 170 assert backend.model == "opus-mt-en-zh"
  171 +
  172 +
def test_compute_safe_input_token_limit_uses_decode_constraints():
    """Budget honours decode-side caps (NLLB) and the input reserve (OPUS)."""
    assert compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=64,
        decoding_length_mode="source",
        decoding_length_extra=8,
    ) == 56
    assert compute_safe_input_token_limit(
        max_input_length=256,
        max_new_tokens=256,
    ) == 248
  187 +
  188 +
def test_split_text_for_translation_prefers_sentence_boundaries():
    """Long Chinese text is cut at sentence punctuation without losing chars."""
    text = (
        "这是一条很长的中文商品描述,包含材质、尺码和适用场景。"
        "适合春夏通勤,也适合日常出街穿搭;"
        "如果长度超了,应该优先按完整语义分句,而不是切成很碎的小片段。"
    )

    pieces = split_text_for_translation(text, max_tokens=36, token_length_fn=len)

    assert len(pieces) >= 2
    # Lossless: concatenating the pieces reproduces the input exactly.
    assert "".join(pieces) == text
    assert max(len(piece) for piece in pieces) <= 36
    assert pieces[0].endswith(("。", ";"))
  206 +
  207 +
class _SegmentingMarianBackend(MarianMTTranslationBackend):
    """Marian backend stub: records batches instead of loading a real model."""

    def _load_model(self):
        # No model is loaded; we only keep a log of what would be translated.
        self.translated_batches = []

    def _token_count(self, text, target_lang, source_lang=None):
        del target_lang, source_lang
        # One character == one token keeps test arithmetic trivial.
        return len(text)

    def _translate_batch(self, texts, target_lang, source_lang=None):
        del source_lang
        self.translated_batches.append(list(texts))
        template = "<{}>" if target_lang == "zh" else "[{}]"
        return [template.format(text.strip()) for text in texts]
  222 +
  223 +
def test_local_backend_splits_oversized_text_before_translation():
    """Oversized input is segmented, batched once, each piece within budget."""
    backend = _SegmentingMarianBackend(
        name="opus-mt-en-zh",
        model_id="Helsinki-NLP/opus-mt-en-zh",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-en-zh",
        device="cpu",
        torch_dtype="float32",
        batch_size=8,
        max_input_length=24,
        max_new_tokens=24,
        num_beams=1,
        source_langs=["en"],
        target_langs=["zh"],
    )

    text = (
        "This soft cotton dress is breathable and lightweight, "
        "works well for spring travel and everyday wear, "
        "and should be split on natural clause boundaries when it gets too long."
    )

    translated = backend.translate(text, source_lang="en", target_lang="zh")

    assert translated is not None
    batches = backend.translated_batches
    # All segments must go through a single flat batch.
    assert len(batches) == 1
    pieces = batches[0]
    assert len(pieces) >= 2
    assert all(len(piece) <= 16 for piece in pieces)
    assert translated == "".join(f"<{piece.strip()}>" for piece in pieces)
... ...
translation/backends/local_ctranslate2.py
... ... @@ -14,6 +14,11 @@ from typing import Dict, List, Optional, Sequence, Union
14 14 from transformers import AutoTokenizer
15 15  
16 16 from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES
  17 +from translation.text_splitter import (
  18 + compute_safe_input_token_limit,
  19 + join_translated_segments,
  20 + split_text_for_translation,
  21 +)
17 22  
18 23 logger = logging.getLogger(__name__)
19 24  
... ... @@ -296,6 +301,82 @@ class LocalCTranslate2TranslationBackend:
296 301 outputs.append(self._decode_tokens(processed))
297 302 return outputs
298 303  
def _token_count(
    self,
    text: str,
    target_lang: str,
    source_lang: Optional[str] = None,
) -> int:
    """Return how many source tokens *text* occupies after CT2 encoding."""
    tokenized = self._encode_source_tokens([text], source_lang, target_lang)
    if not tokenized:
        return 0
    return len(tokenized[0])
  312 +
def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int:
    """Conservative per-request source-token budget.

    The budget only depends on the backend's configured model limits, so the
    language pair is deliberately ignored.
    """
    del target_lang, source_lang
    budget_kwargs = {
        "max_input_length": self.max_input_length,
        "max_new_tokens": self.max_new_tokens,
        "decoding_length_mode": self.ct2_decoding_length_mode,
        "decoding_length_extra": self.ct2_decoding_length_extra,
    }
    return compute_safe_input_token_limit(**budget_kwargs)
  321 +
def _split_text_if_needed(
    self,
    text: str,
    target_lang: str,
    source_lang: Optional[str] = None,
) -> List[str]:
    """Split *text* into budget-sized segments (one segment when it fits)."""
    def _count(segment: str) -> int:
        return self._token_count(segment, target_lang=target_lang, source_lang=source_lang)

    budget = self._effective_input_token_limit(target_lang, source_lang)
    return split_text_for_translation(text, max_tokens=budget, token_length_fn=_count)
  338 +
def _translate_with_segmentation(
    self,
    texts: List[str],
    target_lang: str,
    source_lang: Optional[str] = None,
) -> List[Optional[str]]:
    """Translate *texts*, transparently splitting oversized entries.

    Oversized inputs are split into segments, all segments are translated in
    one flat batch, and each input's translated segments are re-joined.
    Blank inputs yield None, mirroring `_translate_batch`.
    """
    plans: List[List[str]] = []
    flat: List[str] = []
    for text in texts:
        if text.strip():
            segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang)
        else:
            segments = []
        plans.append(segments)
        flat.extend(segments)

    translated: List[Optional[str]] = []
    if flat:
        translated = self._translate_batch(flat, target_lang=target_lang, source_lang=source_lang)

    results: List[Optional[str]] = []
    cursor = 0
    for text, segments in zip(texts, plans):
        count = len(segments)
        chunk = translated[cursor:cursor + count]
        cursor += count
        if count == 0:
            results.append(text if text.strip() else None)
        elif count == 1:
            results.append(chunk[0])
        else:
            results.append(
                join_translated_segments(chunk, target_lang=target_lang, original_text=text)
            )
    return results
  379 +
299 380 def translate(
300 381 self,
301 382 text: Union[str, Sequence[str]],
... ... @@ -312,7 +393,7 @@ class LocalCTranslate2TranslationBackend:
312 393 if not any(item.strip() for item in chunk):
313 394 outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item]
314 395 continue
315   - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
  396 + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang))
316 397 return outputs[0] if is_single else outputs
317 398  
318 399  
... ...
translation/backends/local_seq2seq.py
... ... @@ -11,6 +11,11 @@ import torch
11 11 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12 12  
13 13 from translation.languages import MARIAN_LANGUAGE_DIRECTIONS, NLLB_LANGUAGE_CODES
  14 +from translation.text_splitter import (
  15 + compute_safe_input_token_limit,
  16 + join_translated_segments,
  17 + split_text_for_translation,
  18 +)
14 19  
15 20 logger = logging.getLogger(__name__)
16 21  
... ... @@ -149,6 +154,91 @@ class LocalSeq2SeqTranslationBackend:
149 154 outputs = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
150 155 return [item.strip() if item and item.strip() else None for item in outputs]
151 156  
def _token_count(
    self,
    text: str,
    target_lang: str,
    source_lang: Optional[str] = None,
) -> int:
    """Return the token count of *text* under the configured tokenizer."""
    tokenizer_kwargs = self._prepare_tokenizer(source_lang, target_lang)
    with self._lock:
        batch = self.tokenizer(
            [text],
            truncation=False,
            padding=False,
            **tokenizer_kwargs,
        )
    ids = batch["input_ids"][0]
    # Tensor-like outputs expose `.shape`; plain Python lists fall back to len().
    if hasattr(ids, "shape"):
        return int(ids.shape[-1])
    return len(ids)
  176 +
def _effective_input_token_limit(self, target_lang: str, source_lang: Optional[str] = None) -> int:
    """Conservative source-token budget; independent of the language pair."""
    del target_lang, source_lang
    budget_kwargs = {
        "max_input_length": self.max_input_length,
        "max_new_tokens": self.max_new_tokens,
    }
    return compute_safe_input_token_limit(**budget_kwargs)
  183 +
def _split_text_if_needed(
    self,
    text: str,
    target_lang: str,
    source_lang: Optional[str] = None,
) -> List[str]:
    """Split *text* into budget-sized segments (one segment when it fits)."""
    def _count(segment: str) -> int:
        return self._token_count(segment, target_lang=target_lang, source_lang=source_lang)

    budget = self._effective_input_token_limit(target_lang, source_lang)
    return split_text_for_translation(text, max_tokens=budget, token_length_fn=_count)
  200 +
def _translate_with_segmentation(
    self,
    texts: List[str],
    target_lang: str,
    source_lang: Optional[str] = None,
) -> List[Optional[str]]:
    """Translate *texts*, transparently splitting oversized entries.

    Oversized inputs are split into segments, all segments are translated in
    one flat batch, and each input's translated segments are re-joined.
    Blank inputs yield None, mirroring `_translate_batch`.
    """
    plans: List[List[str]] = []
    flat: List[str] = []
    for text in texts:
        if text.strip():
            segments = self._split_text_if_needed(text, target_lang=target_lang, source_lang=source_lang)
        else:
            segments = []
        plans.append(segments)
        flat.extend(segments)

    translated: List[Optional[str]] = []
    if flat:
        translated = self._translate_batch(flat, target_lang=target_lang, source_lang=source_lang)

    results: List[Optional[str]] = []
    cursor = 0
    for text, segments in zip(texts, plans):
        count = len(segments)
        chunk = translated[cursor:cursor + count]
        cursor += count
        if count == 0:
            results.append(text if text.strip() else None)
        elif count == 1:
            results.append(chunk[0])
        else:
            results.append(
                join_translated_segments(chunk, target_lang=target_lang, original_text=text)
            )
    return results
  241 +
152 242 def translate(
153 243 self,
154 244 text: Union[str, Sequence[str]],
... ... @@ -165,7 +255,7 @@ class LocalSeq2SeqTranslationBackend:
165 255 if not any(item.strip() for item in chunk):
166 256 outputs.extend([None if not item.strip() else item for item in chunk]) # type: ignore[list-item]
167 257 continue
168   - outputs.extend(self._translate_batch(chunk, target_lang=target_lang, source_lang=source_lang))
  258 + outputs.extend(self._translate_with_segmentation(chunk, target_lang=target_lang, source_lang=source_lang))
169 259 return outputs[0] if is_single else outputs
170 260  
171 261  
... ...
translation/text_splitter.py 0 → 100644
... ... @@ -0,0 +1,226 @@
  1 +"""Utilities for token-budget-aware translation text splitting."""
  2 +
  3 +from __future__ import annotations
  4 +
  5 +from typing import Callable, List, Optional
  6 +
  7 +TokenLengthFn = Callable[[str], int]
  8 +
  9 +_CJK_LANGS = {"zh", "ja", "ko"}
  10 +_STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"}
  11 +_WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"}
  12 +_CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"}
  13 +_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
  14 +_NO_SPACE_AFTER = tuple("([{$#@/<")
  15 +
  16 +
def is_cjk_language(lang: Optional[str]) -> bool:
    """Return True when *lang* (case/whitespace-insensitive) is zh/ja/ko."""
    normalized = str(lang or "").strip().lower()
    return normalized in _CJK_LANGS
  19 +
  20 +
def compute_safe_input_token_limit(
    *,
    max_input_length: int,
    max_new_tokens: int,
    decoding_length_mode: str = "fixed",
    decoding_length_extra: int = 0,
    reserve_input_tokens: int = 8,
    reserve_output_tokens: int = 8,
) -> int:
    """Derive a conservative source-token budget for translation splitting.

    A small reserve is kept for tokenizer special tokens on the encode side.
    When the decode side is the tighter constraint, the source budget is also
    capped by the decode settings so splitting happens before the model would
    be likely to truncate its output. The result is never below 8 tokens.
    """
    encode_budget = max(8, int(max_input_length) - max(0, int(reserve_input_tokens)))
    new_tokens = int(max_new_tokens)
    if new_tokens <= 0:
        # No explicit decode cap: only the encode side constrains us.
        return encode_budget

    mode = str(decoding_length_mode or "fixed").strip().lower()
    if mode == "source":
        # Decode length tracks source length plus an extra allowance, so the
        # usable source budget is what remains of max_new_tokens.
        decode_budget = max(8, new_tokens - max(0, int(decoding_length_extra)))
    elif new_tokens >= int(max_input_length):
        # Decode side is at least as roomy as the encode side.
        return encode_budget
    else:
        decode_budget = max(8, new_tokens - max(0, int(reserve_output_tokens)))
    return max(8, min(encode_budget, decode_budget))
  49 +
  50 +
def split_text_for_translation(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
) -> List[str]:
    """Split long text into a few translation-friendly segments.

    Sentence boundaries are preferred, then clause boundaries, then
    whitespace; character-level hard cuts are the last resort. Text that
    already fits the budget (or is empty) is returned as a single segment.
    """
    if not text or token_length_fn(text) <= max_tokens:
        return [text]
    pieces = _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0)
    return [piece for piece in pieces if piece]
  69 +
  70 +
def join_translated_segments(
    segments: List[Optional[str]],
    *,
    target_lang: Optional[str],
    original_text: str,
) -> Optional[str]:
    """Reassemble per-segment translations into a single string.

    CJK targets are joined without spaces; other targets use a space, or a
    newline when the original text contained one. Punctuation-aware glue
    avoids inserting a space before closers / after openers. Returns None
    when no segment produced usable output.
    """
    cleaned = [item.strip() for item in segments if item and item.strip()]
    if not cleaned:
        return None

    if is_cjk_language(target_lang):
        separator = ""
    elif "\n" in original_text:
        separator = "\n"
    else:
        separator = " "

    result = cleaned[0]
    for piece in cleaned[1:]:
        glue_directly = (
            not separator
            or result.endswith(_NO_SPACE_AFTER)
            or piece.startswith(_NO_SPACE_BEFORE)
        )
        result += piece if glue_directly else separator + piece
    return result.strip() or None
  94 +
  95 +
def _split_recursive(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
    level: int,
) -> List[str]:
    """Split *text* at boundary level *level*, escalating to finer levels.

    Adjacent pieces are greedily re-merged while they fit the budget; any
    piece (or merged run) that still exceeds it is recursed at level+1.
    Level 3 and beyond falls back to character-level hard splitting.
    """
    if token_length_fn(text) <= max_tokens:
        return [text]
    if level >= 3:
        return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)

    pieces = _split_by_level(text, level)
    if len(pieces) <= 1:
        # This level found no usable boundary; try the next, finer level.
        return _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)

    def _finer(chunk: str) -> List[str]:
        return _split_recursive(chunk, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)

    result: List[str] = []
    pending = ""
    for piece in pieces:
        combined = pending + piece if pending else piece
        if token_length_fn(combined) <= max_tokens:
            pending = combined
        elif pending:
            result.extend(_finer(pending))
            pending = piece
        else:
            result.extend(_finer(piece))
    if pending:
        result.extend(_finer(pending))
    return result
  129 +
  130 +
def _split_by_level(text: str, level: int) -> List[str]:
    """Cut *text* after every boundary of the given level, losing no chars."""
    parts: List[str] = []
    segment_start = 0
    cursor = 0
    while cursor < len(text):
        boundary_end = _match_boundary(text, cursor, level)
        if boundary_end is None:
            cursor += 1
            continue
        if boundary_end > segment_start:
            parts.append(text[segment_start:boundary_end])
            segment_start = boundary_end
        cursor = boundary_end
    if segment_start < len(text):
        parts.append(text[segment_start:])
    return [part for part in parts if part]
  147 +
  148 +
def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
    """Return the end offset of a level-appropriate boundary at *index*.

    Level 0: strong sentence punctuation (plus sentence-ending '.').
    Level 1: weak clause punctuation.  Level 2: whitespace runs.
    Returns None when no boundary of that level starts at *index*.
    """
    char = text[index]
    if level == 0:
        if char in _STRONG_BOUNDARIES or (char == "." and _is_sentence_period(text, index)):
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 1:
        return _consume_boundary_tail(text, index + 1) if char in _WEAK_BOUNDARIES else None
    if level == 2 and char.isspace():
        end = index + 1
        while end < len(text) and text[end].isspace():
            end += 1
        return end
    return None
  167 +
  168 +
def _consume_boundary_tail(text: str, index: int) -> int:
    """Advance past closing quotes/brackets and then trailing whitespace."""
    end = index
    length = len(text)
    while end < length and text[end] in _CLOSING_CHARS:
        end += 1
    while end < length and text[end].isspace():
        end += 1
    return end
  176 +
  177 +
def _is_sentence_period(text: str, index: int) -> bool:
    """Heuristic: does the '.' at *index* end a sentence (vs. e.g. '3.14')?"""
    before = text[index - 1] if index > 0 else ""
    after = text[index + 1] if index + 1 < len(text) else ""
    if before.isdigit() and after.isdigit():
        # Decimal point inside a number, not a sentence end.
        return False
    return (not after) or after.isspace() or after in _CLOSING_CHARS
  186 +
  187 +
def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    """Last-resort character-level splitting under the token budget."""
    segments: List[str] = []
    rest = text
    while rest:
        if token_length_fn(rest) <= max_tokens:
            segments.append(rest)
            return segments
        raw_cut = _largest_prefix_within_limit(rest, max_tokens=max_tokens, token_length_fn=token_length_fn)
        cut = _refine_cut(rest, raw_cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if cut <= 0:
            # Defensive: always make forward progress.
            cut = max(1, raw_cut)
        segments.append(rest[:cut])
        rest = rest[cut:]
    return segments
  202 +
  203 +
def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Binary-search the longest prefix length whose token count fits.

    Returns at least 1 so callers always make forward progress, even when
    a single character exceeds the budget.
    """
    best = 1
    low, high = 1, len(text)
    while low <= high:
        mid = (low + high) // 2
        if token_length_fn(text[:mid]) <= max_tokens:
            best, low = mid, mid + 1
        else:
            high = mid - 1
    return best
  216 +
  217 +
def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    """Nudge a hard cut leftwards onto a nearby natural boundary.

    Scans at most 32 characters back from *cut* looking for whitespace or a
    strong/weak boundary character. A boundary is only taken when it keeps at
    least half of the original cut length and still fits the token budget;
    otherwise the original *cut* is returned unchanged.
    """
    lower_bound = max(1, cut - 32)
    min_keep = max(1, cut // 2)
    for candidate in range(cut, lower_bound - 1, -1):
        boundary_char = text[candidate - 1]
        if (
            boundary_char.isspace()
            or boundary_char in _STRONG_BOUNDARIES
            or boundary_char in _WEAK_BOUNDARIES
        ):
            if candidate >= min_keep and token_length_fn(text[:candidate]) <= max_tokens:
                return candidate
    # No acceptable boundary found: keep the original cut. (The previous
    # `best = max(best, candidate)` bookkeeping was dead code — `best` started
    # at `cut` and candidates never exceed it, so `best` was always `cut`.)
    return cut
... ...