# text_splitter.py (7.6 KB)
"""Utilities for token-budget-aware translation text splitting."""

from __future__ import annotations

from typing import Callable, List, Optional

TokenLengthFn = Callable[[str], int]

_CJK_LANGS = {"zh", "ja", "ko"}
_STRONG_BOUNDARIES = {"\n", "。", "!", "?", "!", "?", ";", ";", "…"}
_WEAK_BOUNDARIES = {",", ",", "、", ":", ":", "(", ")", "(", ")", "[", "]", "【", "】", "/", "|"}
_CLOSING_CHARS = {'"', "'", "”", "’", ")", "]", "}", ")", "】", "》", "」", "』"}
_NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
_NO_SPACE_AFTER = tuple("([{$#@/<")


def is_cjk_language(lang: Optional[str]) -> bool:
    """Return True when *lang* is one of the codes in ``_CJK_LANGS``.

    The value is stringified, stripped and lower-cased before an exact
    membership test, so ``None`` and empty strings yield False and tags that
    carry extra subtags (anything other than the bare code) do not match.
    """
    normalized = str(lang or "").strip().lower()
    return normalized in _CJK_LANGS


def compute_safe_input_token_limit(
    *,
    max_input_length: int,
    max_new_tokens: int,
    decoding_length_mode: str = "fixed",
    decoding_length_extra: int = 0,
    reserve_input_tokens: int = 8,
    reserve_output_tokens: int = 8,
) -> int:
    """Work out how many source tokens one translation segment may use.

    The encoder-side budget is ``max_input_length`` minus a reserve, and the
    result never drops below 8. When the decode settings are tighter than the
    encode settings the budget is capped by them as well, so text is split
    before the model would have to truncate its output.

    Args:
        max_input_length: Encoder-side token capacity.
        max_new_tokens: Decoder-side token capacity; a value <= 0 disables
            the decode-side cap.
        decoding_length_mode: ``"source"`` subtracts ``decoding_length_extra``
            from ``max_new_tokens``; any other value (or a falsy one) is
            handled as ``"fixed"`` and subtracts ``reserve_output_tokens``.
            Compared after stripping and lower-casing.
        decoding_length_extra: Headroom removed from the decode cap in
            ``"source"`` mode; negative values count as 0.
        reserve_input_tokens: Tokens held back on the input side; negative
            values count as 0.
        reserve_output_tokens: Tokens held back on the output side in fixed
            mode; negative values count as 0.

    Returns:
        The source-token budget, always at least 8.
    """
    floor = 8
    encoder_capacity = int(max_input_length)
    encoder_budget = max(floor, encoder_capacity - max(0, int(reserve_input_tokens)))
    mode = str(decoding_length_mode or "fixed").strip().lower()
    decoder_capacity = int(max_new_tokens)

    # No usable decode limit: only the encoder side constrains the split.
    if decoder_capacity <= 0:
        return encoder_budget

    if mode == "source":
        headroom = max(0, int(decoding_length_extra))
    elif decoder_capacity >= encoder_capacity:
        # Fixed mode with a decoder at least as roomy as the encoder.
        return encoder_budget
    else:
        headroom = max(0, int(reserve_output_tokens))

    decoder_budget = max(floor, decoder_capacity - headroom)
    return max(floor, min(encoder_budget, decoder_budget))


def split_text_for_translation(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
) -> List[str]:
    """Break *text* into segments that each fit the token budget.

    Sentence boundaries are tried first, then clause boundaries, then
    whitespace; character-level cuts are the last resort.

    Args:
        text: Source text to split.
        max_tokens: Token budget per segment.
        token_length_fn: Returns the token count of a string.

    Returns:
        ``[text]`` unchanged when the text is empty or already fits (the
        tokenizer is not consulted for empty text); otherwise the non-empty
        segments in their original order.
    """
    needs_no_split = not text or token_length_fn(text) <= max_tokens
    if needs_no_split:
        return [text]

    chunks = _split_recursive(
        text,
        max_tokens=max_tokens,
        token_length_fn=token_length_fn,
        level=0,
    )
    # Drop empty strings; whitespace-only chunks are kept.
    return list(filter(None, chunks))


def join_translated_segments(
    segments: List[Optional[str]],
    *,
    target_lang: Optional[str],
    original_text: str,
) -> Optional[str]:
    """Stitch translated segments back into one string.

    Each segment is stripped; ``None``, empty and whitespace-only segments
    are discarded. The separator is chosen once for the whole text:

    * nothing when ``target_lang`` is a CJK code,
    * a newline when ``original_text`` contains a newline anywhere,
    * a single space otherwise.

    The separator is also left out wherever the text joined so far ends with
    a character from ``_NO_SPACE_AFTER`` or the next segment starts with one
    from ``_NO_SPACE_BEFORE``.

    NOTE(review): Korean is in the CJK set and is therefore joined without
    any separator, although Korean prose normally puts a space between
    sentences -- confirm this is intended.

    Returns:
        The joined text, or ``None`` when no segment has content.
    """
    cleaned: List[str] = []
    for segment in segments:
        if not segment:
            continue
        stripped = segment.strip()
        if stripped:
            cleaned.append(stripped)
    if not cleaned:
        return None

    no_separator_language = is_cjk_language(target_lang)
    source_is_multiline = "\n" in original_text
    if no_separator_language:
        glue = ""
    elif source_is_multiline:
        glue = "\n"
    else:
        glue = " "

    fragments = [cleaned[0]]
    # Every segment is non-empty, so the running text always ends exactly
    # like the previous segment does.
    for previous, current in zip(cleaned, cleaned[1:]):
        attach_directly = (
            not glue
            or previous.endswith(_NO_SPACE_AFTER)
            or current.startswith(_NO_SPACE_BEFORE)
        )
        if not attach_directly:
            fragments.append(glue)
        fragments.append(current)

    joined = "".join(fragments)
    return joined.strip() or None


def _split_recursive(
    text: str,
    *,
    max_tokens: int,
    token_length_fn: TokenLengthFn,
    level: int,
) -> List[str]:
    if token_length_fn(text) <= max_tokens:
        return [text]
    if level >= 3:
        return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)

    pieces = _split_by_level(text, level)
    if len(pieces) <= 1:
        return _split_recursive(text, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)

    merged: List[str] = []
    buffer = ""
    for piece in pieces:
        candidate = buffer + piece if buffer else piece
        if token_length_fn(candidate) <= max_tokens:
            buffer = candidate
            continue
        if buffer:
            merged.extend(
                _split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1)
            )
            buffer = piece
            continue
        merged.extend(_split_recursive(piece, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1))
    if buffer:
        merged.extend(_split_recursive(buffer, max_tokens=max_tokens, token_length_fn=token_length_fn, level=level + 1))
    return merged


def _split_by_level(text: str, level: int) -> List[str]:
    parts: List[str] = []
    start = 0
    index = 0
    while index < len(text):
        boundary_end = _match_boundary(text, index, level)
        if boundary_end is None:
            index += 1
            continue
        if boundary_end > start:
            parts.append(text[start:boundary_end])
            start = boundary_end
        index = boundary_end
    if start < len(text):
        parts.append(text[start:])
    return [part for part in parts if part]


def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
    char = text[index]
    if level == 0:
        if char in _STRONG_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        if char == "." and _is_sentence_period(text, index):
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 1:
        if char in _WEAK_BOUNDARIES:
            return _consume_boundary_tail(text, index + 1)
        return None
    if level == 2 and char.isspace():
        end = index + 1
        while end < len(text) and text[end].isspace():
            end += 1
        return end
    return None


def _consume_boundary_tail(text: str, index: int) -> int:
    end = index
    while end < len(text) and text[end] in _CLOSING_CHARS:
        end += 1
    while end < len(text) and text[end].isspace():
        end += 1
    return end


def _is_sentence_period(text: str, index: int) -> bool:
    prev_char = text[index - 1] if index > 0 else ""
    next_char = text[index + 1] if index + 1 < len(text) else ""
    if prev_char.isdigit() and next_char.isdigit():
        return False
    if not next_char:
        return True
    return next_char.isspace() or next_char in _CLOSING_CHARS


def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
    segments: List[str] = []
    remaining = text
    while remaining:
        if token_length_fn(remaining) <= max_tokens:
            segments.append(remaining)
            break
        cut = _largest_prefix_within_limit(remaining, max_tokens=max_tokens, token_length_fn=token_length_fn)
        refined_cut = _refine_cut(remaining, cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
        if refined_cut <= 0:
            refined_cut = max(1, cut)
        segments.append(remaining[:refined_cut])
        remaining = remaining[refined_cut:]
    return segments


def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    low = 1
    high = len(text)
    best = 1
    while low <= high:
        mid = (low + high) // 2
        if token_length_fn(text[:mid]) <= max_tokens:
            best = mid
            low = mid + 1
            continue
        high = mid - 1
    return best


def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
    best = cut
    lower_bound = max(1, cut - 32)
    for candidate in range(cut, lower_bound - 1, -1):
        if text[candidate - 1].isspace() or text[candidate - 1] in _STRONG_BOUNDARIES or text[candidate - 1] in _WEAK_BOUNDARIES:
            if candidate >= max(1, cut // 2) and token_length_fn(text[:candidate]) <= max_tokens:
                return candidate
            best = max(best, candidate)
    return best