# test_translator_failure_semantics.py
import logging

from translation.cache import TranslationCache
from translation.logging_utils import (
    TranslationRequestFilter,
    bind_translation_request_id,
    reset_translation_request_id,
)
from translation.service import TranslationService


class _FakeCache:
    def __init__(self):
        self.available = True
        self.storage = {}
        self.get_calls = []
        self.set_calls = []

    def get(self, *, model, target_lang, source_text):
        self.get_calls.append((model, target_lang, source_text))
        return self.storage.get((model, target_lang, source_text))

    def set(self, *, model, target_lang, source_text, translated_text):
        self.set_calls.append((model, target_lang, source_text, translated_text))
        self.storage[(model, target_lang, source_text)] = translated_text


def test_translation_cache_key_format(monkeypatch):
    """Cache keys are a readable prefix followed by a 64-character digest."""
    # Skip the real Redis connection during construction.
    monkeypatch.setattr(TranslationCache, "_init_redis_client", staticmethod(lambda: None))
    cache = TranslationCache({"ttl_seconds": 60, "sliding_expiration": True})

    key = cache.build_key(model="llm", target_lang="en", source_text="商品标题")

    prefix = "trans:llm:en:商品标题"
    assert key.startswith(prefix)
    # The suffix after the readable prefix is a fixed-width (64-char) hash.
    assert len(key) - len(prefix) == 64


def test_service_caches_all_capabilities(monkeypatch):
    """Both single and batch translations are cached per (model, lang, text)."""
    monkeypatch.setattr(TranslationCache, "_init_redis_client", staticmethod(lambda: None))
    backends = {}

    def _stub_create_backend(self, *, name, backend_type, cfg):
        # The service only needs a backend with `model`, `supports_batch`,
        # and `translate`; config and type are irrelevant to the stub.
        del self, backend_type, cfg

        class _StubBackend:
            model = name

            @property
            def supports_batch(self):
                return True

            def translate(self, text, target_lang, source_lang=None, scene=None):
                del target_lang, source_lang, scene
                if isinstance(text, list):
                    return [f"{name}:{piece}" for piece in text]
                return f"{name}:{text}"

        instance = _StubBackend()
        backends[name] = instance
        return instance

    monkeypatch.setattr(TranslationService, "_create_backend", _stub_create_backend)

    llm_capability = {
        "enabled": True,
        "backend": "llm",
        "model": "dummy-llm",
        "base_url": "https://example.com",
        "timeout_sec": 10.0,
        "use_cache": True,
    }
    marian_capability = {
        "enabled": True,
        "backend": "local_marian",
        "model_id": "dummy",
        "model_dir": "dummy",
        "device": "cpu",
        "torch_dtype": "float32",
        "batch_size": 8,
        "max_input_length": 16,
        "max_new_tokens": 16,
        "num_beams": 1,
        "use_cache": True,
    }
    config = {
        "service_url": "http://127.0.0.1:6006",
        "timeout_sec": 10.0,
        "default_model": "llm",
        "default_scene": "general",
        "capabilities": {
            "llm": llm_capability,
            "opus-mt-zh-en": marian_capability,
        },
        "cache": {
            "ttl_seconds": 60,
            "sliding_expiration": True,
        },
    }

    service = TranslationService(config)
    fake_cache = _FakeCache()
    service._translation_cache = fake_cache

    first = service.translate("商品标题", target_lang="en", source_lang="zh", model="llm")
    second = service.translate("商品标题", target_lang="en", source_lang="zh", model="llm")
    batch = service.translate(["连衣裙", "衬衫"], target_lang="en", source_lang="zh", model="opus-mt-zh-en")

    assert first == "llm:商品标题"
    assert second == "llm:商品标题"
    assert batch == ["opus-mt-zh-en:连衣裙", "opus-mt-zh-en:衬衫"]
    # Every translate call probes the cache, including the repeat lookup.
    assert fake_cache.get_calls == [
        ("llm", "en", "商品标题"),
        ("llm", "en", "商品标题"),
        ("opus-mt-zh-en", "en", "连衣裙"),
        ("opus-mt-zh-en", "en", "衬衫"),
    ]
    # The repeated "商品标题" request hits the cache, so only three writes occur.
    assert fake_cache.set_calls == [
        ("llm", "en", "商品标题", "llm:商品标题"),
        ("opus-mt-zh-en", "en", "连衣裙", "opus-mt-zh-en:连衣裙"),
        ("opus-mt-zh-en", "en", "衬衫", "opus-mt-zh-en:衬衫"),
    ]


def test_translation_request_filter_injects_reqid():
    """The log filter copies the bound request id onto each record."""
    reqid, token = bind_translation_request_id("req-test-1234567890")
    try:
        log_record = logging.LogRecord(
            name="translation.service",
            level=logging.INFO,
            pathname=__file__,
            lineno=1,
            msg="hello",
            args=(),
            exc_info=None,
        )
        request_filter = TranslationRequestFilter()
        request_filter.filter(log_record)

        assert reqid == "req-test-1234567890"
        assert log_record.reqid == "req-test-1234567890"
    finally:
        # Always unbind so the context var does not leak into other tests.
        reset_translation_request_id(token)


def test_translation_route_log_focuses_on_routing_decision(monkeypatch, caplog):
    """Exactly one route log line is emitted, stating only the routing decision."""
    monkeypatch.setattr(TranslationCache, "_init_redis_client", staticmethod(lambda: None))

    def _stub_create_backend(self, *, name, backend_type, cfg):
        del self, backend_type, cfg

        class _EchoBackend:
            model = name

            @property
            def supports_batch(self):
                return True

            def translate(self, text, target_lang, source_lang=None, scene=None):
                # Routing, not translation, is under test: echo the input.
                del target_lang, source_lang, scene
                return text

        return _EchoBackend()

    monkeypatch.setattr(TranslationService, "_create_backend", _stub_create_backend)

    config = {
        "service_url": "http://127.0.0.1:6006",
        "timeout_sec": 10.0,
        "default_model": "llm",
        "default_scene": "general",
        "capabilities": {
            "llm": {
                "enabled": True,
                "backend": "llm",
                "model": "dummy-llm",
                "base_url": "https://example.com",
                "timeout_sec": 10.0,
                "use_cache": True,
            }
        },
        "cache": {
            "ttl_seconds": 60,
            "sliding_expiration": True,
        },
    }
    service = TranslationService(config)

    with caplog.at_level(logging.INFO):
        service.translate("商品标题", target_lang="en", source_lang="zh", model="llm")

    route_messages = []
    for record in caplog.records:
        if record.name != "translation.service":
            continue
        message = record.getMessage()
        if message.startswith("Translation route |"):
            route_messages.append(message)

    assert route_messages == [
        "Translation route | backend=llm request_type=single use_cache=True cache_available=False"
    ]