# test_translation_local_backends.py (7.8 KB)
import torch

from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend
from translation.service import TranslationService
from translation.text_splitter import compute_safe_input_token_limit, split_text_for_translation


class _FakeBatch(dict):
    """Dict-backed stand-in for a tokenizer batch that supports ``.to(device)``."""

    def to(self, device):
        """Store *device* under the ``"device"`` key and return this same batch."""
        self.update({"device": device})
        return self


class _FakeTokenizer:
    """Tokenizer double that records its most recent call and returns canned data."""

    def __init__(self):
        self.src_lang = None
        # Pad and EOS use the same token string.
        self.pad_token = self.eos_token = "</s>"
        self.lang_code_to_id = {"eng_Latn": 101, "zho_Hans": 202}
        # Populated by __call__; batch_decode reads it to size its output.
        self.last_call = None

    def __call__(self, texts, **kwargs):
        """Remember the texts and keyword arguments, then return a fixed batch."""
        recorded = {"texts": list(texts)}
        recorded.update(kwargs)
        self.last_call = recorded
        input_ids = torch.tensor([[1, 2, 3]])
        return _FakeBatch({"input_ids": input_ids})

    def batch_decode(self, generated, skip_special_tokens=True):
        """Return one ``"translated"`` string per text seen in the last call."""
        del generated, skip_special_tokens
        text_count = len(self.last_call["texts"])
        return ["translated"] * text_count

    def convert_tokens_to_ids(self, token):
        """Look up *token* in ``lang_code_to_id``; unknown tokens raise ``KeyError``."""
        token_id = self.lang_code_to_id[token]
        return token_id


class _FakeModel:
    """Model double: chainable ``to``/``eval`` plus a ``generate`` that records kwargs."""

    def to(self, device):
        """Remember the requested device and return this model for chaining."""
        setattr(self, "device", device)
        return self

    def eval(self):
        """Return this model unchanged."""
        return self

    def generate(self, **kwargs):
        """Keep the keyword arguments for later inspection and return a fixed output."""
        self.last_generate_kwargs = kwargs
        canned_output = [[42]]
        return canned_output


def _stub_load_model(self):
    """Stand-in for a backend's ``_load_model`` that attaches in-memory fakes."""
    self.tokenizer, self.seq2seq_model = _FakeTokenizer(), _FakeModel()


def test_marian_language_validation(monkeypatch):
    """A zh->en Marian backend translates zh->en and rejects en->zh with ValueError."""
    monkeypatch.setattr(MarianMTTranslationBackend, "_load_model", _stub_load_model)
    backend_options = {
        "name": "opus-mt-zh-en",
        "model_id": "Helsinki-NLP/opus-mt-zh-en",
        "model_dir": "./models/translation/Helsinki-NLP/opus-mt-zh-en",
        "device": "cpu",
        "torch_dtype": "float32",
        "batch_size": 1,
        "max_input_length": 16,
        "max_new_tokens": 16,
        "num_beams": 1,
        "source_langs": ["zh"],
        "target_langs": ["en"],
    }
    backend = MarianMTTranslationBackend(**backend_options)

    # Supported direction goes through the stubbed tokenizer/model.
    assert backend.translate("测试", source_lang="zh", target_lang="en") == "translated"

    # Reversed direction must be rejected; capture the error, then inspect it.
    rejection = None
    try:
        backend.translate("test", source_lang="en", target_lang="zh")
    except ValueError as exc:
        rejection = exc

    if rejection is None:
        raise AssertionError("Expected unsupported source language to raise")
    assert "source languages" in str(rejection)


def test_nllb_uses_src_lang_and_forced_bos(monkeypatch):
    """NLLB en->zh sets the tokenizer ``src_lang`` and passes ``forced_bos_token_id``."""
    monkeypatch.setattr(NLLBTranslationBackend, "_load_model", _stub_load_model)
    backend_options = {
        "name": "nllb-200-distilled-600m",
        "model_id": "facebook/nllb-200-distilled-600M",
        "model_dir": "./models/translation/facebook/nllb-200-distilled-600M",
        "device": "cpu",
        "torch_dtype": "float32",
        "batch_size": 1,
        "max_input_length": 16,
        "max_new_tokens": 16,
        "num_beams": 1,
    }
    backend = NLLBTranslationBackend(**backend_options)

    translated = backend.translate("test", source_lang="en", target_lang="zh")

    assert translated == "translated"
    # Source language is expected as the NLLB-style code on the tokenizer.
    assert backend.tokenizer.src_lang == "eng_Latn"
    # 202 is the id the fake tokenizer maps "zho_Hans" to.
    generate_kwargs = backend.seq2seq_model.last_generate_kwargs
    assert generate_kwargs["forced_bos_token_id"] == 202


def test_translation_service_preloads_enabled_backends(monkeypatch):
    """Constructing the service creates every enabled backend, in config order."""
    created = []

    def _fake_create_backend(self, *, name, backend_type, cfg):
        # Record what the service asked for and hand back a pass-through backend.
        del self, cfg
        created.append((name, backend_type))

        class _Backend:
            model = name

            @property
            def supports_batch(self):
                return True

            def translate(self, text, target_lang, source_lang=None, scene=None):
                del target_lang, source_lang, scene
                return text

        return _Backend()

    def _local_capability(backend_type):
        # Both capabilities share every setting except the backend type.
        return {
            "enabled": True,
            "backend": backend_type,
            "use_cache": True,
            "model_id": "dummy",
            "model_dir": "dummy",
            "device": "cpu",
            "torch_dtype": "float32",
            "batch_size": 1,
            "max_input_length": 8,
            "max_new_tokens": 8,
            "num_beams": 1,
        }

    monkeypatch.setattr(TranslationService, "_create_backend", _fake_create_backend)

    capabilities = {
        "opus-mt-en-zh": _local_capability("local_marian"),
        "nllb-200-distilled-600m": _local_capability("local_nllb"),
    }
    cache_settings = {
        "ttl_seconds": 60,
        "sliding_expiration": True,
    }
    config = {
        "service_url": "http://127.0.0.1:6006",
        "timeout_sec": 10.0,
        "default_model": "opus-mt-en-zh",
        "default_scene": "general",
        "capabilities": capabilities,
        "cache": cache_settings,
    }

    service = TranslationService(config)

    expected_models = ["opus-mt-en-zh", "nllb-200-distilled-600m"]
    assert service.available_models == expected_models
    assert service.loaded_models == expected_models
    assert created == [
        ("opus-mt-en-zh", "local_marian"),
        ("nllb-200-distilled-600m", "local_nllb"),
    ]

    marian_backend = service.get_backend("opus-mt-en-zh")
    assert marian_backend.model == "opus-mt-en-zh"


def test_compute_safe_input_token_limit_uses_decode_constraints():
    """Pin the safe input limits for a source-length-mode case and a default case."""
    source_mode_options = {
        "max_input_length": 256,
        "max_new_tokens": 64,
        "decoding_length_mode": "source",
        "decoding_length_extra": 8,
    }
    default_mode_options = {
        "max_input_length": 256,
        "max_new_tokens": 256,
    }

    # Compute both limits before asserting, so each call runs regardless.
    nllb_limit = compute_safe_input_token_limit(**source_mode_options)
    opus_limit = compute_safe_input_token_limit(**default_mode_options)

    assert nllb_limit == 56
    assert opus_limit == 248


def test_split_text_for_translation_prefers_sentence_boundaries():
    """Splitting is lossless, respects the budget, and first breaks at 。 or ;."""
    sentences = [
        "这是一条很长的中文商品描述,包含材质、尺码和适用场景。",
        "适合春夏通勤,也适合日常出街穿搭;",
        "如果长度超了,应该优先按完整语义分句,而不是切成很碎的小片段。",
    ]
    text = "".join(sentences)

    # len() stands in for the tokenizer: one character counts as one token.
    segments = split_text_for_translation(text, max_tokens=36, token_length_fn=len)

    assert len(segments) >= 2
    # No characters may be dropped, added or reordered by the split.
    assert "".join(segments) == text
    assert not any(len(segment) > 36 for segment in segments)
    first_segment = segments[0]
    assert first_segment.endswith(("。", ";"))


class _SegmentingMarianBackend(MarianMTTranslationBackend):
    """Marian backend double: no model load, character-count tokens, bracketed output."""

    def _load_model(self):
        """Skip model loading; only set up the log of batches sent for translation."""
        self.translated_batches = []

    def _token_count(self, text, target_lang, source_lang=None):
        """Count one token per character, ignoring the language arguments."""
        del target_lang, source_lang
        return len(text)

    def _translate_batch(self, texts, target_lang, source_lang=None):
        """Log the batch, then wrap each stripped text in ``<>`` (zh) or ``[]`` (other)."""
        del source_lang
        self.translated_batches.append(list(texts))
        if target_lang != "zh":
            return [f"[{piece.strip()}]" for piece in texts]
        return [f"<{piece.strip()}>" for piece in texts]


def test_local_backend_splits_oversized_text_before_translation():
    """Oversized input is split into several pieces that go out in a single batch."""
    backend_options = {
        "name": "opus-mt-en-zh",
        "model_id": "Helsinki-NLP/opus-mt-en-zh",
        "model_dir": "./models/translation/Helsinki-NLP/opus-mt-en-zh",
        "device": "cpu",
        "torch_dtype": "float32",
        "batch_size": 8,
        "max_input_length": 24,
        "max_new_tokens": 24,
        "num_beams": 1,
        "source_langs": ["en"],
        "target_langs": ["zh"],
    }
    backend = _SegmentingMarianBackend(**backend_options)

    clauses = [
        "This soft cotton dress is breathable and lightweight, ",
        "works well for spring travel and everyday wear, ",
        "and should be split on natural clause boundaries when it gets too long.",
    ]
    text = "".join(clauses)

    result = backend.translate(text, source_lang="en", target_lang="zh")

    assert result is not None
    # Every piece must have been submitted together in one batch.
    assert len(backend.translated_batches) == 1
    pieces = backend.translated_batches[0]
    assert len(pieces) >= 2
    assert not any(len(piece) > 16 for piece in pieces)
    # The result is the per-piece fake translations concatenated in order.
    expected = "".join(f"<{piece.strip()}>" for piece in pieces)
    assert result == expected