Blame view

config/services_config.py 10.2 KB
42e3aea6   tangwang   tidy
1
  """
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
2
  Services configuration - single source for translation, embedding, rerank.
42e3aea6   tangwang   tidy
3
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
4
5
6
  Translation is modeled as:
  - one translator service endpoint used by business callers
  - multiple translation capabilities loaded inside the translator service
42e3aea6   tangwang   tidy
7
8
9
10
11
12
13
14
  """
  
  from __future__ import annotations
  
  import os
  from dataclasses import dataclass, field
  from functools import lru_cache
  from pathlib import Path
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
15
  from typing import Any, Dict, List, Optional
42e3aea6   tangwang   tidy
16
17
18
19
20
21
  
  import yaml
  
  
  @dataclass
  class ServiceConfig:
      """Settings for a single capability such as embedding or rerank.

      ``provider`` names the active provider and ``providers`` maps each
      provider name to that provider's own settings mapping.
      """

      provider: str
      providers: Dict[str, Any] = field(default_factory=dict)

      def get_provider_cfg(self) -> Dict[str, Any]:
          """Return the settings mapping of the active provider.

          The provider name is trimmed and lower-cased before the lookup. An
          empty dict is returned when there is no entry for it or when
          ``providers`` is not a dict. The stored mapping itself is returned,
          not a copy.
          """
          key = (self.provider or "").strip().lower()
          if not isinstance(self.providers, dict):
              return {}
          return self.providers.get(key, {})
  
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  @dataclass
  class TranslationServiceConfig:
      """Config model for the translation service.

      Business callers talk to a single translator endpoint (``service_url``);
      the translator itself hosts several translation capabilities (models),
      each described by a settings mapping in ``capabilities``.
      """

      service_url: str
      timeout_sec: float
      default_model: str
      default_scene: str
      capabilities: Dict[str, Dict[str, Any]] = field(default_factory=dict)
      cache: Dict[str, Any] = field(default_factory=dict)

      # Static model-name aliases. Deliberately unannotated so that
      # ``@dataclass`` treats them as plain class attributes, not fields.
      _MODEL_ALIASES = {
          "qwen": "qwen-mt",
          "qwen-mt-flash": "qwen-mt",
          "qwen-mt-flush": "qwen-mt",
      }
      # Names that mean "use the configured default model".
      _DEFAULT_SENTINELS = ("service", "default")

      def normalize_model_name(self, model: Optional[str]) -> str:
          """Map a requested model name to its canonical capability key.

          ``None``, blank strings and the sentinels ``"service"`` /
          ``"default"`` all resolve to ``default_model``. Every name is trimmed,
          lower-cased and then passed through the alias table, so e.g.
          ``"QWEN"`` becomes ``"qwen-mt"``.
          """
          name = str(model or "").strip().lower()
          if not name or name in self._DEFAULT_SENTINELS:
              # Send the default through the same trim/lower/alias pipeline as
              # an explicit name, so ``None`` and ``"default"`` always agree.
              name = str(self.default_model).strip().lower()
          return self._MODEL_ALIASES.get(name, name)

      @property
      def enabled_models(self) -> List[str]:
          """Lower-cased names of capabilities whose ``enabled`` flag is truthy."""
          return [
              str(name).strip().lower()
              for name, cfg in self.capabilities.items()
              if isinstance(cfg, dict) and cfg.get("enabled", False)
          ]

      def get_capability_cfg(self, model: Optional[str]) -> Dict[str, Any]:
          """Return a shallow copy of the settings for ``model`` (``{}`` if unknown)."""
          value = self.capabilities.get(self.normalize_model_name(model))
          return dict(value) if isinstance(value, dict) else {}
  
  
42e3aea6   tangwang   tidy
68
  def _load_services_raw(config_path: Optional[Path] = None) -> Dict[str, Any]:
      """Read the YAML config file and return its top-level ``services`` mapping.

      Args:
          config_path: Path of the YAML file. Defaults to ``config.yaml`` in the
              same directory as this module.

      Returns:
          The ``services`` mapping from the file.

      Raises:
          FileNotFoundError: If the file does not exist.
          RuntimeError: If the file cannot be read, decoded or parsed as YAML,
              or if its root or its ``services`` entry is not a mapping.
      """
      path = Path(config_path) if config_path is not None else Path(__file__).parent / "config.yaml"
      if not path.exists():
          raise FileNotFoundError(f"services config file not found: {path}")
      try:
          with open(path, "r", encoding="utf-8") as f:
              data = yaml.safe_load(f)
      except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
          # Only reading/decoding/YAML problems are config errors; anything else
          # is a programming error and should propagate unchanged.
          raise RuntimeError(f"failed to parse services config from {path}: {exc}") from exc
      if not isinstance(data, dict):
          raise RuntimeError(f"invalid config format in {path}: expected mapping root")
      services = data.get("services")
      if not isinstance(services, dict):
          # Name the actual file: callers may pass a path other than config.yaml.
          raise RuntimeError(f"services config {path} must contain a valid 'services' mapping")
      return services
  
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
87
  def _resolve_provider_name(env_name: str, config_provider: Any, capability: str) -> str:
      """Return the provider name for ``capability``, trimmed and lower-cased.

      The environment variable ``env_name`` takes precedence over the
      ``provider`` value from the config file. A blank env value counts as
      unset, and a name that is blank after trimming is rejected instead of
      being returned as ``""``.

      Raises:
          ValueError: If neither source yields a non-blank provider name.
      """
      env_value = (os.getenv(env_name) or "").strip()
      chosen = env_value or config_provider
      provider = str(chosen).strip().lower() if chosen else ""
      if not provider:
          raise ValueError(
              f"services.{capability}.provider is required "
              f"(or set env override {env_name})"
          )
      return provider
42e3aea6   tangwang   tidy
95
96
  
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
97
  def _resolve_translation() -> TranslationServiceConfig:
      """Build the translation service config from ``services.translation``.

      Environment overrides: ``TRANSLATION_SERVICE_URL``,
      ``TRANSLATION_TIMEOUT_SEC``, ``TRANSLATION_MODEL`` and
      ``TRANSLATION_SCENE``. The keys ``base_url``, ``providers`` and
      ``provider`` are accepted as fallbacks for ``service_url``,
      ``capabilities`` and ``default_model`` respectively.

      Capabilities without an explicit ``enabled`` flag are enabled only if
      they are the default model, and the default model is always enabled.

      Raises:
          ValueError: If the resolved default model has no capability entry.
      """
      raw = _load_services_raw()
      cfg = raw.get("translation", {}) if isinstance(raw.get("translation"), dict) else {}

      service_url = (
          os.getenv("TRANSLATION_SERVICE_URL")
          or cfg.get("service_url")
          or cfg.get("base_url")
          or "http://127.0.0.1:6006"
      )
      timeout_sec = float(os.getenv("TRANSLATION_TIMEOUT_SEC") or cfg.get("timeout_sec") or 10.0)

      # Prefer ``capabilities``; fall back to the ``providers`` key.
      raw_capabilities = cfg.get("capabilities")
      if not isinstance(raw_capabilities, dict):
          raw_capabilities = cfg.get("providers")
      capabilities = raw_capabilities if isinstance(raw_capabilities, dict) else {}

      aliases = {
          "qwen": "qwen-mt",
          "qwen-mt-flash": "qwen-mt",
          "qwen-mt-flush": "qwen-mt",
      }
      default_model = str(
          os.getenv("TRANSLATION_MODEL")
          or cfg.get("default_model")
          or cfg.get("provider")
          or "qwen-mt"
      ).strip().lower()
      # Canonicalize the alias *before* the capability loop below, so the
      # implicit ``enabled`` default is computed against the real capability
      # key (e.g. "qwen" -> "qwen-mt") rather than the alias.
      default_model = aliases.get(default_model, default_model)
      default_scene = str(
          os.getenv("TRANSLATION_SCENE")
          or cfg.get("default_scene")
          or "general"
      ).strip() or "general"

      resolved_capabilities: Dict[str, Dict[str, Any]] = {}
      for name, value in capabilities.items():
          normalized = str(name or "").strip().lower()
          if not isinstance(value, dict) or not normalized:
              continue
          copied = dict(value)
          copied.setdefault("enabled", normalized == default_model)
          resolved_capabilities[normalized] = copied

      if default_model not in resolved_capabilities:
          raise ValueError(
              f"services.translation.default_model '{default_model}' is not defined in capabilities"
          )
      # The default model must always be servable, even if the config disabled it.
      if not resolved_capabilities[default_model].get("enabled", False):
          resolved_capabilities[default_model]["enabled"] = True

      cache_cfg = cfg.get("cache", {}) if isinstance(cfg.get("cache"), dict) else {}

      return TranslationServiceConfig(
          service_url=str(service_url).rstrip("/"),
          timeout_sec=timeout_sec,
          default_model=default_model,
          default_scene=default_scene,
          capabilities=resolved_capabilities,
          cache=cache_cfg,
      )
42e3aea6   tangwang   tidy
161
162
163
164
165
166
167
  
  
  def _resolve_embedding() -> ServiceConfig:
      """Build the embedding client config from ``services.embedding``.

      Only the ``http`` provider is accepted. When ``EMBEDDING_SERVICE_URL`` is
      set, its value (trailing slashes removed) replaces
      ``providers.http.base_url`` in a copy of the providers mapping.

      Raises:
          ValueError: If no provider is configured or it is not ``http``.
      """
      services = _load_services_raw()
      section = services.get("embedding", {}) if isinstance(services.get("embedding"), dict) else {}
      provider_map = section.get("providers", {}) if isinstance(section.get("providers"), dict) else {}

      provider = _resolve_provider_name(
          env_name="EMBEDDING_PROVIDER",
          config_provider=section.get("provider"),
          capability="embedding",
      )
      if provider != "http":
          raise ValueError(f"Unsupported embedding provider: {provider}")

      override = os.getenv("EMBEDDING_SERVICE_URL")
      if override:
          # provider is guaranteed to be "http" here (checked above).
          http_settings = dict(provider_map.get("http", {}))
          http_settings["base_url"] = override.rstrip("/")
          provider_map = {**provider_map, "http": http_settings}

      return ServiceConfig(provider=provider, providers=provider_map)
  
  
  def _split_rerank_url(url: str) -> tuple[str, str]:
      """Derive ``(service_url, base_url)`` from a reranker URL.

      ``service_url`` is ``url`` without trailing slashes, with ``/rerank``
      appended unless it already ends with it. ``base_url`` is ``service_url``
      minus that trailing ``/rerank``. Only the suffix is inspected, so hosts
      or path segments that merely contain "rerank" (e.g.
      ``http://reranker:6007``) are handled correctly.
      """
      suffix = "/rerank"
      service_url = url.rstrip("/")
      if not service_url.endswith(suffix):
          service_url = f"{service_url}{suffix}"
      return service_url, service_url[: -len(suffix)]


  def _resolve_rerank() -> ServiceConfig:
      """Build the rerank client config from ``services.rerank``.

      Only the ``http`` provider is accepted. When ``RERANKER_SERVICE_URL`` is
      set, ``providers.http.service_url`` and ``providers.http.base_url`` are
      overridden in a copy of the providers mapping (see ``_split_rerank_url``).

      Raises:
          ValueError: If no provider is configured or it is not ``http``.
      """
      raw = _load_services_raw()
      cfg = raw.get("rerank", {}) if isinstance(raw.get("rerank"), dict) else {}
      providers = cfg.get("providers", {}) if isinstance(cfg.get("providers"), dict) else {}

      provider = _resolve_provider_name(
          env_name="RERANK_PROVIDER",
          config_provider=cfg.get("provider"),
          capability="rerank",
      )
      if provider != "http":
          raise ValueError(f"Unsupported rerank provider: {provider}")

      env_url = os.getenv("RERANKER_SERVICE_URL")
      if env_url:
          service_url, base_url = _split_rerank_url(env_url)
          providers = dict(providers)
          providers["http"] = dict(providers.get("http", {}))
          providers["http"]["base_url"] = base_url
          providers["http"]["service_url"] = service_url

      return ServiceConfig(provider=provider, providers=providers)
  
  
701ae503   tangwang   docs
211
  def _select_backend(cfg: Dict[str, Any], capability: str, env_name: str) -> tuple[str, dict]:
      """Pick the active backend and its settings from a capability section.

      A non-blank value of the env var ``env_name`` overrides ``cfg["backend"]``.
      The chosen name is trimmed and lower-cased before being looked up in
      ``cfg["backends"]``.

      Returns:
          ``(name, settings)`` where ``settings`` is the backend's mapping.

      Raises:
          ValueError: If no backend is named, or the named backend has no
              non-empty settings mapping.
      """
      backends = cfg.get("backends", {}) if isinstance(cfg.get("backends"), dict) else {}
      env_value = (os.getenv(env_name) or "").strip()
      chosen = env_value or cfg.get("backend")
      name = str(chosen).strip().lower() if chosen else ""
      if not name:
          raise ValueError(f"services.{capability}.backend is required (or env {env_name})")
      backend_cfg = backends.get(name, {}) if isinstance(backends.get(name), dict) else {}
      if not backend_cfg:
          raise ValueError(f"services.{capability}.backends.{name} is required")
      return name, backend_cfg


  def get_rerank_backend_config() -> tuple[str, dict]:
      """Return ``(name, settings)`` for the rerank backend.

      The backend comes from env ``RERANK_BACKEND`` or ``services.rerank.backend``.

      Raises:
          ValueError: If the backend or its settings are missing.
      """
      raw = _load_services_raw()
      cfg = raw.get("rerank", {}) if isinstance(raw.get("rerank"), dict) else {}
      return _select_backend(cfg, "rerank", "RERANK_BACKEND")


  def get_embedding_backend_config() -> tuple[str, dict]:
      """Return ``(name, settings)`` for the embedding backend.

      The backend comes from env ``EMBEDDING_BACKEND`` or
      ``services.embedding.backend``.

      Raises:
          ValueError: If the backend or its settings are missing.
      """
      raw = _load_services_raw()
      cfg = raw.get("embedding", {}) if isinstance(raw.get("embedding"), dict) else {}
      return _select_backend(cfg, "embedding", "EMBEDDING_BACKEND")
  
  
42e3aea6   tangwang   tidy
239
  @lru_cache(maxsize=1)
  def get_translation_config() -> TranslationServiceConfig:
      """Return the translation config, resolved once and then memoized.

      Call ``get_translation_config.cache_clear()`` to force re-resolution
      (e.g. after changing the config file or environment).
      """
      config = _resolve_translation()
      return config
  
  
  @lru_cache(maxsize=1)
  def get_embedding_config() -> ServiceConfig:
      """Return the embedding config, resolved once and then memoized.

      Call ``get_embedding_config.cache_clear()`` to force re-resolution.
      """
      config = _resolve_embedding()
      return config
  
  
  @lru_cache(maxsize=1)
  def get_rerank_config() -> ServiceConfig:
      """Return the rerank config, resolved once and then memoized.

      Call ``get_rerank_config.cache_clear()`` to force re-resolution.
      """
      config = _resolve_rerank()
      return config
  
  
  def get_translation_base_url() -> str:
      """Return the translator service URL from the cached translation config."""
      translation_cfg = get_translation_config()
      return translation_cfg.service_url
42e3aea6   tangwang   tidy
256
257
  
  
d4cadc13   tangwang   翻译重构
258
  def get_translation_cache_config() -> Dict[str, Any]:
      """Return the translation cache settings with defaults filled in.

      Keys and defaults: ``enabled`` (True), ``key_prefix`` ("trans:v2"),
      ``ttl_seconds`` (360 days in seconds), and the boolean flags
      ``sliding_expiration``, ``key_include_context``, ``key_include_prompt``
      and ``key_include_source_lang`` (all True). Values are coerced to
      ``bool`` / ``str`` / ``int`` as appropriate.
      """
      raw = get_translation_config().cache
      settings: Dict[str, Any] = {"enabled": bool(raw.get("enabled", True))}
      settings["key_prefix"] = str(raw.get("key_prefix", "trans:v2"))
      settings["ttl_seconds"] = int(raw.get("ttl_seconds", 360 * 24 * 3600))
      for flag in (
          "sliding_expiration",
          "key_include_context",
          "key_include_prompt",
          "key_include_source_lang",
      ):
          settings[flag] = bool(raw.get(flag, True))
      return settings
  
  
42e3aea6   tangwang   tidy
271
  def get_embedding_base_url() -> str:
      """Return the embedding service base URL without trailing slashes.

      ``EMBEDDING_SERVICE_URL`` is read on every call and takes precedence; the
      cached ``providers.http.base_url`` setting is consulted only when the env
      var is unset or empty.

      Raises:
          ValueError: If neither source provides a URL.
      """
      url = os.getenv("EMBEDDING_SERVICE_URL")
      if not url:
          url = get_embedding_config().providers.get("http", {}).get("base_url")
      if not url:
          raise ValueError("Embedding HTTP base_url is not configured")
      return str(url).rstrip("/")
  
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
278
  def get_rerank_base_url() -> str:
      """Return the rerank service URL without trailing slashes.

      ``RERANKER_SERVICE_URL`` is read on every call and used as-is when set.
      Otherwise the cached ``providers.http`` settings are used, preferring
      ``service_url`` over ``base_url``.

      Raises:
          ValueError: If no URL is available from any source.
      """
      url = os.getenv("RERANKER_SERVICE_URL")
      if not url:
          http_settings = get_rerank_config().providers.get("http", {})
          url = http_settings.get("service_url") or http_settings.get("base_url")
      if not url:
          raise ValueError("Rerank HTTP base_url is not configured")
      return str(url).rstrip("/")
42e3aea6   tangwang   tidy
287
288
  
  
5e4dc8e4   tangwang   翻译架构按“一个翻译服务 +
289
290
291
  def get_rerank_service_url() -> str:
      """Alias of ``get_rerank_base_url`` kept for backward compatibility."""
      url = get_rerank_base_url()
      return url