"""Translation service orchestration."""

from __future__ import annotations

import logging
from typing import Dict, List, Optional, Tuple

from config.loader import get_app_config
from config.schema import AppConfig
from translation.cache import TranslationCache
from translation.protocols import TranslateInput, TranslateOutput, TranslationBackendProtocol
from translation.settings import (
    TranslationConfig,
    get_enabled_translation_models,
    get_translation_capability,
    normalize_translation_model,
    normalize_translation_scene,
    translation_cache_probe_models,
)

logger = logging.getLogger(__name__)


class TranslationService:
    """Owns translation backends and routes calls by model and scene."""

    def __init__(self, config: Optional[TranslationConfig] = None, app_config: Optional[AppConfig] = None) -> None:
        self._app_config = app_config or get_app_config()
        self.config = config or self._app_config.services.translation.as_dict()
        self._enabled_capabilities = self._collect_enabled_capabilities()
        if not self._enabled_capabilities:
            raise ValueError("No enabled translation backends found in services.translation.capabilities")
        self._translation_cache = TranslationCache(self.config["cache"])
        self._backends = self._initialize_backends()

    def _collect_enabled_capabilities(self) -> Dict[str, Dict[str, object]]:
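        """Gather configs for every enabled capability, requiring each to name a backend."""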
        enabled: Dict[str, Dict[str, object]] = {}
        for name in get_enabled_translation_models(self.config):
            capability = get_translation_capability(self.config, name, require_enabled=True)
            backend_type = capability.get("backend")
            if not backend_type:
                raise ValueError(f"Translation capability '{name}' must define a backend")
            enabled[name] = capability
        return enabled

    def _create_backend(
        self,
        *,
        name: str,
        backend_type: str,
        cfg: Dict[str, object],
    ) -> TranslationBackendProtocol:
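        """Dispatch to the factory registered for ``backend_type``."""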
        registry = {
            "qwen_mt": self._create_qwen_mt_backend,
            "deepl": self._create_deepl_backend,
            "llm": self._create_llm_backend,
            "local_nllb": self._create_local_nllb_backend,
            "local_marian": self._create_local_marian_backend,
        }
        factory = registry.get(backend_type)
        if factory is None:
            raise ValueError(f"Unsupported translation backend '{backend_type}' for capability '{name}'")
        return factory(name=name, cfg=cfg)

    def _initialize_backends(self) -> Dict[str, TranslationBackendProtocol]:
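        """Construct one backend per enabled capability."""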
        backends: Dict[str, TranslationBackendProtocol] = {}
        for name, capability_cfg in self._enabled_capabilities.items():
            backend_type = str(capability_cfg["backend"])
            logger.info("Initializing translation backend | model=%s backend=%s", name, backend_type)
            backends[name] = self._create_backend(
                name=name,
                backend_type=backend_type,
                cfg=capability_cfg,
            )
            logger.info(
                "Translation backend initialized | model=%s backend=%s use_cache=%s backend_model=%s",
                name,
                backend_type,
                bool(capability_cfg.get("use_cache")),
                getattr(backends[name], "model", name),
            )
        return backends

    def _create_qwen_mt_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        from translation.backends.qwen_mt import QwenMTTranslationBackend

        return QwenMTTranslationBackend(
            capability_name=name,
            model=str(cfg["model"]).strip(),
            base_url=str(cfg["base_url"]).strip(),
            api_key=self._app_config.infrastructure.secrets.dashscope_api_key,
            timeout=int(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_deepl_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        from translation.backends.deepl import DeepLTranslationBackend

        return DeepLTranslationBackend(
            api_key=self._app_config.infrastructure.secrets.deepl_auth_key,
            api_url=str(cfg["api_url"]).strip(),
            timeout=float(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_llm_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        from translation.backends.llm import LLMTranslationBackend

        return LLMTranslationBackend(
            capability_name=name,
            model=str(cfg["model"]).strip(),
            timeout_sec=float(cfg["timeout_sec"]),
            base_url=str(cfg["base_url"]).strip(),
            api_key=self._app_config.infrastructure.secrets.dashscope_api_key,
        )

    def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend

        return NLLBCTranslate2TranslationBackend(
            name=name,
            model_id=str(cfg["model_id"]).strip(),
            model_dir=str(cfg["model_dir"]).strip(),
            device=str(cfg["device"]).strip(),
            torch_dtype=str(cfg["torch_dtype"]).strip(),
            batch_size=int(cfg["batch_size"]),
            max_input_length=int(cfg["max_input_length"]),
            max_new_tokens=int(cfg["max_new_tokens"]),
            num_beams=int(cfg["num_beams"]),
            ct2_model_dir=cfg.get("ct2_model_dir"),
            ct2_compute_type=cfg.get("ct2_compute_type"),
            ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
            ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
            ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
            ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
            ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
            ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
            ct2_decoding_length_mode=str(cfg.get("ct2_decoding_length_mode", "fixed")),
            ct2_decoding_length_extra=int(cfg.get("ct2_decoding_length_extra", 0)),
            ct2_decoding_length_min=int(cfg.get("ct2_decoding_length_min", 1)),
        )

    def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        from translation.backends.local_ctranslate2 import MarianCTranslate2TranslationBackend, get_marian_language_direction

        source_lang, target_lang = get_marian_language_direction(name)

        return MarianCTranslate2TranslationBackend(
            name=name,
            model_id=str(cfg["model_id"]).strip(),
            model_dir=str(cfg["model_dir"]).strip(),
            device=str(cfg["device"]).strip(),
            torch_dtype=str(cfg["torch_dtype"]).strip(),
            batch_size=int(cfg["batch_size"]),
            max_input_length=int(cfg["max_input_length"]),
            max_new_tokens=int(cfg["max_new_tokens"]),
            num_beams=int(cfg["num_beams"]),
            source_langs=[source_lang],
            target_langs=[target_lang],
            ct2_model_dir=cfg.get("ct2_model_dir"),
            ct2_compute_type=cfg.get("ct2_compute_type"),
            ct2_auto_convert=bool(cfg.get("ct2_auto_convert", True)),
            ct2_conversion_quantization=cfg.get("ct2_conversion_quantization"),
            ct2_inter_threads=int(cfg.get("ct2_inter_threads", 1)),
            ct2_intra_threads=int(cfg.get("ct2_intra_threads", 0)),
            ct2_max_queued_batches=int(cfg.get("ct2_max_queued_batches", 0)),
            ct2_batch_type=str(cfg.get("ct2_batch_type", "examples")),
            ct2_decoding_length_mode=str(cfg.get("ct2_decoding_length_mode", "fixed")),
            ct2_decoding_length_extra=int(cfg.get("ct2_decoding_length_extra", 0)),
            ct2_decoding_length_min=int(cfg.get("ct2_decoding_length_min", 1)),
        )

    @property
    def available_models(self) -> List[str]:
        return list(self._enabled_capabilities.keys())

    @property
    def loaded_models(self) -> List[str]:
        return list(self._backends.keys())

    def get_backend(self, model: Optional[str] = None) -> TranslationBackendProtocol:
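        """Resolve ``model`` via ``normalize_translation_model`` and return its backend."""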
        normalized = normalize_translation_model(self.config, model)
        backend = self._backends.get(normalized)
        if backend is None:
            raise ValueError(
                f"Translation model '{normalized}' is not enabled. "
                f"Available models: {', '.join(self.available_models) or 'none'}"
            )
        return backend

    def translate(
        self,
        text: TranslateInput,
        target_lang: str,
        source_lang: Optional[str] = None,
        *,
        model: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> TranslateOutput:
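        """Translate ``text`` (a single string or a sequence) into ``target_lang``.

        Routes to the backend selected by ``model`` and, when the capability
        enables it, consults the translation cache first.
        """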
        normalized_model = normalize_translation_model(self.config, model)
        backend = self.get_backend(normalized_model)
        active_scene = normalize_translation_scene(self.config, scene)
        capability_cfg = self._enabled_capabilities[normalized_model]
        use_cache = bool(capability_cfg.get("use_cache"))
        logger.info(
            "Translation route | backend=%s request_type=%s use_cache=%s cache_available=%s",
            getattr(backend, "model", normalized_model),
            "single" if isinstance(text, str) else "batch",
            use_cache,
            self._translation_cache.available,
        )
        if not use_cache or not self._translation_cache.available:
            return backend.translate(
                text=text,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=active_scene,
            )

        if isinstance(text, str):
            return self._translate_with_cache(
                backend,
                text=text,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=active_scene,
                model=normalized_model,
            )

        return self._translate_batch_with_cache(
            text=text,
            target_lang=target_lang,
            source_lang=source_lang,
            backend=backend,
            scene=active_scene,
            model=normalized_model,
        )

    def _translate_with_cache(
        self,
        backend: TranslationBackendProtocol,
        *,
        text: str,
        target_lang: str,
        source_lang: Optional[str],
        scene: str,
        model: str,
    ) -> Optional[str]:
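        """Serve a single request from the tiered cache, else translate and cache the result."""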
        if not text.strip():
            return text
        cached, _served = self._tiered_cache_get(
            request_model=model,
            target_lang=target_lang,
            source_text=text,
        )
        if cached is not None:
            logger.info(
                "Translation cache served | request_type=single text_len=%s",
                len(text),
            )
            return cached
        translated = backend.translate(
            text=text,
            target_lang=target_lang,
            source_lang=source_lang,
            scene=scene,
        )
        if translated is not None:
            self._translation_cache.set(
                model=model,
                target_lang=target_lang,
                source_text=text,
                translated_text=translated,
            )
            logger.info(
                "Translation backend result cached | request_type=single text_len=%s result_len=%s",
                len(text),
                len(str(translated)),
            )
        else:
            logger.warning(
                "Translation backend returned empty result | request_type=single text_len=%s",
                len(text),
            )
        return translated

    def _tiered_cache_get(
        self,
        *,
        request_model: str,
        target_lang: str,
        source_text: str,
    ) -> Tuple[Optional[str], Optional[str]]:
        """Redis lookup: cache from higher-tier or **same-tier** models may satisfy A.

        Lower-tier entries are never read. Returns ``(translated, served_model)``.
        """
        probe_models = translation_cache_probe_models(self.config, request_model)

        for probe_model in probe_models:
            hit = self._translation_cache.get(
                model=probe_model,
                target_lang=target_lang,
                source_text=source_text,
            )
            if hit is not None:
                return hit, probe_model

        return None, None

    def _translate_batch_with_cache(
        self,
        *,
        text: TranslateInput,
        target_lang: str,
        source_lang: Optional[str],
        backend: TranslationBackendProtocol,
        scene: str,
        model: str,
    ) -> List[Optional[str]]:
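        """Serve cache hits in place, translate only the misses, and cache new results."""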
        texts = list(text)
        results: List[Optional[str]] = [None] * len(texts)
        misses: List[str] = []
        miss_indices: List[int] = []
        cache_hits = 0

        for idx, item in enumerate(texts):
            normalized_text = "" if item is None else str(item)
            if not normalized_text.strip():
                results[idx] = normalized_text
                continue
            cached, _served = self._tiered_cache_get(
                request_model=model,
                target_lang=target_lang,
                source_text=normalized_text,
            )
            if cached is not None:
                results[idx] = cached
                cache_hits += 1
                continue
            misses.append(normalized_text)
            miss_indices.append(idx)

        logger.info(
            "Translation batch cache summary | total=%s cache_hits=%s cache_misses=%s",
            len(texts),
            cache_hits,
            len(misses),
        )

        if misses:
            translated = backend.translate(
                text=misses,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=scene,
            )
            # Backends may return a bare string for a batch call; coerce to a list.
            translated_list = translated if isinstance(translated, list) else [translated]
            # zip() below silently truncates, so surface any length mismatch first.
            if len(translated_list) != len(misses):
                logger.warning(
                    "Translation batch result length mismatch | expected=%s got=%s",
                    len(misses),
                    len(translated_list),
                )
            for idx, original_text, translated_text in zip(miss_indices, misses, translated_list):
                results[idx] = translated_text
                if translated_text is not None:
                    self._translation_cache.set(
                        model=model,
                        target_lang=target_lang,
                        source_text=original_text,
                        translated_text=translated_text,
                    )
                else:
                    logger.warning(
                        "Translation batch item returned empty result | item_index=%s text_len=%s",
                        idx,
                        len(original_text),
                    )

        return results