# server.py
"""
Embedding service (FastAPI).

API (simple list-in, list-out; aligned by index; failures -> null):
- POST /embed/text   body: ["text1", "text2", ...] -> [[...], null, ...]
- POST /embed/image  body: ["url_or_path1", ...]  -> [[...], null, ...]
"""

import logging
import threading
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI

from embeddings.config import CONFIG
from embeddings.bge_model import BgeTextModel
from embeddings.clip_model import ClipImageModel

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Application object; the startup hook and the HTTP routes below register on it.
app = FastAPI(title="saas-search Embedding Service", version="1.0.0")

# Model singletons. They start as None and are assigned by load_models() at
# service startup (not lazily on first request).
_text_model: Optional[BgeTextModel] = None
_image_model: Optional[ClipImageModel] = None
# Switches read by load_models(): a model is only instantiated when its flag
# is True. While a flag is False the matching global stays None, and the
# corresponding /embed/* handler raises RuntimeError ("... model not loaded").
open_text_model = True
open_image_model = False

# One lock per model. The embed handlers hold the matching lock while calling
# into the model, so calls into the same model never overlap in this process.
_text_encode_lock = threading.Lock()
_image_encode_lock = threading.Lock()


@app.on_event("startup")
def load_models():
    """Instantiate the enabled embedding models when the application starts.

    Loading eagerly keeps the model start-up cost out of the first request.
    Each model is only built when its module-level switch
    (``open_text_model`` / ``open_image_model``) is true; a disabled model
    leaves its global as ``None``.

    Raises:
        Exception: any error raised while constructing an enabled model is
            logged with its traceback and then re-raised unchanged.
    """
    global _text_model, _image_model

    logger.info("Loading embedding models at startup...")

    # --- text model -------------------------------------------------------
    if open_text_model:
        try:
            logger.info(f"Loading text model: {CONFIG.TEXT_MODEL_DIR}")
            _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
        except Exception as e:
            # logger.exception logs at ERROR level and attaches the traceback.
            logger.exception(f"Failed to load text model: {e}")
            raise
        else:
            logger.info("Text model loaded successfully")

    # --- image model ------------------------------------------------------
    if open_image_model:
        try:
            logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})")
            _image_model = ClipImageModel(
                model_name=CONFIG.IMAGE_MODEL_NAME,
                device=CONFIG.IMAGE_DEVICE,
            )
        except Exception as e:
            logger.exception(f"Failed to load image model: {e}")
            raise
        else:
            logger.info("Image model loaded successfully")

    logger.info("All embedding models loaded successfully, service ready")


def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
    if embedding is None:
        return None
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding, dtype=np.float32)
    if embedding.ndim != 1:
        embedding = embedding.reshape(-1)
    return embedding.astype(np.float32).tolist()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report service liveness and which models are currently loaded.

    Returns:
        A dict with a constant ``"status": "ok"`` plus one boolean per model
        that is true when the corresponding module-level model is not ``None``.
    """
    text_ready = _text_model is not None
    image_ready = _image_model is not None
    return {
        "status": "ok",
        "text_model_loaded": text_ready,
        "image_model_loaded": image_ready,
    }


@app.post("/embed/text")
def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a list of texts; the result is aligned with the input by index.

    Entries that are ``None`` or blank after stripping are skipped and stay
    ``None`` in the output. All remaining texts are encoded in a single
    ``encode_batch`` call made while holding the text-model lock.

    Encoding is best-effort: if the batch call or the conversion of a vector
    fails, the error is logged with its traceback and every entry not yet
    filled stays ``None`` instead of failing the whole request.

    Args:
        texts: Input strings (request body).

    Returns:
        A list of the same length as ``texts``; each item is the embedding as
        a list of floats, or ``None`` when no embedding was produced.

    Raises:
        RuntimeError: if the text model was not loaded at startup.
    """
    if _text_model is None:
        raise RuntimeError("Text model not loaded")
    out: List[Optional[List[float]]] = [None] * len(texts)

    # Pair each usable text with its original position so results can be
    # written back to the right slot after batching.
    indexed_texts: List[tuple] = []
    for i, t in enumerate(texts):
        if t is None:
            continue
        cleaned = (t if isinstance(t, str) else str(t)).strip()
        if cleaned:
            indexed_texts.append((i, cleaned))

    if not indexed_texts:
        return out

    batch_texts = [t for _, t in indexed_texts]
    try:
        # Hold the lock only for the model call itself.
        with _text_encode_lock:
            embs = _text_model.encode_batch(
                batch_texts, batch_size=int(CONFIG.TEXT_BATCH_SIZE), device=CONFIG.TEXT_DEVICE
            )
        for j, (idx, _t) in enumerate(indexed_texts):
            out[idx] = _as_list(embs[j])
    except Exception:
        # Keep the "failures -> null" contract, but leave a trace for
        # operators instead of silently discarding the error.
        logger.exception(
            "Text embedding failed for a batch of %d text(s); unfilled entries stay null",
            len(batch_texts),
        )
    return out


@app.post("/embed/image")
def embed_image(images: List[str]) -> List[Optional[List[float]]]:
    """Embed a list of image URLs/paths; the result is aligned by index.

    Each input is handled independently via ``encode_image_from_url``.
    Entries that are ``None`` or blank after stripping are skipped and stay
    ``None``. A failure on one input is logged (with traceback) and yields
    ``None`` for that position only; the remaining inputs are still processed.

    The image-model lock is held for the whole loop, so one request's images
    are processed without interleaving with another request's.

    Args:
        images: URLs or paths of the images to embed (request body).

    Returns:
        A list of the same length as ``images``; each item is the embedding
        as a list of floats, or ``None`` when no embedding was produced.

    Raises:
        RuntimeError: if the image model was not loaded at startup.
    """
    if _image_model is None:
        raise RuntimeError("Image model not loaded")
    out: List[Optional[List[float]]] = [None] * len(images)

    with _image_encode_lock:
        for i, url_or_path in enumerate(images):
            try:
                if url_or_path is None:
                    continue
                if not isinstance(url_or_path, str):
                    url_or_path = str(url_or_path)
                url_or_path = url_or_path.strip()
                if not url_or_path:
                    continue
                emb = _image_model.encode_image_from_url(url_or_path)
                out[i] = _as_list(emb)
            except Exception:
                # Keep the "failures -> null" contract, but record which
                # input failed and why so the problem can be diagnosed.
                logger.warning(
                    "Image embedding failed for input %d (%r); returning null",
                    i,
                    url_or_path,
                    exc_info=True,
                )
                out[i] = None
    return out