"""Translation service orchestration."""

from __future__ import annotations

import logging
from typing import Dict, List, Optional

from config.services_config import get_translation_config
from translation.cache import TranslationCache
from translation.protocols import TranslateInput, TranslateOutput, TranslationBackendProtocol
from translation.settings import (
    TranslationConfig,
    get_enabled_translation_models,
    get_translation_capability,
    normalize_translation_model,
    normalize_translation_scene,
)

logger = logging.getLogger(__name__)


class TranslationService:
    """Owns translation backends and routes calls by model and scene.

    Every enabled capability is instantiated eagerly in ``__init__``. Calls to
    :meth:`translate` are dispatched to the backend registered under the
    normalized model name, optionally going through a shared
    ``TranslationCache`` when the capability sets ``use_cache``.
    """

    def __init__(self, config: Optional[TranslationConfig] = None) -> None:
        """Build all enabled backends from *config* (or the global translation config).

        Raises:
            ValueError: if no capability is enabled, a capability lacks a
                ``backend`` entry, or a backend type is unsupported.
        """
        self.config = config or get_translation_config()
        self._enabled_capabilities = self._collect_enabled_capabilities()
        if not self._enabled_capabilities:
            raise ValueError("No enabled translation backends found in services.translation.capabilities")
        self._translation_cache = TranslationCache(self.config["cache"])
        self._backends = self._initialize_backends()

    def _collect_enabled_capabilities(self) -> Dict[str, Dict[str, object]]:
        """Return ``{model_name: capability_cfg}`` for every enabled capability.

        Raises:
            ValueError: if an enabled capability does not declare a ``backend``.
        """
        enabled: Dict[str, Dict[str, object]] = {}
        for name in get_enabled_translation_models(self.config):
            capability = get_translation_capability(self.config, name, require_enabled=True)
            backend_type = capability.get("backend")
            if not backend_type:
                raise ValueError(f"Translation capability '{name}' must define a backend")
            enabled[name] = capability
        return enabled

    def _create_backend(
        self,
        *,
        name: str,
        backend_type: str,
        cfg: Dict[str, object],
    ) -> TranslationBackendProtocol:
        """Instantiate the backend for capability *name* using the factory for *backend_type*.

        Raises:
            ValueError: if *backend_type* has no registered factory.
        """
        registry = {
            "qwen_mt": self._create_qwen_mt_backend,
            "deepl": self._create_deepl_backend,
            "llm": self._create_llm_backend,
            "local_nllb": self._create_local_nllb_backend,
            "local_marian": self._create_local_marian_backend,
        }
        factory = registry.get(backend_type)
        if factory is None:
            raise ValueError(f"Unsupported translation backend '{backend_type}' for capability '{name}'")
        return factory(name=name, cfg=cfg)

    def _initialize_backends(self) -> Dict[str, TranslationBackendProtocol]:
        """Create one backend per enabled capability, keyed by model name."""
        backends: Dict[str, TranslationBackendProtocol] = {}
        for name, capability_cfg in self._enabled_capabilities.items():
            backend_type = str(capability_cfg["backend"])
            logger.info("Initializing translation backend | model=%s backend=%s", name, backend_type)
            backends[name] = self._create_backend(
                name=name,
                backend_type=backend_type,
                cfg=capability_cfg,
            )
            logger.info(
                "Translation backend initialized | model=%s backend=%s use_cache=%s backend_model=%s",
                name,
                backend_type,
                bool(capability_cfg.get("use_cache")),
                getattr(backends[name], "model", name),
            )
        return backends

    def _create_qwen_mt_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build a Qwen-MT backend; requires ``model``, ``base_url`` and ``timeout_sec`` in *cfg*."""
        from translation.backends.qwen_mt import QwenMTTranslationBackend

        return QwenMTTranslationBackend(
            capability_name=name,
            model=str(cfg["model"]).strip(),
            base_url=str(cfg["base_url"]).strip(),
            api_key=cfg.get("api_key"),
            # NOTE(review): timeout is truncated to int here while the other
            # backends receive a float — confirm QwenMTTranslationBackend needs int.
            timeout=int(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_deepl_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build a DeepL backend; requires ``api_url`` and ``timeout_sec`` in *cfg*."""
        from translation.backends.deepl import DeepLTranslationBackend

        return DeepLTranslationBackend(
            api_key=cfg.get("api_key"),
            api_url=str(cfg["api_url"]).strip(),
            timeout=float(cfg["timeout_sec"]),
            glossary_id=cfg.get("glossary_id"),
        )

    def _create_llm_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build an LLM backend; requires ``model``, ``timeout_sec`` and ``base_url`` in *cfg*."""
        from translation.backends.llm import LLMTranslationBackend

        return LLMTranslationBackend(
            capability_name=name,
            model=str(cfg["model"]).strip(),
            timeout_sec=float(cfg["timeout_sec"]),
            base_url=str(cfg["base_url"]).strip(),
        )

    @staticmethod
    def _ctranslate2_options(cfg: Dict[str, object]) -> Dict[str, object]:
        """Return the model and CTranslate2 keyword arguments shared by the local backends.

        Required keys raise ``KeyError`` in the same order as before; optional
        ``ct2_*`` keys fall back to the defaults shown below.
        """
        return {
            "model_id": str(cfg["model_id"]).strip(),
            "model_dir": str(cfg["model_dir"]).strip(),
            "device": str(cfg["device"]).strip(),
            "torch_dtype": str(cfg["torch_dtype"]).strip(),
            "batch_size": int(cfg["batch_size"]),
            "max_input_length": int(cfg["max_input_length"]),
            "max_new_tokens": int(cfg["max_new_tokens"]),
            "num_beams": int(cfg["num_beams"]),
            "ct2_model_dir": cfg.get("ct2_model_dir"),
            "ct2_compute_type": cfg.get("ct2_compute_type"),
            "ct2_auto_convert": bool(cfg.get("ct2_auto_convert", True)),
            "ct2_conversion_quantization": cfg.get("ct2_conversion_quantization"),
            "ct2_inter_threads": int(cfg.get("ct2_inter_threads", 1)),
            "ct2_intra_threads": int(cfg.get("ct2_intra_threads", 0)),
            "ct2_max_queued_batches": int(cfg.get("ct2_max_queued_batches", 0)),
            "ct2_batch_type": str(cfg.get("ct2_batch_type", "examples")),
            "ct2_decoding_length_mode": str(cfg.get("ct2_decoding_length_mode", "fixed")),
            "ct2_decoding_length_extra": int(cfg.get("ct2_decoding_length_extra", 0)),
            "ct2_decoding_length_min": int(cfg.get("ct2_decoding_length_min", 1)),
        }

    def _create_local_nllb_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build a local NLLB backend running on CTranslate2."""
        from translation.backends.local_ctranslate2 import NLLBCTranslate2TranslationBackend

        return NLLBCTranslate2TranslationBackend(
            name=name,
            **self._ctranslate2_options(cfg),
        )

    def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol:
        """Build a local Marian backend whose single language direction is derived from *name*."""
        from translation.backends.local_ctranslate2 import MarianCTranslate2TranslationBackend, get_marian_language_direction

        # Resolve the direction first so an unknown capability name fails
        # before any config value is parsed (same order as before).
        source_lang, target_lang = get_marian_language_direction(name)
        options = self._ctranslate2_options(cfg)
        return MarianCTranslate2TranslationBackend(
            name=name,
            source_langs=[source_lang],
            target_langs=[target_lang],
            **options,
        )

    @property
    def available_models(self) -> List[str]:
        """Names of all enabled capabilities."""
        return list(self._enabled_capabilities.keys())

    @property
    def loaded_models(self) -> List[str]:
        """Names of capabilities whose backends have been instantiated."""
        return list(self._backends.keys())

    def get_backend(self, model: Optional[str] = None) -> TranslationBackendProtocol:
        """Return the backend for *model* after normalization.

        Raises:
            ValueError: if the normalized model has no loaded backend.
        """
        normalized = normalize_translation_model(self.config, model)
        backend = self._backends.get(normalized)
        if backend is None:
            raise ValueError(
                f"Translation model '{normalized}' is not enabled. "
                f"Available models: {', '.join(self.available_models) or 'none'}"
            )
        return backend

    def translate(
        self,
        text: TranslateInput,
        target_lang: str,
        source_lang: Optional[str] = None,
        *,
        model: Optional[str] = None,
        scene: Optional[str] = None,
    ) -> TranslateOutput:
        """Translate *text* with the backend selected by *model* and *scene*.

        Args:
            text: a single string, or an iterable of strings for batch mode.
                Non-list iterables are materialized into a list once.
            target_lang: target language code passed to the backend.
            source_lang: source language code; ``None`` is logged as ``auto``.
            model: capability name (normalized via the translation settings).
            scene: translation scene (normalized via the translation settings).

        Returns:
            The backend result directly when caching is off or unavailable.
            Otherwise a single ``Optional[str]`` for string input, or a list
            aligned with the input items for batch input.

        Raises:
            ValueError: if the model is not enabled.
        """
        normalized_model = normalize_translation_model(self.config, model)
        backend = self.get_backend(normalized_model)
        active_scene = normalize_translation_scene(self.config, scene)
        capability_cfg = self._enabled_capabilities[normalized_model]
        use_cache = bool(capability_cfg.get("use_cache"))

        if not isinstance(text, (str, list)):
            # Materialize once: counting with ``len(list(text))`` used to exhaust
            # generators before the backend or cache ever saw the items.
            text = list(text)
        text_count = 1 if isinstance(text, str) else len(text)

        logger.info(
            "Translation route | model=%s backend=%s scene=%s target_lang=%s source_lang=%s count=%s use_cache=%s cache_available=%s",
            normalized_model,
            getattr(backend, "model", normalized_model),
            active_scene,
            target_lang,
            source_lang or "auto",
            text_count,
            use_cache,
            self._translation_cache.available,
        )

        if not use_cache or not self._translation_cache.available:
            return backend.translate(
                text=text,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=active_scene,
            )

        if isinstance(text, str):
            return self._translate_with_cache(
                backend,
                text=text,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=active_scene,
                model=normalized_model,
            )

        return self._translate_batch_with_cache(
            text=text,
            target_lang=target_lang,
            source_lang=source_lang,
            backend=backend,
            scene=active_scene,
            model=normalized_model,
        )

    def _translate_with_cache(
        self,
        backend: TranslationBackendProtocol,
        *,
        text: str,
        target_lang: str,
        source_lang: Optional[str],
        scene: str,
        model: str,
    ) -> Optional[str]:
        """Translate one string, serving from and populating the cache.

        Blank or whitespace-only *text* is returned unchanged without touching
        the cache or the backend. A ``None`` backend result is not cached.
        """
        if not text.strip():
            return text

        cached = self._translation_cache.get(model=model, target_lang=target_lang, source_text=text)
        if cached is not None:
            logger.info(
                "Translation cache served | model=%s scene=%s target_lang=%s source_lang=%s text_len=%s",
                model,
                scene,
                target_lang,
                source_lang or "auto",
                len(text),
            )
            return cached

        translated = backend.translate(
            text=text,
            target_lang=target_lang,
            source_lang=source_lang,
            scene=scene,
        )
        if translated is not None:
            self._translation_cache.set(
                model=model,
                target_lang=target_lang,
                source_text=text,
                translated_text=translated,
            )
            logger.info(
                "Translation backend result cached | model=%s scene=%s target_lang=%s source_lang=%s text_len=%s result_len=%s",
                model,
                scene,
                target_lang,
                source_lang or "auto",
                len(text),
                len(str(translated)),
            )
        else:
            logger.warning(
                "Translation backend returned empty result | model=%s scene=%s target_lang=%s source_lang=%s text_len=%s",
                model,
                scene,
                target_lang,
                source_lang or "auto",
                len(text),
            )
        return translated

    @staticmethod
    def _as_result_list(translated: object) -> List[Optional[str]]:
        """Coerce a backend batch response into a list (tuples accepted, scalars wrapped)."""
        if isinstance(translated, list):
            return translated
        if isinstance(translated, tuple):
            return list(translated)
        return [translated]

    def _translate_batch_with_cache(
        self,
        *,
        text: TranslateInput,
        target_lang: str,
        source_lang: Optional[str],
        backend: TranslationBackendProtocol,
        scene: str,
        model: str,
    ) -> List[Optional[str]]:
        """Translate a batch, sending only cache misses to the backend in one call.

        Returns a list aligned with the input:
        - ``None`` items become ``""``.
        - Blank items are returned as-is.
        - Cache hits are filled from the cache.
        - Misses are filled from the backend response.

        Slots the backend does not return a value for stay ``None`` and are
        logged.
        """
        texts = list(text)
        results: List[Optional[str]] = [None] * len(texts)
        misses: List[str] = []
        miss_indices: List[int] = []
        cache_hits = 0

        for idx, item in enumerate(texts):
            normalized_text = "" if item is None else str(item)
            if not normalized_text.strip():
                results[idx] = normalized_text
                continue
            cached = self._translation_cache.get(
                model=model,
                target_lang=target_lang,
                source_text=normalized_text,
            )
            if cached is not None:
                results[idx] = cached
                cache_hits += 1
                continue
            misses.append(normalized_text)
            miss_indices.append(idx)

        logger.info(
            "Translation batch cache summary | model=%s scene=%s target_lang=%s source_lang=%s total=%s cache_hits=%s cache_misses=%s",
            model,
            scene,
            target_lang,
            source_lang or "auto",
            len(texts),
            cache_hits,
            len(misses),
        )

        if misses:
            translated = backend.translate(
                text=misses,
                target_lang=target_lang,
                source_lang=source_lang,
                scene=scene,
            )
            translated_list = self._as_result_list(translated)
            if len(translated_list) != len(misses):
                # zip() below truncates silently; surface the mismatch so missing
                # slots (left as None) are not mistaken for successful results.
                logger.warning(
                    "Translation batch size mismatch | model=%s scene=%s target_lang=%s source_lang=%s requested=%s returned=%s",
                    model,
                    scene,
                    target_lang,
                    source_lang or "auto",
                    len(misses),
                    len(translated_list),
                )
            for idx, original_text, translated_text in zip(miss_indices, misses, translated_list):
                results[idx] = translated_text
                if translated_text is not None:
                    self._translation_cache.set(
                        model=model,
                        target_lang=target_lang,
                        source_text=original_text,
                        translated_text=translated_text,
                    )
                else:
                    logger.warning(
                        "Translation batch item returned empty result | model=%s scene=%s target_lang=%s source_lang=%s item_index=%s text_len=%s",
                        model,
                        scene,
                        target_lang,
                        source_lang or "auto",
                        idx,
                        len(original_text),
                    )

        return results