# test_translation_local_backends.py (5.09 KB)
import torch

from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend
from translation.service import TranslationService


class _FakeBatch(dict):
    def to(self, device):
        self["device"] = device
        return self


class _FakeTokenizer:
    def __init__(self):
        self.src_lang = None
        self.pad_token = "</s>"
        self.eos_token = "</s>"
        self.lang_code_to_id = {"eng_Latn": 101, "zho_Hans": 202}
        self.last_call = None

    def __call__(self, texts, **kwargs):
        self.last_call = {"texts": list(texts), **kwargs}
        return _FakeBatch({"input_ids": torch.tensor([[1, 2, 3]])})

    def batch_decode(self, generated, skip_special_tokens=True):
        del generated, skip_special_tokens
        return ["translated" for _ in range(len(self.last_call["texts"]))]

    def convert_tokens_to_ids(self, token):
        return self.lang_code_to_id[token]


class _FakeModel:
    def to(self, device):
        self.device = device
        return self

    def eval(self):
        return self

    def generate(self, **kwargs):
        self.last_generate_kwargs = kwargs
        return [[42]]


def _stub_load_model(self):
    """Replacement for a backend's ``_load_model`` that installs the fake tokenizer and model."""
    fake_tokenizer = _FakeTokenizer()
    fake_model = _FakeModel()
    self.tokenizer = fake_tokenizer
    self.seq2seq_model = fake_model


def test_marian_language_validation(monkeypatch):
    """A zh->en Marian backend translates zh->en and rejects an ``en`` source with ``ValueError``."""
    # Replace the real loader with the stub (presumably invoked by the backend
    # itself -- the call site is not visible in this file).
    monkeypatch.setattr(MarianMTTranslationBackend, "_load_model", _stub_load_model)
    settings = dict(
        name="opus-mt-zh-en",
        model_id="Helsinki-NLP/opus-mt-zh-en",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-zh-en",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
        source_langs=["zh"],
        target_langs=["en"],
    )
    marian = MarianMTTranslationBackend(**settings)

    # Supported direction: the fake tokenizer decodes everything to "translated".
    assert marian.translate("测试", source_lang="zh", target_lang="en") == "translated"

    # Unsupported direction: capture the error first, then inspect its message.
    caught = None
    try:
        marian.translate("test", source_lang="en", target_lang="zh")
    except ValueError as exc:
        caught = exc
    if caught is None:
        raise AssertionError("Expected unsupported source language to raise")
    assert "source languages" in str(caught)


def test_nllb_uses_src_lang_and_forced_bos(monkeypatch):
    """An en->zh NLLB translation sets the tokenizer source language and the forced BOS token id."""
    # Replace the real loader with the stub so the backend ends up holding the fakes.
    monkeypatch.setattr(NLLBTranslationBackend, "_load_model", _stub_load_model)
    settings = dict(
        name="nllb-200-distilled-600m",
        model_id="facebook/nllb-200-distilled-600M",
        model_dir="./models/translation/facebook/nllb-200-distilled-600M",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
    )
    nllb = NLLBTranslationBackend(**settings)

    translated = nllb.translate("test", source_lang="en", target_lang="zh")
    assert translated == "translated"

    # "en" is expected to have been mapped to the "eng_Latn" code on the tokenizer.
    assert nllb.tokenizer.src_lang == "eng_Latn"

    # 202 is the fake tokenizer's id for "zho_Hans".
    generate_kwargs = nllb.seq2seq_model.last_generate_kwargs
    assert generate_kwargs["forced_bos_token_id"] == 202


def test_translation_service_preloads_enabled_backends(monkeypatch):
    """Constructing the service creates a backend for every enabled capability, in config order."""
    created = []

    def _fake_create_backend(self, *, name, backend_type, cfg):
        """Log the requested (name, backend_type) pair and return an echoing backend."""
        del self, cfg
        created.append((name, backend_type))

        class _Backend:
            model = name

            @property
            def supports_batch(self):
                return True

            def translate(self, text, target_lang, source_lang=None, scene=None):
                del target_lang, source_lang, scene
                return text

        return _Backend()

    monkeypatch.setattr(TranslationService, "_create_backend", _fake_create_backend)

    def _local_capability(backend_type):
        """Build one capability entry; only the backend type differs between models."""
        return {
            "enabled": True,
            "backend": backend_type,
            "use_cache": True,
            "model_id": "dummy",
            "model_dir": "dummy",
            "device": "cpu",
            "torch_dtype": "float32",
            "batch_size": 1,
            "max_input_length": 8,
            "max_new_tokens": 8,
            "num_beams": 1,
        }

    # Insertion order here is the order the assertions below expect.
    backend_types = {
        "opus-mt-en-zh": "local_marian",
        "nllb-200-distilled-600m": "local_nllb",
    }
    capabilities = {
        model_name: _local_capability(backend_type)
        for model_name, backend_type in backend_types.items()
    }
    cache_settings = {
        "ttl_seconds": 60,
        "sliding_expiration": True,
    }
    config = {
        "service_url": "http://127.0.0.1:6006",
        "timeout_sec": 10.0,
        "default_model": "opus-mt-en-zh",
        "default_scene": "general",
        "capabilities": capabilities,
        "cache": cache_settings,
    }

    service = TranslationService(config)

    expected_models = list(backend_types)
    assert service.available_models == expected_models
    assert service.loaded_models == expected_models
    assert created == list(backend_types.items())

    opus_backend = service.get_backend("opus-mt-en-zh")
    assert opus_backend.model == "opus-mt-en-zh"