import torch

from translation.backends.local_seq2seq import (
    MarianMTTranslationBackend,
    NLLBTranslationBackend,
)
from translation.service import TranslationService


# Lightweight stand-ins for the Hugging Face tokenizer / seq2seq model so the
# backend tests never download or load real weights.
class _FakeBatch(dict):
    def to(self, device):
        self["device"] = device
        return self


class _FakeTokenizer:
    def __init__(self):
        self.src_lang = None
        self.pad_token = ""
        self.eos_token = ""
        self.lang_code_to_id = {"eng_Latn": 101, "zho_Hans": 202}
        self.last_call = None

    def __call__(self, texts, **kwargs):
        self.last_call = {"texts": list(texts), **kwargs}
        return _FakeBatch({"input_ids": torch.tensor([[1, 2, 3]])})

    def batch_decode(self, generated, skip_special_tokens=True):
        del generated, skip_special_tokens
        return ["translated" for _ in range(len(self.last_call["texts"]))]

    def convert_tokens_to_ids(self, token):
        return self.lang_code_to_id[token]


class _FakeModel:
    def to(self, device):
        self.device = device
        return self

    def eval(self):
        return self

    def generate(self, **kwargs):
        self.last_generate_kwargs = kwargs
        return [[42]]


def _stub_load_model(self):
    # Patched in place of the backends' _load_model so no model files are required.
    self.tokenizer = _FakeTokenizer()
    self.seq2seq_model = _FakeModel()


def test_marian_language_validation(monkeypatch):
    monkeypatch.setattr(MarianMTTranslationBackend, "_load_model", _stub_load_model)
    backend = MarianMTTranslationBackend(
        name="opus-mt-zh-en",
        model_id="Helsinki-NLP/opus-mt-zh-en",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-zh-en",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
        source_langs=["zh"],
        target_langs=["en"],
    )

    result = backend.translate("测试", source_lang="zh", target_lang="en")
    assert result == "translated"

    # An unsupported source language must raise a ValueError.
    try:
        backend.translate("test", source_lang="en", target_lang="zh")
    except ValueError as exc:
        assert "source languages" in str(exc)
    else:
        raise AssertionError("Expected unsupported source language to raise")


def test_nllb_uses_src_lang_and_forced_bos(monkeypatch):
    monkeypatch.setattr(NLLBTranslationBackend, "_load_model", _stub_load_model)
    backend = NLLBTranslationBackend(
        name="nllb-200-distilled-600m",
        model_id="facebook/nllb-200-distilled-600M",
        model_dir="./models/translation/facebook/nllb-200-distilled-600M",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
    )

    result = backend.translate("test", source_lang="en", target_lang="zh")
    assert result == "translated"
    # NLLB must set the tokenizer's source language and force the target
    # language code as the BOS token during generation.
    assert backend.tokenizer.src_lang == "eng_Latn"
    assert backend.seq2seq_model.last_generate_kwargs["forced_bos_token_id"] == 202


def test_translation_service_preloads_enabled_backends(monkeypatch):
    created = []

    def _fake_create_backend(self, *, name, backend_type, cfg):
        del self, cfg
        created.append((name, backend_type))

        class _Backend:
            model = name

            @property
            def supports_batch(self):
                return True

            def translate(self, text, target_lang, source_lang=None, scene=None):
                del target_lang, source_lang, scene
                return text

        return _Backend()

    monkeypatch.setattr(TranslationService, "_create_backend", _fake_create_backend)

    config = {
        "service_url": "http://127.0.0.1:6006",
        "timeout_sec": 10.0,
        "default_model": "opus-mt-en-zh",
        "default_scene": "general",
        "capabilities": {
            "opus-mt-en-zh": {
                "enabled": True,
                "backend": "local_marian",
                "use_cache": True,
                "model_id": "dummy",
                "model_dir": "dummy",
                "device": "cpu",
                "torch_dtype": "float32",
                "batch_size": 1,
                "max_input_length": 8,
                "max_new_tokens": 8,
                "num_beams": 1,
            },
            "nllb-200-distilled-600m": {
                "enabled": True,
                "backend": "local_nllb",
                "use_cache": True,
                "model_id": "dummy",
                "model_dir": "dummy",
                "device": "cpu",
                "torch_dtype": "float32",
                "batch_size": 1,
                "max_input_length": 8,
                "max_new_tokens": 8,
                "num_beams": 1,
            },
        },
        "cache": {
            "ttl_seconds": 60,
            "sliding_expiration": True,
        },
    }

    service = TranslationService(config)

    # Both enabled capabilities should be created eagerly, in config order.
    assert service.available_models == ["opus-mt-en-zh", "nllb-200-distilled-600m"]
    assert service.loaded_models == ["opus-mt-en-zh", "nllb-200-distilled-600m"]
    assert created == [
        ("opus-mt-en-zh", "local_marian"),
        ("nllb-200-distilled-600m", "local_nllb"),
    ]

    backend = service.get_backend("opus-mt-en-zh")
    assert backend.model == "opus-mt-en-zh"