"""
Embedding service (FastAPI).

API (simple list-in, list-out; results are aligned by index; failures -> null):

- POST /embed/text
  body: ["text1", "text2", ...]      -> [[...], null, ...]
- POST /embed/image
  body: ["url_or_path1", ...]        -> [[...], null, ...]
"""

import logging
import threading
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from fastapi import FastAPI

from embeddings.config import CONFIG
from embeddings.bge_model import BgeTextModel
from embeddings.clip_model import ClipImageModel

logger = logging.getLogger(__name__)

app = FastAPI(title="SearchEngine Embedding Service", version="1.0.0")

# Models are constructed lazily on first use (see _get_text_model / _get_image_model).
_text_model: Optional[BgeTextModel] = None
_image_model: Optional[ClipImageModel] = None

# Guard one-time model construction.
_text_init_lock = threading.Lock()
_image_init_lock = threading.Lock()
# Serialize calls into each model object. The endpoints below are plain `def`
# functions, which FastAPI runs in a thread pool, so concurrent requests could
# otherwise call into the same model at the same time.
_text_encode_lock = threading.Lock()
_image_encode_lock = threading.Lock()


def _get_text_model() -> BgeTextModel:
    """Return the shared text model, constructing it on the first call.

    Double-checked locking ensures the model is built at most once even when
    several requests arrive concurrently before it exists.
    """
    global _text_model
    if _text_model is None:
        with _text_init_lock:
            if _text_model is None:
                _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
    return _text_model


def _get_image_model() -> ClipImageModel:
    """Return the shared image model, constructing it on the first call.

    Same double-checked locking scheme as ``_get_text_model``.
    """
    global _image_model
    if _image_model is None:
        with _image_init_lock:
            if _image_model is None:
                _image_model = ClipImageModel(
                    model_name=CONFIG.IMAGE_MODEL_NAME,
                    device=CONFIG.IMAGE_DEVICE,
                )
    return _image_model


def _normalize_inputs(items: List[Any]) -> List[Tuple[int, str]]:
    """Return ``(index, stripped_string)`` pairs for the usable entries of *items*.

    ``None`` entries, and entries that are empty after ``str()`` + ``strip()``,
    are dropped. The caller leaves their output slots as ``None``.
    """
    cleaned: List[Tuple[int, str]] = []
    for i, item in enumerate(items):
        if item is None:
            continue
        text = item if isinstance(item, str) else str(item)
        text = text.strip()
        if text:
            cleaned.append((i, text))
    return cleaned


def _as_list(embedding: Optional[Any]) -> Optional[List[float]]:
    """Convert one embedding into a flat list of float32-rounded Python floats.

    Any array-like input is flattened to 1-D.

    Returns:
        The flattened list, or ``None`` when *embedding* is ``None`` or contains
        NaN/inf. Starlette's ``JSONResponse`` serializes with
        ``allow_nan=False``, so a single non-finite value would otherwise turn
        the entire response into a 500 instead of one null slot.
    """
    if embedding is None:
        return None
    arr = np.asarray(embedding, dtype=np.float32).reshape(-1)
    if not np.isfinite(arr).all():
        return None
    return arr.tolist()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe; does not touch (or load) the models."""
    return {"status": "ok"}


@app.post("/embed/text")
def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of texts in a single ``encode_batch`` call.

    Returns a list aligned with *texts*: ``out[i]`` is the embedding of
    ``texts[i]``, or ``None`` if that entry was blank or encoding failed.
    Failures are logged and reported as nulls rather than raised.
    """
    # Fetched up front (even for empty input) so the first request loads the model.
    model = _get_text_model()
    out: List[Optional[List[float]]] = [None] * len(texts)

    indexed_texts = _normalize_inputs(texts)
    if not indexed_texts:
        return out

    batch_texts = [t for _, t in indexed_texts]
    try:
        with _text_encode_lock:
            embs = model.encode_batch(
                batch_texts,
                batch_size=int(CONFIG.TEXT_BATCH_SIZE),
                device=CONFIG.TEXT_DEVICE,
            )
        if len(embs) != len(indexed_texts):
            # With a count mismatch we cannot tell which vector belongs to which
            # input; returning nulls is safer than returning misaligned results.
            logger.error(
                "Text model returned %d embeddings for %d inputs; returning nulls",
                len(embs),
                len(indexed_texts),
            )
            return out
        for (idx, _text), emb in zip(indexed_texts, embs):
            out[idx] = _as_list(emb)
    except Exception:
        # Best-effort contract: failures become nulls, but keep a record of why.
        logger.exception(
            "Text embedding failed for a batch of %d inputs", len(batch_texts)
        )
    return out


@app.post("/embed/image")
def embed_image(images: List[str]) -> List[Optional[List[float]]]:
    """Embed each image URL/path individually.

    Returns a list aligned with *images*: ``out[i]`` is the embedding of
    ``images[i]``, or ``None`` if that entry was blank or encoding it failed.
    One bad image does not affect the others.
    """
    # Fetched up front (even for empty input) so the first request loads the model.
    model = _get_image_model()
    out: List[Optional[List[float]]] = [None] * len(images)

    indexed_images = _normalize_inputs(images)
    if not indexed_images:
        return out

    with _image_encode_lock:
        for idx, url_or_path in indexed_images:
            try:
                out[idx] = _as_list(model.encode_image_from_url(url_or_path))
            except Exception:
                # Best-effort contract: this slot stays null, but log the cause.
                logger.exception(
                    "Image embedding failed for input %d (%r)", idx, url_or_path
                )
                out[idx] = None
    return out