""" Embedding service (FastAPI). API (simple list-in, list-out; aligned by index): - POST /embed/text body: ["text1", "text2", ...] -> [[...], ...] (TEI/BGE,语义检索 title_embedding) - POST /embed/image body: ["url_or_path1", ...] -> [[...], ...] (CN-CLIP 图向量) - POST /embed/clip_text body: ["短语1", "短语2", ...] -> [[...], ...] (CN-CLIP 文本塔,与 /embed/image 同空间) """ import logging import os import pathlib import threading import time import uuid from collections import deque from dataclasses import dataclass from typing import Any, Dict, List, Optional import numpy as np from fastapi import FastAPI, HTTPException, Request, Response from fastapi.concurrency import run_in_threadpool from config.env_config import REDIS_CONFIG from config.services_config import get_embedding_backend_config from embeddings.cache_keys import ( build_clip_text_cache_key as _mm_clip_text_cache_key, build_image_cache_key as _mm_image_cache_key, build_text_cache_key, ) from embeddings.config import CONFIG from embeddings.protocols import ImageEncoderProtocol from embeddings.redis_embedding_cache import RedisEmbeddingCache from request_log_context import ( LOG_LINE_FORMAT, RequestLogContextFilter, bind_request_log_context, build_request_log_extra, reset_request_log_context, ) app = FastAPI(title="saas-search Embedding Service", version="1.0.0") def configure_embedding_logging() -> None: root_logger = logging.getLogger() if getattr(root_logger, "_embedding_logging_configured", False): return log_dir = pathlib.Path("logs") log_dir.mkdir(exist_ok=True) log_level = os.getenv("LOG_LEVEL", "INFO").upper() numeric_level = getattr(logging, log_level, logging.INFO) formatter = logging.Formatter(LOG_LINE_FORMAT) context_filter = RequestLogContextFilter() root_logger.setLevel(numeric_level) root_logger.handlers.clear() stream_handler = logging.StreamHandler() stream_handler.setLevel(numeric_level) stream_handler.setFormatter(formatter) stream_handler.addFilter(context_filter) 
root_logger.addHandler(stream_handler) verbose_logger = logging.getLogger("embedding.verbose") verbose_logger.setLevel(numeric_level) verbose_logger.handlers.clear() # Consolidate verbose logs into the main embedding log stream. verbose_logger.propagate = True root_logger._embedding_logging_configured = True # type: ignore[attr-defined] configure_embedding_logging() logger = logging.getLogger(__name__) verbose_logger = logging.getLogger("embedding.verbose") # Models are loaded at startup, not lazily _text_model: Optional[Any] = None _image_model: Optional[ImageEncoderProtocol] = None _text_backend_name: str = "" _SERVICE_KIND = (os.getenv("EMBEDDING_SERVICE_KIND", "all") or "all").strip().lower() if _SERVICE_KIND not in {"all", "text", "image"}: raise RuntimeError( f"Invalid EMBEDDING_SERVICE_KIND={_SERVICE_KIND!r}; expected all, text, or image" ) _TEXT_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") _IMAGE_ENABLED_BY_ENV = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") open_text_model = _TEXT_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "text"} open_image_model = _IMAGE_ENABLED_BY_ENV and _SERVICE_KIND in {"all", "image"} _text_encode_lock = threading.Lock() _image_encode_lock = threading.Lock() _TEXT_MICROBATCH_WINDOW_SEC = max( 0.0, float(os.getenv("TEXT_MICROBATCH_WINDOW_MS", "4")) / 1000.0 ) _TEXT_REQUEST_TIMEOUT_SEC = max( 1.0, float(os.getenv("TEXT_REQUEST_TIMEOUT_SEC", "30")) ) _TEXT_MAX_INFLIGHT = max(1, int(os.getenv("TEXT_MAX_INFLIGHT", "32"))) _IMAGE_MAX_INFLIGHT = max(1, int(os.getenv("IMAGE_MAX_INFLIGHT", "20"))) _OVERLOAD_STATUS_CODE = int(os.getenv("EMBEDDING_OVERLOAD_STATUS_CODE", "503")) _LOG_PREVIEW_COUNT = max(1, int(os.getenv("EMBEDDING_LOG_PREVIEW_COUNT", "3"))) _LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_TEXT_PREVIEW_CHARS", "120"))) _LOG_IMAGE_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_IMAGE_PREVIEW_CHARS", "180"))) 
_VECTOR_PREVIEW_DIMS = max(1, int(os.getenv("EMBEDDING_VECTOR_PREVIEW_DIMS", "6")))
_CACHE_PREFIX = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"


@dataclass
class _EmbedResult:
    """Outcome of one embed call: vectors aligned to input order plus cache/latency metadata."""
    vectors: List[Optional[List[float]]]
    cache_hits: int
    cache_misses: int
    backend_elapsed_ms: float
    mode: str


class _EndpointStats:
    """Thread-safe cumulative counters for one endpoint lane (text or image)."""

    def __init__(self, name: str):
        self.name = name
        self._lock = threading.Lock()
        self.request_total = 0
        self.success_total = 0
        self.failure_total = 0
        self.rejected_total = 0
        self.cache_hits = 0
        self.cache_misses = 0
        self.total_latency_ms = 0.0
        self.total_backend_latency_ms = 0.0

    def record_rejected(self) -> None:
        """Count a request turned away by the in-flight limiter."""
        with self._lock:
            self.request_total += 1
            self.rejected_total += 1

    def record_completed(
        self,
        *,
        success: bool,
        latency_ms: float,
        backend_latency_ms: float,
        cache_hits: int,
        cache_misses: int,
    ) -> None:
        """Fold one finished request into the counters (negative inputs clamped to 0)."""
        with self._lock:
            self.request_total += 1
            if success:
                self.success_total += 1
            else:
                self.failure_total += 1
            self.cache_hits += max(0, int(cache_hits))
            self.cache_misses += max(0, int(cache_misses))
            self.total_latency_ms += max(0.0, float(latency_ms))
            self.total_backend_latency_ms += max(0.0, float(backend_latency_ms))

    def snapshot(self) -> Dict[str, Any]:
        """Return a consistent point-in-time copy of all counters plus averages."""
        with self._lock:
            completed = self.success_total + self.failure_total
            return {
                "request_total": self.request_total,
                "success_total": self.success_total,
                "failure_total": self.failure_total,
                "rejected_total": self.rejected_total,
                "cache_hits": self.cache_hits,
                "cache_misses": self.cache_misses,
                "avg_latency_ms": round(self.total_latency_ms / completed, 3) if completed else 0.0,
                "avg_backend_latency_ms": round(self.total_backend_latency_ms / completed, 3) if completed else 0.0,
            }


class _InflightLimiter:
    """Bounded in-flight request counter; high-priority callers may bypass the limit."""

    def __init__(self, name: str, limit: int):
        self.name = name
        self.limit = max(1, int(limit))
        self._lock = threading.Lock()
        self._active = 0
        self._rejected = 0
        self._completed = 0
        self._failed = 0
        self._max_active = 0
        self._priority_bypass_total = 0

    def try_acquire(self, *, bypass_limit: bool = False) -> tuple[bool, int]:
        """Try to take a slot. Returns (accepted, active_count_after_decision)."""
        with self._lock:
            if not bypass_limit and self._active >= self.limit:
                self._rejected += 1
                active = self._active
                return False, active
            self._active += 1
            self._max_active = max(self._max_active, self._active)
            if bypass_limit:
                self._priority_bypass_total += 1
            active = self._active
            return True, active

    def release(self, *, success: bool) -> int:
        """Release a slot and record the outcome. Returns the remaining active count."""
        with self._lock:
            self._active = max(0, self._active - 1)
            if success:
                self._completed += 1
            else:
                self._failed += 1
            active = self._active
            return active

    def snapshot(self) -> Dict[str, int]:
        """Return a consistent point-in-time copy of the limiter counters."""
        with self._lock:
            return {
                "limit": self.limit,
                "active": self._active,
                "rejected_total": self._rejected,
                "completed_total": self._completed,
                "failed_total": self._failed,
                "max_active": self._max_active,
                "priority_bypass_total": self._priority_bypass_total,
            }


def _effective_priority(priority: int) -> int:
    # Collapse any positive priority into the single "high" lane (1); else normal (0).
    return 1 if int(priority) > 0 else 0


def _priority_label(priority: int) -> str:
    """Human-readable label for log lines."""
    return "high" if _effective_priority(priority) > 0 else "normal"


@dataclass
class _TextDispatchTask:
    """One /embed/text request queued for a dispatch worker thread."""
    normalized: List[str]
    effective_normalize: bool
    request_id: str
    user_id: str
    priority: int
    created_at: float
    done: threading.Event
    result: Optional[_EmbedResult] = None
    error: Optional[Exception] = None


# Two-lane FIFO: high-priority tasks always drain before normal ones.
_text_dispatch_high_queue: "deque[_TextDispatchTask]" = deque()
_text_dispatch_normal_queue: "deque[_TextDispatchTask]" = deque()
_text_dispatch_cv = threading.Condition()
_text_dispatch_workers: List[threading.Thread] = []
_text_dispatch_worker_stop = False
_text_dispatch_worker_count = 0


def _text_dispatch_queue_depth() -> Dict[str, int]:
    """Queue depths for /health, read under the dispatch lock."""
    with _text_dispatch_cv:
        return {
            "high": len(_text_dispatch_high_queue),
            "normal": len(_text_dispatch_normal_queue),
            "total": len(_text_dispatch_high_queue) + len(_text_dispatch_normal_queue),
        }


def _pop_text_dispatch_task_locked() -> Optional["_TextDispatchTask"]:
    # Caller must hold _text_dispatch_cv. High lane drains first.
    if _text_dispatch_high_queue:
        return _text_dispatch_high_queue.popleft()
    if _text_dispatch_normal_queue:
        return _text_dispatch_normal_queue.popleft()
    return None


def _start_text_dispatch_workers() -> None:
    """Start (or top up) the dispatch worker pool sized for the active backend."""
    global _text_dispatch_workers, _text_dispatch_worker_stop, _text_dispatch_worker_count
    if _text_model is None:
        return
    # local_st encodes are serialized by a lock anyway, so one worker suffices;
    # remote backends (tei) can fan out up to the in-flight cap.
    target_worker_count = 1 if _text_backend_name == "local_st" else _TEXT_MAX_INFLIGHT
    alive_workers = [worker for worker in _text_dispatch_workers if worker.is_alive()]
    if len(alive_workers) == target_worker_count:
        _text_dispatch_workers = alive_workers
        _text_dispatch_worker_count = target_worker_count
        return
    _text_dispatch_worker_stop = False
    _text_dispatch_worker_count = target_worker_count
    _text_dispatch_workers = []
    for idx in range(target_worker_count):
        worker = threading.Thread(
            target=_text_dispatch_worker_loop,
            args=(idx,),
            name=f"embed-text-dispatch-{idx}",
            daemon=True,
        )
        worker.start()
        _text_dispatch_workers.append(worker)
    logger.info(
        "Started text dispatch workers | backend=%s workers=%d",
        _text_backend_name,
        target_worker_count,
    )


def _stop_text_dispatch_workers() -> None:
    """Signal all dispatch workers to exit; they return once they see the flag."""
    global _text_dispatch_worker_stop
    with _text_dispatch_cv:
        _text_dispatch_worker_stop = True
        _text_dispatch_cv.notify_all()


def _text_dispatch_worker_loop(worker_idx: int) -> None:
    """Worker thread: pop one queued text task at a time and run it through the backend."""
    while True:
        with _text_dispatch_cv:
            while (
                not _text_dispatch_high_queue
                and not _text_dispatch_normal_queue
                and not _text_dispatch_worker_stop
            ):
                _text_dispatch_cv.wait()
            if _text_dispatch_worker_stop:
                return
            task = _pop_text_dispatch_task_locked()
        if task is None:
            continue
        try:
            queue_wait_ms = (time.perf_counter() - task.created_at) * 1000.0
            logger.info(
                "text dispatch start | worker=%d priority=%s inputs=%d queue_wait_ms=%.2f",
                worker_idx,
                _priority_label(task.priority),
                len(task.normalized),
                queue_wait_ms,
                extra=build_request_log_extra(task.request_id, task.user_id),
            )
            task.result = _embed_text_impl(
                task.normalized,
                task.effective_normalize,
                task.request_id,
                task.user_id,
                task.priority,
            )
        except Exception as exc:
            task.error = exc
        finally:
            # Always wake the waiting request, even on failure.
            task.done.set()


def _submit_text_dispatch_and_wait(
    normalized: List[str],
    effective_normalize: bool,
    request_id: str,
    user_id: str,
    priority: int,
) -> _EmbedResult:
    """Enqueue a text embed task for the dispatch workers and block until it completes.

    Raises:
        RuntimeError: if the task does not complete within _TEXT_REQUEST_TIMEOUT_SEC,
            or the worker finished without producing a result.
        Exception: whatever the worker recorded in ``task.error``.
    """
    if not any(worker.is_alive() for worker in _text_dispatch_workers):
        _start_text_dispatch_workers()
    task = _TextDispatchTask(
        normalized=normalized,
        effective_normalize=effective_normalize,
        request_id=request_id,
        user_id=user_id,
        priority=_effective_priority(priority),
        created_at=time.perf_counter(),
        done=threading.Event(),
    )
    with _text_dispatch_cv:
        if task.priority > 0:
            _text_dispatch_high_queue.append(task)
        else:
            _text_dispatch_normal_queue.append(task)
        _text_dispatch_cv.notify()
    # FIX: this wait was previously unbounded, so a wedged backend or dead worker
    # hung the HTTP request (and its threadpool slot) forever. Bound it and drop
    # the task from its queue on timeout, mirroring _encode_single_text_with_microbatch.
    if not task.done.wait(timeout=_TEXT_REQUEST_TIMEOUT_SEC):
        with _text_dispatch_cv:
            queue = _text_dispatch_high_queue if task.priority > 0 else _text_dispatch_normal_queue
            try:
                queue.remove(task)
            except ValueError:
                # Already dequeued by a worker; abandon the in-flight task.
                pass
        raise RuntimeError(
            f"Timed out waiting for text dispatch worker ({_TEXT_REQUEST_TIMEOUT_SEC:.1f}s)"
        )
    if task.error is not None:
        raise task.error
    if task.result is None:
        raise RuntimeError("Text dispatch worker returned empty result")
    return task.result


_text_request_limiter = _InflightLimiter(name="text", limit=_TEXT_MAX_INFLIGHT)
_image_request_limiter = _InflightLimiter(name="image", limit=_IMAGE_MAX_INFLIGHT)
_text_stats = _EndpointStats(name="text")
_image_stats = _EndpointStats(name="image")
# Separate Redis namespaces keep the three embedding spaces from colliding.
_text_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="")
_image_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="image")
_clip_text_cache = RedisEmbeddingCache(key_prefix=_CACHE_PREFIX, namespace="clip_text")


@dataclass
class _SingleTextTask:
    """One single-text encode queued for the local_st micro-batch worker."""
    text: str
    normalize: bool
    priority: int
    created_at: float
    request_id: str
    user_id: str
    done: threading.Event
    result: Optional[List[float]] = None
    error: Optional[Exception] = None


_text_single_high_queue: "deque[_SingleTextTask]" = deque()
_text_single_normal_queue: "deque[_SingleTextTask]" = deque()
_text_single_queue_cv = threading.Condition()
_text_batch_worker: Optional[threading.Thread] = None
_text_batch_worker_stop = False


def _text_microbatch_queue_depth() -> Dict[str, int]:
    """Micro-batch queue depths for /health, read under the queue lock."""
    with _text_single_queue_cv:
        return {
            "high": len(_text_single_high_queue),
            "normal": len(_text_single_normal_queue),
            "total": len(_text_single_high_queue) + len(_text_single_normal_queue),
        }


def _pop_single_text_task_locked() -> Optional["_SingleTextTask"]:
    # Caller must hold _text_single_queue_cv. High lane drains first.
    if _text_single_high_queue:
        return _text_single_high_queue.popleft()
    if _text_single_normal_queue:
        return _text_single_normal_queue.popleft()
    return None


def _compact_preview(text: str, max_chars: int) -> str:
    """Collapse whitespace and truncate to max_chars (with ellipsis) for log lines."""
    compact = " ".join((text or "").split())
    if len(compact) <= max_chars:
        return compact
    return compact[:max_chars] + "..."


def _preview_inputs(items: List[str], max_items: int, max_chars: int) -> List[Dict[str, Any]]:
    """Build a small, log-safe summary of the first few inputs."""
    previews: List[Dict[str, Any]] = []
    for idx, item in enumerate(items[:max_items]):
        previews.append(
            {
                "idx": idx,
                "len": len(item),
                "preview": _compact_preview(item, max_chars),
            }
        )
    return previews


def _preview_vector(vec: Optional[List[float]], max_dims: int = _VECTOR_PREVIEW_DIMS) -> List[float]:
    """First few dimensions of a vector, rounded, for log lines."""
    if not vec:
        return []
    return [round(float(v), 6) for v in vec[:max_dims]]


def _resolve_request_id(http_request: Request) -> str:
    """Use the caller-provided X-Request-ID when present, else mint a short one."""
    header_value = http_request.headers.get("X-Request-ID")
    if header_value and header_value.strip():
        return header_value.strip()[:32]
    return str(uuid.uuid4())[:8]


def _resolve_user_id(http_request: Request) -> str:
    """Read the user id from X-User-ID / User-ID headers; "-1" when absent."""
    header_value = http_request.headers.get("X-User-ID") or http_request.headers.get("User-ID")
    if header_value and header_value.strip():
        return header_value.strip()[:64]
    return "-1"


def _request_client(http_request: Request) -> str:
    """Best-effort client host for log lines ("-" when unknown)."""
    client = getattr(http_request, "client", None)
    host = getattr(client, "host", None)
    return str(host or "-")


def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any:
    # The local sentence-transformers model is shared; serialize encode calls.
    with _text_encode_lock:
        return _text_model.encode(
            texts,
            batch_size=int(CONFIG.TEXT_BATCH_SIZE),
            device=CONFIG.TEXT_DEVICE,
            normalize_embeddings=normalize_embeddings,
        )


def _start_text_batch_worker() -> None:
    """Start the single local_st micro-batch worker thread (no-op if alive)."""
    global _text_batch_worker, _text_batch_worker_stop
    if _text_batch_worker is not None and _text_batch_worker.is_alive():
        return
    _text_batch_worker_stop = False
    _text_batch_worker = threading.Thread(
        target=_text_batch_worker_loop,
        name="embed-text-microbatch-worker",
        daemon=True,
    )
    _text_batch_worker.start()
    logger.info(
        "Started local_st text micro-batch worker | window_ms=%.1f max_batch=%d",
        _TEXT_MICROBATCH_WINDOW_SEC * 1000.0,
        int(CONFIG.TEXT_BATCH_SIZE),
    )


def _stop_text_batch_worker() -> None:
    """Signal the micro-batch worker to exit."""
    global _text_batch_worker_stop
    with _text_single_queue_cv:
        _text_batch_worker_stop = True
        _text_single_queue_cv.notify_all()


def _text_batch_worker_loop() -> None:
    """Collect single-text tasks for up to the micro-batch window, encode them together."""
    max_batch = max(1, int(CONFIG.TEXT_BATCH_SIZE))
    while True:
        with _text_single_queue_cv:
            while (
                not _text_single_high_queue
                and not _text_single_normal_queue
                and not _text_batch_worker_stop
            ):
                _text_single_queue_cv.wait()
            if _text_batch_worker_stop:
                return
            first_task = _pop_single_text_task_locked()
            if first_task is None:
                continue
            batch: List[_SingleTextTask] = [first_task]
            # Keep collecting until the window closes or the batch is full.
            deadline = time.perf_counter() + _TEXT_MICROBATCH_WINDOW_SEC
            while len(batch) < max_batch:
                remaining = deadline - time.perf_counter()
                if remaining <= 0:
                    break
                if not _text_single_high_queue and not _text_single_normal_queue:
                    _text_single_queue_cv.wait(timeout=remaining)
                    continue
                while len(batch) < max_batch:
                    next_task = _pop_single_text_task_locked()
                    if next_task is None:
                        break
                    batch.append(next_task)
        try:
            queue_wait_ms = [(time.perf_counter() - task.created_at) * 1000.0 for task in batch]
            reqids = [task.request_id for task in batch]
            uids = [task.user_id for task in batch]
            logger.info(
                "text microbatch dispatch | size=%d priority=%s queue_wait_ms_min=%.2f queue_wait_ms_max=%.2f reqids=%s uids=%s preview=%s",
                len(batch),
                _priority_label(max(task.priority for task in batch)),
                min(queue_wait_ms) if queue_wait_ms else 0.0,
                max(queue_wait_ms) if queue_wait_ms else 0.0,
                reqids,
                uids,
                _preview_inputs(
                    [task.text for task in batch],
                    _LOG_PREVIEW_COUNT,
                    _LOG_TEXT_PREVIEW_CHARS,
                ),
                extra=build_request_log_extra(),
            )
            batch_t0 = time.perf_counter()
            # Encode without normalization; per-task normalize happens in _as_list below.
            embs = _encode_local_st([task.text for task in batch], normalize_embeddings=False)
            if embs is None or len(embs) != len(batch):
                raise RuntimeError(
                    f"Text model response length mismatch in micro-batch: "
                    f"expected {len(batch)}, got {0 if embs is None else len(embs)}"
                )
            for task, emb in zip(batch, embs):
                vec = _as_list(emb, normalize=task.normalize)
                if vec is None:
                    raise RuntimeError("Text model returned empty embedding in micro-batch")
                task.result = vec
            logger.info(
                "text microbatch done | size=%d reqids=%s uids=%s dim=%d backend_elapsed_ms=%.2f",
                len(batch),
                reqids,
                uids,
                len(batch[0].result) if batch and batch[0].result is not None else 0,
                (time.perf_counter() - batch_t0) * 1000.0,
                extra=build_request_log_extra(),
            )
        except Exception as exc:
            logger.error(
                "text microbatch failed | size=%d reqids=%s uids=%s error=%s",
                len(batch),
                [task.request_id for task in batch],
                [task.user_id for task in batch],
                exc,
                exc_info=True,
                extra=build_request_log_extra(),
            )
            # Fail the whole batch; each waiter re-raises its task.error.
            for task in batch:
                task.error = exc
        finally:
            for task in batch:
                task.done.set()


def _encode_single_text_with_microbatch(
    text: str,
    normalize: bool,
    request_id: str,
    user_id: str,
    priority: int,
) -> List[float]:
    """Enqueue one text for the micro-batch worker and wait (bounded) for its vector."""
    task = _SingleTextTask(
        text=text,
        normalize=normalize,
        priority=_effective_priority(priority),
        created_at=time.perf_counter(),
        request_id=request_id,
        user_id=user_id,
        done=threading.Event(),
    )
    with _text_single_queue_cv:
        if task.priority > 0:
            _text_single_high_queue.append(task)
        else:
            _text_single_normal_queue.append(task)
        _text_single_queue_cv.notify()
    if not task.done.wait(timeout=_TEXT_REQUEST_TIMEOUT_SEC):
        # Timed out: pull the task back off its queue so the worker never sees it.
        with _text_single_queue_cv:
            queue = _text_single_high_queue if task.priority > 0 else _text_single_normal_queue
            try:
                queue.remove(task)
            except ValueError:
                pass
        raise RuntimeError(
            f"Timed out waiting for text micro-batch worker ({_TEXT_REQUEST_TIMEOUT_SEC:.1f}s)"
        )
    if task.error is not None:
        raise task.error
    if task.result is None:
        raise RuntimeError("Text micro-batch worker returned empty result")
    return task.result


@app.on_event("startup")
def load_models():
    """Load models at service startup to avoid first-request latency."""
    global _text_model, _image_model, _text_backend_name
    logger.info(
        "Loading embedding models at startup | service_kind=%s text_enabled=%s image_enabled=%s",
        _SERVICE_KIND,
        open_text_model,
        open_image_model,
    )
    if open_text_model:
        try:
            backend_name, backend_cfg = get_embedding_backend_config()
            _text_backend_name = backend_name
            if backend_name == "tei":
                from embeddings.text_embedding_tei import TEITextModel
                base_url = backend_cfg.get("base_url") or CONFIG.TEI_BASE_URL
                timeout_sec = int(backend_cfg.get("timeout_sec") or CONFIG.TEI_TIMEOUT_SEC)
                logger.info("Loading text backend: tei (base_url=%s)", base_url)
                _text_model = TEITextModel(
                    base_url=str(base_url),
                    timeout_sec=timeout_sec,
                    max_client_batch_size=int(
                        backend_cfg.get("max_client_batch_size") or CONFIG.TEI_MAX_CLIENT_BATCH_SIZE
                    ),
                )
            elif backend_name == "local_st":
                from embeddings.text_embedding_sentence_transformers import Qwen3TextModel
                model_id = backend_cfg.get("model_id") or CONFIG.TEXT_MODEL_ID
                logger.info("Loading text backend: local_st (model=%s)", model_id)
                _text_model = Qwen3TextModel(model_id=str(model_id))
                # Only the local backend benefits from single-text micro-batching.
                _start_text_batch_worker()
            else:
                raise ValueError(
                    f"Unsupported embedding backend: {backend_name}. "
                    "Supported: tei, local_st"
                )
            _start_text_dispatch_workers()
            logger.info("Text backend loaded successfully: %s", _text_backend_name)
        except Exception as e:
            logger.error("Failed to load text model: %s", e, exc_info=True)
            raise
    if open_image_model:
        try:
            if CONFIG.USE_CLIP_AS_SERVICE:
                from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder
                logger.info(
                    "Loading image encoder via clip-as-service: %s (configured model: %s)",
                    CONFIG.CLIP_AS_SERVICE_SERVER,
                    CONFIG.CLIP_AS_SERVICE_MODEL_NAME,
                )
                _image_model = ClipAsServiceImageEncoder(
                    server=CONFIG.CLIP_AS_SERVICE_SERVER,
                    batch_size=CONFIG.IMAGE_BATCH_SIZE,
                )
                logger.info("Image model (clip-as-service) loaded successfully")
            else:
                from embeddings.clip_model import ClipImageModel
                logger.info(
                    "Loading local image model: %s (device: %s)",
                    CONFIG.IMAGE_MODEL_NAME,
                    CONFIG.IMAGE_DEVICE,
                )
                _image_model = ClipImageModel(
                    model_name=CONFIG.IMAGE_MODEL_NAME,
                    device=CONFIG.IMAGE_DEVICE,
                )
                logger.info("Image model (local CN-CLIP) loaded successfully")
        except Exception as e:
            logger.error("Failed to load image model: %s", e, exc_info=True)
            raise
    logger.info("All embedding models loaded successfully, service ready")


@app.on_event("shutdown")
def stop_workers() -> None:
    """Signal all background worker threads to exit on service shutdown."""
    _stop_text_batch_worker()
    _stop_text_dispatch_workers()


def _normalize_vector(vec: np.ndarray) -> np.ndarray:
    """L2-normalize a vector; reject zero/NaN/inf norms rather than emit garbage."""
    norm = float(np.linalg.norm(vec))
    if not np.isfinite(norm) or norm <= 0.0:
        raise RuntimeError("Embedding vector has invalid norm (must be > 0)")
    return vec / norm


def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Optional[List[float]]:
    """Coerce a backend embedding to a flat float32 Python list (optionally L2-normalized)."""
    if embedding is None:
        return None
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding, dtype=np.float32)
    if embedding.ndim != 1:
        embedding = embedding.reshape(-1)
    embedding = embedding.astype(np.float32, copy=False)
    if normalize:
        embedding = _normalize_vector(embedding).astype(np.float32, copy=False)
    return embedding.tolist()


def _try_full_text_cache_hit(
    normalized: List[str],
    effective_normalize: bool,
) -> Optional[_EmbedResult]:
    """Return a cache-only result iff EVERY input hits the text cache; else None."""
    out: List[Optional[List[float]]] = []
    for text in normalized:
        cached = _text_cache.get(build_text_cache_key(text, normalize=effective_normalize))
        if cached is None:
            return None
        vec = _as_list(cached, normalize=False)
        if vec is None:
            return None
        out.append(vec)
    return _EmbedResult(
        vectors=out,
        cache_hits=len(out),
        cache_misses=0,
        backend_elapsed_ms=0.0,
        mode="cache-only",
    )


def _try_full_image_lane_cache_hit(
    items: List[str],
    effective_normalize: bool,
    *,
    lane: str,
) -> Optional[_EmbedResult]:
    """Return a cache-only result iff EVERY input hits the lane cache; else None.

    lane is "image" (CN-CLIP image vectors) or anything else for the clip_text lane.
    """
    out: List[Optional[List[float]]] = []
    for item in items:
        if lane == "image":
            ck = _mm_image_cache_key(
                item, normalize=effective_normalize, model_name=CONFIG.MULTIMODAL_MODEL_NAME
            )
            cached = _image_cache.get(ck)
        else:
            ck = _mm_clip_text_cache_key(
                item, normalize=effective_normalize, model_name=CONFIG.MULTIMODAL_MODEL_NAME
            )
            cached = _clip_text_cache.get(ck)
        if cached is None:
            return None
        vec = _as_list(cached, normalize=False)
        if vec is None:
            return None
        out.append(vec)
    return _EmbedResult(
        vectors=out,
        cache_hits=len(out),
        cache_misses=0,
        backend_elapsed_ms=0.0,
        mode="cache-only",
    )


def _embed_image_lane_impl(
    items: List[str],
    effective_normalize: bool,
    request_id: str,
    user_id: str,
    *,
    lane: str,
) -> _EmbedResult:
    """Embed image URLs or CLIP-text phrases: cache lookup first, backend for misses.

    Raises RuntimeError when the model is not loaded or the backend response is malformed.
    """
    if _image_model is None:
        raise RuntimeError("Image model not loaded")
    out: List[Optional[List[float]]] = [None] * len(items)
    missing_indices: List[int] = []
    missing_items: List[str] = []
    missing_keys: List[str] = []
    cache_hits = 0
    for idx, item in enumerate(items):
        if lane == "image":
            ck = _mm_image_cache_key(
                item, normalize=effective_normalize, model_name=CONFIG.MULTIMODAL_MODEL_NAME
            )
            cached = _image_cache.get(ck)
        else:
            ck = _mm_clip_text_cache_key(
                item, normalize=effective_normalize, model_name=CONFIG.MULTIMODAL_MODEL_NAME
            )
            cached = _clip_text_cache.get(ck)
        if cached is not None:
            vec = _as_list(cached, normalize=False)
            if vec is not None:
                out[idx] = vec
                cache_hits += 1
                continue
        missing_indices.append(idx)
        missing_items.append(item)
        missing_keys.append(ck)
    if not missing_items:
        logger.info(
            "%s lane cache-only | inputs=%d normalize=%s dim=%d cache_hits=%d",
            lane,
            len(items),
            effective_normalize,
            len(out[0]) if out and out[0] is not None else 0,
            cache_hits,
            extra=build_request_log_extra(request_id=request_id, user_id=user_id),
        )
        return _EmbedResult(
            vectors=out,
            cache_hits=cache_hits,
            cache_misses=0,
            backend_elapsed_ms=0.0,
            mode="cache-only",
        )
    backend_t0 = time.perf_counter()
    with _image_encode_lock:
        if lane == "image":
            vectors = _image_model.encode_image_urls(
                missing_items,
                batch_size=CONFIG.IMAGE_BATCH_SIZE,
                normalize_embeddings=effective_normalize,
            )
        else:
            vectors = _image_model.encode_clip_texts(
                missing_items,
                batch_size=CONFIG.IMAGE_BATCH_SIZE,
                normalize_embeddings=effective_normalize,
            )
    if vectors is None or len(vectors) != len(missing_items):
        raise RuntimeError(
            f"{lane} lane length mismatch: expected {len(missing_items)}, "
            f"got {0 if vectors is None else len(vectors)}"
        )
    for pos, ck, vec in zip(missing_indices, missing_keys, vectors):
        out_vec = _as_list(vec, normalize=effective_normalize)
        if out_vec is None:
            raise RuntimeError(f"{lane} lane empty embedding at position {pos}")
        out[pos] = out_vec
        if lane == "image":
            _image_cache.set(ck, np.asarray(out_vec, dtype=np.float32))
        else:
            _clip_text_cache.set(ck, np.asarray(out_vec, dtype=np.float32))
    backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0
    logger.info(
        "%s lane backend-batch | inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f",
        lane,
        len(items),
        effective_normalize,
        len(out[0]) if out and out[0] is not None else 0,
        cache_hits,
        len(missing_items),
        backend_elapsed_ms,
        extra=build_request_log_extra(request_id=request_id, user_id=user_id),
    )
    return _EmbedResult(
        vectors=out,
        cache_hits=cache_hits,
        cache_misses=len(missing_items),
        backend_elapsed_ms=backend_elapsed_ms,
        mode="backend-batch",
    )


@app.get("/health")
def health() -> Dict[str, Any]:
    """Health check endpoint. Returns status and current throttling stats."""
    ready = (not open_text_model or _text_model is not None) and (not open_image_model or _image_model is not None)
    text_dispatch_depth = _text_dispatch_queue_depth()
    text_microbatch_depth = _text_microbatch_queue_depth()
    return {
        "status": "ok" if ready else "degraded",
        "service_kind": _SERVICE_KIND,
        "text_model_loaded": _text_model is not None,
        "text_backend": _text_backend_name,
        "image_model_loaded": _image_model is not None,
        "cache_enabled": {
            "text": _text_cache.redis_client is not None,
            "image": _image_cache.redis_client is not None,
            "clip_text": _clip_text_cache.redis_client is not None,
        },
        "limits": {
            "text": _text_request_limiter.snapshot(),
            "image": _image_request_limiter.snapshot(),
        },
        "stats": {
            "text": _text_stats.snapshot(),
            "image": _image_stats.snapshot(),
        },
        "text_dispatch": {
            "workers": _text_dispatch_worker_count,
            "workers_alive": sum(1 for worker in _text_dispatch_workers if worker.is_alive()),
            "queue_depth": text_dispatch_depth["total"],
            "queue_depth_high": text_dispatch_depth["high"],
            "queue_depth_normal": text_dispatch_depth["normal"],
        },
        "text_microbatch": {
            "window_ms": round(_TEXT_MICROBATCH_WINDOW_SEC * 1000.0, 3),
            "queue_depth": text_microbatch_depth["total"],
            "queue_depth_high": text_microbatch_depth["high"],
            "queue_depth_normal": text_microbatch_depth["normal"],
            "worker_alive": bool(_text_batch_worker is not None and _text_batch_worker.is_alive()),
            "request_timeout_sec": _TEXT_REQUEST_TIMEOUT_SEC,
        },
    }


@app.get("/ready")
def ready() -> Dict[str, Any]:
    """Readiness probe: 503 until all enabled models are loaded."""
    text_ready = (not open_text_model) or (_text_model is not None)
    image_ready = (not open_image_model) or (_image_model is not None)
    if not (text_ready and image_ready):
        raise HTTPException(
            status_code=503,
            detail={
                "service_kind": _SERVICE_KIND,
                "text_ready": text_ready,
                "image_ready": image_ready,
            },
        )
    return {
        "status": "ready",
        "service_kind": _SERVICE_KIND,
        "text_ready": text_ready,
        "image_ready": image_ready,
    }


def _embed_text_impl(
    normalized: List[str],
    effective_normalize: bool,
    request_id: str,
    user_id: str,
    priority: int = 0,
) -> _EmbedResult:
    """Embed a list of texts: cache lookup first, then backend for the misses.

    Single-miss local_st requests are routed through the micro-batch worker;
    results are written back to the Redis text cache.
    """
    if _text_model is None:
        raise RuntimeError("Text model not loaded")
    out: List[Optional[List[float]]] = [None] * len(normalized)
    missing_indices: List[int] = []
    missing_texts: List[str] = []
    missing_cache_keys: List[str] = []
    cache_hits = 0
    for idx, text in enumerate(normalized):
        cache_key = build_text_cache_key(text, normalize=effective_normalize)
        cached = _text_cache.get(cache_key)
        if cached is not None:
            vec = _as_list(cached, normalize=False)
            if vec is not None:
                out[idx] = vec
                cache_hits += 1
                continue
        missing_indices.append(idx)
        missing_texts.append(text)
        missing_cache_keys.append(cache_key)
    if not missing_texts:
        logger.info(
            "text backend done | backend=%s mode=cache-only inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 backend_elapsed_ms=0.00",
            _text_backend_name,
            len(normalized),
            effective_normalize,
            len(out[0]) if out and out[0] is not None else 0,
            cache_hits,
            extra=build_request_log_extra(request_id, user_id),
        )
        return _EmbedResult(
            vectors=out,
            cache_hits=cache_hits,
            cache_misses=0,
            backend_elapsed_ms=0.0,
            mode="cache-only",
        )
    backend_t0 = time.perf_counter()
    try:
        if _text_backend_name == "local_st":
            if len(missing_texts) == 1 and _text_batch_worker is not None:
                computed = [
                    _encode_single_text_with_microbatch(
                        missing_texts[0],
                        normalize=effective_normalize,
                        request_id=request_id,
                        user_id=user_id,
                        priority=priority,
                    )
                ]
                mode = "microbatch-single"
            else:
                embs = _encode_local_st(missing_texts, normalize_embeddings=False)
                computed = []
                for i, emb in enumerate(embs):
                    vec = _as_list(emb, normalize=effective_normalize)
                    if vec is None:
                        raise RuntimeError(f"Text model returned empty embedding for missing index {i}")
                    computed.append(vec)
                mode = "direct-batch"
        else:
            # Remote backend (tei) normalizes server-side when requested.
            embs = _text_model.encode(
                missing_texts,
                batch_size=int(CONFIG.TEXT_BATCH_SIZE),
                device=CONFIG.TEXT_DEVICE,
                normalize_embeddings=effective_normalize,
            )
            computed = []
            for i, emb in enumerate(embs):
                vec = _as_list(emb, normalize=False)
                if vec is None:
                    raise RuntimeError(f"Text model returned empty embedding for missing index {i}")
                computed.append(vec)
            mode = "backend-batch"
    except Exception as e:
        logger.error(
            "Text embedding backend failure: %s",
            e,
            exc_info=True,
            extra=build_request_log_extra(request_id, user_id),
        )
        raise RuntimeError(f"Text embedding backend failure: {e}") from e
    if len(computed) != len(missing_texts):
        raise RuntimeError(
            f"Text model response length mismatch: expected {len(missing_texts)}, "
            f"got {len(computed)}"
        )
    for pos, cache_key, vec in zip(missing_indices, missing_cache_keys, computed):
        out[pos] = vec
        _text_cache.set(cache_key, np.asarray(vec, dtype=np.float32))
    backend_elapsed_ms = (time.perf_counter() - backend_t0) * 1000.0
    logger.info(
        "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d backend_elapsed_ms=%.2f",
        _text_backend_name,
        mode,
        len(normalized),
        effective_normalize,
        len(out[0]) if out and out[0] is not None else 0,
        cache_hits,
        len(missing_texts),
        backend_elapsed_ms,
        extra=build_request_log_extra(request_id, user_id),
    )
    return _EmbedResult(
        vectors=out,
        cache_hits=cache_hits,
        cache_misses=len(missing_texts),
        backend_elapsed_ms=backend_elapsed_ms,
        mode=mode,
    )


@app.post("/embed/text")
async def embed_text(
    texts: List[str],
    http_request: Request,
    response: Response,
    normalize: Optional[bool] = None,
    priority: int = 0,
) -> List[Optional[List[float]]]:
    """POST /embed/text — embed texts; output vectors aligned to input indices.

    Fast path: if every text is cached, respond without touching the limiter
    or dispatch queue. Otherwise the request takes a limiter slot (priority>0
    bypasses the cap) and runs through the dispatch workers.
    """
    if _text_model is None:
        raise HTTPException(status_code=503, detail="Text embedding model not loaded in this service")
    request_id = _resolve_request_id(http_request)
    user_id = _resolve_user_id(http_request)
    _, _, log_tokens = bind_request_log_context(request_id, user_id)
    response.headers["X-Request-ID"] = request_id
    response.headers["X-User-ID"] = user_id
    request_started = time.perf_counter()
    success = False
    backend_elapsed_ms = 0.0
    cache_hits = 0
    cache_misses = 0
    limiter_acquired = False
    try:
        if priority < 0:
            raise HTTPException(status_code=400, detail="priority must be >= 0")
        effective_priority = _effective_priority(priority)
        effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize)
        normalized: List[str] = []
        for i, t in enumerate(texts):
            if not isinstance(t, str):
                raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string")
            s = t.strip()
            if not s:
                raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string")
            normalized.append(s)
        cache_check_started = time.perf_counter()
        cache_only = _try_full_text_cache_hit(normalized, effective_normalize)
        if cache_only is not None:
            latency_ms = (time.perf_counter() - cache_check_started) * 1000.0
            _text_stats.record_completed(
                success=True,
                latency_ms=latency_ms,
                backend_latency_ms=0.0,
                cache_hits=cache_only.cache_hits,
                cache_misses=0,
            )
            logger.info(
                "embed_text response | backend=%s mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f",
                _text_backend_name,
                _priority_label(effective_priority),
                len(normalized),
                effective_normalize,
                len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0,
                cache_only.cache_hits,
                _preview_vector(cache_only.vectors[0] if cache_only.vectors else None),
                latency_ms,
                extra=build_request_log_extra(request_id, user_id),
            )
            return cache_only.vectors
        accepted, active = _text_request_limiter.try_acquire(bypass_limit=effective_priority > 0)
        if not accepted:
            _text_stats.record_rejected()
            logger.warning(
                "embed_text rejected | client=%s backend=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s",
                _request_client(http_request),
                _text_backend_name,
                _priority_label(effective_priority),
                len(normalized),
                effective_normalize,
                active,
                _TEXT_MAX_INFLIGHT,
                _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS),
                extra=build_request_log_extra(request_id, user_id),
            )
            raise HTTPException(
                status_code=_OVERLOAD_STATUS_CODE,
                detail=(
                    "Text embedding service busy for priority=0 requests: "
                    f"active={active}, limit={_TEXT_MAX_INFLIGHT}"
                ),
            )
        limiter_acquired = True
        logger.info(
            "embed_text request | client=%s backend=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s",
            _request_client(http_request),
            _text_backend_name,
            _priority_label(effective_priority),
            len(normalized),
            effective_normalize,
            active,
            _TEXT_MAX_INFLIGHT,
            _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS),
            extra=build_request_log_extra(request_id, user_id),
        )
        verbose_logger.info(
            "embed_text detail | payload=%s normalize=%s backend=%s priority=%s",
            normalized,
            effective_normalize,
            _text_backend_name,
            _priority_label(effective_priority),
            extra=build_request_log_extra(request_id, user_id),
        )
        # Blocking queue wait runs in the threadpool to keep the event loop free.
        result = await run_in_threadpool(
            _submit_text_dispatch_and_wait,
            normalized,
            effective_normalize,
            request_id,
            user_id,
            effective_priority,
        )
        success = True
        backend_elapsed_ms = result.backend_elapsed_ms
        cache_hits = result.cache_hits
        cache_misses = result.cache_misses
        latency_ms = (time.perf_counter() - request_started) * 1000.0
        _text_stats.record_completed(
            success=True,
            latency_ms=latency_ms,
            backend_latency_ms=backend_elapsed_ms,
            cache_hits=cache_hits,
            cache_misses=cache_misses,
        )
        logger.info(
            "embed_text response | backend=%s mode=%s priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f",
            _text_backend_name,
            result.mode,
            _priority_label(effective_priority),
            len(normalized),
            effective_normalize,
            len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0,
            cache_hits,
            cache_misses,
            _preview_vector(result.vectors[0] if result.vectors else None),
            latency_ms,
            extra=build_request_log_extra(request_id, user_id),
        )
        verbose_logger.info(
            "embed_text result detail | count=%d priority=%s first_vector=%s latency_ms=%.2f",
            len(result.vectors),
            _priority_label(effective_priority),
            result.vectors[0][: _VECTOR_PREVIEW_DIMS] if result.vectors and result.vectors[0] is not None else [],
            latency_ms,
            extra=build_request_log_extra(request_id, user_id),
        )
        return result.vectors
    except HTTPException:
        raise
    except Exception as e:
        latency_ms = (time.perf_counter() - request_started) * 1000.0
        _text_stats.record_completed(
            success=False,
            latency_ms=latency_ms,
            backend_latency_ms=backend_elapsed_ms,
            cache_hits=cache_hits,
            cache_misses=cache_misses,
        )
        logger.error(
            "embed_text failed | backend=%s priority=%s inputs=%d normalize=%s latency_ms=%.2f error=%s",
            _text_backend_name,
            _priority_label(effective_priority),
            len(normalized),
            effective_normalize,
            latency_ms,
            e,
            exc_info=True,
            extra=build_request_log_extra(request_id, user_id),
        )
        raise HTTPException(status_code=502, detail=str(e)) from e
    finally:
        if limiter_acquired:
            remaining = _text_request_limiter.release(success=success)
            logger.info(
                "embed_text finalize | success=%s priority=%s active_after=%d",
                success,
                _priority_label(effective_priority),
                remaining,
                extra=build_request_log_extra(request_id, user_id),
            )
        reset_request_log_context(log_tokens)


def _parse_string_inputs(raw: List[Any], *, kind: str, empty_detail: str) -> List[str]:
    """Validate and strip a list of request strings; 400 on non-string or empty entries."""
    out: List[str] = []
    for i, x in enumerate(raw):
        if not isinstance(x, str):
            raise HTTPException(status_code=400, detail=f"Invalid {kind} at index {i}: must be string")
        s = x.strip()
        if not s:
            raise HTTPException(status_code=400, detail=f"Invalid {kind} at index {i}: {empty_detail}")
        out.append(s)
    return out


async def _run_image_lane_embed(
    *,
    route: str,
    lane: str,
    items: List[str],
    http_request: Request,
    response: Response,
    normalize: Optional[bool],
    priority: int,
    preview_chars: int,
) -> List[Optional[List[float]]]:
    request_id = _resolve_request_id(http_request)
    user_id = _resolve_user_id(http_request)
    _, _, log_tokens = bind_request_log_context(request_id, user_id)
    response.headers["X-Request-ID"] = request_id
    response.headers["X-User-ID"] = user_id
    request_started = time.perf_counter()
    success = False
    backend_elapsed_ms = 0.0
    cache_hits = 0
    cache_misses = 0
    limiter_acquired = False
    # Snapshot the inputs so the error log below can report a length even if
    # the failure happens before validation completes.
    items_in: List[str] = list(items)
    try:
        if priority < 0:
            raise HTTPException(status_code=400, detail="priority must be >= 0")
        effective_priority = _effective_priority(priority)
        # Per-request override; falls back to the configured image default.
        effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize)
        cache_check_started = time.perf_counter()
        # Fast path: answer entirely from cache without touching the limiter.
        # Note latency_ms here is measured from the cache check, not from
        # request start.
        cache_only = _try_full_image_lane_cache_hit(items, effective_normalize, lane=lane)
        if cache_only is not None:
            latency_ms = (time.perf_counter() - cache_check_started) * 1000.0
            _image_stats.record_completed(
                success=True,
                latency_ms=latency_ms,
                backend_latency_ms=0.0,
                cache_hits=cache_only.cache_hits,
                cache_misses=0,
            )
            logger.info(
                "%s response | mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d first_vector=%s latency_ms=%.2f",
                route,
                _priority_label(effective_priority),
                len(items),
                effective_normalize,
                len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0,
                cache_only.cache_hits,
                _preview_vector(cache_only.vectors[0] if cache_only.vectors else None),
                latency_ms,
                extra=build_request_log_extra(request_id, user_id),
            )
            return cache_only.vectors
        # Admission control: priority>0 requests bypass the in-flight cap.
        accepted, active = _image_request_limiter.try_acquire(bypass_limit=effective_priority > 0)
        if not accepted:
            _image_stats.record_rejected()
            logger.warning(
                "%s rejected | client=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s",
                route,
                _request_client(http_request),
                _priority_label(effective_priority),
                len(items),
                effective_normalize,
                active,
                _IMAGE_MAX_INFLIGHT,
                _preview_inputs(items, _LOG_PREVIEW_COUNT, preview_chars),
                extra=build_request_log_extra(request_id, user_id),
            )
            raise HTTPException(
                status_code=_OVERLOAD_STATUS_CODE,
                detail=(
                    "Image embedding service busy for priority=0 requests: "
                    f"active={active}, limit={_IMAGE_MAX_INFLIGHT}"
                ),
            )
        limiter_acquired = True
        logger.info(
            "%s request | client=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s",
            route,
            _request_client(http_request),
            _priority_label(effective_priority),
            len(items),
            effective_normalize,
            active,
            _IMAGE_MAX_INFLIGHT,
            _preview_inputs(items, _LOG_PREVIEW_COUNT, preview_chars),
            extra=build_request_log_extra(request_id, user_id),
        )
        # Full payload goes to the verbose logger only.
        verbose_logger.info(
            "%s detail | payload=%s normalize=%s priority=%s",
            route,
            items,
            effective_normalize,
            _priority_label(effective_priority),
            extra=build_request_log_extra(request_id, user_id),
        )
        # Blocking encode runs off the event loop.
        result = await run_in_threadpool(
            _embed_image_lane_impl,
            items,
            effective_normalize,
            request_id,
            user_id,
            lane=lane,
        )
        success = True
        backend_elapsed_ms = result.backend_elapsed_ms
        cache_hits = result.cache_hits
        cache_misses = result.cache_misses
        latency_ms = (time.perf_counter() - request_started) * 1000.0
        _image_stats.record_completed(
            success=True,
            latency_ms=latency_ms,
            backend_latency_ms=backend_elapsed_ms,
            cache_hits=cache_hits,
            cache_misses=cache_misses,
        )
        logger.info(
            "%s response | mode=%s priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=%d first_vector=%s latency_ms=%.2f",
            route,
            result.mode,
            _priority_label(effective_priority),
            len(items),
            effective_normalize,
            len(result.vectors[0]) if result.vectors and result.vectors[0] is not None else 0,
            cache_hits,
            cache_misses,
            _preview_vector(result.vectors[0] if result.vectors else None),
            latency_ms,
            extra=build_request_log_extra(request_id, user_id),
        )
        verbose_logger.info(
            "%s result detail | count=%d first_vector=%s latency_ms=%.2f",
            route,
            len(result.vectors),
            result.vectors[0][: _VECTOR_PREVIEW_DIMS] if result.vectors and result.vectors[0] is not None else [],
            latency_ms,
            extra=build_request_log_extra(request_id, user_id),
        )
        return result.vectors
    except HTTPException:
        raise
    except Exception as e:
        latency_ms = (time.perf_counter() - request_started) * 1000.0
        _image_stats.record_completed(
            success=False,
            latency_ms=latency_ms,
            backend_latency_ms=backend_elapsed_ms,
            cache_hits=cache_hits,
            cache_misses=cache_misses,
        )
        # NOTE(review): ``effective_priority``/``effective_normalize`` may be
        # unbound here if the failure happened before they were assigned
        # (``items_in`` is guarded, these are not) — verify.
        logger.error(
            "%s failed | priority=%s inputs=%d normalize=%s latency_ms=%.2f error=%s",
            route,
            _priority_label(effective_priority),
            len(items_in),
            effective_normalize,
            latency_ms,
            e,
            exc_info=True,
            extra=build_request_log_extra(request_id, user_id),
        )
        raise HTTPException(status_code=502, detail=f"{route} backend failure: {e}") from e
    finally:
        # Release the limiter only if acquired (cache-only / rejected paths
        # never do), then unbind the log context.
        if limiter_acquired:
            remaining = _image_request_limiter.release(success=success)
            logger.info(
                "%s finalize | success=%s priority=%s active_after=%d",
                route,
                success,
                _priority_label(effective_priority),
                remaining,
                extra=build_request_log_extra(request_id, user_id),
            )
        reset_request_log_context(log_tokens)


@app.post("/embed/image")
async def embed_image(
    images: List[str],
    http_request: Request,
    response: Response,
    normalize: Optional[bool] = None,
    priority: int = 0,
) -> List[Optional[List[float]]]:
    """CN-CLIP image embeddings for a list of URLs/paths (index-aligned)."""
    if _image_model is None:
        raise HTTPException(status_code=503, detail="Image embedding model not loaded in this service")
    items = _parse_string_inputs(images, kind="image", empty_detail="empty URL/path")
    return await _run_image_lane_embed(
        route="embed_image",
        lane="image",
        items=items,
        http_request=http_request,
        response=response,
        normalize=normalize,
        priority=priority,
        preview_chars=_LOG_IMAGE_PREVIEW_CHARS,
    )


@app.post("/embed/clip_text")
async def embed_clip_text(
    texts: List[str],
    http_request: Request,
    response: Response,
    normalize: Optional[bool] = None,
    priority: int = 0,
) -> List[Optional[List[float]]]:
    """CN-CLIP text tower; same vector space as ``POST /embed/image``."""
    if _image_model is None:
        raise HTTPException(status_code=503, detail="Image embedding model not loaded in this service")
    items = _parse_string_inputs(texts, kind="text", empty_detail="empty string")
    return await _run_image_lane_embed(
        route="embed_clip_text",
        lane="clip_text",
        items=items,
        http_request=http_request,
        response=response,
        normalize=normalize,
        priority=priority,
        preview_chars=_LOG_TEXT_PREVIEW_CHARS,
    )


def build_image_cache_key(url: str, *, normalize: bool, model_name: Optional[str] = None) -> str:
    """Tests/tools helper: same cache key as the ``/embed/image`` lane.

    ``model_name`` defaults to ``CONFIG.MULTIMODAL_MODEL_NAME`` so external
    callers produce keys matching what the service writes.
    """
    # Note: ``or`` also maps an empty-string model_name to the default.
    return _mm_image_cache_key(
        url, normalize=normalize, model_name=model_name or CONFIG.MULTIMODAL_MODEL_NAME
    )


def build_clip_text_cache_key(text: str, *, normalize: bool, model_name: Optional[str] = None) -> str:
    """Tests/tools helper: same cache key as the ``/embed/clip_text`` lane.

    ``model_name`` defaults to ``CONFIG.MULTIMODAL_MODEL_NAME`` so external
    callers produce keys matching what the service writes.
    """
    # Note: ``or`` also maps an empty-string model_name to the default.
    return _mm_clip_text_cache_key(
        text, normalize=normalize, model_name=model_name or CONFIG.MULTIMODAL_MODEL_NAME
    )