"""Utilities for token-budget-aware translation text splitting.""" from __future__ import annotations from typing import Callable, List, Optional TokenLengthFn = Callable[[str], int] _CJK_LANGS = {"zh", "ja", "ko"} _STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"} _WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"} _CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"} _NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'') _NO_SPACE_AFTER = tuple("([{$#@/<") def is_cjk_language(lang: Optional[str]) -> bool: return str(lang or "").strip().lower() in _CJK_LANGS def compute_safe_input_token_limit( *, max_input_length: int, max_new_tokens: int, decoding_length_mode: str = "fixed", decoding_length_extra: int = 0, reserve_input_tokens: int = 8, reserve_output_tokens: int = 8, ) -> int: """Derive a conservative source-token budget for translation splitting. We keep a small reserve for tokenizer special tokens on the input side. If the decode side is much tighter than the encode side, we also cap the source budget based on decode settings so we split before the model is likely to truncate. """ input_limit = max(8, int(max_input_length) - max(0, int(reserve_input_tokens))) decode_mode = str(decoding_length_mode or "fixed").strip().lower() if int(max_new_tokens) <= 0: return input_limit if decode_mode == "source": output_limit = max(8, int(max_new_tokens) - max(0, int(decoding_length_extra))) return max(8, min(input_limit, output_limit)) if int(max_new_tokens) >= int(max_input_length): return input_limit output_limit = max(8, int(max_new_tokens) - max(0, int(reserve_output_tokens))) return max(8, min(input_limit, output_limit)) def split_text_for_translation( text: str, *, max_tokens: int, token_length_fn: TokenLengthFn, ) -> List[str]: """Split long text into a few translation-friendly segments. The splitter prefers sentence boundaries, then clause boundaries, then whitespace, and only falls back to character-based splitting when needed. 
""" if not text: return [text] if token_length_fn(text) <= max_tokens: return [text] segments = _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=0) return [segment for segment in segments if segment] def join_translated_segments( segments: List[Optional[str]], *, target_lang: Optional[str], original_text: str, ) -> Optional[str]: parts = [segment.strip() for segment in segments if segment and segment.strip()] if not parts: return None separator = "" if is_cjk_language(target_lang) else " " if "\n" in original_text and separator: separator = "\n" merged = parts[0] for part in parts[1:]: if not separator: merged += part continue if merged.endswith(_NO_SPACE_AFTER) or part.startswith(_NO_SPACE_BEFORE): merged += part continue merged += separator + part return merged.strip() or None def _split_recursive( text: str, *, max_tokens: int, token_length_fn: TokenLengthFn, level: int, ) -> List[str]: if token_length_fn(text) <= max_tokens: return [text] if level >= 3: return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn) pieces = _split_by_level(text, level) if len(pieces) <= 1: return _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1) merged: List[str] = [] buffer = "" for piece in pieces: candidate = buffer + piece if buffer else piece if token_length_fn(candidate) <= max_tokens: buffer = candidate continue if buffer: merged.extend( _split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1) ) buffer = piece continue merged.extend(_split_recursive(piece, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)) if buffer: merged.extend(_split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)) return merged def _split_by_level(text: str, level: int) -> List[str]: parts: List[str] = [] start = 0 index = 0 while index < len(text): boundary_end = _match_boundary(text, index, level) if boundary_end is None: index += 1 continue if boundary_end > start: parts.append(text[start:boundary_end]) start = boundary_end index = boundary_end if start < len(text): parts.append(text[start:]) return [part for part in parts if part] def _match_boundary(text: str, index: int, level: int) -> Optional[int]: char = text[index] if level == 0: if char in _STRONG_BOUNDARIES: return _consume_boundary_tail(text, index + 1) if char == "." 

def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
    char = text[index]
    if level == 0:
        # Sentence-level boundaries, including "." when it ends a sentence.
        if char in _STRONG_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        if char == "." and _is_sentence_period(text, index):
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 1:
        # Clause-level punctuation.
        if char in _WEAK_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 2 and char.isspace():
        # Whitespace runs.
        end = index + 1
        while end < len(text) and text[end].isspace():
            end += 1
        return end
    return None


def _consume_boundary_tail(text: str, index: int) -> int:
    # Keep trailing quotes/brackets and whitespace attached to the segment
    # that the boundary closes.
    end = index
    while end < len(text) and text[end] in _CLOSING_CHARS:
        end += 1
    while end < len(text) and text[end].isspace():
        end += 1
    return end


def _is_sentence_period(text: str, index: int) -> bool:
    prev_char = text[index - 1] if index > 0 else ""
    next_char = text[index + 1] if index + 1 < len(text) else ""
    # Decimal numbers like "3.14" are not sentence boundaries.
    if prev_char.isdigit() and next_char.isdigit():
        return False
    if not next_char:
        return True
    return next_char.isspace() or next_char in _CLOSING_CHARS


def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    # Character-based fallback: repeatedly cut the largest prefix that fits,
    # nudging each cut back to a nearby boundary character when possible.
    segments: List[str] = []
    remaining = text
    while remaining:
        if token_length_fn(remaining) <= max_tokens:
            segments.append(remaining)
            break
        cut = _largest_prefix_within_limit(remaining, max_tokens=max_tokens, token_length_fn=token_length_fn)
        refined_cut = _refine_cut(remaining, cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if refined_cut <= 0:
            refined_cut = max(1, cut)
        segments.append(remaining[:refined_cut])
        remaining = remaining[refined_cut:]
    return segments


def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    # Binary search for the longest prefix whose token count fits the budget.
    low = 1
    high = len(text)
    best = 1
    while low <= high:
        mid = (low + high) // 2
        if token_length_fn(text[:mid]) <= max_tokens:
            best = mid
            low = mid + 1
            continue
        high = mid - 1
    return best


def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    # Walk back up to 32 characters looking for a whitespace or punctuation
    # boundary; accept it only if it is no earlier than the halfway point of
    # the cut and the shortened prefix still fits the budget.
    best = cut
    lower_bound = max(1, cut - 32)
    for candidate in range(cut, lower_bound - 1, -1):
        boundary_char = text[candidate - 1]
        if boundary_char.isspace() or boundary_char in _STRONG_BOUNDARIES or boundary_char in _WEAK_BOUNDARIES:
            if candidate >= max(1, cut // 2) and token_length_fn(text[:candidate]) <= max_tokens:
                return candidate
            best = max(best, candidate)
    return best
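
# A fallback sketch (illustrative only): an unbroken run with no sentence,
# clause, or whitespace boundaries forces the character-based _hard_split.
# The 4-characters-per-token counter below is an assumption for the demo.
if __name__ == "__main__":
    def _approx_tokens(text: str) -> int:
        return max(1, len(text) // 4)

    blob = "x" * 400  # no boundary characters at any split level
    chunks = split_text_for_translation(blob, max_tokens=25, token_length_fn=_approx_tokens)
    print([len(chunk) for chunk in chunks])  # [103, 103, 103, 91] with this counter
    assert "".join(chunks) == blob  # hard cuts are lossless
    assert all(_approx_tokens(chunk) <= 25 for chunk in chunks)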