"""Translation config normalization and validation helpers."""

from __future__ import annotations

import math
from typing import Any, Dict, List, Mapping, Optional, Tuple

from translation.scenes import normalize_scene_name

TranslationConfig = Dict[str, Any]


def build_translation_config(raw_cfg: Mapping[str, Any]) -> TranslationConfig:
    """Validate ``services.translation`` and return its normalized form.

    The result contains ``service_url`` (trailing slashes stripped),
    ``timeout_sec``, ``default_model`` (lower-cased), ``default_scene``
    (passed through ``normalize_scene_name``), ``cache`` and ``capabilities``.

    Raises:
        ValueError: if ``raw_cfg`` is not a mapping, a required field is
            missing or malformed, the default model is not a defined and
            enabled capability, or ``cache.model_quality_tiers`` names an
            unknown capability.
    """
    if not isinstance(raw_cfg, Mapping):
        raise ValueError("services.translation must be a mapping")

    config: TranslationConfig = {
        "service_url": _require_http_url(
            raw_cfg.get("service_url"), "services.translation.service_url"
        ).rstrip("/"),
        "timeout_sec": _require_positive_float(
            raw_cfg.get("timeout_sec"), "services.translation.timeout_sec"
        ),
        "default_model": _require_string(
            raw_cfg.get("default_model"), "services.translation.default_model"
        ).lower(),
        "default_scene": normalize_scene_name(
            _require_string(raw_cfg.get("default_scene"), "services.translation.default_scene")
        ),
        "cache": _build_cache_config(raw_cfg.get("cache")),
        "capabilities": _build_capabilities(raw_cfg.get("capabilities")),
    }

    default_model = config["default_model"]
    capabilities = config["capabilities"]
    if default_model not in capabilities:
        raise ValueError(
            f"services.translation.default_model '{default_model}' is not defined in services.translation.capabilities"
        )
    if not capabilities[default_model]["enabled"]:
        raise ValueError(
            f"services.translation.default_model '{default_model}' must reference an enabled capability"
        )
    # Defensive only: the two checks above already guarantee that the default
    # model is an enabled capability, so this cannot currently trigger.
    if not get_enabled_translation_models(config):
        raise ValueError("At least one translation capability must be enabled")

    _validate_model_quality_tiers(config)
    return config


def normalize_translation_model(config: Mapping[str, Any], model: Optional[str]) -> str:
    """Return ``model`` (or the configured default) stripped and lower-cased.

    Raises:
        ValueError: if neither ``model`` nor ``config["default_model"]``
            yields a non-empty name.
    """
    normalized = str(model or config.get("default_model") or "").strip().lower()
    if not normalized:
        raise ValueError("translation model cannot be empty")
    return normalized


def normalize_translation_scene(config: Mapping[str, Any], scene: Optional[str]) -> str:
    """Normalize ``scene``, falling back to ``config["default_scene"]`` when falsy."""
    return normalize_scene_name(scene or config.get("default_scene"))


def get_enabled_translation_models(config: Mapping[str, Any]) -> List[str]:
    """Return the names of capabilities whose ``enabled`` flag is exactly ``True``.

    Raises:
        ValueError: if ``config["capabilities"]`` is not a mapping.
    """
    capabilities = config.get("capabilities")
    if not isinstance(capabilities, Mapping):
        raise ValueError("translation config missing capabilities")
    return [
        name
        for name, capability in capabilities.items()
        if isinstance(capability, Mapping) and capability.get("enabled") is True
    ]


def get_translation_capability(
    config: Mapping[str, Any],
    model: Optional[str],
    *,
    require_enabled: bool = False,
) -> Dict[str, Any]:
    """Return a shallow copy of the capability entry for ``model``.

    ``model`` is normalized with :func:`normalize_translation_model`, so a
    falsy value selects the configured default model.

    Raises:
        ValueError: if capabilities are missing, the model is not defined,
            or ``require_enabled`` is set and the model is not enabled.
    """
    normalized = normalize_translation_model(config, model)
    capabilities = config.get("capabilities")
    if not isinstance(capabilities, Mapping):
        raise ValueError("translation config missing capabilities")

    capability = capabilities.get(normalized)
    if not isinstance(capability, Mapping):
        raise ValueError(f"Translation capability '{normalized}' is not defined")

    if require_enabled and capability.get("enabled") is not True:
        enabled = ", ".join(get_enabled_translation_models(config)) or "none"
        raise ValueError(f"Translation model '{normalized}' is not enabled. Available models: {enabled}")
    return dict(capability)


def get_translation_cache(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``config["cache"]``.

    Raises:
        ValueError: if ``config["cache"]`` is not a mapping.
    """
    cache = config.get("cache")
    if not isinstance(cache, Mapping):
        raise ValueError("translation config missing cache")
    return dict(cache)


def translation_cache_probe_models(config: Mapping[str, Any], request_model: str) -> List[str]:
    """Redis cache key models to try.

    Sort order: (1) **tier** descending (higher quality first); (2) within the
    same tier, the **request model** before other peers; (3) remaining ties by
    model name.

    For a request to model A with tier T, probes every configured model whose
    tier is **greater than or equal to** T. Lower tiers are never used.

    When ``enable_model_quality_tier_cache`` is false, only the request model
    is probed. When ``model_quality_tiers`` is empty or ``request_model`` is
    not listed, only the request model is probed (legacy exact-match behavior).
    """
    rm = str(request_model or "").strip().lower()

    cache = config.get("cache")
    if not isinstance(cache, Mapping):
        return [rm]
    if not bool(cache.get("enable_model_quality_tier_cache", True)):
        return [rm]

    tiers = cache.get("model_quality_tiers")
    if not isinstance(tiers, Mapping) or not tiers:
        return [rm]
    if rm not in tiers:
        return [rm]

    threshold = int(tiers[rm])
    candidates: List[Tuple[int, str]] = []
    for name, tier_val in tiers.items():
        tier = int(tier_val)
        if tier >= threshold:
            candidates.append((tier, str(name).strip().lower()))

    candidates.sort(key=lambda item: (-item[0], 0 if item[1] == rm else 1, item[1]))
    # dict.fromkeys drops repeated names while keeping first-seen (sorted) order.
    return list(dict.fromkeys(name for _tier, name in candidates))


def _build_cache_config(raw_cache: Any) -> Dict[str, Any]:
    """Validate the ``cache`` section and return its normalized form.

    ``enable_model_quality_tier_cache`` defaults to ``True`` when the key is
    absent; ``ttl_seconds`` and ``sliding_expiration`` are required.

    Raises:
        ValueError: if ``raw_cache`` is not a mapping or a field is invalid.
    """
    if not isinstance(raw_cache, Mapping):
        raise ValueError("services.translation.cache must be a mapping")

    if "enable_model_quality_tier_cache" in raw_cache:
        enable_tier_cache = _require_bool(
            raw_cache["enable_model_quality_tier_cache"],
            "services.translation.cache.enable_model_quality_tier_cache",
        )
    else:
        enable_tier_cache = True

    return {
        "ttl_seconds": _require_positive_int(
            raw_cache.get("ttl_seconds"), "services.translation.cache.ttl_seconds"
        ),
        "sliding_expiration": _require_bool(
            raw_cache.get("sliding_expiration"),
            "services.translation.cache.sliding_expiration",
        ),
        "enable_model_quality_tier_cache": enable_tier_cache,
        "model_quality_tiers": _build_model_quality_tiers(raw_cache.get("model_quality_tiers")),
    }


def _build_model_quality_tiers(raw: Any) -> Dict[str, int]:
    """Return ``{lower-cased model name: non-negative tier}``; ``None`` yields ``{}``.

    Raises:
        ValueError: if ``raw`` is neither ``None`` nor a mapping, a key is
            empty, or a tier is not a non-negative integer.
    """
    if raw is None:
        return {}
    if not isinstance(raw, Mapping):
        raise ValueError("services.translation.cache.model_quality_tiers must be a mapping")

    resolved: Dict[str, int] = {}
    for name, tier_val in raw.items():
        cap = _require_string(name, "services.translation.cache.model_quality_tiers key").lower()
        field = f"services.translation.cache.model_quality_tiers.{cap}"
        resolved[cap] = _require_non_negative_int(tier_val, field)
    return resolved


def _validate_model_quality_tiers(config: TranslationConfig) -> None:
    """Ensure every model named in ``cache.model_quality_tiers`` is a known capability.

    Raises:
        ValueError: if a tier entry references a name missing from
            ``config["capabilities"]``.
    """
    tiers = config["cache"].get("model_quality_tiers")
    if not isinstance(tiers, Mapping) or not tiers:
        return
    caps = config["capabilities"]
    for name in tiers:
        if name not in caps:
            raise ValueError(
                f"services.translation.cache.model_quality_tiers references unknown capability '{name}'"
            )


def _require_non_negative_int(value: Any, field_name: str) -> int:
    """Return ``value`` as an int, raising ``ValueError`` unless it is >= 0."""
    parsed = _require_int(value, field_name)
    if parsed < 0:
        raise ValueError(f"{field_name} must be >= 0")
    return parsed


def _build_capabilities(raw_capabilities: Any) -> Dict[str, Dict[str, Any]]:
    """Validate the ``capabilities`` section, keyed by lower-cased capability name.

    Each entry is shallow-copied; ``enabled`` must be a bool and ``backend``
    is lower-cased before backend-specific validation runs.

    Raises:
        ValueError: if the section or any entry is not a mapping, or an
            entry fails validation.
    """
    if not isinstance(raw_capabilities, Mapping):
        raise ValueError("services.translation.capabilities must be a mapping")

    resolved: Dict[str, Dict[str, Any]] = {}
    for name, raw_capability in raw_capabilities.items():
        if not isinstance(raw_capability, Mapping):
            raise ValueError(f"services.translation.capabilities.{name} must be a mapping")
        capability_name = _require_string(name, "translation capability name").lower()
        prefix = f"services.translation.capabilities.{capability_name}"

        capability = dict(raw_capability)
        capability["enabled"] = _require_bool(capability.get("enabled"), f"{prefix}.enabled")
        capability["backend"] = _require_string(capability.get("backend"), f"{prefix}.backend").lower()
        _validate_capability(capability_name, capability)
        resolved[capability_name] = capability
    return resolved


def _validate_capability(name: str, capability: Mapping[str, Any]) -> None:
    """Check the fields required by ``capability["backend"]``.

    ``use_cache`` must be a bool for every backend. Supported backends are
    ``qwen_mt``, ``llm``, ``deepl``, ``local_nllb`` and ``local_marian``.

    Raises:
        ValueError: if a required field is missing or malformed, or the
            backend is not supported.
    """
    prefix = f"services.translation.capabilities.{name}"
    backend = capability.get("backend")
    _require_bool(capability.get("use_cache"), f"{prefix}.use_cache")

    # Both remote chat-style backends require the same three fields.
    if backend in {"qwen_mt", "llm"}:
        _require_string(capability.get("model"), f"{prefix}.model")
        _require_http_url(capability.get("base_url"), f"{prefix}.base_url")
        _require_positive_float(capability.get("timeout_sec"), f"{prefix}.timeout_sec")
        return

    if backend == "deepl":
        _require_http_url(capability.get("api_url"), f"{prefix}.api_url")
        _require_positive_float(capability.get("timeout_sec"), f"{prefix}.timeout_sec")
        return

    if backend in {"local_nllb", "local_marian"}:
        for key in ("model_id", "model_dir", "device", "torch_dtype"):
            _require_string(capability.get(key), f"{prefix}.{key}")
        for key in ("batch_size", "max_input_length", "max_new_tokens", "num_beams"):
            _require_positive_int(capability.get(key), f"{prefix}.{key}")

        # The ct2_* options are optional and only validated when present.
        if "ct2_decoding_length_mode" in capability:
            mode = _require_string(
                capability.get("ct2_decoding_length_mode"),
                f"{prefix}.ct2_decoding_length_mode",
            ).lower()
            if mode not in {"fixed", "source"}:
                raise ValueError(f"{prefix}.ct2_decoding_length_mode must be one of: fixed, source")
        if "ct2_decoding_length_extra" in capability:
            _require_int(capability.get("ct2_decoding_length_extra"), f"{prefix}.ct2_decoding_length_extra")
        if "ct2_decoding_length_min" in capability:
            _require_positive_int(capability.get("ct2_decoding_length_min"), f"{prefix}.ct2_decoding_length_min")
        return

    raise ValueError(f"Unsupported translation backend '{backend}' for capability '{name}'")


def _require_string(value: Any, field_name: str) -> str:
    """Return ``str(value)`` stripped of surrounding whitespace.

    Any falsy ``value`` (``None``, ``""``, ``0``, ``False``) is treated as
    missing.

    Raises:
        ValueError: if the resulting text is empty.
    """
    text = str(value or "").strip()
    if not text:
        raise ValueError(f"{field_name} is required")
    return text


def _require_float(value: Any, field_name: str) -> float:
    """Return ``value`` converted to a finite float.

    Raises:
        ValueError: if ``value`` is ``None`` or ``""``, is a bool, cannot be
            converted by ``float()``, or converts to NaN or infinity. The
            message always names ``field_name``.
    """
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    # bool is a subclass of int; float(True) == 1.0 would hide a config typo.
    if isinstance(value, bool):
        raise ValueError(f"{field_name} must be a number")
    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"{field_name} must be a number") from exc
    # NaN compares False against every bound, so range checks would pass it.
    if not math.isfinite(parsed):
        raise ValueError(f"{field_name} must be a finite number")
    return parsed


def _require_positive_float(value: Any, field_name: str) -> float:
    """Return ``value`` as a float, raising ``ValueError`` unless it is > 0."""
    parsed = _require_float(value, field_name)
    if parsed <= 0:
        raise ValueError(f"{field_name} must be greater than 0")
    return parsed


def _require_int(value: Any, field_name: str) -> int:
    """Return ``value`` converted to an int.

    Integral floats (``3.0``) and numeric strings accepted by ``int()`` are
    allowed; bools and non-integral floats are rejected instead of being
    silently coerced or truncated.

    Raises:
        ValueError: if ``value`` is ``None`` or ``""`` or is not an integer
            value. The message always names ``field_name``.
    """
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    if isinstance(value, bool):
        raise ValueError(f"{field_name} must be an integer")
    # is_integer() is False for NaN/inf as well as for fractional values.
    if isinstance(value, float) and not value.is_integer():
        raise ValueError(f"{field_name} must be an integer")
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"{field_name} must be an integer") from exc


def _require_positive_int(value: Any, field_name: str) -> int:
    """Return ``value`` as an int, raising ``ValueError`` unless it is > 0."""
    parsed = _require_int(value, field_name)
    if parsed <= 0:
        raise ValueError(f"{field_name} must be greater than 0")
    return parsed


def _require_bool(value: Any, field_name: str) -> bool:
    """Return ``value`` unchanged, raising ``ValueError`` unless it is a bool."""
    if not isinstance(value, bool):
        raise ValueError(f"{field_name} must be a boolean")
    return value


def _require_http_url(value: Any, field_name: str) -> str:
    """Return ``value`` as a stripped string that starts with an HTTP(S) scheme.

    Raises:
        ValueError: if the value is empty or does not start with
            ``http://`` or ``https://`` (the comparison is case-sensitive).
    """
    text = _require_string(value, field_name)
    if not text.startswith(("http://", "https://")):
        raise ValueError(f"{field_name} must start with http:// or https://")
    return text