# settings.py (12 KB)
"""Translation config normalization and validation helpers."""

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Tuple

from translation.scenes import normalize_scene_name


# Normalized translation settings as returned by ``build_translation_config``:
# a plain dict with the keys ``service_url``, ``timeout_sec``, ``default_model``,
# ``default_scene``, ``cache`` and ``capabilities``.
TranslationConfig = Dict[str, Any]


def build_translation_config(raw_cfg: Mapping[str, Any]) -> TranslationConfig:
    """Validate the raw ``services.translation`` section and return it normalized.

    The service URL loses its trailing slashes, the default model name is
    lower-cased and the default scene goes through ``normalize_scene_name``.
    The ``cache`` and ``capabilities`` sub-sections are validated by their own
    builders.

    Raises:
        ValueError: if ``raw_cfg`` is not a mapping, a field is missing or
            invalid, the default model is undefined or disabled, or a quality
            tier names an unknown capability.
    """
    if not isinstance(raw_cfg, Mapping):
        raise ValueError("services.translation must be a mapping")

    # Fields are checked in this fixed order, so the first invalid one is reported.
    service_url = _require_http_url(raw_cfg.get("service_url"), "services.translation.service_url")
    timeout_sec = _require_positive_float(raw_cfg.get("timeout_sec"), "services.translation.timeout_sec")
    default_model = _require_string(raw_cfg.get("default_model"), "services.translation.default_model").lower()
    scene_name = _require_string(raw_cfg.get("default_scene"), "services.translation.default_scene")
    default_scene = normalize_scene_name(scene_name)
    cache = _build_cache_config(raw_cfg.get("cache"))
    capabilities = _build_capabilities(raw_cfg.get("capabilities"))

    config: TranslationConfig = {
        "service_url": service_url.rstrip("/"),
        "timeout_sec": timeout_sec,
        "default_model": default_model,
        "default_scene": default_scene,
        "cache": cache,
        "capabilities": capabilities,
    }

    if default_model not in capabilities:
        raise ValueError(
            f"services.translation.default_model '{default_model}' is not defined in services.translation.capabilities"
        )
    if not capabilities[default_model]["enabled"]:
        raise ValueError(
            f"services.translation.default_model '{default_model}' must reference an enabled capability"
        )
    # Defensive only: the enabled default model checked above already guarantees
    # a non-empty list here.
    enabled_models = get_enabled_translation_models(config)
    if not enabled_models:
        raise ValueError("At least one translation capability must be enabled")

    _validate_model_quality_tiers(config)
    return config


def normalize_translation_model(config: Mapping[str, Any], model: Optional[str]) -> str:
    """Return the lower-cased, stripped model name for a request.

    A falsy ``model`` falls back to ``config["default_model"]``.

    Raises:
        ValueError: if the resulting name is empty.
    """
    candidate = model or config.get("default_model") or ""
    model_name = str(candidate).strip().lower()
    if model_name:
        return model_name
    raise ValueError("translation model cannot be empty")


def normalize_translation_scene(config: Mapping[str, Any], scene: Optional[str]) -> str:
    """Resolve the scene for a request via ``normalize_scene_name``.

    A truthy ``scene`` is used as given; otherwise ``config["default_scene"]``
    is passed on instead.
    """
    requested_scene = scene if scene else config.get("default_scene")
    return normalize_scene_name(requested_scene)


def get_enabled_translation_models(config: Mapping[str, Any]) -> List[str]:
    """List the capability names whose ``enabled`` flag is exactly ``True``.

    Names are returned in the iteration order of ``config["capabilities"]``.
    Entries that are not mappings are skipped.

    Raises:
        ValueError: if ``config["capabilities"]`` is not a mapping.
    """
    capability_map = config.get("capabilities")
    if not isinstance(capability_map, Mapping):
        raise ValueError("translation config missing capabilities")

    enabled_names: List[str] = []
    for model_name, entry in capability_map.items():
        if not isinstance(entry, Mapping):
            continue
        # Identity check on purpose: truthy non-bool values do not count as enabled.
        if entry.get("enabled") is True:
            enabled_names.append(model_name)
    return enabled_names


def get_translation_capability(
    config: Mapping[str, Any],
    model: Optional[str],
    *,
    require_enabled: bool = False,
) -> Dict[str, Any]:
    """Return a shallow copy of the capability entry for ``model``.

    The model name is first resolved with ``normalize_translation_model``, so
    a falsy ``model`` selects the configured default.

    Raises:
        ValueError: if the config has no capabilities mapping, the model is
            not defined, or ``require_enabled`` is set and the model's
            ``enabled`` flag is not ``True``.
    """
    model_name = normalize_translation_model(config, model)
    capability_map = config.get("capabilities")
    if not isinstance(capability_map, Mapping):
        raise ValueError("translation config missing capabilities")

    entry = capability_map.get(model_name)
    if not isinstance(entry, Mapping):
        raise ValueError(f"Translation capability '{model_name}' is not defined")

    if not require_enabled or entry.get("enabled") is True:
        return dict(entry)

    # List the usable alternatives in the error to help the caller pick one.
    available = get_enabled_translation_models(config)
    enabled = ", ".join(available) or "none"
    raise ValueError(f"Translation model '{model_name}' is not enabled. Available models: {enabled}")


def get_translation_cache(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``config["cache"]``.

    Raises:
        ValueError: if the ``cache`` entry is absent or not a mapping.
    """
    cache_section = config.get("cache")
    if isinstance(cache_section, Mapping):
        return dict(cache_section)
    raise ValueError("translation config missing cache")


def translation_cache_probe_models(config: Mapping[str, Any], request_model: str) -> List[str]:
    """Return the model names whose cache entries may serve ``request_model``.

    With tier sharing active, every model in ``cache.model_quality_tiers``
    whose tier is **greater than or equal to** the request model's tier is
    returned; lower tiers are never included. Ordering is:

    1. higher tier first;
    2. inside one tier, the request model ahead of its peers;
    3. remaining ties alphabetically by model name.

    Only the (stripped, lower-cased) request model is returned when any of
    these holds: ``cache`` is not a mapping, ``enable_model_quality_tier_cache``
    is falsy (it defaults to true when absent), ``model_quality_tiers`` is
    empty or not a mapping, or the request model has no tier.
    """
    requested = str(request_model or "").strip().lower()
    exact_only = [requested]

    cache_cfg = config.get("cache")
    if not isinstance(cache_cfg, Mapping):
        return exact_only
    if not cache_cfg.get("enable_model_quality_tier_cache", True):
        return exact_only
    tier_map = cache_cfg.get("model_quality_tiers")
    if not isinstance(tier_map, Mapping) or not tier_map or requested not in tier_map:
        return exact_only

    floor = int(tier_map[requested])
    candidates: List[Tuple[str, int]] = [
        (str(raw_name).strip().lower(), int(raw_tier)) for raw_name, raw_tier in tier_map.items()
    ]
    eligible = [pair for pair in candidates if pair[1] >= floor]
    ranked = sorted(
        eligible,
        key=lambda pair: (-pair[1], pair[0] != requested, pair[0]),
    )
    # dict.fromkeys keeps the first (best-ranked) occurrence of each name.
    return list(dict.fromkeys(model_name for model_name, _tier in ranked))


def _build_cache_config(raw_cache: Any) -> Dict[str, Any]:
    """Validate ``services.translation.cache`` and return the normalized dict.

    ``ttl_seconds`` must be a positive integer and ``sliding_expiration`` a
    boolean. ``enable_model_quality_tier_cache`` is optional and defaults to
    ``True``; when present it must be a boolean. ``model_quality_tiers`` is
    optional and normalized by ``_build_model_quality_tiers``.

    Raises:
        ValueError: if ``raw_cache`` is not a mapping or a field is invalid.
    """
    if not isinstance(raw_cache, Mapping):
        raise ValueError("services.translation.cache must be a mapping")

    # Tier sharing is opt-out: only an explicit value is validated.
    tier_cache_enabled = True
    if "enable_model_quality_tier_cache" in raw_cache:
        tier_cache_enabled = _require_bool(
            raw_cache["enable_model_quality_tier_cache"],
            "services.translation.cache.enable_model_quality_tier_cache",
        )

    ttl_seconds = _require_positive_int(raw_cache.get("ttl_seconds"), "services.translation.cache.ttl_seconds")
    sliding_expiration = _require_bool(
        raw_cache.get("sliding_expiration"),
        "services.translation.cache.sliding_expiration",
    )
    quality_tiers = _build_model_quality_tiers(raw_cache.get("model_quality_tiers"))

    return {
        "ttl_seconds": ttl_seconds,
        "sliding_expiration": sliding_expiration,
        "enable_model_quality_tier_cache": tier_cache_enabled,
        "model_quality_tiers": quality_tiers,
    }


def _build_model_quality_tiers(raw: Any) -> Dict[str, int]:
    """Normalize the optional ``model_quality_tiers`` mapping.

    ``None`` yields an empty dict. Otherwise each key is lower-cased and each
    value must be a non-negative integer.

    Raises:
        ValueError: if ``raw`` is neither ``None`` nor a mapping, or an entry
            fails validation.
    """
    if raw is None:
        return {}
    if not isinstance(raw, Mapping):
        raise ValueError("services.translation.cache.model_quality_tiers must be a mapping")

    tiers_by_model: Dict[str, int] = {}
    for raw_name, raw_tier in raw.items():
        model_name = _require_string(raw_name, "services.translation.cache.model_quality_tiers key").lower()
        tiers_by_model[model_name] = _require_non_negative_int(
            raw_tier,
            f"services.translation.cache.model_quality_tiers.{model_name}",
        )
    return tiers_by_model


def _validate_model_quality_tiers(config: TranslationConfig) -> None:
    tiers = config["cache"].get("model_quality_tiers")
    if not isinstance(tiers, Mapping) or not tiers:
        return
    caps = config["capabilities"]
    for name in tiers:
        if name not in caps:
            raise ValueError(
                f"services.translation.cache.model_quality_tiers references unknown capability '{name}'"
            )


def _require_non_negative_int(value: Any, field_name: str) -> int:
    """Parse ``value`` with ``_require_int`` and require it to be zero or greater.

    Raises:
        ValueError: if parsing fails or the parsed number is negative.
    """
    number = _require_int(value, field_name)
    if number >= 0:
        return number
    raise ValueError(f"{field_name} must be >= 0")


def _build_capabilities(raw_capabilities: Any) -> Dict[str, Dict[str, Any]]:
    """Validate ``services.translation.capabilities`` and key it by lower-cased name.

    Each entry is copied, its ``enabled`` flag must be a boolean, its
    ``backend`` is lower-cased, and the backend-specific fields are checked by
    ``_validate_capability``. Other keys are carried over unchanged.

    Raises:
        ValueError: if the section or any entry is not a mapping, or an entry
            fails validation.
    """
    if not isinstance(raw_capabilities, Mapping):
        raise ValueError("services.translation.capabilities must be a mapping")

    capabilities_by_name: Dict[str, Dict[str, Any]] = {}
    for raw_name, raw_entry in raw_capabilities.items():
        if not isinstance(raw_entry, Mapping):
            raise ValueError(f"services.translation.capabilities.{raw_name} must be a mapping")

        model_name = _require_string(raw_name, "translation capability name").lower()
        field_prefix = f"services.translation.capabilities.{model_name}"
        enabled = _require_bool(raw_entry.get("enabled"), f"{field_prefix}.enabled")
        backend = _require_string(raw_entry.get("backend"), f"{field_prefix}.backend").lower()

        # Copy so that normalization never mutates the caller's raw config.
        entry: Dict[str, Any] = {**raw_entry, "enabled": enabled, "backend": backend}
        _validate_capability(model_name, entry)
        capabilities_by_name[model_name] = entry

    return capabilities_by_name


def _validate_capability(name: str, capability: Mapping[str, Any]) -> None:
    """Check the fields a capability needs for its backend.

    Every capability needs a boolean ``use_cache``. Beyond that:

    * ``qwen_mt`` / ``llm``: ``model``, an http(s) ``base_url``, positive ``timeout_sec``.
    * ``deepl``: an http(s) ``api_url``, positive ``timeout_sec``.
    * ``local_nllb`` / ``local_marian``: model location, device and dtype
      strings, positive generation limits, plus the optional
      ``ct2_decoding_length_*`` settings.

    Values are only checked here; nothing is written back to ``capability``.

    Raises:
        ValueError: if a field is missing or invalid, or the backend is not
            one of those listed above.
    """
    prefix = f"services.translation.capabilities.{name}"
    backend = capability.get("backend")
    _require_bool(capability.get("use_cache"), f"{prefix}.use_cache")

    if backend in ("qwen_mt", "llm"):
        # Both remote chat-style backends share the same required fields.
        _require_string(capability.get("model"), f"{prefix}.model")
        _require_http_url(capability.get("base_url"), f"{prefix}.base_url")
        _require_positive_float(capability.get("timeout_sec"), f"{prefix}.timeout_sec")
    elif backend == "deepl":
        _require_http_url(capability.get("api_url"), f"{prefix}.api_url")
        _require_positive_float(capability.get("timeout_sec"), f"{prefix}.timeout_sec")
    elif backend in {"local_nllb", "local_marian"}:
        for key in ("model_id", "model_dir", "device", "torch_dtype"):
            _require_string(capability.get(key), f"{prefix}.{key}")
        for key in ("batch_size", "max_input_length", "max_new_tokens", "num_beams"):
            _require_positive_int(capability.get(key), f"{prefix}.{key}")

        # The ct2_* settings are optional and only validated when present.
        if "ct2_decoding_length_mode" in capability:
            length_mode = _require_string(
                capability.get("ct2_decoding_length_mode"),
                f"{prefix}.ct2_decoding_length_mode",
            ).lower()
            if length_mode not in {"fixed", "source"}:
                raise ValueError(f"{prefix}.ct2_decoding_length_mode must be one of: fixed, source")
        if "ct2_decoding_length_extra" in capability:
            _require_int(capability.get("ct2_decoding_length_extra"), f"{prefix}.ct2_decoding_length_extra")
        if "ct2_decoding_length_min" in capability:
            _require_positive_int(capability.get("ct2_decoding_length_min"), f"{prefix}.ct2_decoding_length_min")
    else:
        raise ValueError(f"Unsupported translation backend '{backend}' for capability '{name}'")


def _require_string(value: Any, field_name: str) -> str:
    text = str(value or "").strip()
    if not text:
        raise ValueError(f"{field_name} is required")
    return text


def _require_float(value: Any, field_name: str) -> float:
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    return float(value)


def _require_positive_float(value: Any, field_name: str) -> float:
    """Parse ``value`` with ``_require_float`` and reject zero or negative results.

    Raises:
        ValueError: if parsing fails or the parsed number is ``<= 0``.
    """
    number = _require_float(value, field_name)
    is_non_positive = number <= 0
    if is_non_positive:
        raise ValueError(f"{field_name} must be greater than 0")
    return number


def _require_int(value: Any, field_name: str) -> int:
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    return int(value)


def _require_positive_int(value: Any, field_name: str) -> int:
    """Parse ``value`` with ``_require_int`` and require it to be at least 1.

    Raises:
        ValueError: if parsing fails or the parsed number is ``<= 0``.
    """
    number = _require_int(value, field_name)
    if number > 0:
        return number
    raise ValueError(f"{field_name} must be greater than 0")


def _require_bool(value: Any, field_name: str) -> bool:
    if not isinstance(value, bool):
        raise ValueError(f"{field_name} must be a boolean")
    return value


def _require_http_url(value: Any, field_name: str) -> str:
    """Return the stripped string when it starts with ``http://`` or ``https://``.

    Only the prefix is checked (case-sensitively); the rest of the URL is not
    parsed.

    Raises:
        ValueError: if the value is missing or has a different prefix.
    """
    url = _require_string(value, field_name)
    if url.startswith(("http://", "https://")):
        return url
    raise ValueError(f"{field_name} must start with http:// or https://")