# services_config.py (8.17 KB)
"""
Services configuration - single source for translation, embedding, rerank providers.

All provider selection and endpoint resolution is centralized here.
Priority: env vars > config.yaml.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import yaml


@dataclass
class ServiceConfig:
    """Config for one capability (translation/embedding/rerank).

    Attributes:
        provider: Name of the selected provider.
        providers: Mapping of provider name to that provider's settings.
    """
    provider: str
    providers: Dict[str, Any] = field(default_factory=dict)

    def get_provider_cfg(self) -> Dict[str, Any]:
        """Return the settings mapping of the currently selected provider.

        The provider name is stripped and lower-cased before the lookup.

        Returns:
            The selected provider's settings, or an empty dict when
            ``providers`` is not a dict, the provider has no entry, or the
            entry is not a mapping (e.g. ``None`` from an empty YAML value).
        """
        if not isinstance(self.providers, dict):
            return {}
        key = (self.provider or "").strip().lower()
        selected = self.providers.get(key)
        # Honor the Dict return contract: a null/scalar entry would otherwise
        # leak through and break callers that chain ``.get(...)`` on it.
        return selected if isinstance(selected, dict) else {}


def _load_services_raw(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Read the YAML config file and return its top-level ``services`` mapping.

    Args:
        config_path: Location of the YAML file. When ``None``, ``config.yaml``
            next to this module is used.

    Returns:
        The mapping stored under the ``services`` key.

    Raises:
        FileNotFoundError: The file does not exist.
        RuntimeError: The file cannot be opened or parsed, its root is not a
            mapping, or ``services`` is missing / not a mapping.
    """
    if config_path is None:
        source = Path(__file__).parent / "config.yaml"
    else:
        source = Path(config_path)

    if not source.exists():
        raise FileNotFoundError(f"services config file not found: {source}")

    try:
        with source.open("r", encoding="utf-8") as handle:
            document = yaml.safe_load(handle)
    except Exception as exc:
        # Any read or parse failure is reported uniformly with the file path.
        raise RuntimeError(f"failed to parse services config from {source}: {exc}") from exc

    if not isinstance(document, dict):
        raise RuntimeError(f"invalid config format in {source}: expected mapping root")

    services_block = document.get("services")
    if isinstance(services_block, dict):
        return services_block
    raise RuntimeError("config.yaml must contain a valid 'services' mapping")


def _resolve_provider_name(
    env_name: str,
    config_provider: Any,
    capability: str,
) -> str:
    """Pick the provider name for a capability, preferring the env override.

    Args:
        env_name: Environment variable that may override the configured value.
        config_provider: Provider value read from the config file (may be
            ``None`` or any falsy value when not configured).
        capability: Capability name, used only in the error message.

    Returns:
        The chosen provider name, stripped and lower-cased.

    Raises:
        ValueError: Neither the environment variable nor the config supplies
            a non-blank provider name.
    """
    # A blank (empty or whitespace-only) env var counts as "not set", so it
    # falls back to the config value instead of producing an empty name.
    env_value = (os.getenv(env_name) or "").strip()
    candidate = env_value or config_provider
    name = str(candidate).strip().lower() if candidate else ""
    if not name:
        raise ValueError(
            f"services.{capability}.provider is required "
            f"(or set env override {env_name})"
        )
    return name


def _resolve_translation() -> ServiceConfig:
    """Build the translation ServiceConfig from config.yaml plus env overrides.

    The provider name comes from ``TRANSLATION_PROVIDER`` or the config's
    ``provider`` key. When the provider is ``http`` and
    ``TRANSLATION_SERVICE_URL`` is set, that URL (without trailing slashes)
    replaces the http provider's ``base_url``.

    Raises:
        ValueError: The provider is missing or is not one of the accepted names.
    """
    services = _load_services_raw()

    section = services.get("translation")
    if not isinstance(section, dict):
        section = {}
    provider_map = section.get("providers")
    if not isinstance(provider_map, dict):
        provider_map = {}

    selected = _resolve_provider_name(
        env_name="TRANSLATION_PROVIDER",
        config_provider=section.get("provider"),
        capability="translation",
    )
    accepted = {"direct", "local", "inprocess", "http", "service"}
    if selected not in accepted:
        raise ValueError(f"Unsupported translation provider: {selected}")

    # The URL override applies to the http provider only.
    url_override = os.getenv("TRANSLATION_SERVICE_URL")
    if selected == "http" and url_override:
        # Copy both levels so the mapping loaded from YAML is not mutated.
        http_settings = dict(provider_map.get("http", {}))
        http_settings["base_url"] = url_override.rstrip("/")
        provider_map = dict(provider_map)
        provider_map["http"] = http_settings

    return ServiceConfig(provider=selected, providers=provider_map)


def _resolve_embedding() -> ServiceConfig:
    """Build the embedding ServiceConfig from config.yaml plus env overrides.

    The provider name comes from ``EMBEDDING_PROVIDER`` or the config's
    ``provider`` key; only ``http`` is accepted. When
    ``EMBEDDING_SERVICE_URL`` is set, that URL (without trailing slashes)
    replaces the http provider's ``base_url``.

    Raises:
        ValueError: The provider is missing or is not ``http``.
    """
    services = _load_services_raw()

    section = services.get("embedding")
    if not isinstance(section, dict):
        section = {}
    provider_map = section.get("providers")
    if not isinstance(provider_map, dict):
        provider_map = {}

    selected = _resolve_provider_name(
        env_name="EMBEDDING_PROVIDER",
        config_provider=section.get("provider"),
        capability="embedding",
    )
    if selected != "http":
        raise ValueError(f"Unsupported embedding provider: {selected}")

    # Past the check above the provider is always "http".
    url_override = os.getenv("EMBEDDING_SERVICE_URL")
    if url_override:
        # Copy both levels so the mapping loaded from YAML is not mutated.
        http_settings = dict(provider_map.get("http", {}))
        http_settings["base_url"] = url_override.rstrip("/")
        provider_map = dict(provider_map)
        provider_map["http"] = http_settings

    return ServiceConfig(provider=selected, providers=provider_map)


def _resolve_rerank() -> ServiceConfig:
    """Build the rerank ServiceConfig from config.yaml plus env overrides.

    The provider name comes from ``RERANK_PROVIDER`` or the config's
    ``provider`` key; only ``http`` is accepted. When
    ``RERANKER_SERVICE_URL`` is set it overrides the http provider's URLs:
    ``service_url`` is the URL ending in ``/rerank`` and ``base_url`` is the
    same URL without that trailing segment.

    Returns:
        The resolved ServiceConfig for the rerank capability.

    Raises:
        ValueError: The provider is missing or is not ``http``.
    """
    raw = _load_services_raw()
    cfg = raw.get("rerank", {}) if isinstance(raw.get("rerank"), dict) else {}
    providers = cfg.get("providers", {}) if isinstance(cfg.get("providers"), dict) else {}

    provider = _resolve_provider_name(
        env_name="RERANK_PROVIDER",
        config_provider=cfg.get("provider"),
        capability="rerank",
    )
    if provider != "http":
        raise ValueError(f"Unsupported rerank provider: {provider}")

    env_url = os.getenv("RERANKER_SERVICE_URL")
    if env_url:
        suffix = "/rerank"
        url = env_url.rstrip("/")
        # Only a trailing "/rerank" segment is significant. Stripping or
        # detecting it anywhere else would corrupt URLs whose host or path
        # merely contains that text (e.g. ".../reranker/rerank").
        if url.endswith(suffix):
            base_url = url[: -len(suffix)]
            service_url = url
        else:
            base_url = url
            service_url = f"{url}{suffix}"

        # Copy both levels so the mapping loaded from YAML is not mutated.
        providers = dict(providers)
        http_cfg = providers.get("http")
        providers["http"] = dict(http_cfg) if isinstance(http_cfg, dict) else {}
        providers["http"]["base_url"] = base_url
        providers["http"]["service_url"] = service_url

    return ServiceConfig(provider=provider, providers=providers)


def get_rerank_backend_config() -> tuple[str, dict]:
    """
    Look up which reranker backend to use and return its settings.

    The backend name is taken from env RERANK_BACKEND when set, otherwise
    from services.rerank.backend; it is stripped and lower-cased before the
    lookup in services.rerank.backends.

    Returns (backend_name, backend_cfg).
    Raises ValueError when no backend name is given or when the named
    backend has no non-empty mapping under services.rerank.backends.
    """
    services = _load_services_raw()

    section = services.get("rerank")
    if not isinstance(section, dict):
        section = {}
    backend_map = section.get("backends")
    if not isinstance(backend_map, dict):
        backend_map = {}

    selected = os.getenv("RERANK_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.rerank.backend is required (or env RERANK_BACKEND)")

    backend_name = str(selected).strip().lower()
    settings = backend_map.get(backend_name)
    # A missing, non-mapping, or empty entry all count as "not configured".
    if not isinstance(settings, dict) or not settings:
        raise ValueError(f"services.rerank.backends.{backend_name} is required")
    return backend_name, settings


def get_embedding_backend_config() -> tuple[str, dict]:
    """
    Look up which embedding backend to use and return its settings.

    The backend name is taken from env EMBEDDING_BACKEND when set, otherwise
    from services.embedding.backend; it is stripped and lower-cased before
    the lookup in services.embedding.backends.

    Returns (backend_name, backend_cfg).
    Raises ValueError when no backend name is given or when the named
    backend has no non-empty mapping under services.embedding.backends.
    """
    services = _load_services_raw()

    section = services.get("embedding")
    if not isinstance(section, dict):
        section = {}
    backend_map = section.get("backends")
    if not isinstance(backend_map, dict):
        backend_map = {}

    selected = os.getenv("EMBEDDING_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.embedding.backend is required (or env EMBEDDING_BACKEND)")

    backend_name = str(selected).strip().lower()
    settings = backend_map.get(backend_name)
    # A missing, non-mapping, or empty entry all count as "not configured".
    if not isinstance(settings, dict) or not settings:
        raise ValueError(f"services.embedding.backends.{backend_name} is required")
    return backend_name, settings


@lru_cache(maxsize=1)
def get_translation_config() -> ServiceConfig:
    """Return the translation ServiceConfig.

    The result of ``_resolve_translation()`` is memoized by ``lru_cache``,
    so repeated calls reuse it until ``cache_clear()`` is invoked.
    """
    resolved = _resolve_translation()
    return resolved


@lru_cache(maxsize=1)
def get_embedding_config() -> ServiceConfig:
    """Return the embedding ServiceConfig.

    The result of ``_resolve_embedding()`` is memoized by ``lru_cache``,
    so repeated calls reuse it until ``cache_clear()`` is invoked.
    """
    resolved = _resolve_embedding()
    return resolved


@lru_cache(maxsize=1)
def get_rerank_config() -> ServiceConfig:
    """Return the rerank ServiceConfig.

    The result of ``_resolve_rerank()`` is memoized by ``lru_cache``,
    so repeated calls reuse it until ``cache_clear()`` is invoked.
    """
    resolved = _resolve_rerank()
    return resolved


def get_translation_base_url() -> str:
    """Return the translation HTTP base URL without trailing slashes.

    ``TRANSLATION_SERVICE_URL`` wins when set; otherwise the http provider's
    ``base_url`` from the translation config is used.

    Raises:
        ValueError: Neither source provides a URL.
    """
    url = os.getenv("TRANSLATION_SERVICE_URL")
    if not url:
        # The config is consulted only when the env var is absent or empty.
        http_settings = get_translation_config().providers.get("http", {})
        url = http_settings.get("base_url")
    if not url:
        raise ValueError("Translation HTTP base_url is not configured")
    return str(url).rstrip("/")


def get_embedding_base_url() -> str:
    """Return the embedding HTTP base URL without trailing slashes.

    ``EMBEDDING_SERVICE_URL`` wins when set; otherwise the http provider's
    ``base_url`` from the embedding config is used.

    Raises:
        ValueError: Neither source provides a URL.
    """
    url = os.getenv("EMBEDDING_SERVICE_URL")
    if not url:
        # The config is consulted only when the env var is absent or empty.
        http_settings = get_embedding_config().providers.get("http", {})
        url = http_settings.get("base_url")
    if not url:
        raise ValueError("Embedding HTTP base_url is not configured")
    return str(url).rstrip("/")


def get_rerank_service_url() -> str:
    """Return the rerank endpoint URL, always ending in ``/rerank``.

    ``RERANKER_SERVICE_URL`` wins when set; otherwise the http provider's
    ``service_url``, then ``base_url``, from the rerank config is used.
    Trailing slashes are removed and ``/rerank`` is appended unless the URL
    already ends with it.

    Raises:
        ValueError: No source provides a URL.
    """
    endpoint = os.getenv("RERANKER_SERVICE_URL")
    if not endpoint:
        # The config is consulted only when the env var is absent or empty.
        http_settings = get_rerank_config().providers.get("http", {})
        endpoint = http_settings.get("service_url") or http_settings.get("base_url")
    if not endpoint:
        raise ValueError("Rerank HTTP service_url/base_url is not configured")

    endpoint = str(endpoint).rstrip("/")
    if endpoint.endswith("/rerank"):
        return endpoint
    return f"{endpoint}/rerank"


def clear_services_cache() -> None:
    """Drop the memoized translation, embedding and rerank configs (for tests)."""
    cached_getters = (
        get_translation_config,
        get_embedding_config,
        get_rerank_config,
    )
    for getter in cached_getters:
        getter.cache_clear()