Commit 7214c2e7ba4d70785b5c0a0a1a1c145448c3f5ef
1 parent
4747e2f4
**Implemented**
- Text and image embedding are now split into separate services/processes, while still keeping a single replica as requested. The split lives in [embeddings/server.py](/data/saas-search/embeddings/server.py#L112), [config/services_config.py](/data/saas-search/config/services_config.py#L68), [providers/embedding.py](/data/saas-search/providers/embedding.py#L27), and the start scripts [scripts/start_embedding_service.sh](/data/saas-search/scripts/start_embedding_service.sh#L36), [scripts/start_embedding_text_service.sh](/data/saas-search/scripts/start_embedding_text_service.sh), [scripts/start_embedding_image_service.sh](/data/saas-search/scripts/start_embedding_image_service.sh). - Independent admission control is in place now: text and image have separate inflight limits, and image can be kept much stricter than text. The request handling, reject path, `/health`, and `/ready` are in [embeddings/server.py](/data/saas-search/embeddings/server.py#L613), [embeddings/server.py](/data/saas-search/embeddings/server.py#L786), and [embeddings/server.py](/data/saas-search/embeddings/server.py#L1028). - I checked the Redis embedding cache. It did exist, but there was a real flaw: cache keys did not distinguish `normalize=true` from `normalize=false`. I fixed that in [embeddings/cache_keys.py](/data/saas-search/embeddings/cache_keys.py#L6), and both text and image now use the same normalize-aware keying. I also added service-side BF16 cache hits that short-circuit before the model lane, so repeated requests no longer get throttled behind image inference. **What This Means** - Image pressure no longer blocks text, because they are on different ports/processes. - Repeated text/image requests now return from Redis without consuming model capacity. - Over-capacity requests are rejected quickly instead of sitting blocked. - I did not add a load balancer or multi-replica HA, per your GPU constraint. 
I also did not build Grafana/Prometheus dashboards in this pass, but `/health` now exposes the metrics needed to wire them. **Validation** - Tests passed: `.venv/bin/python -m pytest -q tests/test_embedding_pipeline.py tests/test_embedding_service_limits.py` -> `10 passed` - Stress test tool updates are in [scripts/perf_api_benchmark.py](/data/saas-search/scripts/perf_api_benchmark.py#L155) - Fresh benchmark on split text service `6105`: 535 requests / 3s, 100% success, `174.56 rps`, avg `88.48 ms` - Fresh benchmark on split image service `6108`: 1213 requests / 3s, 100% success, `403.32 rps`, avg `9.64 ms` - Live health after the run showed cache hits and non-zero cache-hit latency accounting: - text `avg_latency_ms=4.251` - image `avg_latency_ms=1.462`
Showing
18 changed files
with
744 additions
and
116 deletions
Show diff stats
config/__init__.py
| ... | ... | @@ -27,6 +27,8 @@ from .services_config import ( |
| 27 | 27 | get_rerank_backend_config, |
| 28 | 28 | get_translation_base_url, |
| 29 | 29 | get_embedding_base_url, |
| 30 | + get_embedding_text_base_url, | |
| 31 | + get_embedding_image_base_url, | |
| 30 | 32 | get_rerank_service_url, |
| 31 | 33 | get_translation_cache_config, |
| 32 | 34 | ServiceConfig, |
| ... | ... | @@ -53,6 +55,8 @@ __all__ = [ |
| 53 | 55 | 'get_rerank_backend_config', |
| 54 | 56 | 'get_translation_base_url', |
| 55 | 57 | 'get_embedding_base_url', |
| 58 | + 'get_embedding_text_base_url', | |
| 59 | + 'get_embedding_image_base_url', | |
| 56 | 60 | 'get_rerank_service_url', |
| 57 | 61 | 'get_translation_cache_config', |
| 58 | 62 | 'ServiceConfig', | ... | ... |
config/config.yaml
| ... | ... | @@ -199,6 +199,8 @@ services: |
| 199 | 199 | providers: |
| 200 | 200 | http: |
| 201 | 201 | base_url: "http://127.0.0.1:6005" |
| 202 | + text_base_url: "http://127.0.0.1:6005" | |
| 203 | + image_base_url: "http://127.0.0.1:6008" | |
| 202 | 204 | # 服务内文本后端(embedding 进程启动时读取) |
| 203 | 205 | backend: "tei" # tei | local_st |
| 204 | 206 | backends: | ... | ... |
config/env_config.py
| ... | ... | @@ -61,6 +61,10 @@ INDEXER_PORT = int(os.getenv('INDEXER_PORT', 6004)) |
| 61 | 61 | # Optional dependent services |
| 62 | 62 | EMBEDDING_HOST = os.getenv('EMBEDDING_HOST', '127.0.0.1') |
| 63 | 63 | EMBEDDING_PORT = int(os.getenv('EMBEDDING_PORT', 6005)) |
| 64 | +EMBEDDING_TEXT_HOST = os.getenv('EMBEDDING_TEXT_HOST', EMBEDDING_HOST) | |
| 65 | +EMBEDDING_TEXT_PORT = int(os.getenv('EMBEDDING_TEXT_PORT', EMBEDDING_PORT)) | |
| 66 | +EMBEDDING_IMAGE_HOST = os.getenv('EMBEDDING_IMAGE_HOST', EMBEDDING_HOST) | |
| 67 | +EMBEDDING_IMAGE_PORT = int(os.getenv('EMBEDDING_IMAGE_PORT', 6008)) | |
| 64 | 68 | TRANSLATION_HOST = os.getenv('TRANSLATION_HOST', '127.0.0.1') |
| 65 | 69 | TRANSLATION_PORT = int(os.getenv('TRANSLATION_PORT', 6006)) |
| 66 | 70 | RERANKER_HOST = os.getenv('RERANKER_HOST', '127.0.0.1') |
| ... | ... | @@ -74,6 +78,12 @@ INDEXER_BASE_URL = os.getenv('INDEXER_BASE_URL') or ( |
| 74 | 78 | f'http://localhost:{INDEXER_PORT}' if INDEXER_HOST == '0.0.0.0' else f'http://{INDEXER_HOST}:{INDEXER_PORT}' |
| 75 | 79 | ) |
| 76 | 80 | EMBEDDING_SERVICE_URL = os.getenv('EMBEDDING_SERVICE_URL') or f'http://{EMBEDDING_HOST}:{EMBEDDING_PORT}' |
| 81 | +EMBEDDING_TEXT_SERVICE_URL = os.getenv('EMBEDDING_TEXT_SERVICE_URL') or ( | |
| 82 | + f'http://{EMBEDDING_TEXT_HOST}:{EMBEDDING_TEXT_PORT}' | |
| 83 | +) | |
| 84 | +EMBEDDING_IMAGE_SERVICE_URL = os.getenv('EMBEDDING_IMAGE_SERVICE_URL') or ( | |
| 85 | + f'http://{EMBEDDING_IMAGE_HOST}:{EMBEDDING_IMAGE_PORT}' | |
| 86 | +) | |
| 77 | 87 | RERANKER_SERVICE_URL = os.getenv('RERANKER_SERVICE_URL') or f'http://{RERANKER_HOST}:{RERANKER_PORT}/rerank' |
| 78 | 88 | |
| 79 | 89 | # Model IDs / paths | ... | ... |
config/services_config.py
| ... | ... | @@ -79,10 +79,17 @@ def _resolve_embedding() -> ServiceConfig: |
| 79 | 79 | raise ValueError(f"Unsupported embedding provider: {provider}") |
| 80 | 80 | |
| 81 | 81 | env_url = os.getenv("EMBEDDING_SERVICE_URL") |
| 82 | - if env_url and provider == "http": | |
| 82 | + env_text_url = os.getenv("EMBEDDING_TEXT_SERVICE_URL") | |
| 83 | + env_image_url = os.getenv("EMBEDDING_IMAGE_SERVICE_URL") | |
| 84 | + if (env_url or env_text_url or env_image_url) and provider == "http": | |
| 83 | 85 | providers = dict(providers) |
| 84 | 86 | providers["http"] = dict(providers.get("http", {})) |
| 85 | - providers["http"]["base_url"] = env_url.rstrip("/") | |
| 87 | + if env_url: | |
| 88 | + providers["http"]["base_url"] = env_url.rstrip("/") | |
| 89 | + if env_text_url: | |
| 90 | + providers["http"]["text_base_url"] = env_text_url.rstrip("/") | |
| 91 | + if env_image_url: | |
| 92 | + providers["http"]["image_base_url"] = env_image_url.rstrip("/") | |
| 86 | 93 | |
| 87 | 94 | return ServiceConfig(provider=provider, providers=providers) |
| 88 | 95 | |
| ... | ... | @@ -165,12 +172,44 @@ def get_translation_cache_config() -> Dict[str, Any]: |
| 165 | 172 | |
| 166 | 173 | |
| 167 | 174 | def get_embedding_base_url() -> str: |
| 168 | - base = os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_config().providers.get("http", {}).get("base_url") | |
| 175 | + provider_cfg = get_embedding_config().providers.get("http", {}) | |
| 176 | + base = ( | |
| 177 | + os.getenv("EMBEDDING_SERVICE_URL") | |
| 178 | + or provider_cfg.get("base_url") | |
| 179 | + or provider_cfg.get("text_base_url") | |
| 180 | + or provider_cfg.get("image_base_url") | |
| 181 | + ) | |
| 169 | 182 | if not base: |
| 170 | 183 | raise ValueError("Embedding HTTP base_url is not configured") |
| 171 | 184 | return str(base).rstrip("/") |
| 172 | 185 | |
| 173 | 186 | |
| 187 | +def get_embedding_text_base_url() -> str: | |
| 188 | + provider_cfg = get_embedding_config().providers.get("http", {}) | |
| 189 | + base = ( | |
| 190 | + os.getenv("EMBEDDING_TEXT_SERVICE_URL") | |
| 191 | + or provider_cfg.get("text_base_url") | |
| 192 | + or os.getenv("EMBEDDING_SERVICE_URL") | |
| 193 | + or provider_cfg.get("base_url") | |
| 194 | + ) | |
| 195 | + if not base: | |
| 196 | + raise ValueError("Embedding text HTTP base_url is not configured") | |
| 197 | + return str(base).rstrip("/") | |
| 198 | + | |
| 199 | + | |
| 200 | +def get_embedding_image_base_url() -> str: | |
| 201 | + provider_cfg = get_embedding_config().providers.get("http", {}) | |
| 202 | + base = ( | |
| 203 | + os.getenv("EMBEDDING_IMAGE_SERVICE_URL") | |
| 204 | + or provider_cfg.get("image_base_url") | |
| 205 | + or os.getenv("EMBEDDING_SERVICE_URL") | |
| 206 | + or provider_cfg.get("base_url") | |
| 207 | + ) | |
| 208 | + if not base: | |
| 209 | + raise ValueError("Embedding image HTTP base_url is not configured") | |
| 210 | + return str(base).rstrip("/") | |
| 211 | + | |
| 212 | + | |
| 174 | 213 | def get_rerank_base_url() -> str: |
| 175 | 214 | base = ( |
| 176 | 215 | os.getenv("RERANKER_SERVICE_URL") | ... | ... |
| ... | ... | @@ -0,0 +1,13 @@ |
| 1 | +"""Shared cache key helpers for embedding inputs.""" | |
| 2 | + | |
| 3 | +from __future__ import annotations | |
| 4 | + | |
| 5 | + | |
| 6 | +def build_text_cache_key(text: str, *, normalize: bool) -> str: | |
| 7 | + normalized_text = str(text or "").strip() | |
| 8 | + return f"norm:{1 if normalize else 0}:text:{normalized_text}" | |
| 9 | + | |
| 10 | + | |
| 11 | +def build_image_cache_key(url: str, *, normalize: bool) -> str: | |
| 12 | + normalized_url = str(url or "").strip() | |
| 13 | + return f"norm:{1 if normalize else 0}:image:{normalized_url}" | ... | ... |
embeddings/image_encoder.py
| ... | ... | @@ -10,8 +10,9 @@ from PIL import Image |
| 10 | 10 | |
| 11 | 11 | logger = logging.getLogger(__name__) |
| 12 | 12 | |
| 13 | -from config.services_config import get_embedding_base_url | |
| 13 | +from config.services_config import get_embedding_image_base_url | |
| 14 | 14 | from config.env_config import REDIS_CONFIG |
| 15 | +from embeddings.cache_keys import build_image_cache_key | |
| 15 | 16 | from embeddings.redis_embedding_cache import RedisEmbeddingCache |
| 16 | 17 | |
| 17 | 18 | |
| ... | ... | @@ -23,7 +24,12 @@ class CLIPImageEncoder: |
| 23 | 24 | """ |
| 24 | 25 | |
| 25 | 26 | def __init__(self, service_url: Optional[str] = None): |
| 26 | - resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() | |
| 27 | + resolved_url = ( | |
| 28 | + service_url | |
| 29 | + or os.getenv("EMBEDDING_IMAGE_SERVICE_URL") | |
| 30 | + or os.getenv("EMBEDDING_SERVICE_URL") | |
| 31 | + or get_embedding_image_base_url() | |
| 32 | + ) | |
| 27 | 33 | self.service_url = str(resolved_url).rstrip("/") |
| 28 | 34 | self.endpoint = f"{self.service_url}/embed/image" |
| 29 | 35 | # Reuse embedding cache prefix, but separate namespace for images to avoid collisions. |
| ... | ... | @@ -75,7 +81,8 @@ class CLIPImageEncoder: |
| 75 | 81 | Returns: |
| 76 | 82 | Embedding vector |
| 77 | 83 | """ |
| 78 | - cached = self.cache.get(url) | |
| 84 | + cache_key = build_image_cache_key(url, normalize=normalize_embeddings) | |
| 85 | + cached = self.cache.get(cache_key) | |
| 79 | 86 | if cached is not None: |
| 80 | 87 | return cached |
| 81 | 88 | |
| ... | ... | @@ -85,7 +92,7 @@ class CLIPImageEncoder: |
| 85 | 92 | vec = np.array(response_data[0], dtype=np.float32) |
| 86 | 93 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 87 | 94 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 88 | - self.cache.set(url, vec) | |
| 95 | + self.cache.set(cache_key, vec) | |
| 89 | 96 | return vec |
| 90 | 97 | |
| 91 | 98 | def encode_batch( |
| ... | ... | @@ -116,7 +123,8 @@ class CLIPImageEncoder: |
| 116 | 123 | |
| 117 | 124 | normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item] |
| 118 | 125 | for pos, url in enumerate(normalized_urls): |
| 119 | - cached = self.cache.get(url) | |
| 126 | + cache_key = build_image_cache_key(url, normalize=normalize_embeddings) | |
| 127 | + cached = self.cache.get(cache_key) | |
| 120 | 128 | if cached is not None: |
| 121 | 129 | results.append(cached) |
| 122 | 130 | else: |
| ... | ... | @@ -139,7 +147,7 @@ class CLIPImageEncoder: |
| 139 | 147 | vec = np.array(embedding, dtype=np.float32) |
| 140 | 148 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 141 | 149 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 142 | - self.cache.set(url, vec) | |
| 150 | + self.cache.set(build_image_cache_key(url, normalize=normalize_embeddings), vec) | |
| 143 | 151 | pos = pending_positions[i + j] |
| 144 | 152 | results[pos] = vec |
| 145 | 153 | ... | ... |
embeddings/redis_embedding_cache.py
| ... | ... | @@ -12,10 +12,13 @@ from __future__ import annotations |
| 12 | 12 | |
| 13 | 13 | import logging |
| 14 | 14 | from datetime import timedelta |
| 15 | -from typing import Optional | |
| 15 | +from typing import Any, Optional | |
| 16 | 16 | |
| 17 | 17 | import numpy as np |
| 18 | -import redis | |
| 18 | +try: | |
| 19 | + import redis | |
| 20 | +except ImportError: # pragma: no cover - runtime fallback for minimal envs | |
| 21 | + redis = None # type: ignore[assignment] | |
| 19 | 22 | |
| 20 | 23 | from config.env_config import REDIS_CONFIG |
| 21 | 24 | from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis |
| ... | ... | @@ -40,6 +43,11 @@ class RedisEmbeddingCache: |
| 40 | 43 | self.redis_client = redis_client |
| 41 | 44 | return |
| 42 | 45 | |
| 46 | + if redis is None: | |
| 47 | + logger.warning("redis package is not installed, continuing without embedding cache") | |
| 48 | + self.redis_client = None | |
| 49 | + return | |
| 50 | + | |
| 43 | 51 | try: |
| 44 | 52 | client = redis.Redis( |
| 45 | 53 | host=REDIS_CONFIG.get("host", "localhost"), |
| ... | ... | @@ -104,4 +112,3 @@ class RedisEmbeddingCache: |
| 104 | 112 | except Exception as e: |
| 105 | 113 | logger.warning("Error storing embedding in cache: %s", e) |
| 106 | 114 | return False |
| 107 | - | ... | ... |
embeddings/server.py
| ... | ... | @@ -21,9 +21,12 @@ import numpy as np |
| 21 | 21 | from fastapi import FastAPI, HTTPException, Request, Response |
| 22 | 22 | from fastapi.concurrency import run_in_threadpool |
| 23 | 23 | |
| 24 | +from config.env_config import REDIS_CONFIG | |
| 24 | 25 | from config.services_config import get_embedding_backend_config |
| 26 | +from embeddings.cache_keys import build_image_cache_key, build_text_cache_key | |
| 25 | 27 | from embeddings.config import CONFIG |
| 26 | 28 | from embeddings.protocols import ImageEncoderProtocol |
| 29 | +from embeddings.redis_embedding_cache import RedisEmbeddingCache | |
| 27 | 30 | |
| 28 | 31 | app = FastAPI(title="saas-search Embedding Service", version="1.0.0") |
| 29 | 32 | |
| ... | ... | @@ -106,8 +109,15 @@ verbose_logger = logging.getLogger("embedding.verbose") |
| 106 | 109 | _text_model: Optional[Any] = None |
| 107 | 110 | _image_model: Optional[ImageEncoderProtocol] = None |
| 108 | 111 | _text_backend_name: str = "" |
| 109 | -open_text_model = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") | |
| 110 | -open_image_model = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") | |
| 112 | +_SERVICE_KIND = (os.getenv("EMBEDDING_SERVICE_KIND", "all") or "all").strip().lower() | |
| 113 | +if _SERVICE_KIND not in {"all", "text", "image"}: | |
| 114 | + raise RuntimeError( | |
| 115 | + f"Invalid EMBEDDING_SERVICE_KIND={_SERVICE_KIND!r}; expected all, text, or image" | |
| 116 | + ) | |
| 117 | +_TEXT_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") | |
| 118 | +_IMAGE_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") | |
| 119 | +open_text_model = _TEXT_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "text"} | |
| 120 | +open_image_model = _IMAGE_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "image"} | |
| 111 | 121 | |
| 112 | 122 | _text_encode_lock = threading.Lock() |
| 113 | 123 | _image_encode_lock = threading.Lock() |
| ... | ... | @@ -125,6 +135,71 @@ _LOG_PREVIEW_COUNT = max(1, int(os.getenv("EMBEDDING_LOG_PREVIEW_COUNT", "3"))) |
| 125 | 135 | _LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_TEXT_PREVIEW_CHARS", "120"))) |
| 126 | 136 | _LOG_IMAGE_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_IMAGE_PREVIEW_CHARS", "180"))) |
| 127 | 137 | _VECTOR_PREVIEW_DIMS = max(1, int(os.getenv("EMBEDDING_VECTOR_PREVIEW_DIMS", "6"))) |
| 138 | +_CACHE_PREFIX = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" | |
| 139 | + | |
| 140 | + | |
| 141 | +@dataclass | |
| 142 | +class _EmbedResult: | |
| 143 | + vectors: List[Optional[List[float]]] | |
| 144 | + cache_hits: int | |
| 145 | + cache_misses: int | |
| 146 | + backend_elapsed_ms: float | |
| 147 | + mode: str | |
| 148 | + | |
| 149 | + | |
| 150 | +class _EndpointStats: | |
| 151 | + def __init__(self, name: str): | |
| 152 | + self.name = name | |
| 153 | + self._lock = threading.Lock() | |
| 154 | + self.request_total = 0 | |
| 155 | + self.success_total = 0 | |
| 156 | + self.failure_total = 0 | |
| 157 | + self.rejected_total = 0 | |
| 158 | + self.cache_hits = 0 | |
| 159 | + self.cache_misses = 0 | |
| 160 | + self.total_latency_ms = 0.0 | |
| 161 | + self.total_backend_latency_ms = 0.0 | |
| 162 | + | |
| 163 | + def record_rejected(self) -> None: | |
| 164 | + with self._lock: | |
| 165 | + self.request_total += 1 | |
| 166 | + self.rejected_total += 1 | |
| 167 | + | |
| 168 | + def record_completed( | |
| 169 | + self, | |
| 170 | + *, | |
| 171 | + success: bool, | |
| 172 | + latency_ms: float, | |
| 173 | + backend_latency_ms: float, | |
| 174 | + cache_hits: int, | |
| 175 | + cache_misses: int, | |
| 176 | + ) -> None: | |
| 177 | + with self._lock: | |
| 178 | + self.request_total += 1 | |
| 179 | + if success: | |
| 180 | + self.success_total += 1 | |
| 181 | + else: | |
| 182 | + self.failure_total += 1 | |
| 183 | + self.cache_hits += max(0, int(cache_hits)) | |
| 184 | + self.cache_misses += max(0, int(cache_misses)) | |
| 185 | + self.total_latency_ms += max(0.0, float(latency_ms)) | |
| 186 | + self.total_backend_latency_ms += max(0.0, float(backend_latency_ms)) | |
| 187 | + | |
| 188 | + def snapshot(self) -> Dict[str, Any]: | |
| 189 | + with self._lock: | |
| 190 | + completed = self.success_total + self.failure_total | |
| 191 | + return { | |
| 192 | + "request_total": self.request_total, | |
| 193 | + "success_total": self.success_total, | |
| 194 | + "failure_total": self.failure_total, | |
| 195 | + "rejected_total": self.rejected_total, | |
| 196 | + "cache_hits": self.cache_hits, | |
| 197 | + "cache_misses": self.cache_misses, | |
| 198 | + "avg_latency_ms": round(self.total_latency_ms / completed, 3) if completed else 0.0, | |
| 199 | + "avg_backend_latency_ms": round(self.total_backend_latency_ms / completed, 3) | |
| 200 | + if completed | |
| 201 | + else 0.0, | |
| 202 | + } | |
| 128 | 203 | |
| 129 | 204 | |
| 130 | 205 | class _InflightLimiter: |
| ... | ... | @@ -176,6 +251,10 @@ class _InflightLimiter: |
| 176 | 251 | |
| 177 | 252 | _text_request_limiter = _InflightLimiter(name="text", limit=_TEXT_MAX_INFLIGHT) |
| 178 | 253 | _image_request_limiter = _InflightLimiter(name="image", limit=_IMAGE_MAX_INFLIGHT) |
| 254 | +_text_stats = _EndpointStats(name="text") | |
| 255 | +_image_stats = _EndpointStats(name="image") | |
| 256 | +_text_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="") | |
| 257 | +_image_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="image") | |
| 179 | 258 | |
| 180 | 259 | |
| 181 | 260 | @dataclass |
| ... | ... | @@ -377,7 +456,12 @@ def load_models(): |
| 377 | 456 | """Load models at service startup to avoid first-request latency.""" |
| 378 | 457 | global _text_model, _image_model, _text_backend_name |
| 379 | 458 | |
| 380 | - logger.info("Loading embedding models at startup...") | |
| 459 | + logger.info( | |
| 460 | + "Loading embedding models at startup | service_kind=%s text_enabled=%s image_enabled=%s", | |
| 461 | + _SERVICE_KIND, | |
| 462 | + open_text_model, | |
| 463 | + open_image_model, | |
| 464 | + ) | |
| 381 | 465 | |
| 382 | 466 | if open_text_model: |
| 383 | 467 | try: |
| ... | ... | @@ -482,18 +566,72 @@ def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Option |
| 482 | 566 | return embedding.tolist() |
| 483 | 567 | |
| 484 | 568 | |
| 569 | +def _try_full_text_cache_hit( | |
| 570 | + normalized: List[str], | |
| 571 | + effective_normalize: bool, | |
| 572 | +) -> Optional[_EmbedResult]: | |
| 573 | + out: List[Optional[List[float]]] = [] | |
| 574 | + for text in normalized: | |
| 575 | + cached = _text_cache.get(build_text_cache_key(text, normalize=effective_normalize)) | |
| 576 | + if cached is None: | |
| 577 | + return None | |
| 578 | + vec = _as_list(cached, normalize=False) | |
| 579 | + if vec is None: | |
| 580 | + return None | |
| 581 | + out.append(vec) | |
| 582 | + return _EmbedResult( | |
| 583 | + vectors=out, | |
| 584 | + cache_hits=len(out), | |
| 585 | + cache_misses=0, | |
| 586 | + backend_elapsed_ms=0.0, | |
| 587 | + mode="cache-only", | |
| 588 | + ) | |
| 589 | + | |
| 590 | + | |
| 591 | +def _try_full_image_cache_hit( | |
| 592 | + urls: List[str], | |
| 593 | + effective_normalize: bool, | |
| 594 | +) -> Optional[_EmbedResult]: | |
| 595 | + out: List[Optional[List[float]]] = [] | |
| 596 | + for url in urls: | |
| 597 | + cached = _image_cache.get(build_image_cache_key(url, normalize=effective_normalize)) | |
| 598 | + if cached is None: | |
| 599 | + return None | |
| 600 | + vec = _as_list(cached, normalize=False) | |
| 601 | + if vec is None: | |
| 602 | + return None | |
| 603 | + out.append(vec) | |
| 604 | + return _EmbedResult( | |
| 605 | + vectors=out, | |
| 606 | + cache_hits=len(out), | |
| 607 | + cache_misses=0, | |
| 608 | + backend_elapsed_ms=0.0, | |
| 609 | + mode="cache-only", | |
| 610 | + ) | |
| 611 | + | |
| 612 | + | |
| 485 | 613 | @app.get("/health") |
| 486 | 614 | def health() -> Dict[str, Any]: |
| 487 | 615 | """Health check endpoint. Returns status and current throttling stats.""" |
| 616 | + ready = (not open_text_model or _text_model is not None) and (not open_image_model or _image_model is not None) | |
| 488 | 617 | return { |
| 489 | - "status": "ok", | |
| 618 | + "status": "ok" if ready else "degraded", | |
| 619 | + "service_kind": _SERVICE_KIND, | |
| 490 | 620 | "text_model_loaded": _text_model is not None, |
| 491 | 621 | "text_backend": _text_backend_name, |
| 492 | 622 | "image_model_loaded": _image_model is not None, |
| 623 | + "cache_enabled": { | |
| 624 | + "text": _text_cache.redis_client is not None, | |
| 625 | + "image": _image_cache.redis_client is not None, | |
| 626 | + }, | |
| 493 | 627 | "limits": { |
| 494 | 628 | "text": _text_request_limiter.snapshot(), |
| 495 | 629 | "image": _image_request_limiter.snapshot(), |
| 496 | 630 | }, |
| 631 | + "stats": { | |
| 632 | + "text": _text_stats.snapshot(), | |
| 633 | + "image": _image_stats.snapshot(), | |
| 634 | + }, | |
| 497 | 635 | "text_microbatch": { |
| 498 | 636 | "window_ms": round(_TEXT_MICROBATCH_WINDOW_SEC * 1000.0, 3), |
| 499 | 637 | "queue_depth": len(_text_single_queue), |
| ... | ... | @@ -503,44 +641,105 @@ def health() -> Dict[str, Any]: |
| 503 | 641 | } |
| 504 | 642 | |
| 505 | 643 | |
| 644 | +@app.get("/ready") | |
| 645 | +def ready() -> Dict[str, Any]: | |
| 646 | + text_ready = (not open_text_model) or (_text_model is not None) | |
| 647 | + image_ready = (not open_image_model) or (_image_model is not None) | |
| 648 | + if not (text_ready and image_ready): | |
| 649 | + raise HTTPException( | |
| 650 | + status_code=503, | |
| 651 | + detail={ | |
| 652 | + "service_kind": _SERVICE_KIND, | |
| 653 | + "text_ready": text_ready, | |
| 654 | + "image_ready": image_ready, | |
| 655 | + }, | |
| 656 | + ) | |
| 657 | + return { | |
| 658 | + "status": "ready", | |
| 659 | + "service_kind": _SERVICE_KIND, | |
| 660 | + "text_ready": text_ready, | |
| 661 | + "image_ready": image_ready, | |
| 662 | + } | |
| 663 | + | |
| 664 | + | |
| 506 | 665 | def _embed_text_impl( |
| 507 | 666 | normalized: List[str], |
| 508 | 667 | effective_normalize: bool, |
| 509 | 668 | request_id: str, |
| 510 | -) -> List[Optional[List[float]]]: | |
| 669 | +) -> _EmbedResult: | |
| 511 | 670 | if _text_model is None: |
| 512 | 671 | raise RuntimeError("Text model not loaded") |
| 513 | 672 | |
| 514 | - t0 = time.perf_counter() | |
| 673 | + out: List[Optional[List[float]]] = [None] * len(normalized) | |
| 674 | + missing_indices: List[int] = [] | |
| 675 | + missing_texts: List[str] = [] | |
| 676 | + missing_cache_keys: List[str] = [] | |
| 677 | + cache_hits = 0 | |
| 678 | + for idx, text in enumerate(normalized): | |
| 679 | + cache_key = build_text_cache_key(text, normalize=effective_normalize) | |
| 680 | + cached = _text_cache.get(cache_key) | |
| 681 | + if cached is not None: | |
| 682 | + vec = _as_list(cached, normalize=False) | |
| 683 | + if vec is not None: | |
| 684 | + out[idx] = vec | |
| 685 | + cache_hits += 1 | |
| 686 | + continue | |
| 687 | + missing_indices.append(idx) | |
| 688 | + missing_texts.append(text) | |
| 689 | + missing_cache_keys.append(cache_key) | |
| 690 | + | |
| 691 | + if not missing_texts: | |
| 692 | + logger.info( | |
| 693 | + "text backend done | backend=%s mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 backend_elapsed_ms=0.00", | |
| 694 | + _text_backend_name, | |
| 695 | + len(normalized), | |
| 696 | + effective_normalize, | |
| 697 | + len(out[0]) if out and out[0] is not None else 0, | |
| 698 | + cache_hits, | |
| 699 | + extra=_request_log_extra(request_id), | |
| 700 | + ) | |
| 701 | + return _EmbedResult( | |
| 702 | + vectors=out, | |
| 703 | + cache_hits=cache_hits, | |
| 704 | + cache_misses=0, | |
| 705 | + backend_elapsed_ms=0.0, | |
| 706 | + mode="cache-only", | |
| 707 | + ) | |
| 708 | + | |
| 709 | + backend_t0 = time.perf_counter() | |
| 515 | 710 | try: |
| 516 | 711 | if _text_backend_name == "local_st": |
| 517 | - if len(normalized) == 1 and _text_batch_worker is not None: | |
| 518 | - out = [ | |
| 712 | + if len(missing_texts) == 1 and _text_batch_worker is not None: | |
| 713 | + computed = [ | |
| 519 | 714 | _encode_single_text_with_microbatch( |
| 520 | - normalized[0], | |
| 715 | + missing_texts[0], | |
| 521 | 716 | normalize=effective_normalize, |
| 522 | 717 | request_id=request_id, |
| 523 | 718 | ) |
| 524 | 719 | ] |
| 525 | - logger.info( | |
| 526 | - "text backend done | backend=%s mode=microbatch-single inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", | |
| 527 | - _text_backend_name, | |
| 528 | - len(normalized), | |
| 529 | - effective_normalize, | |
| 530 | - len(out[0]) if out and out[0] is not None else 0, | |
| 531 | - (time.perf_counter() - t0) * 1000.0, | |
| 532 | - extra=_request_log_extra(request_id), | |
| 533 | - ) | |
| 534 | - return out | |
| 535 | - embs = _encode_local_st(normalized, normalize_embeddings=False) | |
| 536 | - mode = "direct-batch" | |
| 720 | + mode = "microbatch-single" | |
| 721 | + else: | |
| 722 | + embs = _encode_local_st(missing_texts, normalize_embeddings=False) | |
| 723 | + computed = [] | |
| 724 | + for i, emb in enumerate(embs): | |
| 725 | + vec = _as_list(emb, normalize=effective_normalize) | |
| 726 | + if vec is None: | |
| 727 | + raise RuntimeError(f"Text model returned empty embedding for missing index {i}") | |
| 728 | + computed.append(vec) | |
| 729 | + mode = "direct-batch" | |
| 537 | 730 | else: |
| 538 | 731 | embs = _text_model.encode( |
| 539 | - normalized, | |
| 732 | + missing_texts, | |
| 540 | 733 | batch_size=int(CONFIG.TEXT_BATCH_SIZE), |
| 541 | 734 | device=CONFIG.TEXT_DEVICE, |
| 542 | 735 | normalize_embeddings=effective_normalize, |
| 543 | 736 | ) |
| 737 | + computed = [] | |
| 738 | + for i, emb in enumerate(embs): | |
| 739 | + vec = _as_list(emb, normalize=False) | |
| 740 | + if vec is None: | |
| 741 | + raise RuntimeError(f"Text model returned empty embedding for missing index {i}") | |
| 742 | + computed.append(vec) | |
| 544 | 743 | mode = "backend-batch" |
| 545 | 744 | except Exception as e: |
| 546 | 745 | logger.error( |
| ... | ... | @@ -551,30 +750,37 @@ def _embed_text_impl( |
| 551 | 750 | ) |
| 552 | 751 | raise RuntimeError(f"Text embedding backend failure: {e}") from e |
| 553 | 752 | |
| 554 | - if embs is None or len(embs) != len(normalized): | |
| 753 | + if len(computed) != len(missing_texts): | |
| 555 | 754 | raise RuntimeError( |
| 556 | - f"Text model response length mismatch: expected {len(normalized)}, " | |
| 557 | - f"got {0 if embs is None else len(embs)}" | |
| 755 | + f"Text model response length mismatch: expected {len(missing_texts)}, " | |
| 756 | + f"got {len(computed)}" | |
| 558 | 757 | ) |
| 559 | 758 | |
| 560 | - out: List[Optional[List[float]]] = [] | |
| 561 | - for i, emb in enumerate(embs): | |
| 562 | - vec = _as_list(emb, normalize=effective_normalize) | |
| 563 | - if vec is None: | |
| 564 | - raise RuntimeError(f"Text model returned empty embedding for index {i}") | |
| 565 | - out.append(vec) | |
| 759 | + for pos, cache_key, vec in zip(missing_indices, missing_cache_keys, computed): | |
| 760 | + out[pos] = vec | |
| 761 | + _text_cache.set(cache_key, np.asarray(vec, dtype=np.float32)) | |
| 762 | + | |
| 763 | + backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0 | |
| 566 | 764 | |
| 567 | 765 | logger.info( |
| 568 | - "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", | |
| 766 | + "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f", | |
| 569 | 767 | _text_backend_name, |
| 570 | 768 | mode, |
| 571 | 769 | len(normalized), |
| 572 | 770 | effective_normalize, |
| 573 | 771 | len(out[0]) if out and out[0] is not None else 0, |
| 574 | - (time.perf_counter() - t0) * 1000.0, | |
| 772 | + cache_hits, | |
| 773 | + len(missing_texts), | |
| 774 | + backend_elapsed_ms, | |
| 575 | 775 | extra=_request_log_extra(request_id), |
| 576 | 776 | ) |
| 577 | - return out | |
| 777 | + return _EmbedResult( | |
| 778 | + vectors=out, | |
| 779 | + cache_hits=cache_hits, | |
| 780 | + cache_misses=len(missing_texts), | |
| 781 | + backend_elapsed_ms=backend_elapsed_ms, | |
| 782 | + mode=mode, | |
| 783 | + ) | |
| 578 | 784 | |
| 579 | 785 | |
| 580 | 786 | @app.post("/embed/text") |
| ... | ... | @@ -584,6 +790,9 @@ async def embed_text( |
| 584 | 790 | response: Response, |
| 585 | 791 | normalize: Optional[bool] = None, |
| 586 | 792 | ) -> List[Optional[List[float]]]: |
| 793 | + if _text_model is None: | |
| 794 | + raise HTTPException(status_code=503, detail="Text embedding model not loaded in this service") | |
| 795 | + | |
| 587 | 796 | request_id = _resolve_request_id(http_request) |
| 588 | 797 | response.headers["X-Request-ID"] = request_id |
| 589 | 798 | |
| ... | ... | @@ -597,8 +806,33 @@ async def embed_text( |
| 597 | 806 | raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") |
| 598 | 807 | normalized.append(s) |
| 599 | 808 | |
| 809 | + cache_check_started = time.perf_counter() | |
| 810 | + cache_only = _try_full_text_cache_hit(normalized, effective_normalize) | |
| 811 | + if cache_only is not None: | |
| 812 | + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 813 | + _text_stats.record_completed( | |
| 814 | + success=True, | |
| 815 | + latency_ms=latency_ms, | |
| 816 | + backend_latency_ms=0.0, | |
| 817 | + cache_hits=cache_only.cache_hits, | |
| 818 | + cache_misses=0, | |
| 819 | + ) | |
| 820 | + logger.info( | |
| 821 | + "embed_text response | backend=%s mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 822 | + _text_backend_name, | |
| 823 | + len(normalized), | |
| 824 | + effective_normalize, | |
| 825 | + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 826 | + cache_only.cache_hits, | |
| 827 | + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 828 | + latency_ms, | |
| 829 | + extra=_request_log_extra(request_id), | |
| 830 | + ) | |
| 831 | + return cache_only.vectors | |
| 832 | + | |
| 600 | 833 | accepted, active = _text_request_limiter.try_acquire() |
| 601 | 834 | if not accepted: |
| 835 | + _text_stats.record_rejected() | |
| 602 | 836 | logger.warning( |
| 603 | 837 | "embed_text rejected | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| 604 | 838 | _request_client(http_request), |
| ... | ... | @@ -617,6 +851,9 @@ async def embed_text( |
| 617 | 851 | |
| 618 | 852 | request_started = time.perf_counter() |
| 619 | 853 | success = False |
| 854 | + backend_elapsed_ms = 0.0 | |
| 855 | + cache_hits = 0 | |
| 856 | + cache_misses = 0 | |
| 620 | 857 | try: |
| 621 | 858 | logger.info( |
| 622 | 859 | "embed_text request | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| ... | ... | @@ -636,31 +873,53 @@ async def embed_text( |
| 636 | 873 | _text_backend_name, |
| 637 | 874 | extra=_request_log_extra(request_id), |
| 638 | 875 | ) |
| 639 | - out = await run_in_threadpool(_embed_text_impl, normalized, effective_normalize, request_id) | |
| 876 | + result = await run_in_threadpool(_embed_text_impl, normalized, effective_normalize, request_id) | |
| 640 | 877 | success = True |
| 878 | + backend_elapsed_ms = result.backend_elapsed_ms | |
| 879 | + cache_hits = result.cache_hits | |
| 880 | + cache_misses = result.cache_misses | |
| 641 | 881 | latency_ms = (time.perf_counter() - request_started) * 1000.0 |
| 882 | + _text_stats.record_completed( | |
| 883 | + success=True, | |
| 884 | + latency_ms=latency_ms, | |
| 885 | + backend_latency_ms=backend_elapsed_ms, | |
| 886 | + cache_hits=cache_hits, | |
| 887 | + cache_misses=cache_misses, | |
| 888 | + ) | |
| 642 | 889 | logger.info( |
| 643 | - "embed_text response | backend=%s inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", | |
| 890 | + "embed_text response | backend=%s mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f", | |
| 644 | 891 | _text_backend_name, |
| 892 | + result.mode, | |
| 645 | 893 | len(normalized), |
| 646 | 894 | effective_normalize, |
| 647 | - len(out[0]) if out and out[0] is not None else 0, | |
| 648 | - _preview_vector(out[0] if out else None), | |
| 895 | + len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0, | |
| 896 | + cache_hits, | |
| 897 | + cache_misses, | |
| 898 | + _preview_vector(result.vectors[0] if result.vectors else None), | |
| 649 | 899 | latency_ms, |
| 650 | 900 | extra=_request_log_extra(request_id), |
| 651 | 901 | ) |
| 652 | 902 | verbose_logger.info( |
| 653 | 903 | "embed_text result detail | count=%d first_vector=%s latency_ms=%.2f", |
| 654 | - len(out), | |
| 655 | - out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], | |
| 904 | + len(result.vectors), | |
| 905 | + result.vectors[0][: _VECTOR_PREVIEW_DIMS] | |
| 906 | + if result.vectors and result.vectors[0] is not None | |
| 907 | + else [], | |
| 656 | 908 | latency_ms, |
| 657 | 909 | extra=_request_log_extra(request_id), |
| 658 | 910 | ) |
| 659 | - return out | |
| 911 | + return result.vectors | |
| 660 | 912 | except HTTPException: |
| 661 | 913 | raise |
| 662 | 914 | except Exception as e: |
| 663 | 915 | latency_ms = (time.perf_counter() - request_started) * 1000.0 |
| 916 | + _text_stats.record_completed( | |
| 917 | + success=False, | |
| 918 | + latency_ms=latency_ms, | |
| 919 | + backend_latency_ms=backend_elapsed_ms, | |
| 920 | + cache_hits=cache_hits, | |
| 921 | + cache_misses=cache_misses, | |
| 922 | + ) | |
| 664 | 923 | logger.error( |
| 665 | 924 | "embed_text failed | backend=%s inputs=%d normalize=%s latency_ms=%.2f error=%s", |
| 666 | 925 | _text_backend_name, |
| ... | ... | @@ -686,39 +945,84 @@ def _embed_image_impl( |
| 686 | 945 | urls: List[str], |
| 687 | 946 | effective_normalize: bool, |
| 688 | 947 | request_id: str, |
| 689 | -) -> List[Optional[List[float]]]: | |
| 948 | +) -> _EmbedResult: | |
| 690 | 949 | if _image_model is None: |
| 691 | 950 | raise RuntimeError("Image model not loaded") |
| 692 | 951 | |
| 693 | - t0 = time.perf_counter() | |
| 952 | + out: List[Optional[List[float]]] = [None] * len(urls) | |
| 953 | + missing_indices: List[int] = [] | |
| 954 | + missing_urls: List[str] = [] | |
| 955 | + missing_cache_keys: List[str] = [] | |
| 956 | + cache_hits = 0 | |
| 957 | + for idx, url in enumerate(urls): | |
| 958 | + cache_key = build_image_cache_key(url, normalize=effective_normalize) | |
| 959 | + cached = _image_cache.get(cache_key) | |
| 960 | + if cached is not None: | |
| 961 | + vec = _as_list(cached, normalize=False) | |
| 962 | + if vec is not None: | |
| 963 | + out[idx] = vec | |
| 964 | + cache_hits += 1 | |
| 965 | + continue | |
| 966 | + missing_indices.append(idx) | |
| 967 | + missing_urls.append(url) | |
| 968 | + missing_cache_keys.append(cache_key) | |
| 969 | + | |
| 970 | + if not missing_urls: | |
| 971 | + logger.info( | |
| 972 | + "image backend done | mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 backend_elapsed_ms=0.00", | |
| 973 | + len(urls), | |
| 974 | + effective_normalize, | |
| 975 | + len(out[0]) if out and out[0] is not None else 0, | |
| 976 | + cache_hits, | |
| 977 | + extra=_request_log_extra(request_id), | |
| 978 | + ) | |
| 979 | + return _EmbedResult( | |
| 980 | + vectors=out, | |
| 981 | + cache_hits=cache_hits, | |
| 982 | + cache_misses=0, | |
| 983 | + backend_elapsed_ms=0.0, | |
| 984 | + mode="cache-only", | |
| 985 | + ) | |
| 986 | + | |
| 987 | + backend_t0 = time.perf_counter() | |
| 694 | 988 | with _image_encode_lock: |
| 695 | 989 | vectors = _image_model.encode_image_urls( |
| 696 | - urls, | |
| 990 | + missing_urls, | |
| 697 | 991 | batch_size=CONFIG.IMAGE_BATCH_SIZE, |
| 698 | 992 | normalize_embeddings=effective_normalize, |
| 699 | 993 | ) |
| 700 | - if vectors is None or len(vectors) != len(urls): | |
| 994 | + if vectors is None or len(vectors) != len(missing_urls): | |
| 701 | 995 | raise RuntimeError( |
| 702 | - f"Image model response length mismatch: expected {len(urls)}, " | |
| 996 | + f"Image model response length mismatch: expected {len(missing_urls)}, " | |
| 703 | 997 | f"got {0 if vectors is None else len(vectors)}" |
| 704 | 998 | ) |
| 705 | 999 | |
| 706 | - out: List[Optional[List[float]]] = [] | |
| 707 | - for i, vec in enumerate(vectors): | |
| 1000 | + for pos, cache_key, vec in zip(missing_indices, missing_cache_keys, vectors): | |
| 708 | 1001 | out_vec = _as_list(vec, normalize=effective_normalize) |
| 709 | 1002 | if out_vec is None: |
| 710 | - raise RuntimeError(f"Image model returned empty embedding for index {i}") | |
| 711 | - out.append(out_vec) | |
| 1003 | + raise RuntimeError(f"Image model returned empty embedding for position {pos}") | |
| 1004 | + out[pos] = out_vec | |
| 1005 | + _image_cache.set(cache_key, np.asarray(out_vec, dtype=np.float32)) | |
| 1006 | + | |
| 1007 | + backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0 | |
| 712 | 1008 | |
| 713 | 1009 | logger.info( |
| 714 | - "image backend done | inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", | |
| 1010 | + "image backend done | mode=backend-batch inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f", | |
| 715 | 1011 | len(urls), |
| 716 | 1012 | effective_normalize, |
| 717 | 1013 | len(out[0]) if out and out[0] is not None else 0, |
| 718 | - (time.perf_counter() - t0) * 1000.0, | |
| 1014 | + cache_hits, | |
| 1015 | + len(missing_urls), | |
| 1016 | + backend_elapsed_ms, | |
| 719 | 1017 | extra=_request_log_extra(request_id), |
| 720 | 1018 | ) |
| 721 | - return out | |
| 1019 | + return _EmbedResult( | |
| 1020 | + vectors=out, | |
| 1021 | + cache_hits=cache_hits, | |
| 1022 | + cache_misses=len(missing_urls), | |
| 1023 | + backend_elapsed_ms=backend_elapsed_ms, | |
| 1024 | + mode="backend-batch", | |
| 1025 | + ) | |
| 722 | 1026 | |
| 723 | 1027 | |
| 724 | 1028 | @app.post("/embed/image") |
| ... | ... | @@ -728,6 +1032,9 @@ async def embed_image( |
| 728 | 1032 | response: Response, |
| 729 | 1033 | normalize: Optional[bool] = None, |
| 730 | 1034 | ) -> List[Optional[List[float]]]: |
| 1035 | + if _image_model is None: | |
| 1036 | + raise HTTPException(status_code=503, detail="Image embedding model not loaded in this service") | |
| 1037 | + | |
| 731 | 1038 | request_id = _resolve_request_id(http_request) |
| 732 | 1039 | response.headers["X-Request-ID"] = request_id |
| 733 | 1040 | |
| ... | ... | @@ -741,8 +1048,32 @@ async def embed_image( |
| 741 | 1048 | raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") |
| 742 | 1049 | urls.append(s) |
| 743 | 1050 | |
| 1051 | + cache_check_started = time.perf_counter() | |
| 1052 | + cache_only = _try_full_image_cache_hit(urls, effective_normalize) | |
| 1053 | + if cache_only is not None: | |
| 1054 | + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 1055 | + _image_stats.record_completed( | |
| 1056 | + success=True, | |
| 1057 | + latency_ms=latency_ms, | |
| 1058 | + backend_latency_ms=0.0, | |
| 1059 | + cache_hits=cache_only.cache_hits, | |
| 1060 | + cache_misses=0, | |
| 1061 | + ) | |
| 1062 | + logger.info( | |
| 1063 | + "embed_image response | mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 1064 | + len(urls), | |
| 1065 | + effective_normalize, | |
| 1066 | + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 1067 | + cache_only.cache_hits, | |
| 1068 | + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 1069 | + latency_ms, | |
| 1070 | + extra=_request_log_extra(request_id), | |
| 1071 | + ) | |
| 1072 | + return cache_only.vectors | |
| 1073 | + | |
| 744 | 1074 | accepted, active = _image_request_limiter.try_acquire() |
| 745 | 1075 | if not accepted: |
| 1076 | + _image_stats.record_rejected() | |
| 746 | 1077 | logger.warning( |
| 747 | 1078 | "embed_image rejected | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| 748 | 1079 | _request_client(http_request), |
| ... | ... | @@ -760,6 +1091,9 @@ async def embed_image( |
| 760 | 1091 | |
| 761 | 1092 | request_started = time.perf_counter() |
| 762 | 1093 | success = False |
| 1094 | + backend_elapsed_ms = 0.0 | |
| 1095 | + cache_hits = 0 | |
| 1096 | + cache_misses = 0 | |
| 763 | 1097 | try: |
| 764 | 1098 | logger.info( |
| 765 | 1099 | "embed_image request | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| ... | ... | @@ -777,30 +1111,52 @@ async def embed_image( |
| 777 | 1111 | effective_normalize, |
| 778 | 1112 | extra=_request_log_extra(request_id), |
| 779 | 1113 | ) |
| 780 | - out = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) | |
| 1114 | + result = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) | |
| 781 | 1115 | success = True |
| 1116 | + backend_elapsed_ms = result.backend_elapsed_ms | |
| 1117 | + cache_hits = result.cache_hits | |
| 1118 | + cache_misses = result.cache_misses | |
| 782 | 1119 | latency_ms = (time.perf_counter() - request_started) * 1000.0 |
| 1120 | + _image_stats.record_completed( | |
| 1121 | + success=True, | |
| 1122 | + latency_ms=latency_ms, | |
| 1123 | + backend_latency_ms=backend_elapsed_ms, | |
| 1124 | + cache_hits=cache_hits, | |
| 1125 | + cache_misses=cache_misses, | |
| 1126 | + ) | |
| 783 | 1127 | logger.info( |
| 784 | - "embed_image response | inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", | |
| 1128 | + "embed_image response | mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f", | |
| 1129 | + result.mode, | |
| 785 | 1130 | len(urls), |
| 786 | 1131 | effective_normalize, |
| 787 | - len(out[0]) if out and out[0] is not None else 0, | |
| 788 | - _preview_vector(out[0] if out else None), | |
| 1132 | + len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0, | |
| 1133 | + cache_hits, | |
| 1134 | + cache_misses, | |
| 1135 | + _preview_vector(result.vectors[0] if result.vectors else None), | |
| 789 | 1136 | latency_ms, |
| 790 | 1137 | extra=_request_log_extra(request_id), |
| 791 | 1138 | ) |
| 792 | 1139 | verbose_logger.info( |
| 793 | 1140 | "embed_image result detail | count=%d first_vector=%s latency_ms=%.2f", |
| 794 | - len(out), | |
| 795 | - out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], | |
| 1141 | + len(result.vectors), | |
| 1142 | + result.vectors[0][: _VECTOR_PREVIEW_DIMS] | |
| 1143 | + if result.vectors and result.vectors[0] is not None | |
| 1144 | + else [], | |
| 796 | 1145 | latency_ms, |
| 797 | 1146 | extra=_request_log_extra(request_id), |
| 798 | 1147 | ) |
| 799 | - return out | |
| 1148 | + return result.vectors | |
| 800 | 1149 | except HTTPException: |
| 801 | 1150 | raise |
| 802 | 1151 | except Exception as e: |
| 803 | 1152 | latency_ms = (time.perf_counter() - request_started) * 1000.0 |
| 1153 | + _image_stats.record_completed( | |
| 1154 | + success=False, | |
| 1155 | + latency_ms=latency_ms, | |
| 1156 | + backend_latency_ms=backend_elapsed_ms, | |
| 1157 | + cache_hits=cache_hits, | |
| 1158 | + cache_misses=cache_misses, | |
| 1159 | + ) | |
| 804 | 1160 | logger.error( |
| 805 | 1161 | "embed_image failed | inputs=%d normalize=%s latency_ms=%.2f error=%s", |
| 806 | 1162 | len(urls), | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -10,19 +10,26 @@ import requests |
| 10 | 10 | |
| 11 | 11 | logger = logging.getLogger(__name__) |
| 12 | 12 | |
| 13 | -from config.services_config import get_embedding_base_url | |
| 13 | +from config.services_config import get_embedding_text_base_url | |
| 14 | +from embeddings.cache_keys import build_text_cache_key | |
| 14 | 15 | from embeddings.redis_embedding_cache import RedisEmbeddingCache |
| 15 | 16 | |
| 16 | 17 | # Try to import REDIS_CONFIG, but allow import to fail |
| 17 | 18 | from config.env_config import REDIS_CONFIG |
| 18 | 19 | |
| 20 | + | |
| 19 | 21 | class TextEmbeddingEncoder: |
| 20 | 22 | """ |
| 21 | 23 | Text embedding encoder using network service. |
| 22 | 24 | """ |
| 23 | 25 | |
| 24 | 26 | def __init__(self, service_url: Optional[str] = None): |
| 25 | - resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() | |
| 27 | + resolved_url = ( | |
| 28 | + service_url | |
| 29 | + or os.getenv("EMBEDDING_TEXT_SERVICE_URL") | |
| 30 | + or os.getenv("EMBEDDING_SERVICE_URL") | |
| 31 | + or get_embedding_text_base_url() | |
| 32 | + ) | |
| 26 | 33 | self.service_url = str(resolved_url).rstrip("/") |
| 27 | 34 | self.endpoint = f"{self.service_url}/embed/text" |
| 28 | 35 | self.expire_time = timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) |
| ... | ... | @@ -87,9 +94,8 @@ class TextEmbeddingEncoder: |
| 87 | 94 | uncached_texts: List[str] = [] |
| 88 | 95 | |
| 89 | 96 | embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) |
| 90 | - | |
| 91 | 97 | for i, text in enumerate(sentences): |
| 92 | - cached = self._get_cached_embedding(text) | |
| 98 | + cached = self._get_cached_embedding(text, normalize_embeddings=normalize_embeddings) | |
| 93 | 99 | if cached is not None: |
| 94 | 100 | embeddings[i] = cached |
| 95 | 101 | else: |
| ... | ... | @@ -115,7 +121,11 @@ class TextEmbeddingEncoder: |
| 115 | 121 | embedding_array = np.array(embedding, dtype=np.float32) |
| 116 | 122 | if self._is_valid_embedding(embedding_array): |
| 117 | 123 | embeddings[original_idx] = embedding_array |
| 118 | - self._set_cached_embedding(text, embedding_array) | |
| 124 | + self._set_cached_embedding( | |
| 125 | + text, | |
| 126 | + embedding_array, | |
| 127 | + normalize_embeddings=normalize_embeddings, | |
| 128 | + ) | |
| 119 | 129 | else: |
| 120 | 130 | raise ValueError( |
| 121 | 131 | f"Invalid embedding returned from service for text index {original_idx}" |
| ... | ... | @@ -150,20 +160,32 @@ class TextEmbeddingEncoder: |
| 150 | 160 | def _get_cached_embedding( |
| 151 | 161 | self, |
| 152 | 162 | query: str, |
| 163 | + *, | |
| 164 | + normalize_embeddings: bool, | |
| 153 | 165 | ) -> Optional[np.ndarray]: |
| 154 | 166 | """Get embedding from cache if exists (with sliding expiration).""" |
| 155 | - embedding = self.cache.get(query) | |
| 167 | + embedding = self.cache.get(build_text_cache_key(query, normalize=normalize_embeddings)) | |
| 156 | 168 | if embedding is not None: |
| 157 | - logger.debug(f"Cache hit for embedding: {query}") | |
| 169 | + logger.debug( | |
| 170 | + "Cache hit for text embedding | normalize=%s query=%s", | |
| 171 | + normalize_embeddings, | |
| 172 | + query, | |
| 173 | + ) | |
| 158 | 174 | return embedding |
| 159 | 175 | |
| 160 | 176 | def _set_cached_embedding( |
| 161 | 177 | self, |
| 162 | 178 | query: str, |
| 163 | 179 | embedding: np.ndarray, |
| 180 | + *, | |
| 181 | + normalize_embeddings: bool, | |
| 164 | 182 | ) -> bool: |
| 165 | 183 | """Store embedding in cache.""" |
| 166 | - ok = self.cache.set(query, embedding) | |
| 184 | + ok = self.cache.set(build_text_cache_key(query, normalize=normalize_embeddings), embedding) | |
| 167 | 185 | if ok: |
| 168 | - logger.debug(f"Successfully cached embedding for query: {query}") | |
| 186 | + logger.debug( | |
| 187 | + "Successfully cached text embedding | normalize=%s query=%s", | |
| 188 | + normalize_embeddings, | |
| 189 | + query, | |
| 190 | + ) | |
| 169 | 191 | return ok | ... | ... |
providers/embedding.py
| ... | ... | @@ -2,7 +2,11 @@ |
| 2 | 2 | |
| 3 | 3 | from __future__ import annotations |
| 4 | 4 | |
| 5 | -from config.services_config import get_embedding_config, get_embedding_base_url | |
| 5 | +from config.services_config import ( | |
| 6 | + get_embedding_config, | |
| 7 | + get_embedding_image_base_url, | |
| 8 | + get_embedding_text_base_url, | |
| 9 | +) | |
| 6 | 10 | |
| 7 | 11 | |
| 8 | 12 | def create_embedding_provider() -> "EmbeddingProvider": |
| ... | ... | @@ -21,13 +25,14 @@ class EmbeddingProvider: |
| 21 | 25 | """ |
| 22 | 26 | |
| 23 | 27 | def __init__(self) -> None: |
| 24 | - self._base_url = get_embedding_base_url() | |
| 28 | + self._text_base_url = get_embedding_text_base_url() | |
| 29 | + self._image_base_url = get_embedding_image_base_url() | |
| 25 | 30 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 26 | 31 | from embeddings.image_encoder import CLIPImageEncoder |
| 27 | 32 | |
| 28 | 33 | # Initialize once; avoid per-access instantiation. |
| 29 | - self._text_encoder = TextEmbeddingEncoder(service_url=self._base_url) | |
| 30 | - self._image_encoder = CLIPImageEncoder(service_url=self._base_url) | |
| 34 | + self._text_encoder = TextEmbeddingEncoder(service_url=self._text_base_url) | |
| 35 | + self._image_encoder = CLIPImageEncoder(service_url=self._image_base_url) | |
| 31 | 36 | |
| 32 | 37 | @property |
| 33 | 38 | def text_encoder(self): | ... | ... |
requirements_embedding_service.txt
scripts/perf_api_benchmark.py
| ... | ... | @@ -6,6 +6,7 @@ Default scenarios (aligned with docs/搜索API对接指南.md): |
| 6 | 6 | - backend_search POST /search/ |
| 7 | 7 | - backend_suggest GET /search/suggestions |
| 8 | 8 | - embed_text POST /embed/text |
| 9 | +- embed_image POST /embed/image | |
| 9 | 10 | - translate POST /translate |
| 10 | 11 | - rerank POST /rerank |
| 11 | 12 | |
| ... | ... | @@ -158,6 +159,13 @@ def make_default_templates(tenant_id: str) -> Dict[str, List[RequestTemplate]]: |
| 158 | 159 | json_body=["wireless mouse", "gaming keyboard", "barbie doll"], |
| 159 | 160 | ) |
| 160 | 161 | ], |
| 162 | + "embed_image": [ | |
| 163 | + RequestTemplate( | |
| 164 | + method="POST", | |
| 165 | + path="/embed/image", | |
| 166 | + json_body=["/data/saas-search/docs/image-dress1.png"], | |
| 167 | + ) | |
| 168 | + ], | |
| 161 | 169 | "translate": [ |
| 162 | 170 | RequestTemplate( |
| 163 | 171 | method="POST", |
| ... | ... | @@ -220,7 +228,8 @@ def build_scenarios(args: argparse.Namespace) -> Dict[str, Scenario]: |
| 220 | 228 | scenario_base = { |
| 221 | 229 | "backend_search": args.backend_base, |
| 222 | 230 | "backend_suggest": args.backend_base, |
| 223 | - "embed_text": args.embedding_base, | |
| 231 | + "embed_text": args.embedding_text_base, | |
| 232 | + "embed_image": args.embedding_image_base, | |
| 224 | 233 | "translate": args.translator_base, |
| 225 | 234 | "rerank": args.reranker_base, |
| 226 | 235 | } |
| ... | ... | @@ -433,7 +442,7 @@ def parse_args() -> argparse.Namespace: |
| 433 | 442 | "--scenario", |
| 434 | 443 | type=str, |
| 435 | 444 | default="all", |
| 436 | - help="Scenario: backend_search | backend_suggest | embed_text | translate | rerank | all | comma-separated list", | |
| 445 | + help="Scenario: backend_search | backend_suggest | embed_text | embed_image | translate | rerank | all | comma-separated list", | |
| 437 | 446 | ) |
| 438 | 447 | parser.add_argument("--tenant-id", type=str, default="162", help="Tenant ID for backend search/suggest") |
| 439 | 448 | parser.add_argument("--duration", type=int, default=30, help="Duration seconds per scenario; <=0 means no duration cap") |
| ... | ... | @@ -443,7 +452,8 @@ def parse_args() -> argparse.Namespace: |
| 443 | 452 | parser.add_argument("--max-errors", type=int, default=0, help="Stop scenario when accumulated errors reach this value") |
| 444 | 453 | |
| 445 | 454 | parser.add_argument("--backend-base", type=str, default="http://127.0.0.1:6002", help="Base URL for backend search API") |
| 446 | - parser.add_argument("--embedding-base", type=str, default="http://127.0.0.1:6005", help="Base URL for embedding service") | |
| 455 | + parser.add_argument("--embedding-text-base", type=str, default="http://127.0.0.1:6005", help="Base URL for text embedding service") | |
| 456 | + parser.add_argument("--embedding-image-base", type=str, default="http://127.0.0.1:6008", help="Base URL for image embedding service") | |
| 447 | 457 | parser.add_argument("--translator-base", type=str, default="http://127.0.0.1:6006", help="Base URL for translation service") |
| 448 | 458 | parser.add_argument("--reranker-base", type=str, default="http://127.0.0.1:6007", help="Base URL for reranker service") |
| 449 | 459 | |
| ... | ... | @@ -547,7 +557,7 @@ async def main_async() -> int: |
| 547 | 557 | args = parse_args() |
| 548 | 558 | scenarios = build_scenarios(args) |
| 549 | 559 | |
| 550 | - all_names = ["backend_search", "backend_suggest", "embed_text", "translate", "rerank"] | |
| 560 | + all_names = ["backend_search", "backend_suggest", "embed_text", "embed_image", "translate", "rerank"] | |
| 551 | 561 | if args.scenario == "all": |
| 552 | 562 | run_names = [x for x in all_names if x in scenarios] |
| 553 | 563 | else: |
| ... | ... | @@ -595,7 +605,8 @@ async def main_async() -> int: |
| 595 | 605 | print(f" timeout={args.timeout}s") |
| 596 | 606 | print(f" max_errors={args.max_errors}") |
| 597 | 607 | print(f" backend_base={args.backend_base}") |
| 598 | - print(f" embedding_base={args.embedding_base}") | |
| 608 | + print(f" embedding_text_base={args.embedding_text_base}") | |
| 609 | + print(f" embedding_image_base={args.embedding_image_base}") | |
| 599 | 610 | print(f" translator_base={args.translator_base}") |
| 600 | 611 | print(f" reranker_base={args.reranker_base}") |
| 601 | 612 | if args.rerank_dynamic_docs: |
| ... | ... | @@ -643,7 +654,8 @@ async def main_async() -> int: |
| 643 | 654 | "timeout_sec": args.timeout, |
| 644 | 655 | "max_errors": args.max_errors, |
| 645 | 656 | "backend_base": args.backend_base, |
| 646 | - "embedding_base": args.embedding_base, | |
| 657 | + "embedding_text_base": args.embedding_text_base, | |
| 658 | + "embedding_image_base": args.embedding_image_base, | |
| 647 | 659 | "translator_base": args.translator_base, |
| 648 | 660 | "reranker_base": args.reranker_base, |
| 649 | 661 | "cases_file": args.cases_file or None, | ... | ... |
scripts/service_ctl.sh
| ... | ... | @@ -16,9 +16,9 @@ mkdir -p "${LOG_DIR}" |
| 16 | 16 | source "${PROJECT_ROOT}/scripts/lib/load_env.sh" |
| 17 | 17 | |
| 18 | 18 | CORE_SERVICES=("backend" "indexer" "frontend") |
| 19 | -OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "translator" "reranker") | |
| 19 | +OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker") | |
| 20 | 20 | FULL_SERVICES=("${OPTIONAL_SERVICES[@]}" "${CORE_SERVICES[@]}") |
| 21 | -STOP_ORDER_SERVICES=("frontend" "indexer" "backend" "reranker" "translator" "embedding" "cnclip" "tei") | |
| 21 | +STOP_ORDER_SERVICES=("frontend" "indexer" "backend" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") | |
| 22 | 22 | |
| 23 | 23 | all_services() { |
| 24 | 24 | echo "${FULL_SERVICES[@]}" |
| ... | ... | @@ -30,7 +30,8 @@ get_port() { |
| 30 | 30 | backend) echo "${API_PORT:-6002}" ;; |
| 31 | 31 | indexer) echo "${INDEXER_PORT:-6004}" ;; |
| 32 | 32 | frontend) echo "${FRONTEND_PORT:-6003}" ;; |
| 33 | - embedding) echo "${EMBEDDING_PORT:-6005}" ;; | |
| 33 | + embedding) echo "${EMBEDDING_TEXT_PORT:-${EMBEDDING_PORT:-6005}}" ;; | |
| 34 | + embedding-image) echo "${EMBEDDING_IMAGE_PORT:-6008}" ;; | |
| 34 | 35 | translator) echo "${TRANSLATION_PORT:-6006}" ;; |
| 35 | 36 | reranker) echo "${RERANKER_PORT:-6007}" ;; |
| 36 | 37 | tei) echo "${TEI_PORT:-8080}" ;; |
| ... | ... | @@ -65,7 +66,8 @@ service_start_cmd() { |
| 65 | 66 | backend) echo "./scripts/start_backend.sh" ;; |
| 66 | 67 | indexer) echo "./scripts/start_indexer.sh" ;; |
| 67 | 68 | frontend) echo "./scripts/start_frontend.sh" ;; |
| 68 | - embedding) echo "./scripts/start_embedding_service.sh" ;; | |
| 69 | + embedding) echo "./scripts/start_embedding_text_service.sh" ;; | |
| 70 | + embedding-image) echo "./scripts/start_embedding_image_service.sh" ;; | |
| 69 | 71 | translator) echo "./scripts/start_translator.sh" ;; |
| 70 | 72 | reranker) echo "./scripts/start_reranker.sh" ;; |
| 71 | 73 | tei) echo "./scripts/start_tei_service.sh" ;; |
| ... | ... | @@ -77,7 +79,7 @@ service_start_cmd() { |
| 77 | 79 | service_exists() { |
| 78 | 80 | local service="$1" |
| 79 | 81 | case "${service}" in |
| 80 | - backend|indexer|frontend|embedding|translator|reranker|tei|cnclip) return 0 ;; | |
| 82 | + backend|indexer|frontend|embedding|embedding-image|translator|reranker|tei|cnclip) return 0 ;; | |
| 81 | 83 | *) return 1 ;; |
| 82 | 84 | esac |
| 83 | 85 | } |
| ... | ... | @@ -95,7 +97,7 @@ validate_targets() { |
| 95 | 97 | health_path_for_service() { |
| 96 | 98 | local service="$1" |
| 97 | 99 | case "${service}" in |
| 98 | - backend|indexer|embedding|translator|reranker|tei) echo "/health" ;; | |
| 100 | + backend|indexer|embedding|embedding-image|translator|reranker|tei) echo "/health" ;; | |
| 99 | 101 | *) echo "" ;; |
| 100 | 102 | esac |
| 101 | 103 | } | ... | ... |
scripts/start_embedding_service.sh
| 1 | 1 | #!/bin/bash |
| 2 | 2 | # |
| 3 | -# Start Embedding Service (port 6005). | |
| 3 | +# Start Embedding Service (combined/text/image mode). | |
| 4 | 4 | # |
| 5 | 5 | # Design: |
| 6 | 6 | # - Run in isolated venv `.venv-embedding` (do not pollute main `.venv`) |
| ... | ... | @@ -33,18 +33,55 @@ CLIP_AS_SERVICE_SERVER=$("${PYTHON_BIN}" -c "from embeddings.config import CONFI |
| 33 | 33 | CLIP_AS_SERVICE_MODEL_NAME=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_MODEL_NAME)") |
| 34 | 34 | TEXT_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_embedding_backend_config; print(get_embedding_backend_config()[0])") |
| 35 | 35 | TEI_BASE_URL=$("${PYTHON_BIN}" -c "import os; from config.services_config import get_embedding_backend_config; from embeddings.config import CONFIG; _, cfg = get_embedding_backend_config(); print(os.getenv('TEI_BASE_URL') or cfg.get('base_url') or CONFIG.TEI_BASE_URL)") |
| 36 | +SERVICE_KIND="${1:-${EMBEDDING_SERVICE_KIND:-all}}" | |
| 37 | +SERVICE_KIND="$(echo "${SERVICE_KIND}" | tr '[:upper:]' '[:lower:]')" | |
| 38 | +if [[ "${SERVICE_KIND}" != "all" && "${SERVICE_KIND}" != "text" && "${SERVICE_KIND}" != "image" ]]; then | |
| 39 | + echo "ERROR: invalid embedding service kind: ${SERVICE_KIND}. expected all|text|image" >&2 | |
| 40 | + exit 1 | |
| 41 | +fi | |
| 42 | + | |
| 43 | +ENABLE_TEXT_MODEL="${EMBEDDING_ENABLE_TEXT_MODEL:-true}" | |
| 44 | +ENABLE_TEXT_MODEL="$(echo "${ENABLE_TEXT_MODEL}" | tr '[:upper:]' '[:lower:]')" | |
| 36 | 45 | ENABLE_IMAGE_MODEL="${EMBEDDING_ENABLE_IMAGE_MODEL:-true}" |
| 37 | 46 | ENABLE_IMAGE_MODEL="$(echo "${ENABLE_IMAGE_MODEL}" | tr '[:upper:]' '[:lower:]')" |
| 38 | -if [[ "${ENABLE_IMAGE_MODEL}" == "1" || "${ENABLE_IMAGE_MODEL}" == "true" || "${ENABLE_IMAGE_MODEL}" == "yes" ]]; then | |
| 39 | - IMAGE_MODEL_ENABLED=1 | |
| 40 | -else | |
| 41 | - IMAGE_MODEL_ENABLED=0 | |
| 47 | + | |
| 48 | +TEXT_MODEL_ENABLED=0 | |
| 49 | +IMAGE_MODEL_ENABLED=0 | |
| 50 | +if [[ "${SERVICE_KIND}" == "all" || "${SERVICE_KIND}" == "text" ]]; then | |
| 51 | + if [[ "${ENABLE_TEXT_MODEL}" == "1" || "${ENABLE_TEXT_MODEL}" == "true" || "${ENABLE_TEXT_MODEL}" == "yes" ]]; then | |
| 52 | + TEXT_MODEL_ENABLED=1 | |
| 53 | + fi | |
| 54 | +fi | |
| 55 | +if [[ "${SERVICE_KIND}" == "all" || "${SERVICE_KIND}" == "image" ]]; then | |
| 56 | + if [[ "${ENABLE_IMAGE_MODEL}" == "1" || "${ENABLE_IMAGE_MODEL}" == "true" || "${ENABLE_IMAGE_MODEL}" == "yes" ]]; then | |
| 57 | + IMAGE_MODEL_ENABLED=1 | |
| 58 | + fi | |
| 42 | 59 | fi |
| 43 | 60 | |
| 44 | 61 | EMBEDDING_SERVICE_HOST="${EMBEDDING_HOST:-${DEFAULT_EMBEDDING_SERVICE_HOST}}" |
| 45 | -EMBEDDING_SERVICE_PORT="${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}" | |
| 62 | +if [[ "${SERVICE_KIND}" == "text" ]]; then | |
| 63 | + EMBEDDING_SERVICE_PORT="${EMBEDDING_TEXT_PORT:-${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}}" | |
| 64 | +elif [[ "${SERVICE_KIND}" == "image" ]]; then | |
| 65 | + EMBEDDING_SERVICE_PORT="${EMBEDDING_IMAGE_PORT:-6008}" | |
| 66 | +else | |
| 67 | + EMBEDDING_SERVICE_PORT="${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}" | |
| 68 | +fi | |
| 46 | 69 | |
| 47 | -if [[ "${TEXT_BACKEND}" == "tei" ]]; then | |
| 70 | +export EMBEDDING_SERVICE_KIND="${SERVICE_KIND}" | |
| 71 | +export EMBEDDING_HOST="${EMBEDDING_SERVICE_HOST}" | |
| 72 | +export EMBEDDING_PORT="${EMBEDDING_SERVICE_PORT}" | |
| 73 | +if [[ "${TEXT_MODEL_ENABLED}" == "1" ]]; then | |
| 74 | + export EMBEDDING_ENABLE_TEXT_MODEL=true | |
| 75 | +else | |
| 76 | + export EMBEDDING_ENABLE_TEXT_MODEL=false | |
| 77 | +fi | |
| 78 | +if [[ "${IMAGE_MODEL_ENABLED}" == "1" ]]; then | |
| 79 | + export EMBEDDING_ENABLE_IMAGE_MODEL=true | |
| 80 | +else | |
| 81 | + export EMBEDDING_ENABLE_IMAGE_MODEL=false | |
| 82 | +fi | |
| 83 | + | |
| 84 | +if [[ "${TEXT_MODEL_ENABLED}" == "1" && "${TEXT_BACKEND}" == "tei" ]]; then | |
| 48 | 85 | if ! curl -sf "${TEI_BASE_URL%/}/health" >/dev/null 2>&1; then |
| 49 | 86 | echo "ERROR: TEI backend is selected but TEI is not reachable: ${TEI_BASE_URL}/health" >&2 |
| 50 | 87 | echo "Please start TEI first: ./scripts/start_tei_service.sh" >&2 |
| ... | ... | @@ -81,12 +118,16 @@ fi |
| 81 | 118 | echo "========================================" |
| 82 | 119 | echo "Starting Embedding Service" |
| 83 | 120 | echo "========================================" |
| 121 | +echo "Kind: ${SERVICE_KIND}" | |
| 84 | 122 | echo "Python: ${PYTHON_BIN}" |
| 85 | 123 | echo "Host: ${EMBEDDING_SERVICE_HOST}" |
| 86 | 124 | echo "Port: ${EMBEDDING_SERVICE_PORT}" |
| 87 | -echo "Text backend: ${TEXT_BACKEND}" | |
| 88 | -echo "Text max inflight: ${TEXT_MAX_INFLIGHT:-32}" | |
| 89 | -if [[ "${TEXT_BACKEND}" == "tei" ]]; then | |
| 125 | +echo "Text backend enabled: ${TEXT_MODEL_ENABLED}" | |
| 126 | +if [[ "${TEXT_MODEL_ENABLED}" == "1" ]]; then | |
| 127 | + echo "Text backend: ${TEXT_BACKEND}" | |
| 128 | + echo "Text max inflight: ${TEXT_MAX_INFLIGHT:-32}" | |
| 129 | +fi | |
| 130 | +if [[ "${TEXT_MODEL_ENABLED}" == "1" && "${TEXT_BACKEND}" == "tei" ]]; then | |
| 90 | 131 | echo "TEI URL: ${TEI_BASE_URL}" |
| 91 | 132 | fi |
| 92 | 133 | if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then |
| ... | ... | @@ -94,12 +135,20 @@ if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then |
| 94 | 135 | elif [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then |
| 95 | 136 | echo "Image backend: clip-as-service (${CLIP_AS_SERVICE_SERVER}, model=${CLIP_AS_SERVICE_MODEL_NAME})" |
| 96 | 137 | fi |
| 97 | -echo "Image max inflight: ${IMAGE_MAX_INFLIGHT:-1}" | |
| 138 | +if [[ "${IMAGE_MODEL_ENABLED}" == "1" ]]; then | |
| 139 | + echo "Image max inflight: ${IMAGE_MAX_INFLIGHT:-1}" | |
| 140 | +fi | |
| 98 | 141 | echo "Logs: logs/embedding_api.log, logs/embedding_api_error.log, logs/verbose/embedding_verbose.log" |
| 99 | 142 | echo |
| 100 | 143 | echo "Tips:" |
| 101 | 144 | echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." |
| 102 | -echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | |
| 145 | +if [[ "${SERVICE_KIND}" == "text" ]]; then | |
| 146 | + echo " - Clients can set EMBEDDING_TEXT_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | |
| 147 | +elif [[ "${SERVICE_KIND}" == "image" ]]; then | |
| 148 | + echo " - Clients can set EMBEDDING_IMAGE_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | |
| 149 | +else | |
| 150 | + echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" | |
| 151 | +fi | |
| 103 | 152 | echo |
| 104 | 153 | |
| 105 | 154 | UVICORN_LOG_LEVEL="${EMBEDDING_UVICORN_LOG_LEVEL:-info}" | ... | ... |
tests/test_embedding_pipeline.py
| ... | ... | @@ -12,7 +12,9 @@ from config import ( |
| 12 | 12 | SearchConfig, |
| 13 | 13 | ) |
| 14 | 14 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 15 | +from embeddings.image_encoder import CLIPImageEncoder | |
| 15 | 16 | from embeddings.bf16 import encode_embedding_for_redis |
| 17 | +from embeddings.cache_keys import build_image_cache_key, build_text_cache_key | |
| 16 | 18 | from query import QueryParser |
| 17 | 19 | |
| 18 | 20 | |
| ... | ... | @@ -67,6 +69,18 @@ class _FakeQueryEncoder: |
| 67 | 69 | return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) |
| 68 | 70 | |
| 69 | 71 | |
| 72 | +class _FakeEmbeddingCache: | |
| 73 | + def __init__(self): | |
| 74 | + self.store: Dict[str, np.ndarray] = {} | |
| 75 | + | |
| 76 | + def get(self, key: str): | |
| 77 | + return self.store.get(key) | |
| 78 | + | |
| 79 | + def set(self, key: str, embedding: np.ndarray): | |
| 80 | + self.store[key] = np.asarray(embedding, dtype=np.float32) | |
| 81 | + return True | |
| 82 | + | |
| 83 | + | |
| 70 | 84 | def _build_test_config() -> SearchConfig: |
| 71 | 85 | return SearchConfig( |
| 72 | 86 | field_boosts={"title.en": 3.0}, |
| ... | ... | @@ -91,8 +105,8 @@ def _build_test_config() -> SearchConfig: |
| 91 | 105 | |
| 92 | 106 | |
| 93 | 107 | def test_text_embedding_encoder_response_alignment(monkeypatch): |
| 94 | - fake_redis = _FakeRedis() | |
| 95 | - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) | |
| 108 | + fake_cache = _FakeEmbeddingCache() | |
| 109 | + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) | |
| 96 | 110 | |
| 97 | 111 | def _fake_post(url, json, timeout, **kwargs): |
| 98 | 112 | assert url.endswith("/embed/text") |
| ... | ... | @@ -112,8 +126,8 @@ def test_text_embedding_encoder_response_alignment(monkeypatch): |
| 112 | 126 | |
| 113 | 127 | |
| 114 | 128 | def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): |
| 115 | - fake_redis = _FakeRedis() | |
| 116 | - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) | |
| 129 | + fake_cache = _FakeEmbeddingCache() | |
| 130 | + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) | |
| 117 | 131 | |
| 118 | 132 | def _fake_post(url, json, timeout, **kwargs): |
| 119 | 133 | return _FakeResponse([[0.1, 0.2], None]) |
| ... | ... | @@ -126,10 +140,10 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): |
| 126 | 140 | |
| 127 | 141 | |
| 128 | 142 | def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 129 | - fake_redis = _FakeRedis() | |
| 143 | + fake_cache = _FakeEmbeddingCache() | |
| 130 | 144 | cached = np.array([0.9, 0.8], dtype=np.float32) |
| 131 | - fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached) | |
| 132 | - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) | |
| 145 | + fake_cache.store[build_text_cache_key("cached-text", normalize=True)] = cached | |
| 146 | + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) | |
| 133 | 147 | |
| 134 | 148 | calls = {"count": 0} |
| 135 | 149 | |
| ... | ... | @@ -147,6 +161,29 @@ def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 147 | 161 | assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32)) |
| 148 | 162 | |
| 149 | 163 | |
| 164 | +def test_image_embedding_encoder_cache_hit(monkeypatch): | |
| 165 | + fake_cache = _FakeEmbeddingCache() | |
| 166 | + cached = np.array([0.5, 0.6], dtype=np.float32) | |
| 167 | + url = "https://example.com/a.jpg" | |
| 168 | + fake_cache.store[build_image_cache_key(url, normalize=True)] = cached | |
| 169 | + monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) | |
| 170 | + | |
| 171 | + calls = {"count": 0} | |
| 172 | + | |
| 173 | + def _fake_post(url, params, json, timeout, **kwargs): | |
| 174 | + calls["count"] += 1 | |
| 175 | + return _FakeResponse([[0.1, 0.2]]) | |
| 176 | + | |
| 177 | + monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post) | |
| 178 | + | |
| 179 | + encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008") | |
| 180 | + out = encoder.encode_batch(["https://example.com/a.jpg", "https://example.com/b.jpg"]) | |
| 181 | + | |
| 182 | + assert calls["count"] == 1 | |
| 183 | + assert np.allclose(out[0], cached) | |
| 184 | + assert np.allclose(out[1], np.array([0.1, 0.2], dtype=np.float32)) | |
| 185 | + | |
| 186 | + | |
| 150 | 187 | def test_query_parser_generates_query_vector_with_encoder(): |
| 151 | 188 | parser = QueryParser( |
| 152 | 189 | config=_build_test_config(), | ... | ... |
tests/test_embedding_service_limits.py
| ... | ... | @@ -28,6 +28,24 @@ class _FakeTextModel: |
| 28 | 28 | return [np.array([1.0, 2.0, 3.0], dtype=np.float32)] |
| 29 | 29 | |
| 30 | 30 | |
| 31 | +class _FakeImageModel: | |
| 32 | + def encode_image_urls(self, urls, batch_size, normalize_embeddings): | |
| 33 | + raise AssertionError("image backend should not be called on cache hit") | |
| 34 | + | |
| 35 | + | |
| 36 | +class _FakeCache: | |
| 37 | + def __init__(self, store=None): | |
| 38 | + self.store = store or {} | |
| 39 | + self.redis_client = object() | |
| 40 | + | |
| 41 | + def get(self, key): | |
| 42 | + return self.store.get(key) | |
| 43 | + | |
| 44 | + def set(self, key, value): | |
| 45 | + self.store[key] = np.asarray(value, dtype=np.float32) | |
| 46 | + return True | |
| 47 | + | |
| 48 | + | |
| 31 | 49 | def test_health_exposes_limit_stats(monkeypatch): |
| 32 | 50 | monkeypatch.setattr( |
| 33 | 51 | embedding_server, |
| ... | ... | @@ -39,6 +57,8 @@ def test_health_exposes_limit_stats(monkeypatch): |
| 39 | 57 | "_image_request_limiter", |
| 40 | 58 | embedding_server._InflightLimiter("image", 1), |
| 41 | 59 | ) |
| 60 | + monkeypatch.setattr(embedding_server, "_text_model", object()) | |
| 61 | + monkeypatch.setattr(embedding_server, "_image_model", object()) | |
| 42 | 62 | |
| 43 | 63 | payload = embedding_server.health() |
| 44 | 64 | |
| ... | ... | @@ -53,6 +73,7 @@ def test_embed_image_rejects_when_image_lane_is_full(monkeypatch): |
| 53 | 73 | acquired, _ = limiter.try_acquire() |
| 54 | 74 | assert acquired is True |
| 55 | 75 | monkeypatch.setattr(embedding_server, "_image_request_limiter", limiter) |
| 76 | + monkeypatch.setattr(embedding_server, "_image_model", object()) | |
| 56 | 77 | |
| 57 | 78 | response = _DummyResponse() |
| 58 | 79 | with pytest.raises(embedding_server.HTTPException) as exc_info: |
| ... | ... | @@ -91,3 +112,29 @@ def test_embed_text_returns_request_id_and_vector(monkeypatch): |
| 91 | 112 | |
| 92 | 113 | assert response.headers["X-Request-ID"] == "req-123456" |
| 93 | 114 | assert result == [[1.0, 2.0, 3.0]] |
| 115 | + | |
| 116 | + | |
| 117 | +def test_embed_image_service_cache_hit_bypasses_backend(monkeypatch): | |
| 118 | + cache_key = embedding_server.build_image_cache_key("https://example.com/a.jpg", normalize=True) | |
| 119 | + fake_cache = _FakeCache({cache_key: np.array([0.7, 0.8], dtype=np.float32)}) | |
| 120 | + monkeypatch.setattr( | |
| 121 | + embedding_server, | |
| 122 | + "_image_request_limiter", | |
| 123 | + embedding_server._InflightLimiter("image", 1), | |
| 124 | + ) | |
| 125 | + monkeypatch.setattr(embedding_server, "_image_model", _FakeImageModel()) | |
| 126 | + monkeypatch.setattr(embedding_server, "_image_cache", fake_cache) | |
| 127 | + | |
| 128 | + request = _DummyRequest(headers={"X-Request-ID": "img-cache-hit"}) | |
| 129 | + response = _DummyResponse() | |
| 130 | + result = asyncio.run( | |
| 131 | + embedding_server.embed_image( | |
| 132 | + ["https://example.com/a.jpg"], | |
| 133 | + request, | |
| 134 | + response, | |
| 135 | + normalize=True, | |
| 136 | + ) | |
| 137 | + ) | |
| 138 | + | |
| 139 | + assert response.headers["X-Request-ID"] == "img-cache-hit" | |
| 140 | + assert result == [[0.699999988079071, 0.800000011920929]] | ... | ... |