diff --git a/config/__init__.py b/config/__init__.py index 32de35a..2c6bbc0 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -27,6 +27,8 @@ from .services_config import ( get_rerank_backend_config, get_translation_base_url, get_embedding_base_url, + get_embedding_text_base_url, + get_embedding_image_base_url, get_rerank_service_url, get_translation_cache_config, ServiceConfig, @@ -53,6 +55,8 @@ __all__ = [ 'get_rerank_backend_config', 'get_translation_base_url', 'get_embedding_base_url', + 'get_embedding_text_base_url', + 'get_embedding_image_base_url', 'get_rerank_service_url', 'get_translation_cache_config', 'ServiceConfig', diff --git a/config/config.yaml b/config/config.yaml index 34c227d..892574a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -199,6 +199,8 @@ services: providers: http: base_url: "http://127.0.0.1:6005" + text_base_url: "http://127.0.0.1:6005" + image_base_url: "http://127.0.0.1:6008" # 服务内文本后端(embedding 进程启动时读取) backend: "tei" # tei | local_st backends: diff --git a/config/env_config.py b/config/env_config.py index aa989be..9971068 100644 --- a/config/env_config.py +++ b/config/env_config.py @@ -61,6 +61,10 @@ INDEXER_PORT = int(os.getenv('INDEXER_PORT', 6004)) # Optional dependent services EMBEDDING_HOST = os.getenv('EMBEDDING_HOST', '127.0.0.1') EMBEDDING_PORT = int(os.getenv('EMBEDDING_PORT', 6005)) +EMBEDDING_TEXT_HOST = os.getenv('EMBEDDING_TEXT_HOST', EMBEDDING_HOST) +EMBEDDING_TEXT_PORT = int(os.getenv('EMBEDDING_TEXT_PORT', EMBEDDING_PORT)) +EMBEDDING_IMAGE_HOST = os.getenv('EMBEDDING_IMAGE_HOST', EMBEDDING_HOST) +EMBEDDING_IMAGE_PORT = int(os.getenv('EMBEDDING_IMAGE_PORT', 6008)) TRANSLATION_HOST = os.getenv('TRANSLATION_HOST', '127.0.0.1') TRANSLATION_PORT = int(os.getenv('TRANSLATION_PORT', 6006)) RERANKER_HOST = os.getenv('RERANKER_HOST', '127.0.0.1') @@ -74,6 +78,12 @@ INDEXER_BASE_URL = os.getenv('INDEXER_BASE_URL') or ( f'http://localhost:{INDEXER_PORT}' if INDEXER_HOST == '0.0.0.0' else f'http://{INDEXER_HOST}:{INDEXER_PORT}' ) EMBEDDING_SERVICE_URL = os.getenv('EMBEDDING_SERVICE_URL') or f'http://{EMBEDDING_HOST}:{EMBEDDING_PORT}' +EMBEDDING_TEXT_SERVICE_URL = os.getenv('EMBEDDING_TEXT_SERVICE_URL') or ( + f'http://{EMBEDDING_TEXT_HOST}:{EMBEDDING_TEXT_PORT}' +) +EMBEDDING_IMAGE_SERVICE_URL = os.getenv('EMBEDDING_IMAGE_SERVICE_URL') or ( + f'http://{EMBEDDING_IMAGE_HOST}:{EMBEDDING_IMAGE_PORT}' +) RERANKER_SERVICE_URL = os.getenv('RERANKER_SERVICE_URL') or f'http://{RERANKER_HOST}:{RERANKER_PORT}/rerank' # Model IDs / paths diff --git a/config/services_config.py b/config/services_config.py index 9141322..50ed7e0 100644 --- a/config/services_config.py +++ b/config/services_config.py @@ -79,10 +79,17 @@ def _resolve_embedding() -> ServiceConfig: raise ValueError(f"Unsupported embedding provider: {provider}") env_url = os.getenv("EMBEDDING_SERVICE_URL") - if env_url and provider == "http": + env_text_url = os.getenv("EMBEDDING_TEXT_SERVICE_URL") + env_image_url = os.getenv("EMBEDDING_IMAGE_SERVICE_URL") + if (env_url or env_text_url or env_image_url) and provider == "http": providers = dict(providers) providers["http"] = dict(providers.get("http", {})) - providers["http"]["base_url"] = env_url.rstrip("/") + if env_url: + providers["http"]["base_url"] = env_url.rstrip("/") + if env_text_url: + providers["http"]["text_base_url"] = env_text_url.rstrip("/") + if env_image_url: + providers["http"]["image_base_url"] = env_image_url.rstrip("/") return ServiceConfig(provider=provider, providers=providers) @@ -165,12 +172,44 @@ def get_translation_cache_config() -> Dict[str, Any]: def get_embedding_base_url() -> str: - base = os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_config().providers.get("http", {}).get("base_url") + provider_cfg = get_embedding_config().providers.get("http", {}) + base = ( + os.getenv("EMBEDDING_SERVICE_URL") + or provider_cfg.get("base_url") + or provider_cfg.get("text_base_url") + or provider_cfg.get("image_base_url") + ) if not base: raise ValueError("Embedding HTTP base_url is not configured") return str(base).rstrip("/") +def get_embedding_text_base_url() -> str: + provider_cfg = get_embedding_config().providers.get("http", {}) + base = ( + os.getenv("EMBEDDING_TEXT_SERVICE_URL") + or provider_cfg.get("text_base_url") + or os.getenv("EMBEDDING_SERVICE_URL") + or provider_cfg.get("base_url") + ) + if not base: + raise ValueError("Embedding text HTTP base_url is not configured") + return str(base).rstrip("/") + + +def get_embedding_image_base_url() -> str: + provider_cfg = get_embedding_config().providers.get("http", {}) + base = ( + os.getenv("EMBEDDING_IMAGE_SERVICE_URL") + or provider_cfg.get("image_base_url") + or os.getenv("EMBEDDING_SERVICE_URL") + or provider_cfg.get("base_url") + ) + if not base: + raise ValueError("Embedding image HTTP base_url is not configured") + return str(base).rstrip("/") + + def get_rerank_base_url() -> str: base = ( os.getenv("RERANKER_SERVICE_URL") diff --git a/embeddings/cache_keys.py b/embeddings/cache_keys.py new file mode 100644 index 0000000..1bb887a --- /dev/null +++ b/embeddings/cache_keys.py @@ -0,0 +1,13 @@ +"""Shared cache key helpers for embedding inputs.""" + +from __future__ import annotations + + +def build_text_cache_key(text: str, *, normalize: bool) -> str: + normalized_text = str(text or "").strip() + return f"norm:{1 if normalize else 0}:text:{normalized_text}" + + +def build_image_cache_key(url: str, *, normalize: bool) -> str: + normalized_url = str(url or "").strip() + return f"norm:{1 if normalize else 0}:image:{normalized_url}" diff --git a/embeddings/image_encoder.py b/embeddings/image_encoder.py index d2b8e4c..94c7ada 100644 --- a/embeddings/image_encoder.py +++ b/embeddings/image_encoder.py @@ -10,8 +10,9 @@ from PIL import Image logger = logging.getLogger(__name__) -from config.services_config import get_embedding_base_url +from config.services_config import get_embedding_image_base_url from config.env_config import REDIS_CONFIG +from embeddings.cache_keys import build_image_cache_key from embeddings.redis_embedding_cache import RedisEmbeddingCache @@ -23,7 +24,12 @@ class CLIPImageEncoder: """ def __init__(self, service_url: Optional[str] = None): - resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() + resolved_url = ( + service_url + or os.getenv("EMBEDDING_IMAGE_SERVICE_URL") + or os.getenv("EMBEDDING_SERVICE_URL") + or get_embedding_image_base_url() + ) self.service_url = str(resolved_url).rstrip("/") self.endpoint = f"{self.service_url}/embed/image" # Reuse embedding cache prefix, but separate namespace for images to avoid collisions. @@ -75,7 +81,8 @@ class CLIPImageEncoder: Returns: Embedding vector """ - cached = self.cache.get(url) + cache_key = build_image_cache_key(url, normalize=normalize_embeddings) + cached = self.cache.get(cache_key) if cached is not None: return cached @@ -85,7 +92,7 @@ class CLIPImageEncoder: vec = np.array(response_data[0], dtype=np.float32) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): raise RuntimeError(f"Invalid image embedding returned for URL: {url}") - self.cache.set(url, vec) + self.cache.set(cache_key, vec) return vec def encode_batch( @@ -116,7 +123,8 @@ class CLIPImageEncoder: normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item] for pos, url in enumerate(normalized_urls): - cached = self.cache.get(url) + cache_key = build_image_cache_key(url, normalize=normalize_embeddings) + cached = self.cache.get(cache_key) if cached is not None: results.append(cached) else: @@ -139,7 +147,7 @@ class CLIPImageEncoder: vec = np.array(embedding, dtype=np.float32) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): raise RuntimeError(f"Invalid image embedding returned for URL: {url}") - self.cache.set(url, vec) + self.cache.set(build_image_cache_key(url, normalize=normalize_embeddings), vec) pos = pending_positions[i + j] results[pos] = vec diff --git a/embeddings/redis_embedding_cache.py b/embeddings/redis_embedding_cache.py index 0a1a1e4..50298e6 100644 --- a/embeddings/redis_embedding_cache.py +++ b/embeddings/redis_embedding_cache.py @@ -12,10 +12,13 @@ from __future__ import annotations import logging from datetime import timedelta -from typing import Optional +from typing import Any, Optional import numpy as np -import redis +try: + import redis +except ImportError: # pragma: no cover - runtime fallback for minimal envs + redis = None # type: ignore[assignment] from config.env_config import REDIS_CONFIG from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis @@ -40,6 +43,11 @@ class RedisEmbeddingCache: self.redis_client = redis_client return + if redis is None: + logger.warning("redis package is not installed, continuing without embedding cache") + self.redis_client = None + return + try: client = redis.Redis( host=REDIS_CONFIG.get("host", "localhost"), @@ -104,4 +112,3 @@ class RedisEmbeddingCache: except Exception as e: logger.warning("Error storing embedding in cache: %s", e) return False - diff --git a/embeddings/server.py b/embeddings/server.py index bd244a0..85e0c28 100644 --- a/embeddings/server.py +++ b/embeddings/server.py @@ -21,9 +21,12 @@ import numpy as np from fastapi import FastAPI, HTTPException, Request, Response from fastapi.concurrency import run_in_threadpool +from config.env_config import REDIS_CONFIG from config.services_config import get_embedding_backend_config +from embeddings.cache_keys import build_image_cache_key, build_text_cache_key from embeddings.config import CONFIG from embeddings.protocols import ImageEncoderProtocol +from embeddings.redis_embedding_cache import RedisEmbeddingCache app = FastAPI(title="saas-search Embedding Service", version="1.0.0") @@ -106,8 +109,15 @@ verbose_logger = logging.getLogger("embedding.verbose") _text_model: Optional[Any] = None _image_model: Optional[ImageEncoderProtocol] = None _text_backend_name: str = "" -open_text_model = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") -open_image_model = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") +_SERVICE_KIND = (os.getenv("EMBEDDING_SERVICE_KIND", "all") or "all").strip().lower() +if _SERVICE_KIND not in {"all", "text", "image"}: + raise RuntimeError( + f"Invalid EMBEDDING_SERVICE_KIND={_SERVICE_KIND!r}; expected all, text, or image" + ) +_TEXT_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") +_IMAGE_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") +open_text_model = _TEXT_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "text"} +open_image_model = _IMAGE_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "image"} _text_encode_lock = threading.Lock() _image_encode_lock = threading.Lock() @@ -125,6 +135,71 @@ _LOG_PREVIEW_COUNT = max(1, int(os.getenv("EMBEDDING_LOG_PREVIEW_COUNT", "3"))) _LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_TEXT_PREVIEW_CHARS", "120"))) _LOG_IMAGE_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_IMAGE_PREVIEW_CHARS", "180"))) _VECTOR_PREVIEW_DIMS = max(1, int(os.getenv("EMBEDDING_VECTOR_PREVIEW_DIMS", "6"))) +_CACHE_PREFIX = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" + + +@dataclass +class _EmbedResult: + vectors: List[Optional[List[float]]] + cache_hits: int + cache_misses: int + backend_elapsed_ms: float + mode: str + + +class _EndpointStats: + def __init__(self, name: str): + self.name = name + self._lock = threading.Lock() + self.request_total = 0 + self.success_total = 0 + self.failure_total = 0 + self.rejected_total = 0 + self.cache_hits = 0 + self.cache_misses = 0 + self.total_latency_ms = 0.0 + self.total_backend_latency_ms = 0.0 + + def record_rejected(self) -> None: + with self._lock: + self.request_total += 1 + self.rejected_total += 1 + + def record_completed( + self, + *, + success: bool, + latency_ms: float, + backend_latency_ms: float, + cache_hits: int, + cache_misses: int, + ) -> None: + with self._lock: + self.request_total += 1 + if success: + self.success_total += 1 + else: + self.failure_total += 1 + self.cache_hits += max(0, int(cache_hits)) + self.cache_misses += max(0, int(cache_misses)) + self.total_latency_ms += max(0.0, float(latency_ms)) + self.total_backend_latency_ms += max(0.0, float(backend_latency_ms)) + + def snapshot(self) -> Dict[str, Any]: + with self._lock: + completed = self.success_total + self.failure_total + return { + "request_total": self.request_total, + "success_total": self.success_total, + "failure_total": self.failure_total, + "rejected_total": self.rejected_total, + "cache_hits": self.cache_hits, + "cache_misses": self.cache_misses, + "avg_latency_ms": round(self.total_latency_ms / completed, 3) if completed else 0.0, + "avg_backend_latency_ms": round(self.total_backend_latency_ms / completed, 3) + if completed + else 0.0, + } class _InflightLimiter: @@ -176,6 +251,10 @@ class _InflightLimiter: _text_request_limiter = _InflightLimiter(name="text", limit=_TEXT_MAX_INFLIGHT) _image_request_limiter = _InflightLimiter(name="image", limit=_IMAGE_MAX_INFLIGHT) +_text_stats = _EndpointStats(name="text") +_image_stats = _EndpointStats(name="image") +_text_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="") +_image_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="image") @dataclass @@ -377,7 +456,12 @@ def load_models(): """Load models at service startup to avoid first-request latency.""" global _text_model, _image_model, _text_backend_name - logger.info("Loading embedding models at startup...") + logger.info( + "Loading embedding models at startup | service_kind=%s text_enabled=%s image_enabled=%s", + _SERVICE_KIND, + open_text_model, + open_image_model, + ) if open_text_model: try: @@ -482,18 +566,72 @@ def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Option return embedding.tolist() +def _try_full_text_cache_hit( + normalized: List[str], + effective_normalize: bool, +) -> Optional[_EmbedResult]: + out: List[Optional[List[float]]] = [] + for text in normalized: + cached = _text_cache.get(build_text_cache_key(text, normalize=effective_normalize)) + if cached is None: + return None + vec = _as_list(cached, normalize=False) + if vec is None: + return None + out.append(vec) + return _EmbedResult( + vectors=out, + cache_hits=len(out), + cache_misses=0, + backend_elapsed_ms=0.0, + mode="cache-only", + ) + + +def _try_full_image_cache_hit( + urls: List[str], + effective_normalize: bool, +) -> Optional[_EmbedResult]: + out: List[Optional[List[float]]] = [] + for url in urls: + cached = _image_cache.get(build_image_cache_key(url, normalize=effective_normalize)) + if cached is None: + return None + vec = _as_list(cached, normalize=False) + if vec is None: + return None + out.append(vec) + return _EmbedResult( + vectors=out, + cache_hits=len(out), + cache_misses=0, + backend_elapsed_ms=0.0, + mode="cache-only", + ) + + @app.get("/health") def health() -> Dict[str, Any]: """Health check endpoint. Returns status and current throttling stats.""" + ready = (not open_text_model or _text_model is not None) and (not open_image_model or _image_model is not None) return { - "status": "ok", + "status": "ok" if ready else "degraded", + "service_kind": _SERVICE_KIND, "text_model_loaded": _text_model is not None, "text_backend": _text_backend_name, "image_model_loaded": _image_model is not None, + "cache_enabled": { + "text": _text_cache.redis_client is not None, + "image": _image_cache.redis_client is not None, + }, "limits": { "text": _text_request_limiter.snapshot(), "image": _image_request_limiter.snapshot(), }, + "stats": { + "text": _text_stats.snapshot(), + "image": _image_stats.snapshot(), + }, "text_microbatch": { "window_ms": round(_TEXT_MICROBATCH_WINDOW_SEC * 1000.0, 3), "queue_depth": len(_text_single_queue), @@ -503,44 +641,105 @@ def health() -> Dict[str, Any]: } +@app.get("/ready") +def ready() -> Dict[str, Any]: + text_ready = (not open_text_model) or (_text_model is not None) + image_ready = (not open_image_model) or (_image_model is not None) + if not (text_ready and image_ready): + raise HTTPException( + status_code=503, + detail={ + "service_kind": _SERVICE_KIND, + "text_ready": text_ready, + "image_ready": image_ready, + }, + ) + return { + "status": "ready", + "service_kind": _SERVICE_KIND, + "text_ready": text_ready, + "image_ready": image_ready, + } + + def _embed_text_impl( normalized: List[str], effective_normalize: bool, request_id: str, -) -> List[Optional[List[float]]]: +) -> _EmbedResult: if _text_model is None: raise RuntimeError("Text model not loaded") - t0 = time.perf_counter() + out: List[Optional[List[float]]] = [None] * len(normalized) + missing_indices: List[int] = [] + missing_texts: List[str] = [] + missing_cache_keys: List[str] = [] + cache_hits = 0 + for idx, text in enumerate(normalized): + cache_key = build_text_cache_key(text, normalize=effective_normalize) + cached = _text_cache.get(cache_key) + if cached is not None: + vec = _as_list(cached, normalize=False) + if vec is not None: + out[idx] = vec + cache_hits += 1 + continue + missing_indices.append(idx) + missing_texts.append(text) + missing_cache_keys.append(cache_key) + + if not missing_texts: + logger.info( + "text backend done | backend=%s mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 backend_elapsed_ms=0.00", + _text_backend_name, + len(normalized), + effective_normalize, + len(out[0]) if out and out[0] is not None else 0, + cache_hits, + extra=_request_log_extra(request_id), + ) + return _EmbedResult( + vectors=out, + cache_hits=cache_hits, + cache_misses=0, + backend_elapsed_ms=0.0, + mode="cache-only", + ) + + backend_t0 = time.perf_counter() try: if _text_backend_name == "local_st": - if len(normalized) == 1 and _text_batch_worker is not None: - out = [ + if len(missing_texts) == 1 and _text_batch_worker is not None: + computed = [ _encode_single_text_with_microbatch( - normalized[0], + missing_texts[0], normalize=effective_normalize, request_id=request_id, ) ] - logger.info( - "text backend done | backend=%s mode=microbatch-single inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", - _text_backend_name, - len(normalized), - effective_normalize, - len(out[0]) if out and out[0] is not None else 0, - (time.perf_counter() - t0) * 1000.0, - extra=_request_log_extra(request_id), - ) - return out - embs = _encode_local_st(normalized, normalize_embeddings=False) - mode = "direct-batch" + mode = "microbatch-single" + else: + embs = _encode_local_st(missing_texts, normalize_embeddings=False) + computed = [] + for i, emb in enumerate(embs): + vec = _as_list(emb, normalize=effective_normalize) + if vec is None: + raise RuntimeError(f"Text model returned empty embedding for missing index {i}") + computed.append(vec) + mode = "direct-batch" else: embs = _text_model.encode( - normalized, + missing_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE, normalize_embeddings=effective_normalize, ) + computed = [] + for i, emb in enumerate(embs): + vec = _as_list(emb, normalize=False) + if vec is None: + raise RuntimeError(f"Text model returned empty embedding for missing index {i}") + computed.append(vec) mode = "backend-batch" except Exception as e: logger.error( @@ -551,30 +750,37 @@ def _embed_text_impl( ) raise RuntimeError(f"Text embedding backend failure: {e}") from e - if embs is None or len(embs) != len(normalized): + if len(computed) != len(missing_texts): raise RuntimeError( - f"Text model response length mismatch: expected {len(normalized)}, " - f"got {0 if embs is None else len(embs)}" + f"Text model response length mismatch: expected {len(missing_texts)}, " + f"got {len(computed)}" ) - out: List[Optional[List[float]]] = [] - for i, emb in enumerate(embs): - vec = _as_list(emb, normalize=effective_normalize) - if vec is None: - raise RuntimeError(f"Text model returned empty embedding for index {i}") - out.append(vec) + for pos, cache_key, vec in zip(missing_indices, missing_cache_keys, computed): + out[pos] = vec + _text_cache.set(cache_key, np.asarray(vec, dtype=np.float32)) + + backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0 logger.info( - "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", + "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f", _text_backend_name, mode, len(normalized), effective_normalize, len(out[0]) if out and out[0] is not None else 0, - (time.perf_counter() - t0) * 1000.0, + cache_hits, + len(missing_texts), + backend_elapsed_ms, extra=_request_log_extra(request_id), ) - return out + return _EmbedResult( + vectors=out, + cache_hits=cache_hits, + cache_misses=len(missing_texts), + backend_elapsed_ms=backend_elapsed_ms, + mode=mode, + ) @app.post("/embed/text") @@ -584,6 +790,9 @@ async def embed_text( response: Response, normalize: Optional[bool] = None, ) -> List[Optional[List[float]]]: + if _text_model is None: + raise HTTPException(status_code=503, detail="Text embedding model not loaded in this service") + request_id = _resolve_request_id(http_request) response.headers["X-Request-ID"] = request_id @@ -597,8 +806,33 @@ async def embed_text( raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") normalized.append(s) + cache_check_started = time.perf_counter() + cache_only = _try_full_text_cache_hit(normalized, effective_normalize) + if cache_only is not None: + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 + _text_stats.record_completed( + success=True, + latency_ms=latency_ms, + backend_latency_ms=0.0, + cache_hits=cache_only.cache_hits, + cache_misses=0, + ) + logger.info( + "embed_text response | backend=%s mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", + _text_backend_name, + len(normalized), + effective_normalize, + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, + cache_only.cache_hits, + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), + latency_ms, + extra=_request_log_extra(request_id), + ) + return cache_only.vectors + accepted, active = _text_request_limiter.try_acquire() if not accepted: + _text_stats.record_rejected() logger.warning( "embed_text rejected | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", _request_client(http_request), @@ -617,6 +851,9 @@ async def embed_text( request_started = time.perf_counter() success = False + backend_elapsed_ms = 0.0 + cache_hits = 0 + cache_misses = 0 try: logger.info( "embed_text request | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", @@ -636,31 +873,53 @@ async def embed_text( _text_backend_name, extra=_request_log_extra(request_id), ) - out = await run_in_threadpool(_embed_text_impl, normalized, effective_normalize, request_id) + result = await run_in_threadpool(_embed_text_impl, normalized, effective_normalize, request_id) success = True + backend_elapsed_ms = result.backend_elapsed_ms + cache_hits = result.cache_hits + cache_misses = result.cache_misses latency_ms = (time.perf_counter() - request_started) * 1000.0 + _text_stats.record_completed( + success=True, + latency_ms=latency_ms, + backend_latency_ms=backend_elapsed_ms, + cache_hits=cache_hits, + cache_misses=cache_misses, + ) logger.info( - "embed_text response | backend=%s inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", + "embed_text response | backend=%s mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f", _text_backend_name, + result.mode, len(normalized), effective_normalize, - len(out[0]) if out and out[0] is not None else 0, - _preview_vector(out[0] if out else None), + len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0, + cache_hits, + cache_misses, + _preview_vector(result.vectors[0] if result.vectors else None), latency_ms, extra=_request_log_extra(request_id), ) verbose_logger.info( "embed_text result detail | count=%d first_vector=%s latency_ms=%.2f", - len(out), - out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], + len(result.vectors), + result.vectors[0][: _VECTOR_PREVIEW_DIMS] + if result.vectors and result.vectors[0] is not None + else [], latency_ms, extra=_request_log_extra(request_id), ) - return out + return result.vectors except HTTPException: raise except Exception as e: latency_ms = (time.perf_counter() - request_started) * 1000.0 + _text_stats.record_completed( + success=False, + latency_ms=latency_ms, + backend_latency_ms=backend_elapsed_ms, + cache_hits=cache_hits, + cache_misses=cache_misses, + ) logger.error( "embed_text failed | backend=%s inputs=%d normalize=%s latency_ms=%.2f error=%s", _text_backend_name, @@ -686,39 +945,84 @@ def _embed_image_impl( urls: List[str], effective_normalize: bool, request_id: str, -) -> List[Optional[List[float]]]: +) -> _EmbedResult: if _image_model is None: raise RuntimeError("Image model not loaded") - t0 = time.perf_counter() + out: List[Optional[List[float]]] = [None] * len(urls) + missing_indices: List[int] = [] + missing_urls: List[str] = [] + missing_cache_keys: List[str] = [] + cache_hits = 0 + for idx, url in enumerate(urls): + cache_key = build_image_cache_key(url, normalize=effective_normalize) + cached = _image_cache.get(cache_key) + if cached is not None: + vec = _as_list(cached, normalize=False) + if vec is not None: + out[idx] = vec + cache_hits += 1 + continue + missing_indices.append(idx) + missing_urls.append(url) + missing_cache_keys.append(cache_key) + + if not missing_urls: + logger.info( + "image backend done | mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 backend_elapsed_ms=0.00", + len(urls), + effective_normalize, + len(out[0]) if out and out[0] is not None else 0, + cache_hits, + extra=_request_log_extra(request_id), + ) + return _EmbedResult( + vectors=out, + cache_hits=cache_hits, + cache_misses=0, + backend_elapsed_ms=0.0, + mode="cache-only", + ) + + backend_t0 = time.perf_counter() with _image_encode_lock: vectors = _image_model.encode_image_urls( - urls, + missing_urls, batch_size=CONFIG.IMAGE_BATCH_SIZE, normalize_embeddings=effective_normalize, ) - if vectors is None or len(vectors) != len(urls): + if vectors is None or len(vectors) != len(missing_urls): raise RuntimeError( - f"Image model response length mismatch: expected {len(urls)}, " + f"Image model response length mismatch: expected {len(missing_urls)}, " f"got {0 if vectors is None else len(vectors)}" ) - out: List[Optional[List[float]]] = [] - for i, vec in enumerate(vectors): + for pos, cache_key, vec in zip(missing_indices, missing_cache_keys, vectors): out_vec = _as_list(vec, normalize=effective_normalize) if out_vec is None: - raise RuntimeError(f"Image model returned empty embedding for index {i}") - out.append(out_vec) + raise RuntimeError(f"Image model returned empty embedding for position {pos}") + out[pos] = out_vec + _image_cache.set(cache_key, np.asarray(out_vec, dtype=np.float32)) + + backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0 logger.info( - "image backend done | inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", + "image backend done | mode=backend-batch inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f", len(urls), effective_normalize, len(out[0]) if out and out[0] is not None else 0, - (time.perf_counter() - t0) * 1000.0, + cache_hits, + len(missing_urls), + backend_elapsed_ms, extra=_request_log_extra(request_id), ) - return out + return _EmbedResult( + vectors=out, + cache_hits=cache_hits, + cache_misses=len(missing_urls), + backend_elapsed_ms=backend_elapsed_ms, + mode="backend-batch", + ) @app.post("/embed/image") @@ -728,6 +1032,9 @@ async def embed_image( response: Response, normalize: Optional[bool] = None, ) -> List[Optional[List[float]]]: + if _image_model is None: + raise HTTPException(status_code=503, detail="Image embedding model not loaded in this service") + request_id = _resolve_request_id(http_request) response.headers["X-Request-ID"] = request_id @@ -741,8 +1048,32 @@ async def embed_image( raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") urls.append(s) + cache_check_started = time.perf_counter() + cache_only = _try_full_image_cache_hit(urls, effective_normalize) + if cache_only is not None: + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 + _image_stats.record_completed( + success=True, + latency_ms=latency_ms, + backend_latency_ms=0.0, + cache_hits=cache_only.cache_hits, + cache_misses=0, + ) + logger.info( + "embed_image response | mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", + len(urls), + effective_normalize, + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, + cache_only.cache_hits, + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), + latency_ms, + extra=_request_log_extra(request_id), + ) + return cache_only.vectors + accepted, active = _image_request_limiter.try_acquire() if not accepted: + _image_stats.record_rejected() logger.warning( "embed_image rejected | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", _request_client(http_request), @@ -760,6 +1091,9 @@ async def embed_image( request_started = time.perf_counter() success = False + backend_elapsed_ms = 0.0 + cache_hits = 0 + cache_misses = 0 try: logger.info( "embed_image request | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", @@ -777,30 +1111,52 @@ async def embed_image( effective_normalize, extra=_request_log_extra(request_id), ) - out = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) + result = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) success = True + backend_elapsed_ms = result.backend_elapsed_ms + cache_hits = result.cache_hits + cache_misses = result.cache_misses latency_ms = (time.perf_counter() - request_started) * 1000.0 + _image_stats.record_completed( + success=True, + latency_ms=latency_ms, + backend_latency_ms=backend_elapsed_ms, + cache_hits=cache_hits, + cache_misses=cache_misses, + ) logger.info( - "embed_image response | inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", + "embed_image response | mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f", + result.mode, len(urls), effective_normalize, - len(out[0]) if out and out[0] is not None else 0, - _preview_vector(out[0] if out else None), + len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0, + cache_hits, + cache_misses, + _preview_vector(result.vectors[0] if result.vectors else None), latency_ms, extra=_request_log_extra(request_id), ) verbose_logger.info( "embed_image result detail | count=%d first_vector=%s latency_ms=%.2f", - len(out), - out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], + len(result.vectors), + result.vectors[0][: _VECTOR_PREVIEW_DIMS] + if result.vectors and result.vectors[0] is not None + else [], latency_ms, extra=_request_log_extra(request_id), ) - return out + return result.vectors except HTTPException: raise except Exception as e: latency_ms = (time.perf_counter() - request_started) * 1000.0 + _image_stats.record_completed( + success=False, + latency_ms=latency_ms, + backend_latency_ms=backend_elapsed_ms, + cache_hits=cache_hits, + cache_misses=cache_misses, + ) logger.error( "embed_image failed | inputs=%d normalize=%s latency_ms=%.2f error=%s", len(urls), diff --git a/embeddings/text_encoder.py b/embeddings/text_encoder.py index d3f08fe..ca95b23 100644 --- a/embeddings/text_encoder.py +++ b/embeddings/text_encoder.py @@ -10,19 +10,26 @@ import requests logger = logging.getLogger(__name__) -from config.services_config import get_embedding_base_url +from config.services_config import get_embedding_text_base_url +from embeddings.cache_keys import build_text_cache_key from embeddings.redis_embedding_cache import RedisEmbeddingCache # Try to import REDIS_CONFIG, but allow import to fail from config.env_config import REDIS_CONFIG + class TextEmbeddingEncoder: """ Text embedding encoder using network service. """ def __init__(self, service_url: Optional[str] = None): - resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() + resolved_url = ( + service_url + or os.getenv("EMBEDDING_TEXT_SERVICE_URL") + or os.getenv("EMBEDDING_SERVICE_URL") + or get_embedding_text_base_url() + ) self.service_url = str(resolved_url).rstrip("/") self.endpoint = f"{self.service_url}/embed/text" self.expire_time = timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) @@ -87,9 +94,8 @@ class TextEmbeddingEncoder: uncached_texts: List[str] = [] embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) - for i, text in enumerate(sentences): - cached = self._get_cached_embedding(text) + cached = self._get_cached_embedding(text, normalize_embeddings=normalize_embeddings) if cached is not None: embeddings[i] = cached else: @@ -115,7 +121,11 @@ class TextEmbeddingEncoder: embedding_array = np.array(embedding, dtype=np.float32) if self._is_valid_embedding(embedding_array): embeddings[original_idx] = embedding_array - self._set_cached_embedding(text, embedding_array) + self._set_cached_embedding( + text, + embedding_array, + normalize_embeddings=normalize_embeddings, + ) else: raise ValueError( f"Invalid embedding returned from service for text index {original_idx}" @@ -150,20 +160,32 @@ class TextEmbeddingEncoder: def _get_cached_embedding( self, query: str, + *, + normalize_embeddings: bool, ) -> Optional[np.ndarray]: """Get embedding from cache if exists (with sliding expiration).""" - embedding = self.cache.get(query) + embedding = self.cache.get(build_text_cache_key(query, normalize=normalize_embeddings)) if embedding is not None: - logger.debug(f"Cache hit for embedding: {query}") + logger.debug( + "Cache hit for text embedding | normalize=%s query=%s", + normalize_embeddings, + query, + ) return embedding def _set_cached_embedding( self, query: str, embedding: np.ndarray, + *, + normalize_embeddings: bool, ) -> bool: """Store embedding in cache.""" - ok = self.cache.set(query, embedding) + ok = self.cache.set(build_text_cache_key(query, normalize=normalize_embeddings), embedding) if ok: - logger.debug(f"Successfully cached embedding for query: {query}") + logger.debug( + "Successfully cached text embedding | normalize=%s query=%s", + normalize_embeddings, + query, + ) return ok diff --git a/providers/embedding.py b/providers/embedding.py index fd7a0be..ff96342 100644 --- a/providers/embedding.py +++ b/providers/embedding.py @@ -2,7 +2,11 @@ from __future__ import annotations -from config.services_config import get_embedding_config, get_embedding_base_url +from config.services_config import ( + get_embedding_config, + get_embedding_image_base_url, + get_embedding_text_base_url, +) def create_embedding_provider() -> "EmbeddingProvider": @@ -21,13 +25,14 @@ class EmbeddingProvider: """ def __init__(self) -> None: - self._base_url = get_embedding_base_url() + self._text_base_url = get_embedding_text_base_url() + self._image_base_url = get_embedding_image_base_url() from embeddings.text_encoder import TextEmbeddingEncoder from embeddings.image_encoder import CLIPImageEncoder # Initialize once; avoid per-access instantiation. - self._text_encoder = TextEmbeddingEncoder(service_url=self._base_url) - self._image_encoder = CLIPImageEncoder(service_url=self._base_url) + self._text_encoder = TextEmbeddingEncoder(service_url=self._text_base_url) + self._image_encoder = CLIPImageEncoder(service_url=self._image_base_url) @property def text_encoder(self): diff --git a/requirements_embedding_service.txt b/requirements_embedding_service.txt index a3406e4..1086750 100644 --- a/requirements_embedding_service.txt +++ b/requirements_embedding_service.txt @@ -10,6 +10,7 @@ pydantic>=2.0.0 requests>=2.31.0 numpy>=1.24.0 pyyaml>=6.0 +redis>=5.0.0 # Image backend via clip-as-service client setuptools<82 diff --git a/scripts/perf_api_benchmark.py b/scripts/perf_api_benchmark.py index 7f0defc..23acfbe 100755 --- a/scripts/perf_api_benchmark.py +++ b/scripts/perf_api_benchmark.py @@ -6,6 +6,7 @@ Default scenarios (aligned with docs/搜索API对接指南.md): - backend_search POST /search/ - backend_suggest GET /search/suggestions - embed_text POST /embed/text +- embed_image POST /embed/image - translate POST /translate - rerank POST /rerank @@ -158,6 +159,13 @@ def make_default_templates(tenant_id: str) -> Dict[str, List[RequestTemplate]]: json_body=["wireless mouse", "gaming keyboard", "barbie doll"], ) ], + "embed_image": [ + RequestTemplate( + method="POST", + path="/embed/image", + json_body=["/data/saas-search/docs/image-dress1.png"], + ) + ], "translate": [ RequestTemplate( method="POST", @@ -220,7 +228,8 @@ def build_scenarios(args: argparse.Namespace) -> Dict[str, Scenario]: scenario_base = { "backend_search": args.backend_base, "backend_suggest": args.backend_base, - "embed_text": args.embedding_base, + "embed_text": args.embedding_text_base, + "embed_image": args.embedding_image_base, "translate": args.translator_base, "rerank": args.reranker_base, } @@ -433,7 +442,7 @@ def parse_args() -> argparse.Namespace: "--scenario", type=str, default="all", - help="Scenario: backend_search | backend_suggest | embed_text | translate | rerank | all | comma-separated list", + help="Scenario: backend_search | backend_suggest | embed_text | embed_image | translate | rerank | all | comma-separated list", ) parser.add_argument("--tenant-id", type=str, default="162", help="Tenant ID for backend search/suggest") parser.add_argument("--duration", type=int, default=30, help="Duration seconds per scenario; <=0 means no duration cap") @@ -443,7 +452,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--max-errors", type=int, default=0, help="Stop scenario when accumulated errors reach this value") parser.add_argument("--backend-base", type=str, default="http://127.0.0.1:6002", help="Base URL for backend search API") - parser.add_argument("--embedding-base", type=str, default="http://127.0.0.1:6005", help="Base URL for embedding service") + parser.add_argument("--embedding-text-base", type=str, default="http://127.0.0.1:6005", help="Base URL for text embedding service") + parser.add_argument("--embedding-image-base", type=str, default="http://127.0.0.1:6008", help="Base URL for image embedding service") parser.add_argument("--translator-base", type=str, default="http://127.0.0.1:6006", help="Base URL for translation service") parser.add_argument("--reranker-base", type=str, default="http://127.0.0.1:6007", help="Base URL for reranker service") @@ -547,7 +557,7 @@ async def main_async() -> int: args = parse_args() scenarios = build_scenarios(args) - all_names = ["backend_search", "backend_suggest", "embed_text", "translate", "rerank"] + all_names = ["backend_search", "backend_suggest", "embed_text", "embed_image", "translate", "rerank"] if args.scenario == "all": run_names = [x for x in all_names if x in scenarios] else: @@ -595,7 +605,8 @@ async def main_async() -> int: print(f" timeout={args.timeout}s") print(f" max_errors={args.max_errors}") print(f" backend_base={args.backend_base}") - print(f" embedding_base={args.embedding_base}") + print(f" embedding_text_base={args.embedding_text_base}") + print(f" embedding_image_base={args.embedding_image_base}") print(f" translator_base={args.translator_base}") print(f" reranker_base={args.reranker_base}") if args.rerank_dynamic_docs: @@ -643,7 +654,8 @@ async def main_async() -> int: "timeout_sec": args.timeout, "max_errors": args.max_errors, "backend_base": args.backend_base, - "embedding_base": args.embedding_base, + "embedding_text_base": args.embedding_text_base, + "embedding_image_base": args.embedding_image_base, "translator_base": args.translator_base, "reranker_base": args.reranker_base, "cases_file": args.cases_file or None, diff --git a/scripts/service_ctl.sh b/scripts/service_ctl.sh index d6076c5..86eff9e 100755 --- a/scripts/service_ctl.sh +++ b/scripts/service_ctl.sh @@ -16,9 +16,9 @@ mkdir -p "${LOG_DIR}" source "${PROJECT_ROOT}/scripts/lib/load_env.sh" CORE_SERVICES=("backend" "indexer" "frontend") -OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "translator" "reranker") +OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker") FULL_SERVICES=("${OPTIONAL_SERVICES[@]}" "${CORE_SERVICES[@]}") -STOP_ORDER_SERVICES=("frontend" "indexer" "backend" "reranker" "translator" "embedding" "cnclip" "tei") +STOP_ORDER_SERVICES=("frontend" "indexer" "backend" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") all_services() { echo "${FULL_SERVICES[@]}" @@ -30,7 +30,8 @@ get_port() { backend) echo "${API_PORT:-6002}" ;; indexer) echo "${INDEXER_PORT:-6004}" ;; frontend) echo "${FRONTEND_PORT:-6003}" ;; - embedding) echo "${EMBEDDING_PORT:-6005}" ;; + embedding) echo "${EMBEDDING_TEXT_PORT:-${EMBEDDING_PORT:-6005}}" ;; + embedding-image) echo "${EMBEDDING_IMAGE_PORT:-6008}" ;; translator) echo "${TRANSLATION_PORT:-6006}" ;; reranker) echo "${RERANKER_PORT:-6007}" ;; tei) echo "${TEI_PORT:-8080}" ;; @@ -65,7 +66,8 @@ service_start_cmd() { backend) echo "./scripts/start_backend.sh" ;; indexer) echo "./scripts/start_indexer.sh" ;; frontend) echo "./scripts/start_frontend.sh" ;; - embedding) echo "./scripts/start_embedding_service.sh" ;; + embedding) echo "./scripts/start_embedding_text_service.sh" ;; + embedding-image) echo "./scripts/start_embedding_image_service.sh" ;; translator) echo "./scripts/start_translator.sh" ;; reranker) echo "./scripts/start_reranker.sh" ;; tei) echo "./scripts/start_tei_service.sh" ;; @@ -77,7 +79,7 @@ service_start_cmd() { service_exists() { local service="$1" case "${service}" in - backend|indexer|frontend|embedding|translator|reranker|tei|cnclip) return 0 ;; + backend|indexer|frontend|embedding|embedding-image|translator|reranker|tei|cnclip) return 0 ;; *) return 1 ;; esac } @@ -95,7 +97,7 @@ validate_targets() { health_path_for_service() { local service="$1" case "${service}" in - backend|indexer|embedding|translator|reranker|tei) echo "/health" ;; + backend|indexer|embedding|embedding-image|translator|reranker|tei) echo "/health" ;; *) echo "" ;; esac } diff --git a/scripts/start_embedding_image_service.sh b/scripts/start_embedding_image_service.sh new file mode 100755 index 0000000..d9be1bd --- /dev/null +++ b/scripts/start_embedding_image_service.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +set -euo pipefail + +cd "$(dirname "$0")/.." + +exec ./scripts/start_embedding_service.sh image diff --git a/scripts/start_embedding_service.sh b/scripts/start_embedding_service.sh index 0eb7252..a289868 100755 --- a/scripts/start_embedding_service.sh +++ b/scripts/start_embedding_service.sh @@ -1,6 +1,6 @@ #!/bin/bash # -# Start Embedding Service (port 6005). +# Start Embedding Service (combined/text/image mode). # # Design: # - Run in isolated venv `.venv-embedding` (do not pollute main `.venv`) @@ -33,18 +33,55 @@ CLIP_AS_SERVICE_SERVER=$("${PYTHON_BIN}" -c "from embeddings.config import CONFI CLIP_AS_SERVICE_MODEL_NAME=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_MODEL_NAME)") TEXT_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_embedding_backend_config; print(get_embedding_backend_config()[0])") TEI_BASE_URL=$("${PYTHON_BIN}" -c "import os; from config.services_config import get_embedding_backend_config; from embeddings.config import CONFIG; _, cfg = get_embedding_backend_config(); print(os.getenv('TEI_BASE_URL') or cfg.get('base_url') or CONFIG.TEI_BASE_URL)") +SERVICE_KIND="${1:-${EMBEDDING_SERVICE_KIND:-all}}" +SERVICE_KIND="$(echo "${SERVICE_KIND}" | tr '[:upper:]' '[:lower:]')" +if [[ "${SERVICE_KIND}" != "all" && "${SERVICE_KIND}" != "text" && "${SERVICE_KIND}" != "image" ]]; then + echo "ERROR: invalid embedding service kind: ${SERVICE_KIND}. expected all|text|image" >&2 + exit 1 +fi + +ENABLE_TEXT_MODEL="${EMBEDDING_ENABLE_TEXT_MODEL:-true}" +ENABLE_TEXT_MODEL="$(echo "${ENABLE_TEXT_MODEL}" | tr '[:upper:]' '[:lower:]')" ENABLE_IMAGE_MODEL="${EMBEDDING_ENABLE_IMAGE_MODEL:-true}" ENABLE_IMAGE_MODEL="$(echo "${ENABLE_IMAGE_MODEL}" | tr '[:upper:]' '[:lower:]')" -if [[ "${ENABLE_IMAGE_MODEL}" == "1" || "${ENABLE_IMAGE_MODEL}" == "true" || "${ENABLE_IMAGE_MODEL}" == "yes" ]]; then - IMAGE_MODEL_ENABLED=1 -else - IMAGE_MODEL_ENABLED=0 + +TEXT_MODEL_ENABLED=0 +IMAGE_MODEL_ENABLED=0 +if [[ "${SERVICE_KIND}" == "all" || "${SERVICE_KIND}" == "text" ]]; then + if [[ "${ENABLE_TEXT_MODEL}" == "1" || "${ENABLE_TEXT_MODEL}" == "true" || "${ENABLE_TEXT_MODEL}" == "yes" ]]; then + TEXT_MODEL_ENABLED=1 + fi +fi +if [[ "${SERVICE_KIND}" == "all" || "${SERVICE_KIND}" == "image" ]]; then + if [[ "${ENABLE_IMAGE_MODEL}" == "1" || "${ENABLE_IMAGE_MODEL}" == "true" || "${ENABLE_IMAGE_MODEL}" == "yes" ]]; then + IMAGE_MODEL_ENABLED=1 + fi fi EMBEDDING_SERVICE_HOST="${EMBEDDING_HOST:-${DEFAULT_EMBEDDING_SERVICE_HOST}}" -EMBEDDING_SERVICE_PORT="${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}" +if [[ "${SERVICE_KIND}" == "text" ]]; then + EMBEDDING_SERVICE_PORT="${EMBEDDING_TEXT_PORT:-${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}}" +elif [[ "${SERVICE_KIND}" == "image" ]]; then + EMBEDDING_SERVICE_PORT="${EMBEDDING_IMAGE_PORT:-6008}" +else + EMBEDDING_SERVICE_PORT="${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}" +fi -if [[ "${TEXT_BACKEND}" == "tei" ]]; then +export EMBEDDING_SERVICE_KIND="${SERVICE_KIND}" +export EMBEDDING_HOST="${EMBEDDING_SERVICE_HOST}" +export EMBEDDING_PORT="${EMBEDDING_SERVICE_PORT}" +if [[ "${TEXT_MODEL_ENABLED}" == "1" ]]; then + export EMBEDDING_ENABLE_TEXT_MODEL=true +else + export EMBEDDING_ENABLE_TEXT_MODEL=false +fi +if [[ "${IMAGE_MODEL_ENABLED}" == "1" ]]; then + export EMBEDDING_ENABLE_IMAGE_MODEL=true +else + export EMBEDDING_ENABLE_IMAGE_MODEL=false +fi + +if [[ "${TEXT_MODEL_ENABLED}" == "1" && "${TEXT_BACKEND}" == "tei" ]]; then if ! curl -sf "${TEI_BASE_URL%/}/health" >/dev/null 2>&1; then echo "ERROR: TEI backend is selected but TEI is not reachable: ${TEI_BASE_URL}/health" >&2 echo "Please start TEI first: ./scripts/start_tei_service.sh" >&2 @@ -81,12 +118,16 @@ fi echo "========================================" echo "Starting Embedding Service" echo "========================================" +echo "Kind: ${SERVICE_KIND}" echo "Python: ${PYTHON_BIN}" echo "Host: ${EMBEDDING_SERVICE_HOST}" echo "Port: ${EMBEDDING_SERVICE_PORT}" -echo "Text backend: ${TEXT_BACKEND}" -echo "Text max inflight: ${TEXT_MAX_INFLIGHT:-32}" -if [[ "${TEXT_BACKEND}" == "tei" ]]; then +echo "Text backend enabled: ${TEXT_MODEL_ENABLED}" +if [[ "${TEXT_MODEL_ENABLED}" == "1" ]]; then + echo "Text backend: ${TEXT_BACKEND}" + echo "Text max inflight: ${TEXT_MAX_INFLIGHT:-32}" +fi +if [[ "${TEXT_MODEL_ENABLED}" == "1" && "${TEXT_BACKEND}" == "tei" ]]; then echo "TEI URL: ${TEI_BASE_URL}" fi if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then @@ -94,12 +135,20 @@ if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then elif [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then echo "Image backend: clip-as-service (${CLIP_AS_SERVICE_SERVER}, model=${CLIP_AS_SERVICE_MODEL_NAME})" fi -echo "Image max inflight: ${IMAGE_MAX_INFLIGHT:-1}" +if [[ "${IMAGE_MODEL_ENABLED}" == "1" ]]; then + echo "Image max inflight: ${IMAGE_MAX_INFLIGHT:-1}" +fi echo "Logs: logs/embedding_api.log, logs/embedding_api_error.log, logs/verbose/embedding_verbose.log" echo echo "Tips:" echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." -echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" +if [[ "${SERVICE_KIND}" == "text" ]]; then + echo " - Clients can set EMBEDDING_TEXT_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" +elif [[ "${SERVICE_KIND}" == "image" ]]; then + echo " - Clients can set EMBEDDING_IMAGE_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" +else + echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" +fi echo UVICORN_LOG_LEVEL="${EMBEDDING_UVICORN_LOG_LEVEL:-info}" diff --git a/scripts/start_embedding_text_service.sh b/scripts/start_embedding_text_service.sh new file mode 100755 index 0000000..40b7652 --- /dev/null +++ b/scripts/start_embedding_text_service.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +set -euo pipefail + +cd "$(dirname "$0")/.." + +exec ./scripts/start_embedding_service.sh text diff --git a/tests/test_embedding_pipeline.py b/tests/test_embedding_pipeline.py index 8e826cd..3470fad 100644 --- a/tests/test_embedding_pipeline.py +++ b/tests/test_embedding_pipeline.py @@ -12,7 +12,9 @@ from config import ( SearchConfig, ) from embeddings.text_encoder import TextEmbeddingEncoder +from embeddings.image_encoder import CLIPImageEncoder from embeddings.bf16 import encode_embedding_for_redis +from embeddings.cache_keys import build_image_cache_key, build_text_cache_key from query import QueryParser @@ -67,6 +69,18 @@ class _FakeQueryEncoder: return np.array([np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences], dtype=object) +class _FakeEmbeddingCache: + def __init__(self): + self.store: Dict[str, np.ndarray] = {} + + def get(self, key: str): + return self.store.get(key) + + def set(self, key: str, embedding: np.ndarray): + self.store[key] = np.asarray(embedding, dtype=np.float32) + return True + + def _build_test_config() -> SearchConfig: return SearchConfig( field_boosts={"title.en": 3.0}, @@ -91,8 +105,8 @@ def _build_test_config() -> SearchConfig: def test_text_embedding_encoder_response_alignment(monkeypatch): - fake_redis = _FakeRedis() - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) + fake_cache = _FakeEmbeddingCache() + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) def _fake_post(url, json, timeout, **kwargs): assert url.endswith("/embed/text") @@ -112,8 +126,8 @@ def test_text_embedding_encoder_response_alignment(monkeypatch): def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): - fake_redis = _FakeRedis() - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) + fake_cache = _FakeEmbeddingCache() + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) def _fake_post(url, json, timeout, **kwargs): return _FakeResponse([[0.1, 0.2], None]) @@ -126,10 +140,10 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): def test_text_embedding_encoder_cache_hit(monkeypatch): - fake_redis = _FakeRedis() + fake_cache = _FakeEmbeddingCache() cached = np.array([0.9, 0.8], dtype=np.float32) - fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached) - monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) + fake_cache.store[build_text_cache_key("cached-text", normalize=True)] = cached + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) calls = {"count": 0} @@ -147,6 +161,29 @@ def test_text_embedding_encoder_cache_hit(monkeypatch): assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32)) +def test_image_embedding_encoder_cache_hit(monkeypatch): + fake_cache = _FakeEmbeddingCache() + cached = np.array([0.5, 0.6], dtype=np.float32) + url = "https://example.com/a.jpg" + fake_cache.store[build_image_cache_key(url, normalize=True)] = cached + monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) + + calls = {"count": 0} + + def _fake_post(url, params, json, timeout, **kwargs): + calls["count"] += 1 + return _FakeResponse([[0.1, 0.2]]) + + monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post) + + encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008") + out = encoder.encode_batch(["https://example.com/a.jpg", "https://example.com/b.jpg"]) + + assert calls["count"] == 1 + assert np.allclose(out[0], cached) + assert np.allclose(out[1], np.array([0.1, 0.2], dtype=np.float32)) + + def test_query_parser_generates_query_vector_with_encoder(): parser = QueryParser( config=_build_test_config(), diff --git a/tests/test_embedding_service_limits.py b/tests/test_embedding_service_limits.py index 2daa28d..7d14ab7 100644 --- a/tests/test_embedding_service_limits.py +++ b/tests/test_embedding_service_limits.py @@ -28,6 +28,24 @@ class _FakeTextModel: return [np.array([1.0, 2.0, 3.0], dtype=np.float32)] +class _FakeImageModel: + def encode_image_urls(self, urls, batch_size, normalize_embeddings): + raise AssertionError("image backend should not be called on cache hit") + + +class _FakeCache: + def __init__(self, store=None): + self.store = store or {} + self.redis_client = object() + + def get(self, key): + return self.store.get(key) + + def set(self, key, value): + self.store[key] = np.asarray(value, dtype=np.float32) + return True + + def test_health_exposes_limit_stats(monkeypatch): monkeypatch.setattr( embedding_server, @@ -39,6 +57,8 @@ def test_health_exposes_limit_stats(monkeypatch): "_image_request_limiter", embedding_server._InflightLimiter("image", 1), ) + monkeypatch.setattr(embedding_server, "_text_model", object()) + monkeypatch.setattr(embedding_server, "_image_model", object()) payload = embedding_server.health() @@ -53,6 +73,7 @@ def test_embed_image_rejects_when_image_lane_is_full(monkeypatch): acquired, _ = limiter.try_acquire() assert acquired is True monkeypatch.setattr(embedding_server, "_image_request_limiter", limiter) + monkeypatch.setattr(embedding_server, "_image_model", object()) response = _DummyResponse() with pytest.raises(embedding_server.HTTPException) as exc_info: @@ -91,3 +112,29 @@ def test_embed_text_returns_request_id_and_vector(monkeypatch): assert response.headers["X-Request-ID"] == "req-123456" assert result == [[1.0, 2.0, 3.0]] + + +def test_embed_image_service_cache_hit_bypasses_backend(monkeypatch): + cache_key = embedding_server.build_image_cache_key("https://example.com/a.jpg", normalize=True) + fake_cache = _FakeCache({cache_key: np.array([0.7, 0.8], dtype=np.float32)}) + monkeypatch.setattr( + embedding_server, + "_image_request_limiter", + embedding_server._InflightLimiter("image", 1), + ) + monkeypatch.setattr(embedding_server, "_image_model", _FakeImageModel()) + monkeypatch.setattr(embedding_server, "_image_cache", fake_cache) + + request = _DummyRequest(headers={"X-Request-ID": "img-cache-hit"}) + response = _DummyResponse() + result = asyncio.run( + embedding_server.embed_image( + ["https://example.com/a.jpg"], + request, + response, + normalize=True, + ) + ) + + assert response.headers["X-Request-ID"] == "img-cache-hit" + assert result == [[0.699999988079071, 0.800000011920929]] -- libgit2 0.21.2