# server.py (6.03 KB)
"""
Embedding service (FastAPI).

API (simple list-in, list-out; aligned by index; failures -> null):
- POST /embed/text   body: ["text1", "text2", ...] -> [[...], null, ...]
- POST /embed/image  body: ["url_or_path1", ...]  -> [[...], null, ...]
"""

import logging
import threading
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI

from embeddings.config import CONFIG
from embeddings.protocols import ImageEncoderProtocol

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# FastAPI application object; the route and startup decorators below attach to it.
app = FastAPI(title="saas-search Embedding Service", version="1.0.0")

# Models are loaded at startup (see load_models), not lazily.
# Both stay None until load_models() assigns them; _image_model is also reset
# to None when loading the image model fails.
_text_model: Optional[Any] = None
_image_model: Optional[ImageEncoderProtocol] = None
# Feature switches read by load_models(): when False the corresponding model
# is not loaded at all and its global stays None.
open_text_model = True
open_image_model = True  # Enable image embedding (clip-as-service or local CLIP, chosen via CONFIG)

# Each lock is held around the corresponding encode call, so only one text
# (resp. image) encode runs at a time within this process.
# NOTE(review): presumably the encoders are not safe for concurrent use — confirm.
_text_encode_lock = threading.Lock()
_image_encode_lock = threading.Lock()


@app.on_event("startup")
def load_models():
    """Load models at service startup to avoid first-request latency.

    The text model is mandatory: if it is enabled and fails to load, the
    error is logged and re-raised so the exception propagates out of the
    startup hook.

    The image model is optional: if it is enabled and fails to load, the
    error is logged, ``_image_model`` is left as ``None`` and startup
    continues, so ``/embed/text`` keeps working.

    Raises:
        Exception: whatever the text model import/constructor raised.
    """
    global _text_model, _image_model

    logger.info("Loading embedding models at startup...")

    # Load text model
    if open_text_model:
        try:
            # Imported here so the dependency is only needed when enabled.
            from embeddings.qwen3_model import Qwen3TextModel

            logger.info("Loading text model: %s", CONFIG.TEXT_MODEL_ID)
            _text_model = Qwen3TextModel(model_id=CONFIG.TEXT_MODEL_ID)
            logger.info("Text model loaded successfully")
        except Exception as e:
            logger.error("Failed to load text model: %s", e, exc_info=True)
            raise

    # Load image model: clip-as-service (recommended) or local CN-CLIP
    # IMPORTANT: failures here should NOT prevent the whole service from starting.
    # If image model cannot be loaded, we keep `_image_model` as None and only
    # disable /embed/image while keeping /embed/text fully functional.
    if open_image_model:
        try:
            if CONFIG.USE_CLIP_AS_SERVICE:
                from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder

                logger.info(
                    "Loading image encoder via clip-as-service: %s",
                    CONFIG.CLIP_AS_SERVICE_SERVER,
                )
                _image_model = ClipAsServiceImageEncoder(
                    server=CONFIG.CLIP_AS_SERVICE_SERVER,
                    batch_size=CONFIG.IMAGE_BATCH_SIZE,
                )
                logger.info("Image model (clip-as-service) loaded successfully")
            else:
                from embeddings.clip_model import ClipImageModel

                logger.info(
                    "Loading local image model: %s (device: %s)",
                    CONFIG.IMAGE_MODEL_NAME,
                    CONFIG.IMAGE_DEVICE,
                )
                _image_model = ClipImageModel(
                    model_name=CONFIG.IMAGE_MODEL_NAME,
                    device=CONFIG.IMAGE_DEVICE,
                )
                logger.info("Image model (local CN-CLIP) loaded successfully")
        except Exception as e:
            logger.error(
                "Failed to load image model; image embeddings will be disabled but text embeddings remain available: %s",
                e,
                exc_info=True,
            )
            _image_model = None

    # Report the real outcome: do not claim full success when the image model
    # was requested but could not be loaded.
    if open_image_model and _image_model is None:
        logger.warning(
            "Service ready in degraded mode: image model is not loaded, "
            "image embeddings are disabled"
        )
    else:
        logger.info("All embedding models loaded successfully, service ready")


def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
    if embedding is None:
        return None
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding, dtype=np.float32)
    if embedding.ndim != 1:
        embedding = embedding.reshape(-1)
    return embedding.astype(np.float32).tolist()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report service liveness and which models are currently loaded.

    Returns a dict with a constant ``"status": "ok"`` plus one boolean per
    model telling whether its module-level global is set.
    """
    text_ready = _text_model is not None
    image_ready = _image_model is not None
    return dict(
        status="ok",
        text_model_loaded=text_ready,
        image_model_loaded=image_ready,
    )


@app.post("/embed/text")
def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a list of texts; the result is aligned with the input by index.

    Entries that are ``None``, empty or whitespace-only are skipped and stay
    ``None`` in the output. Non-string entries are coerced with ``str()``.
    All remaining texts are stripped and encoded in a single
    ``encode_batch`` call made while holding ``_text_encode_lock``.

    If encoding or conversion fails, the exception is logged and the entries
    not yet filled remain ``None`` (best-effort: the request itself does not
    fail).

    Args:
        texts: Input texts.

    Returns:
        One embedding (list of floats) or ``None`` per input entry.

    Raises:
        RuntimeError: if the text model is not loaded.
    """
    if _text_model is None:
        raise RuntimeError("Text model not loaded")
    out: List[Optional[List[float]]] = [None] * len(texts)

    # Keep (original index, cleaned text) pairs so results can be written
    # back to the right slot after the skipped entries are dropped.
    indexed_texts: List[tuple] = []
    for i, t in enumerate(texts):
        if t is None:
            continue
        if not isinstance(t, str):
            t = str(t)
        t = t.strip()
        if not t:
            continue
        indexed_texts.append((i, t))

    if not indexed_texts:
        return out

    batch_texts = [t for _, t in indexed_texts]
    try:
        with _text_encode_lock:
            embs = _text_model.encode_batch(
                batch_texts,
                batch_size=int(CONFIG.TEXT_BATCH_SIZE),
                device=CONFIG.TEXT_DEVICE,
                normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS),
            )
        for j, (idx, _t) in enumerate(indexed_texts):
            out[idx] = _as_list(embs[j])
    except Exception:
        # Best-effort API: keep the remaining entries as None, but never
        # swallow the failure silently -- record it with the traceback.
        logger.exception(
            "Text embedding failed for a batch of %d text(s); unfilled entries are returned as null",
            len(batch_texts),
        )
    return out


@app.post("/embed/image")
def embed_image(images: List[str]) -> List[Optional[List[float]]]:
    """Embed a list of image URLs/paths; the result is aligned by index.

    When the image model is not loaded, a warning is logged and a list of
    ``None`` with the same length as the input is returned. Entries that are
    ``None``, empty or whitespace-only are skipped and stay ``None``.
    Non-string entries are coerced with ``str()``.

    The remaining inputs are encoded in a single ``encode_image_urls`` call
    made while holding ``_image_encode_lock``. If that call (or converting
    its result) fails, the exception is logged and every attempted entry is
    reset to ``None``; the request itself does not fail.

    Args:
        images: Image URLs or paths.

    Returns:
        One embedding (list of floats) or ``None`` per input entry.
    """
    if _image_model is None:
        # Graceful degradation: keep API shape but return all None
        logger.warning("embed_image called but image model is not loaded; returning all None vectors")
        return [None] * len(images)
    out: List[Optional[List[float]]] = [None] * len(images)

    # Normalize inputs, remembering each kept item's original position.
    urls = []
    indices = []
    for i, url_or_path in enumerate(images):
        if url_or_path is None:
            continue
        if not isinstance(url_or_path, str):
            url_or_path = str(url_or_path)
        url_or_path = url_or_path.strip()
        if url_or_path:
            urls.append(url_or_path)
            indices.append(i)

    if not urls:
        return out

    with _image_encode_lock:
        try:
            # Both ClipAsServiceImageEncoder and ClipImageModel implement encode_image_urls(urls, batch_size)
            vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE)
            for j, idx in enumerate(indices):
                # Tolerate an encoder that returns fewer vectors than inputs.
                out[idx] = _as_list(vectors[j] if j < len(vectors) else None)
        except Exception:
            # Best-effort API: report nulls for the whole batch, but log the
            # failure with its traceback instead of discarding it.
            logger.exception(
                "Image embedding failed for a batch of %d image(s); returning null for all of them",
                len(urls),
            )
            for idx in indices:
                out[idx] = None
    return out