# services_config.py (7.4 KB)
"""
Services configuration - single source for translation, embedding, rerank.

Translation is modeled as:
- one translator service endpoint used by business callers
- multiple translation capabilities loaded inside the translator service
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml
from translation.settings import TranslationConfig, build_translation_config, get_translation_cache


@dataclass
class ServiceConfig:
    """Resolved settings for a single capability (embedding or rerank).

    ``provider`` names the selected provider and ``providers`` maps provider
    names to their individual settings.
    """

    provider: str
    providers: Dict[str, Any] = field(default_factory=dict)

    def get_provider_cfg(self) -> Dict[str, Any]:
        """Return the settings of the selected provider, or ``{}`` if absent.

        The provider name is stripped and lower-cased before the lookup. An
        empty dict is also returned when ``providers`` is not a dict.
        """
        key = (self.provider or "").strip().lower()
        if not isinstance(self.providers, dict):
            return {}
        return self.providers.get(key, {})


def _load_services_raw(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Read the YAML config file and return its ``services`` mapping.

    Args:
        config_path: File to read. When ``None``, ``config.yaml`` located in
            the same directory as this module is used.

    Returns:
        The mapping stored under the top-level ``services`` key.

    Raises:
        FileNotFoundError: The file does not exist.
        RuntimeError: The file cannot be opened or parsed, its root is not a
            mapping, or ``services`` is missing or not a mapping.
    """
    if config_path is None:
        path = Path(__file__).parent / "config.yaml"
    else:
        path = Path(config_path)

    if not path.exists():
        raise FileNotFoundError(f"services config file not found: {path}")

    try:
        with path.open("r", encoding="utf-8") as handle:
            document = yaml.safe_load(handle)
    except Exception as exc:
        # Any open/read/parse failure is reported uniformly as a RuntimeError.
        raise RuntimeError(f"failed to parse services config from {path}: {exc}") from exc

    if not isinstance(document, dict):
        raise RuntimeError(f"invalid config format in {path}: expected mapping root")

    section = document.get("services")
    if isinstance(section, dict):
        return section
    raise RuntimeError("config.yaml must contain a valid 'services' mapping")


def _resolve_provider_name(env_name: str, config_provider: Any, capability: str) -> str:
    """Pick the provider name for a capability and return it normalized.

    The environment variable ``env_name`` takes precedence over
    ``config_provider``. Each candidate is stripped and lower-cased; a value
    that is empty after stripping counts as missing, so a blank environment
    override falls back to the configured value instead of yielding ``""``.

    Args:
        env_name: Name of the environment variable that overrides the config.
        config_provider: Provider value read from the config file, if any.
        capability: Capability name, used only in the error message.

    Returns:
        The stripped, lower-cased provider name (never empty).

    Raises:
        ValueError: Neither the environment nor the config supplies a
            non-blank provider name.
    """
    for candidate in (os.getenv(env_name), config_provider):
        if not candidate:
            continue
        normalized = str(candidate).strip().lower()
        if normalized:
            return normalized
    raise ValueError(
        f"services.{capability}.provider is required "
        f"(or set env override {env_name})"
    )


def _resolve_translation() -> TranslationConfig:
    """Load the ``translation`` section and hand it to ``build_translation_config``.

    A missing or non-mapping ``translation`` section is passed on as ``{}``.
    """
    services = _load_services_raw()
    section = services.get("translation")
    if not isinstance(section, dict):
        section = {}
    return build_translation_config(section)


def _resolve_embedding() -> ServiceConfig:
    """Build the embedding ``ServiceConfig`` from the config file and environment.

    The provider name comes from ``EMBEDDING_PROVIDER`` or
    ``services.embedding.provider``; only ``http`` is accepted. The http
    provider's ``text_base_url`` / ``image_base_url`` can be overridden with
    ``EMBEDDING_TEXT_SERVICE_URL`` / ``EMBEDDING_IMAGE_SERVICE_URL`` (trailing
    slashes are removed from the overrides).

    Returns:
        A ``ServiceConfig`` whose ``providers["http"]`` holds both base URLs.

    Raises:
        ValueError: The provider is missing or is not ``http``, or one of the
            two base URLs is configured neither in the file nor via env.
    """
    services = _load_services_raw()
    section = services.get("embedding")
    if not isinstance(section, dict):
        section = {}
    configured = section.get("providers")
    if not isinstance(configured, dict):
        configured = {}

    provider = _resolve_provider_name(
        env_name="EMBEDDING_PROVIDER",
        config_provider=section.get("provider"),
        capability="embedding",
    )
    if provider != "http":
        raise ValueError(f"Unsupported embedding provider: {provider}")

    # Copy before editing so the mapping loaded from the file is not mutated.
    providers = dict(configured)
    existing = providers.get("http")
    # An empty ``http:`` key in YAML loads as None; treat any non-mapping value
    # as "no settings" rather than failing inside dict().
    http_cfg = dict(existing) if isinstance(existing, dict) else {}

    url_overrides = (
        ("text_base_url", "EMBEDDING_TEXT_SERVICE_URL"),
        ("image_base_url", "EMBEDDING_IMAGE_SERVICE_URL"),
    )
    for key, env_name in url_overrides:
        env_value = os.getenv(env_name)
        if env_value:
            http_cfg[key] = env_value.rstrip("/")
    for key, _env_name in url_overrides:
        if not http_cfg.get(key):
            raise ValueError(f"services.embedding.providers.http.{key} is required")

    providers["http"] = http_cfg
    return ServiceConfig(provider=provider, providers=providers)


def _split_rerank_url(raw_url: str) -> tuple[str, str]:
    """Return ``(base_url, service_url)`` for a rerank endpoint URL.

    Trailing slashes are dropped. ``service_url`` always ends with
    ``/rerank``; ``base_url`` is the same URL without that final segment.
    Only a ``/rerank`` at the end of the path is recognised, so hosts such as
    ``reranker`` or ``rerank`` are never mistaken for the endpoint segment.
    """
    suffix = "/rerank"
    url = raw_url.rstrip("/")
    scheme_at = url.find("://")
    # First "/" after the "scheme://" prefix marks where the path begins.
    path_at = url.find("/", scheme_at + 3 if scheme_at >= 0 else 0)
    suffix_at = len(url) - len(suffix)
    if path_at >= 0 and suffix_at >= path_at and url.endswith(suffix):
        base_url = url[:suffix_at]
    else:
        base_url = url
    return base_url, base_url + suffix


def _resolve_rerank() -> ServiceConfig:
    """Build the rerank ``ServiceConfig`` from the config file and environment.

    The provider name comes from ``RERANK_PROVIDER`` or
    ``services.rerank.provider``; only ``http`` is accepted. When
    ``RERANKER_SERVICE_URL`` is set, the http provider's ``base_url`` and
    ``service_url`` are derived from it (``service_url`` ends with
    ``/rerank``, ``base_url`` does not).

    Returns:
        A ``ServiceConfig`` for the rerank capability.

    Raises:
        ValueError: The provider is missing or is not ``http``.
    """
    services = _load_services_raw()
    section = services.get("rerank")
    if not isinstance(section, dict):
        section = {}
    providers = section.get("providers")
    if not isinstance(providers, dict):
        providers = {}

    provider = _resolve_provider_name(
        env_name="RERANK_PROVIDER",
        config_provider=section.get("provider"),
        capability="rerank",
    )
    if provider != "http":
        raise ValueError(f"Unsupported rerank provider: {provider}")

    env_url = os.getenv("RERANKER_SERVICE_URL")
    if env_url:
        base_url, service_url = _split_rerank_url(env_url)
        # Copy before editing so the mapping loaded from the file is not mutated.
        providers = dict(providers)
        existing = providers.get("http")
        # An empty ``http:`` key in YAML loads as None; fall back to {}.
        http_cfg = dict(existing) if isinstance(existing, dict) else {}
        http_cfg["base_url"] = base_url
        http_cfg["service_url"] = service_url
        providers["http"] = http_cfg

    return ServiceConfig(provider=provider, providers=providers)


def get_rerank_backend_config() -> tuple[str, dict]:
    """Return the selected rerank backend name together with its settings.

    The name is taken from ``RERANK_BACKEND`` or ``services.rerank.backend``
    and is stripped and lower-cased before the lookup in
    ``services.rerank.backends``.

    Raises:
        ValueError: No backend name is set, or the named backend has no
            non-empty mapping under ``services.rerank.backends``.
    """
    services = _load_services_raw()
    section = services.get("rerank")
    if not isinstance(section, dict):
        section = {}
    backends = section.get("backends")
    if not isinstance(backends, dict):
        backends = {}

    selected = os.getenv("RERANK_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.rerank.backend is required (or env RERANK_BACKEND)")

    name = str(selected).strip().lower()
    settings = backends.get(name)
    if not isinstance(settings, dict) or not settings:
        raise ValueError(f"services.rerank.backends.{name} is required")
    return name, settings


def get_embedding_backend_config() -> tuple[str, dict]:
    """Return the selected embedding backend name together with its settings.

    The name is taken from ``EMBEDDING_BACKEND`` or
    ``services.embedding.backend`` and is stripped and lower-cased before the
    lookup in ``services.embedding.backends``.

    Raises:
        ValueError: No backend name is set, or the named backend has no
            non-empty mapping under ``services.embedding.backends``.
    """
    services = _load_services_raw()
    section = services.get("embedding")
    if not isinstance(section, dict):
        section = {}
    backends = section.get("backends")
    if not isinstance(backends, dict):
        backends = {}

    selected = os.getenv("EMBEDDING_BACKEND") or section.get("backend")
    if not selected:
        raise ValueError("services.embedding.backend is required (or env EMBEDDING_BACKEND)")

    name = str(selected).strip().lower()
    settings = backends.get(name)
    if not isinstance(settings, dict) or not settings:
        raise ValueError(f"services.embedding.backends.{name} is required")
    return name, settings


@lru_cache(maxsize=1)
def get_translation_config() -> TranslationConfig:
    """Return the translation config, resolving it on the first call only.

    The result is memoized by ``lru_cache``; use
    ``get_translation_config.cache_clear()`` to force a fresh resolve.
    """
    resolved = _resolve_translation()
    return resolved


@lru_cache(maxsize=1)
def get_embedding_config() -> ServiceConfig:
    """Return the embedding config, resolving it on the first call only.

    The result is memoized by ``lru_cache``; use
    ``get_embedding_config.cache_clear()`` to force a fresh resolve.
    """
    resolved = _resolve_embedding()
    return resolved


@lru_cache(maxsize=1)
def get_rerank_config() -> ServiceConfig:
    """Return the rerank config, resolving it on the first call only.

    The result is memoized by ``lru_cache``; use
    ``get_rerank_config.cache_clear()`` to force a fresh resolve.
    """
    resolved = _resolve_rerank()
    return resolved


def get_translation_base_url() -> str:
    """Return the ``service_url`` entry of the translation config as a string."""
    translation_cfg = get_translation_config()
    service_url = translation_cfg["service_url"]
    return str(service_url)


def get_translation_cache_config() -> Dict[str, Any]:
    """Return what ``get_translation_cache`` yields for the translation config."""
    translation_cfg = get_translation_config()
    return get_translation_cache(translation_cfg)


def get_embedding_text_base_url() -> str:
    """Return the text-embedding base URL without trailing slashes.

    ``EMBEDDING_TEXT_SERVICE_URL`` wins over the http provider's
    ``text_base_url`` from the embedding config.

    Raises:
        ValueError: Neither source provides a URL.
    """
    http_cfg = get_embedding_config().providers.get("http", {})
    env_override = os.getenv("EMBEDDING_TEXT_SERVICE_URL")
    url = env_override if env_override else http_cfg.get("text_base_url")
    if url:
        return str(url).rstrip("/")
    raise ValueError("Embedding text HTTP base_url is not configured")


def get_embedding_image_base_url() -> str:
    """Return the image-embedding base URL without trailing slashes.

    ``EMBEDDING_IMAGE_SERVICE_URL`` wins over the http provider's
    ``image_base_url`` from the embedding config.

    Raises:
        ValueError: Neither source provides a URL.
    """
    http_cfg = get_embedding_config().providers.get("http", {})
    env_override = os.getenv("EMBEDDING_IMAGE_SERVICE_URL")
    url = env_override if env_override else http_cfg.get("image_base_url")
    if url:
        return str(url).rstrip("/")
    raise ValueError("Embedding image HTTP base_url is not configured")


def get_rerank_base_url() -> str:
    """Return the rerank URL without trailing slashes.

    Sources are tried in order: ``RERANKER_SERVICE_URL``, then the http
    provider's ``service_url``, then its ``base_url``. The rerank config is
    only consulted when the environment variable is unset or empty.

    Raises:
        ValueError: None of the sources provides a URL.
    """
    url = os.getenv("RERANKER_SERVICE_URL")
    if not url:
        http_cfg = get_rerank_config().providers.get("http", {})
        url = http_cfg.get("service_url") or http_cfg.get("base_url")
    if not url:
        raise ValueError("Rerank HTTP base_url is not configured")
    return str(url).rstrip("/")


def get_rerank_service_url() -> str:
    """Backward-compatible alias that returns ``get_rerank_base_url()``."""
    url = get_rerank_base_url()
    return url