services_config.py 6.65 KB
"""
Services configuration - single source for translation, embedding, rerank providers.

All provider selection and endpoint resolution is centralized here.
Priority: env vars > config.yaml > defaults.
"""

from __future__ import annotations

import os
import warnings
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import yaml


@dataclass
class ServiceConfig:
    """Resolved settings for one capability (translation, embedding or rerank).

    ``provider`` names the active provider; ``providers`` maps provider names
    to their individual configuration mappings.
    """

    # Active provider name, e.g. "http" or "direct".
    provider: str
    # Per-provider configuration, keyed by provider name.
    providers: Dict[str, Any] = field(default_factory=dict)

    def get_provider_cfg(self) -> Dict[str, Any]:
        """Return the configuration mapping of the active provider.

        The provider name is trimmed and lower-cased before the lookup (an
        empty/falsy provider looks up ``""``). Returns ``{}`` when there is no
        matching entry or when ``providers`` is not a dict.
        """
        key = self.provider.strip().lower() if self.provider else ""
        if not isinstance(self.providers, dict):
            return {}
        return self.providers.get(key, {})


def _load_services_raw(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load services block from config.yaml."""
    if config_path is None:
        config_path = Path(__file__).parent / "config.yaml"
    path = Path(config_path)
    if not path.exists():
        return {}
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except Exception:
        return {}
    services = data.get("services") if isinstance(data, dict) else {}
    return services if isinstance(services, dict) else {}


def _resolve_translation() -> ServiceConfig:
    """Build the translation ServiceConfig from env vars and config.yaml.

    Provider precedence: the ``TRANSLATION_PROVIDER`` env var, then
    ``services.translation.provider``, then ``"direct"``. The result is
    trimmed and lower-cased.

    When the chosen provider is ``http`` and ``TRANSLATION_SERVICE_URL`` is
    set, that URL (trailing slashes removed) becomes ``base_url`` of the
    ``http`` entry. This happens in a copy of the providers mapping.
    """
    section = _load_services_raw().get("translation")
    if not isinstance(section, dict):
        section = {}
    provider_map = section.get("providers")
    if not isinstance(provider_map, dict):
        provider_map = {}

    chosen = str(
        os.getenv("TRANSLATION_PROVIDER") or section.get("provider") or "direct"
    ).strip().lower()

    override = os.getenv("TRANSLATION_SERVICE_URL")
    if override and chosen == "http":
        # Copy both levels so the loaded mapping is never mutated in place.
        http_cfg = dict(provider_map.get("http", {}))
        http_cfg["base_url"] = override.rstrip("/")
        provider_map = {**provider_map, "http": http_cfg}

    return ServiceConfig(provider=chosen, providers=provider_map)


def _resolve_embedding() -> ServiceConfig:
    """Build the embedding ServiceConfig from env vars and config.yaml.

    Provider precedence: the ``EMBEDDING_PROVIDER`` env var, then
    ``services.embedding.provider``, then ``"http"``. The result is
    trimmed and lower-cased.

    When the chosen provider is ``http`` and ``EMBEDDING_SERVICE_URL`` is
    set, that URL (trailing slashes removed) becomes ``base_url`` of the
    ``http`` entry. This happens in a copy of the providers mapping.
    """
    embedding_section = _load_services_raw().get("embedding")
    embedding_section = embedding_section if isinstance(embedding_section, dict) else {}
    known = embedding_section.get("providers")
    known = known if isinstance(known, dict) else {}

    name = os.getenv("EMBEDDING_PROVIDER") or embedding_section.get("provider")
    if not name:
        name = "http"
    name = str(name).strip().lower()

    url_from_env = os.getenv("EMBEDDING_SERVICE_URL")
    if url_from_env and name == "http":
        # Copy both levels so the loaded mapping is never mutated in place.
        patched_http = dict(known.get("http", {}))
        patched_http["base_url"] = url_from_env.rstrip("/")
        known = {**known, "http": patched_http}

    return ServiceConfig(provider=name, providers=known)


def _resolve_rerank() -> ServiceConfig:
    """Build the rerank ServiceConfig from env vars and config.yaml.

    Provider precedence: the ``RERANK_PROVIDER`` env var, then
    ``services.rerank.provider``, then ``"http"``. The result is trimmed and
    lower-cased.

    ``RERANKER_SERVICE_URL``, when set, overrides the ``http`` entry whatever
    provider is chosen. The override works on a copy of the providers mapping
    and sets two keys:
    - ``service_url``: the full endpoint, always ending in ``/rerank``;
    - ``base_url``: the same URL without that trailing ``/rerank`` segment.
    """
    raw = _load_services_raw()
    cfg = raw.get("rerank", {}) if isinstance(raw.get("rerank"), dict) else {}
    providers = cfg.get("providers", {}) if isinstance(cfg.get("providers"), dict) else {}

    provider = (
        os.getenv("RERANK_PROVIDER")
        or cfg.get("provider")
        or "http"
    )
    provider = str(provider).strip().lower()

    env_url = os.getenv("RERANKER_SERVICE_URL")
    if env_url:
        suffix = "/rerank"
        url = env_url.rstrip("/")
        # Only a *trailing* "/rerank" segment marks the endpoint. Substring
        # tests misfire on hosts like "http://reranker:6007", which contain
        # "/rerank" right after the "//" of the scheme.
        if url.endswith(suffix):
            service_url = url
            base_url = url[: -len(suffix)]
        else:
            base_url = url
            service_url = f"{url}{suffix}"
        providers = dict(providers)
        providers["http"] = dict(providers.get("http", {}))
        providers["http"]["base_url"] = base_url
        providers["http"]["service_url"] = service_url

    return ServiceConfig(provider=provider, providers=providers)


def get_rerank_backend_config() -> tuple[str, dict]:
    """Pick the reranker backend and its settings for the reranker service.

    Backend precedence: the ``RERANK_BACKEND`` env var, then
    ``services.rerank.backend``, then ``"qwen3_vllm"``. The name is trimmed
    and lower-cased.

    Returns:
        A ``(backend_name, backend_cfg)`` tuple. ``backend_cfg`` is the
        matching ``services.rerank.backends`` entry, or ``{}`` when that
        entry is missing or not a dict.
    """
    section = _load_services_raw().get("rerank")
    section = section if isinstance(section, dict) else {}
    available = section.get("backends")
    available = available if isinstance(available, dict) else {}

    backend = str(
        os.getenv("RERANK_BACKEND") or section.get("backend") or "qwen3_vllm"
    ).strip().lower()

    settings = available.get(backend)
    return backend, (settings if isinstance(settings, dict) else {})


def get_embedding_backend_config() -> tuple[str, dict]:
    """Pick the embedding backend and its settings for the embedding service.

    Backend precedence: the ``EMBEDDING_BACKEND`` env var, then
    ``services.embedding.backend``, then ``"tei"``. The name is trimmed and
    lower-cased.

    Returns:
        A ``(backend_name, backend_cfg)`` tuple. ``backend_cfg`` is the
        matching ``services.embedding.backends`` entry, or ``{}`` when that
        entry is missing or not a dict.
    """
    section = _load_services_raw().get("embedding")
    section = section if isinstance(section, dict) else {}
    available = section.get("backends")
    available = available if isinstance(available, dict) else {}

    backend = str(
        os.getenv("EMBEDDING_BACKEND") or section.get("backend") or "tei"
    ).strip().lower()

    settings = available.get(backend)
    return backend, (settings if isinstance(settings, dict) else {})


@lru_cache(maxsize=1)
def get_translation_config() -> ServiceConfig:
    """Return the translation ServiceConfig, resolved once and then memoised.

    Call ``clear_services_cache()`` to force re-resolution.
    """
    resolved = _resolve_translation()
    return resolved


@lru_cache(maxsize=1)
def get_embedding_config() -> ServiceConfig:
    """Return the embedding ServiceConfig, resolved once and then memoised.

    Call ``clear_services_cache()`` to force re-resolution.
    """
    resolved = _resolve_embedding()
    return resolved


@lru_cache(maxsize=1)
def get_rerank_config() -> ServiceConfig:
    """Return the rerank ServiceConfig, resolved once and then memoised.

    Call ``clear_services_cache()`` to force re-resolution.
    """
    resolved = _resolve_rerank()
    return resolved


def get_translation_base_url() -> str:
    """Return the translation HTTP base URL without a trailing slash.

    Precedence: the ``TRANSLATION_SERVICE_URL`` env var, then
    ``providers.http.base_url`` of the cached translation config, then
    ``http://127.0.0.1:6006``.
    """
    from_env = os.getenv("TRANSLATION_SERVICE_URL")
    if from_env:
        return str(from_env).rstrip("/")
    configured = get_translation_config().providers.get("http", {}).get("base_url")
    return str(configured or "http://127.0.0.1:6006").rstrip("/")


def get_embedding_base_url() -> str:
    """Return the embedding HTTP base URL without a trailing slash.

    Precedence: the ``EMBEDDING_SERVICE_URL`` env var, then
    ``providers.http.base_url`` of the cached embedding config, then
    ``http://127.0.0.1:6005``.
    """
    from_env = os.getenv("EMBEDDING_SERVICE_URL")
    if from_env:
        return str(from_env).rstrip("/")
    configured = get_embedding_config().providers.get("http", {}).get("base_url")
    return str(configured or "http://127.0.0.1:6005").rstrip("/")


def get_rerank_service_url() -> str:
    """Return the full rerank endpoint URL, always ending in ``/rerank``.

    Precedence: the ``RERANKER_SERVICE_URL`` env var, then the cached rerank
    config's ``providers.http.service_url``, then its
    ``providers.http.base_url``, then ``http://127.0.0.1:6007``. Trailing
    slashes are removed, and ``/rerank`` is appended unless already present
    at the end.
    """
    candidate = os.getenv("RERANKER_SERVICE_URL")
    if not candidate:
        http_cfg = get_rerank_config().providers.get("http", {})
        candidate = (
            http_cfg.get("service_url")
            or http_cfg.get("base_url")
            or "http://127.0.0.1:6007"
        )
    endpoint = str(candidate).rstrip("/")
    if endpoint.endswith("/rerank"):
        return endpoint
    return f"{endpoint}/rerank"


def clear_services_cache() -> None:
    """Drop the memoised translation/embedding/rerank configs (for tests).

    After this call, the next ``get_*_config()`` call re-reads env vars and
    config.yaml.
    """
    for cached_getter in (get_translation_config, get_embedding_config, get_rerank_config):
        cached_getter.cache_clear()