# settings.py 8.71 KB
"""Translation config normalization and validation helpers."""

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional

from translation.scenes import normalize_scene_name


TranslationConfig = Dict[str, Any]


def build_translation_config(raw_cfg: Mapping[str, Any]) -> TranslationConfig:
    """Validate the raw ``services.translation`` section and return it normalized.

    The result contains ``service_url`` (trailing slashes removed),
    ``timeout_sec``, ``default_model`` (lower-cased), ``default_scene``
    (passed through ``normalize_scene_name``), ``cache`` and ``capabilities``.

    Raises:
        ValueError: If ``raw_cfg`` is not a mapping, any field fails
            validation, or ``default_model`` does not name a defined and
            enabled capability.
    """
    if not isinstance(raw_cfg, Mapping):
        raise ValueError("services.translation must be a mapping")

    section = "services.translation"

    # Fields are validated top to bottom, so the first invalid one is reported.
    service_url = _require_http_url(raw_cfg.get("service_url"), f"{section}.service_url")
    timeout_sec = _require_positive_float(raw_cfg.get("timeout_sec"), f"{section}.timeout_sec")
    model_name = _require_string(raw_cfg.get("default_model"), f"{section}.default_model").lower()
    scene_name = normalize_scene_name(
        _require_string(raw_cfg.get("default_scene"), f"{section}.default_scene")
    )
    cache_cfg = _build_cache_config(raw_cfg.get("cache"))
    capability_map = _build_capabilities(raw_cfg.get("capabilities"))

    # The default model must point at a capability that exists and is enabled.
    if model_name not in capability_map:
        raise ValueError(
            f"services.translation.default_model '{model_name}' is not defined in services.translation.capabilities"
        )
    if not capability_map[model_name]["enabled"]:
        raise ValueError(
            f"services.translation.default_model '{model_name}' must reference an enabled capability"
        )

    config: TranslationConfig = {
        "service_url": service_url.rstrip("/"),
        "timeout_sec": timeout_sec,
        "default_model": model_name,
        "default_scene": scene_name,
        "cache": cache_cfg,
        "capabilities": capability_map,
    }

    # Defensive: the enabled-default check above already implies a non-empty list.
    if not get_enabled_translation_models(config):
        raise ValueError("At least one translation capability must be enabled")

    return config


def normalize_translation_model(config: Mapping[str, Any], model: Optional[str]) -> str:
    """Return the model name to use, stripped and lower-cased.

    A falsy ``model`` falls back to ``config["default_model"]``. The fallback
    is decided before stripping, so a whitespace-only ``model`` does not fall
    back and is rejected as empty.

    Raises:
        ValueError: If the resulting name is empty.
    """
    candidate = model or config.get("default_model") or ""
    cleaned = str(candidate).strip().lower()
    if cleaned:
        return cleaned
    raise ValueError("translation model cannot be empty")


def normalize_translation_scene(config: Mapping[str, Any], scene: Optional[str]) -> str:
    """Normalize ``scene``, falling back to ``config["default_scene"]`` when it is falsy.

    The chosen value is handed to ``normalize_scene_name`` unchanged; any
    validation or error comes from that helper.
    """
    chosen = scene if scene else config.get("default_scene")
    return normalize_scene_name(chosen)


def get_enabled_translation_models(config: Mapping[str, Any]) -> List[str]:
    """List the capability names whose ``enabled`` flag is exactly ``True``.

    Entries that are not mappings are skipped. Names are returned in the
    iteration order of ``config["capabilities"]``.

    Raises:
        ValueError: If ``config["capabilities"]`` is not a mapping.
    """
    capabilities = config.get("capabilities")
    if not isinstance(capabilities, Mapping):
        raise ValueError("translation config missing capabilities")

    enabled_names: List[str] = []
    for model_name, entry in capabilities.items():
        if not isinstance(entry, Mapping):
            continue
        # Identity check: truthy non-bool values such as 1 do not count.
        if entry.get("enabled") is True:
            enabled_names.append(model_name)
    return enabled_names


def get_translation_capability(
    config: Mapping[str, Any],
    model: Optional[str],
    *,
    require_enabled: bool = False,
) -> Dict[str, Any]:
    """Return a shallow copy of the capability entry for ``model``.

    ``model`` is resolved through ``normalize_translation_model``, so a falsy
    value selects the configured default.

    Args:
        config: Translation config containing a ``capabilities`` mapping.
        model: Requested model name, or a falsy value for the default.
        require_enabled: When true, reject capabilities whose ``enabled``
            flag is not exactly ``True``.

    Raises:
        ValueError: If the model name is empty, ``capabilities`` is not a
            mapping, the capability is not defined, or it is disabled while
            ``require_enabled`` is set.
    """
    model_key = normalize_translation_model(config, model)

    capability_map = config.get("capabilities")
    if not isinstance(capability_map, Mapping):
        raise ValueError("translation config missing capabilities")

    entry = capability_map.get(model_key)
    if not isinstance(entry, Mapping):
        raise ValueError(f"Translation capability '{model_key}' is not defined")

    if require_enabled:
        is_enabled = entry.get("enabled") is True
        if not is_enabled:
            # List the usable alternatives in the error to aid troubleshooting.
            available = get_enabled_translation_models(config)
            enabled = ", ".join(available) or "none"
            raise ValueError(f"Translation model '{model_key}' is not enabled. Available models: {enabled}")

    return dict(entry)


def get_translation_cache(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``config["cache"]``.

    Raises:
        ValueError: If ``config["cache"]`` is not a mapping.
    """
    cache_section = config.get("cache")
    if isinstance(cache_section, Mapping):
        return dict(cache_section)
    raise ValueError("translation config missing cache")


def _build_cache_config(raw_cache: Any) -> Dict[str, Any]:
    """Validate the ``services.translation.cache`` section.

    Returns a dict with ``ttl_seconds`` (positive int) and
    ``sliding_expiration`` (bool); other keys in ``raw_cache`` are dropped.

    Raises:
        ValueError: If ``raw_cache`` is not a mapping or a field is invalid.
    """
    if not isinstance(raw_cache, Mapping):
        raise ValueError("services.translation.cache must be a mapping")

    # ttl_seconds is checked before sliding_expiration.
    ttl = _require_positive_int(
        raw_cache.get("ttl_seconds"),
        "services.translation.cache.ttl_seconds",
    )
    sliding = _require_bool(
        raw_cache.get("sliding_expiration"),
        "services.translation.cache.sliding_expiration",
    )
    return {"ttl_seconds": ttl, "sliding_expiration": sliding}


def _build_capabilities(raw_capabilities: Any) -> Dict[str, Dict[str, Any]]:
    """Validate and normalize the ``services.translation.capabilities`` mapping.

    Each capability name is stripped and lower-cased. Each entry is copied,
    its ``enabled`` flag is checked to be a bool, its ``backend`` is
    lower-cased, and the backend-specific fields are validated by
    ``_validate_capability``. All other keys are kept as supplied.

    Raises:
        ValueError: If the section or an entry is not a mapping, a field is
            invalid, or two names collide after normalization.
    """
    if not isinstance(raw_capabilities, Mapping):
        raise ValueError("services.translation.capabilities must be a mapping")

    resolved: Dict[str, Dict[str, Any]] = {}
    for name, raw_capability in raw_capabilities.items():
        if not isinstance(raw_capability, Mapping):
            raise ValueError(f"services.translation.capabilities.{name} must be a mapping")

        capability_name = _require_string(name, "translation capability name").lower()
        prefix = f"services.translation.capabilities.{capability_name}"

        # Names are normalized, so "Qwen" and "qwen" map to the same key.
        # Reject the collision instead of silently keeping only the last one.
        if capability_name in resolved:
            raise ValueError(
                f"{prefix} is defined more than once (capability names are case-insensitive)"
            )

        capability = dict(raw_capability)
        capability["enabled"] = _require_bool(capability.get("enabled"), f"{prefix}.enabled")
        capability["backend"] = _require_string(capability.get("backend"), f"{prefix}.backend").lower()
        _validate_capability(capability_name, capability)
        resolved[capability_name] = capability

    return resolved


def _validate_capability(name: str, capability: Mapping[str, Any]) -> None:
    """Check that ``capability`` has the fields its backend requires.

    ``use_cache`` must be a bool for every backend. The remaining required
    fields depend on ``capability["backend"]``. Values are only validated
    here; nothing is written back to ``capability``.

    Raises:
        ValueError: If a required field is missing or invalid, or the
            backend is not one of the supported names.
    """
    path = f"services.translation.capabilities.{name}"
    backend = capability.get("backend")
    _require_bool(capability.get("use_cache"), f"{path}.use_cache")

    if backend in ("qwen_mt", "llm"):
        # Both backends require the same three fields.
        _require_string(capability.get("model"), f"{path}.model")
        _require_http_url(capability.get("base_url"), f"{path}.base_url")
        _require_positive_float(capability.get("timeout_sec"), f"{path}.timeout_sec")
    elif backend == "deepl":
        _require_http_url(capability.get("api_url"), f"{path}.api_url")
        _require_positive_float(capability.get("timeout_sec"), f"{path}.timeout_sec")
    elif backend in {"local_nllb", "local_marian"}:
        for key in ("model_id", "model_dir", "device", "torch_dtype"):
            _require_string(capability.get(key), f"{path}.{key}")
        for key in ("batch_size", "max_input_length", "max_new_tokens", "num_beams"):
            _require_positive_int(capability.get(key), f"{path}.{key}")

        # The ct2_* settings are optional and only checked when present.
        if "ct2_decoding_length_mode" in capability:
            length_mode = _require_string(
                capability.get("ct2_decoding_length_mode"),
                f"{path}.ct2_decoding_length_mode",
            ).lower()
            if length_mode not in {"fixed", "source"}:
                raise ValueError(f"{path}.ct2_decoding_length_mode must be one of: fixed, source")
        if "ct2_decoding_length_extra" in capability:
            _require_int(capability.get("ct2_decoding_length_extra"), f"{path}.ct2_decoding_length_extra")
        if "ct2_decoding_length_min" in capability:
            _require_positive_int(capability.get("ct2_decoding_length_min"), f"{path}.ct2_decoding_length_min")
    else:
        raise ValueError(f"Unsupported translation backend '{backend}' for capability '{name}'")


def _require_string(value: Any, field_name: str) -> str:
    text = str(value or "").strip()
    if not text:
        raise ValueError(f"{field_name} is required")
    return text


def _require_float(value: Any, field_name: str) -> float:
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    return float(value)


def _require_positive_float(value: Any, field_name: str) -> float:
    """Return ``value`` parsed by ``_require_float``, requiring it to be > 0.

    Raises:
        ValueError: If the value is missing, or is not strictly greater than
            zero (this includes NaN).
    """
    parsed = _require_float(value, field_name)
    # Written as ``not parsed > 0`` rather than ``parsed <= 0``: every ordered
    # comparison with NaN is False, so the latter would let NaN through.
    if not parsed > 0:
        raise ValueError(f"{field_name} must be greater than 0")
    return parsed


def _require_int(value: Any, field_name: str) -> int:
    if value in (None, ""):
        raise ValueError(f"{field_name} is required")
    return int(value)


def _require_positive_int(value: Any, field_name: str) -> int:
    """Return ``value`` parsed by ``_require_int``, requiring it to be > 0.

    Raises:
        ValueError: If the value is missing or is zero or negative.
    """
    number = _require_int(value, field_name)
    if number > 0:
        return number
    raise ValueError(f"{field_name} must be greater than 0")


def _require_bool(value: Any, field_name: str) -> bool:
    if not isinstance(value, bool):
        raise ValueError(f"{field_name} must be a boolean")
    return value


def _require_http_url(value: Any, field_name: str) -> str:
    """Return ``value`` as stripped text, requiring an ``http://`` or ``https://`` prefix.

    Only the prefix is checked (case-sensitively); the rest of the URL is not
    parsed.

    Raises:
        ValueError: If the value is missing or lacks one of the two prefixes.
    """
    url = _require_string(value, field_name)
    if url.startswith(("http://", "https://")):
        return url
    raise ValueError(f"{field_name} must start with http:// or https://")