# services_config.py
"""
Services configuration - single source for translation, embedding, rerank.

Translation is modeled as:
- one translator service endpoint used by business callers
- multiple translation capabilities loaded inside the translator service
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml


@dataclass
class ServiceConfig:
    """Resolved settings for one capability (embedding or rerank)."""

    # Name of the selected provider.
    provider: str
    # Per-provider settings, keyed by provider name.
    providers: Dict[str, Any] = field(default_factory=dict)

    def get_provider_cfg(self) -> Dict[str, Any]:
        """Return the settings stored for the selected provider.

        The provider name is trimmed and lower-cased before the lookup. An
        empty dict is returned when ``providers`` is not a dict or holds no
        entry under that name.
        """
        lookup_key = (self.provider or "").strip().lower()
        if not isinstance(self.providers, dict):
            return {}
        return self.providers.get(lookup_key, {})


@dataclass
class TranslationServiceConfig:
    """Dedicated config model for the translation service.

    Holds the translator endpoint, the default model/scene, the per-model
    capability settings and the cache settings.
    """

    # Base URL of the translator service.
    service_url: str
    # Request timeout, in seconds.
    timeout_sec: float
    # Model used when a caller does not name one.
    default_model: str
    # Scene used when a caller does not name one.
    default_scene: str
    # Per-model capability settings, keyed by model name.
    capabilities: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Translation cache settings.
    cache: Dict[str, Any] = field(default_factory=dict)

    def normalize_model_name(self, model: Optional[str]) -> str:
        """Return the canonical, lower-cased model name for ``model``.

        ``None``, an empty or whitespace-only string, ``"service"`` and
        ``"default"`` all resolve to the default model. Known aliases
        (``qwen``, ``qwen-mt-flash``, ``qwen-mt-flush``) resolve to
        ``qwen-mt``. Any other name is returned trimmed and lower-cased.
        """
        aliases = {
            "qwen": "qwen-mt",
            "qwen-mt-flash": "qwen-mt",
            "qwen-mt-flush": "qwen-mt",
        }
        # Canonicalise the default the same way as an explicit name so that
        # None, "default" and "service" always agree with each other.
        default = str(self.default_model or "").strip().lower()
        default = aliases.get(default, default)

        requested = str(model or "").strip().lower()
        # A blank name (including whitespace-only) means "use the default".
        if not requested or requested in ("service", "default"):
            return default
        return aliases.get(requested, requested)

    @property
    def enabled_models(self) -> List[str]:
        """Names (trimmed, lower-cased) of capabilities whose ``enabled`` flag is truthy."""
        return [
            str(name).strip().lower()
            for name, cfg in self.capabilities.items()
            if isinstance(cfg, dict) and cfg.get("enabled", False)
        ]

    def get_capability_cfg(self, model: Optional[str]) -> Dict[str, Any]:
        """Return a shallow copy of the capability settings for ``model``.

        The name is normalized with ``normalize_model_name`` first. An empty
        dict is returned when no dict is stored under the normalized name.
        """
        value = self.capabilities.get(self.normalize_model_name(model))
        return dict(value) if isinstance(value, dict) else {}


def _load_services_raw(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load the YAML config and return its top-level ``services`` mapping.

    Args:
        config_path: File to read. When ``None``, ``config.yaml`` next to
            this module is used.

    Raises:
        FileNotFoundError: The file does not exist.
        RuntimeError: The file cannot be read or parsed, its root is not a
            mapping, or it has no ``services`` mapping.
    """
    if config_path is None:
        path = Path(__file__).parent / "config.yaml"
    else:
        path = Path(config_path)

    if not path.exists():
        raise FileNotFoundError(f"services config file not found: {path}")

    try:
        with path.open("r", encoding="utf-8") as handle:
            document = yaml.safe_load(handle)
    except Exception as exc:
        raise RuntimeError(f"failed to parse services config from {path}: {exc}") from exc

    if not isinstance(document, dict):
        raise RuntimeError(f"invalid config format in {path}: expected mapping root")

    services = document.get("services")
    if isinstance(services, dict):
        return services
    raise RuntimeError("config.yaml must contain a valid 'services' mapping")


def _resolve_provider_name(env_name: str, config_provider: Any, capability: str) -> str:
    """Pick the provider name for ``capability``, preferring the env override.

    The environment variable ``env_name`` wins over ``config_provider``.
    Values that are empty or whitespace-only are treated as unset, so a
    blank override falls through to the config value instead of producing
    an empty provider name.

    Returns:
        The chosen provider name, trimmed and lower-cased.

    Raises:
        ValueError: Neither source supplies a non-blank provider name.
    """
    for candidate in (os.getenv(env_name), config_provider):
        if not candidate:
            continue
        name = str(candidate).strip().lower()
        if name:
            return name
    raise ValueError(
        f"services.{capability}.provider is required "
        f"(or set env override {env_name})"
    )


def _resolve_translation() -> TranslationServiceConfig:
    """Build the translation service config from ``config.yaml`` and env vars.

    Environment variables ``TRANSLATION_SERVICE_URL``,
    ``TRANSLATION_TIMEOUT_SEC``, ``TRANSLATION_MODEL`` and
    ``TRANSLATION_SCENE`` take precedence over the ``services.translation``
    section; built-in fallbacks apply when neither is set.

    Raises:
        ValueError: The resolved default model has no entry in capabilities.
    """
    services = _load_services_raw()
    section = services.get("translation")
    if not isinstance(section, dict):
        section = {}

    endpoint = (
        os.getenv("TRANSLATION_SERVICE_URL")
        or section.get("service_url")
        or section.get("base_url")
        or "http://127.0.0.1:6006"
    )
    timeout = float(
        os.getenv("TRANSLATION_TIMEOUT_SEC") or section.get("timeout_sec") or 10.0
    )

    # "capabilities" is preferred; "providers" is read when it is absent or
    # not a mapping.
    declared = section.get("capabilities")
    if not isinstance(declared, dict):
        declared = section.get("providers")
    if not isinstance(declared, dict):
        declared = {}

    model_choice = (
        os.getenv("TRANSLATION_MODEL")
        or section.get("default_model")
        or section.get("provider")
        or "qwen-mt"
    )
    chosen_model = str(model_choice).strip().lower()

    scene_choice = os.getenv("TRANSLATION_SCENE") or section.get("default_scene") or "general"
    scene = str(scene_choice).strip() or "general"

    capability_map: Dict[str, Dict[str, Any]] = {}
    for raw_name, raw_value in declared.items():
        if isinstance(raw_value, dict):
            key = str(raw_name or "").strip().lower()
            if key:
                entry = dict(raw_value)
                # Without an explicit flag, only the entry whose name equals
                # the (not yet alias-resolved) default model is enabled.
                entry.setdefault("enabled", key == chosen_model)
                capability_map[key] = entry

    # Alias resolution intentionally happens after the loop above.
    known_aliases = {
        "qwen": "qwen-mt",
        "qwen-mt-flash": "qwen-mt",
        "qwen-mt-flush": "qwen-mt",
    }
    chosen_model = known_aliases.get(chosen_model, chosen_model)

    if chosen_model not in capability_map:
        raise ValueError(
            f"services.translation.default_model '{chosen_model}' is not defined in capabilities"
        )
    # The default model is always enabled, whatever its configured flag says.
    default_entry = capability_map[chosen_model]
    if not default_entry.get("enabled", False):
        default_entry["enabled"] = True

    cache_section = section.get("cache")
    if not isinstance(cache_section, dict):
        cache_section = {}

    return TranslationServiceConfig(
        service_url=str(endpoint).rstrip("/"),
        timeout_sec=timeout,
        default_model=chosen_model,
        default_scene=scene,
        capabilities=capability_map,
        cache=cache_section,
    )


def _resolve_embedding() -> ServiceConfig:
    """Build the embedding ``ServiceConfig`` from ``config.yaml`` and env vars.

    ``EMBEDDING_PROVIDER`` overrides ``services.embedding.provider``. When
    ``EMBEDDING_SERVICE_URL`` is set to a non-blank value it replaces
    ``providers.http.base_url`` (trailing slashes removed) on a copy, leaving
    the loaded config mapping untouched.

    Raises:
        ValueError: No provider is configured, or it is not ``"http"``.
    """
    raw = _load_services_raw()
    cfg = raw.get("embedding")
    if not isinstance(cfg, dict):
        cfg = {}
    providers = cfg.get("providers")
    if not isinstance(providers, dict):
        providers = {}

    provider = _resolve_provider_name(
        env_name="EMBEDDING_PROVIDER",
        config_provider=cfg.get("provider"),
        capability="embedding",
    )
    if provider != "http":
        raise ValueError(f"Unsupported embedding provider: {provider}")

    env_url = (os.getenv("EMBEDDING_SERVICE_URL") or "").strip()
    if env_url:
        # A YAML entry such as ``http:`` with no value loads as None; treat
        # any non-mapping entry as empty instead of failing in dict().
        existing = providers.get("http")
        http_cfg = dict(existing) if isinstance(existing, dict) else {}
        http_cfg["base_url"] = env_url.rstrip("/")
        providers = {**providers, "http": http_cfg}

    return ServiceConfig(provider=provider, providers=providers)


def _rerank_urls_from_override(override: str) -> tuple[str, str]:
    """Derive ``(base_url, service_url)`` from a ``RERANKER_SERVICE_URL`` value.

    Only the path part of the URL is inspected, so a host name that happens
    to start with "rerank" (e.g. ``http://reranker:6007``) is never mistaken
    for the ``/rerank`` endpoint or mangled when the base URL is derived.
    """
    endpoint = "/rerank"
    url = override.strip().rstrip("/")

    # Split "scheme://host[:port]" from the path without touching either.
    scheme_sep = url.find("://")
    path_start = url.find("/", scheme_sep + 3) if scheme_sep != -1 else url.find("/")
    if path_start == -1:
        origin, path = url, ""
    else:
        origin, path = url[:path_start], url[path_start:]

    if path.endswith(endpoint):
        # Already points at the endpoint: the base is everything before it.
        return origin + path[: -len(endpoint)], url
    if f"{endpoint}/" in path:
        # Endpoint segment in the middle of the path: keep the URL as given
        # and drop just that segment for the base, as this function did before.
        return origin + path.replace(f"{endpoint}/", "/", 1), url
    return url, f"{url}{endpoint}"


def _resolve_rerank() -> ServiceConfig:
    """Build the rerank ``ServiceConfig`` from ``config.yaml`` and env vars.

    ``RERANK_PROVIDER`` overrides ``services.rerank.provider``. When
    ``RERANKER_SERVICE_URL`` is set to a non-blank value, both
    ``providers.http.base_url`` and ``providers.http.service_url`` are
    derived from it on a copy, leaving the loaded config mapping untouched.

    Raises:
        ValueError: No provider is configured, or it is not ``"http"``.
    """
    raw = _load_services_raw()
    cfg = raw.get("rerank")
    if not isinstance(cfg, dict):
        cfg = {}
    providers = cfg.get("providers")
    if not isinstance(providers, dict):
        providers = {}

    provider = _resolve_provider_name(
        env_name="RERANK_PROVIDER",
        config_provider=cfg.get("provider"),
        capability="rerank",
    )
    if provider != "http":
        raise ValueError(f"Unsupported rerank provider: {provider}")

    env_url = (os.getenv("RERANKER_SERVICE_URL") or "").strip()
    if env_url:
        base_url, service_url = _rerank_urls_from_override(env_url)
        # Treat a non-mapping ``http`` entry (e.g. YAML null) as empty.
        existing = providers.get("http")
        http_cfg = dict(existing) if isinstance(existing, dict) else {}
        http_cfg["base_url"] = base_url
        http_cfg["service_url"] = service_url
        providers = {**providers, "http": http_cfg}

    return ServiceConfig(provider=provider, providers=providers)


def get_rerank_backend_config() -> tuple[str, dict]:
    """Return ``(backend_name, backend_settings)`` for the rerank capability.

    The name comes from ``RERANK_BACKEND`` or ``services.rerank.backend`` and
    is trimmed and lower-cased; the settings are read from
    ``services.rerank.backends.<name>``.

    Raises:
        ValueError: No backend is named, or its settings are missing, empty
            or not a mapping.
    """
    services = _load_services_raw()
    section = services.get("rerank")
    if not isinstance(section, dict):
        section = {}
    backends = section.get("backends")
    if not isinstance(backends, dict):
        backends = {}

    selected = os.getenv("RERANK_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.rerank.backend is required (or env RERANK_BACKEND)")

    backend_name = str(selected).strip().lower()
    settings = backends.get(backend_name)
    if isinstance(settings, dict) and settings:
        return backend_name, settings
    raise ValueError(f"services.rerank.backends.{backend_name} is required")


def get_embedding_backend_config() -> tuple[str, dict]:
    """Return ``(backend_name, backend_settings)`` for the embedding capability.

    The name comes from ``EMBEDDING_BACKEND`` or ``services.embedding.backend``
    and is trimmed and lower-cased; the settings are read from
    ``services.embedding.backends.<name>``.

    Raises:
        ValueError: No backend is named, or its settings are missing, empty
            or not a mapping.
    """
    services = _load_services_raw()
    section = services.get("embedding")
    if not isinstance(section, dict):
        section = {}
    backends = section.get("backends")
    if not isinstance(backends, dict):
        backends = {}

    selected = os.getenv("EMBEDDING_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.embedding.backend is required (or env EMBEDDING_BACKEND)")

    backend_name = str(selected).strip().lower()
    settings = backends.get(backend_name)
    if isinstance(settings, dict) and settings:
        return backend_name, settings
    raise ValueError(f"services.embedding.backends.{backend_name} is required")


@lru_cache(maxsize=1)
def get_translation_config() -> TranslationServiceConfig:
    """Return the translation config, resolved on first call and then cached."""
    resolved = _resolve_translation()
    return resolved


@lru_cache(maxsize=1)
def get_embedding_config() -> ServiceConfig:
    """Return the embedding config, resolved on first call and then cached."""
    resolved = _resolve_embedding()
    return resolved


@lru_cache(maxsize=1)
def get_rerank_config() -> ServiceConfig:
    """Return the rerank config, resolved on first call and then cached."""
    resolved = _resolve_rerank()
    return resolved


def get_translation_base_url() -> str:
    """Return the ``service_url`` of the cached translation config."""
    translation = get_translation_config()
    return translation.service_url


def get_translation_cache_config() -> Dict[str, Any]:
    """Return the translation cache settings with defaults filled in.

    Values are read from the ``cache`` mapping of the translation config and
    coerced to ``bool``/``str``/``int``. Every flag defaults to ``True``,
    ``key_prefix`` to ``"trans:v2"`` and ``ttl_seconds`` to 360 days.
    """
    settings = get_translation_config().cache

    def flag(key: str) -> bool:
        # All boolean options are on unless the config says otherwise.
        return bool(settings.get(key, True))

    default_ttl = 360 * 24 * 3600
    return {
        "enabled": flag("enabled"),
        "key_prefix": str(settings.get("key_prefix", "trans:v2")),
        "ttl_seconds": int(settings.get("ttl_seconds", default_ttl)),
        "sliding_expiration": flag("sliding_expiration"),
        "key_include_context": flag("key_include_context"),
        "key_include_prompt": flag("key_include_prompt"),
        "key_include_source_lang": flag("key_include_source_lang"),
    }


def get_embedding_base_url() -> str:
    """Return the embedding HTTP base URL without trailing slashes.

    ``EMBEDDING_SERVICE_URL`` is used when set; otherwise the value comes
    from ``providers.http.base_url`` of the cached embedding config.

    Raises:
        ValueError: Neither source provides a URL.
    """
    url = os.getenv("EMBEDDING_SERVICE_URL")
    if not url:
        # The config is only consulted when the env override is absent.
        http_settings = get_embedding_config().providers.get("http", {})
        url = http_settings.get("base_url")
    if not url:
        raise ValueError("Embedding HTTP base_url is not configured")
    return str(url).rstrip("/")


def get_rerank_base_url() -> str:
    """Return the rerank HTTP URL without trailing slashes.

    ``RERANKER_SERVICE_URL`` is used when set; otherwise the value comes from
    ``providers.http.service_url`` of the cached rerank config, then from
    ``providers.http.base_url``.

    Raises:
        ValueError: None of the sources provides a URL.
    """
    url = os.getenv("RERANKER_SERVICE_URL")
    if not url:
        # The config is only consulted when the env override is absent.
        url = get_rerank_config().providers.get("http", {}).get("service_url")
    if not url:
        url = get_rerank_config().providers.get("http", {}).get("base_url")
    if not url:
        raise ValueError("Rerank HTTP base_url is not configured")
    return str(url).rstrip("/")


def get_rerank_service_url() -> str:
    """Backward-compatible alias that returns ``get_rerank_base_url()``."""
    url = get_rerank_base_url()
    return url