# service.py
"""Translation service orchestration."""

from __future__ import annotations

import logging
import threading
from typing import Dict, List, Optional

from config.services_config import get_translation_config
from translation.protocols import TranslateInput, TranslateOutput, TranslationBackendProtocol
from translation.settings import (
    TranslationConfig,
    get_enabled_translation_models,
    get_translation_capability,
    normalize_translation_model,
    normalize_translation_scene,
)

# Module-level logger named after this module's import path (lazy %-style args used below).
logger = logging.getLogger(__name__)


class TranslationService:
    """Routes translation requests to lazily-created, cached backends.

    One backend instance is kept per enabled capability name; creation is
    guarded by a lock so concurrent callers share a single instance.
    """

    def __init__(self, config: Optional[TranslationConfig] = None) -> None:
        """Load capability configuration and validate the enabled set.

        Raises:
            ValueError: if no capability is enabled under
                services.translation.capabilities, or an enabled capability
                omits its backend type.
        """
        self.config = config or get_translation_config()
        self._enabled_capabilities = self._collect_enabled_capabilities()
        # Cache of capability name -> instantiated backend, filled on demand.
        self._backends: Dict[str, TranslationBackendProtocol] = {}
        self._backend_lock = threading.Lock()
        if not self._enabled_capabilities:
            raise ValueError("No enabled translation backends found in services.translation.capabilities")

    def _collect_enabled_capabilities(self) -> Dict[str, Dict[str, object]]:
        """Return ``{capability name: capability config}`` for enabled capabilities."""
        capabilities: Dict[str, Dict[str, object]] = {}
        for model_name in get_enabled_translation_models(self.config):
            capability = get_translation_capability(self.config, model_name, require_enabled=True)
            # A capability without a backend type can never be instantiated.
            if not capability.get("backend"):
                raise ValueError(f"Translation capability '{model_name}' must define a backend")
            capabilities[model_name] = capability
        return capabilities

    def _create_backend(
        self,
        *,
        name: str,
        backend_type: str,
        cfg: Dict[str, object],
    ) -> TranslationBackendProtocol:
        """Instantiate a backend of ``backend_type`` for capability ``name``."""
        dispatch = {
            "qwen_mt": self._create_qwen_mt_backend,
            "deepl": self._create_deepl_backend,
            "llm": self._create_llm_backend,
            "local_nllb": self._create_local_nllb_backend,
            "local_marian": self._create_local_marian_backend,
        }
        if backend_type not in dispatch:
            raise ValueError(f"Unsupported translation backend '{backend_type}' for capability '{name}'")
        return dispatch[backend_type](name=name, cfg=cfg)

    def _create_qwen_mt_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build the Qwen-MT API backend from its capability config."""
        from translation.backends.qwen_mt import QwenMTTranslationBackend

        model_id = str(cfg["model"]).strip()
        endpoint = str(cfg["base_url"]).strip()
        return QwenMTTranslationBackend(
            capability_name=name,
            model=model_id,
            base_url=endpoint,
            api_key=cfg.get("api_key"),
            use_cache=bool(cfg["use_cache"]),
            timeout=int(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_deepl_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build the DeepL HTTP backend; ``name`` is not passed to its constructor."""
        from translation.backends.deepl import DeepLTranslationBackend

        endpoint = str(cfg["api_url"]).strip()
        return DeepLTranslationBackend(
            api_key=cfg.get("api_key"),
            api_url=endpoint,
            timeout=float(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_llm_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build the generic LLM-based backend from its capability config."""
        from translation.backends.llm import LLMTranslationBackend

        model_id = str(cfg["model"]).strip()
        endpoint = str(cfg["base_url"]).strip()
        return LLMTranslationBackend(
            capability_name=name,
            model=model_id,
            timeout_sec=float(cfg["timeout_sec"]),
            base_url=endpoint,
        )

    @staticmethod
    def _seq2seq_kwargs(cfg: Dict[str, object]) -> Dict[str, object]:
        """Constructor kwargs shared by the local seq2seq backends (NLLB, Marian)."""
        return {
            "model_id": str(cfg["model_id"]).strip(),
            "model_dir": str(cfg["model_dir"]).strip(),
            "device": str(cfg["device"]).strip(),
            "torch_dtype": str(cfg["torch_dtype"]).strip(),
            "batch_size": int(cfg["batch_size"]),
            "max_input_length": int(cfg["max_input_length"]),
            "max_new_tokens": int(cfg["max_new_tokens"]),
            "num_beams": int(cfg["num_beams"]),
        }

    def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build the local NLLB seq2seq backend from its capability config."""
        from translation.backends.local_seq2seq import NLLBTranslationBackend

        return NLLBTranslationBackend(name=name, **self._seq2seq_kwargs(cfg))

    def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build a local MarianMT backend; the language pair is derived from ``name``."""
        from translation.backends.local_seq2seq import MarianMTTranslationBackend, get_marian_language_direction

        source_lang, target_lang = get_marian_language_direction(name)
        return MarianMTTranslationBackend(
            name=name,
            source_langs=[source_lang],
            target_langs=[target_lang],
            **self._seq2seq_kwargs(cfg),
        )

    @property
    def available_models(self) -> List[str]:
        """Names of every enabled capability, instantiated or not."""
        return list(self._enabled_capabilities)

    @property
    def loaded_models(self) -> List[str]:
        """Names of capabilities whose backend has already been created."""
        return list(self._backends)

    def get_backend(self, model: Optional[str] = None) -> TranslationBackendProtocol:
        """Return the (cached) backend for ``model``, creating it on first use.

        Raises:
            ValueError: if the resolved model name is not among the enabled
                capabilities.
        """
        resolved = normalize_translation_model(self.config, model)
        capability = self._enabled_capabilities.get(resolved)
        if capability is None:
            raise ValueError(
                f"Translation model '{resolved}' is not enabled. "
                f"Available models: {', '.join(self.available_models) or 'none'}"
            )
        # Lock-free fast path for an already-created backend.
        cached = self._backends.get(resolved)
        if cached is not None:
            return cached
        with self._backend_lock:
            # Re-check under the lock: another thread may have created it.
            cached = self._backends.get(resolved)
            if cached is None:
                kind = str(capability["backend"])
                logger.info("Initializing translation backend | model=%s backend=%s", resolved, kind)
                cached = self._create_backend(
                    name=resolved,
                    backend_type=kind,
                    cfg=capability,
                )
                self._backends[resolved] = cached
        return cached

    def translate(
        self,
        text: TranslateInput,
        target_lang: str,
        source_lang: Optional[str] = None,
        *,
        model: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> TranslateOutput:
        """Translate ``text`` into ``target_lang`` via the selected backend.

        ``model`` and ``scene`` are normalized against the service config
        before dispatch; ``source_lang`` may be None for auto-detection by
        the backend (assumption — confirm against backend implementations).
        """
        engine = self.get_backend(model)
        active_scene = normalize_translation_scene(self.config, scene)
        return engine.translate(
            text=text,
            target_lang=target_lang,
            source_lang=source_lang,
            scene=active_scene,
        )