server.py 5.78 KB
"""
Embedding service (FastAPI).

API (simple list-in, list-out; aligned by index):
- POST /embed/text   body: ["text1", "text2", ...] -> [[...], ...]
- POST /embed/image  body: ["url_or_path1", ...]  -> [[...], ...]
"""

import logging
import threading
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException

from embeddings.config import CONFIG
from embeddings.protocols import ImageEncoderProtocol

logger = logging.getLogger(__name__)

app = FastAPI(title="saas-search Embedding Service", version="1.0.0")

# Model slots. Models are loaded eagerly by the startup hook (load_models), not
# lazily on first request. A slot stays None until startup has run, or forever
# when its feature switch below is False.
_text_model: Optional[Any] = None
_image_model: Optional[ImageEncoderProtocol] = None
# Feature switches. When False, the matching model is never loaded and requests
# to the matching /embed/* endpoint are rejected because its slot is None.
open_text_model = True
open_image_model = True  # Enables image embedding; backend (clip-as-service or local CN-CLIP) is chosen by CONFIG.USE_CLIP_AS_SERVICE

# One lock per model. Encode calls hold it, so at most one request uses a given
# model instance at a time.
_text_encode_lock = threading.Lock()
_image_encode_lock = threading.Lock()


@app.on_event("startup")
def load_models():
    """Load the enabled embedding models when the service starts.

    Loading eagerly avoids making the first request pay the model
    initialisation cost. Populates ``_text_model`` when ``open_text_model``
    is set and ``_image_model`` when ``open_image_model`` is set. A failure
    while importing or constructing either model is logged with its
    traceback and re-raised, so it propagates out of the startup hook.
    """
    global _text_model, _image_model

    logger.info("Loading embedding models at startup...")

    if open_text_model:
        _text_model = _load_text_model()

    if open_image_model:
        _image_model = _load_image_model()

    logger.info("All embedding models loaded successfully, service ready")


def _load_text_model() -> Any:
    """Construct the Qwen3 text model configured by ``CONFIG.TEXT_MODEL_ID``."""
    try:
        # Imported here (not at module top) so the dependency is only needed
        # when text embedding is enabled. Import errors are logged like load errors.
        from embeddings.qwen3_model import Qwen3TextModel

        logger.info("Loading text model: %s", CONFIG.TEXT_MODEL_ID)
        model = Qwen3TextModel(model_id=CONFIG.TEXT_MODEL_ID)
    except Exception as e:
        logger.exception("Failed to load text model: %s", e)
        raise
    logger.info("Text model loaded successfully")
    return model


def _load_image_model() -> ImageEncoderProtocol:
    """Construct the image encoder selected by ``CONFIG.USE_CLIP_AS_SERVICE``.

    When the flag is true, this builds a clip-as-service client
    (recommended). Otherwise it builds a local CN-CLIP model.
    """
    try:
        if CONFIG.USE_CLIP_AS_SERVICE:
            from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder

            logger.info(
                "Loading image encoder via clip-as-service: %s",
                CONFIG.CLIP_AS_SERVICE_SERVER,
            )
            model = ClipAsServiceImageEncoder(
                server=CONFIG.CLIP_AS_SERVICE_SERVER,
                batch_size=CONFIG.IMAGE_BATCH_SIZE,
            )
            backend = "clip-as-service"
        else:
            from embeddings.clip_model import ClipImageModel

            logger.info(
                "Loading local image model: %s (device: %s)",
                CONFIG.IMAGE_MODEL_NAME,
                CONFIG.IMAGE_DEVICE,
            )
            model = ClipImageModel(
                model_name=CONFIG.IMAGE_MODEL_NAME,
                device=CONFIG.IMAGE_DEVICE,
            )
            backend = "local CN-CLIP"
    except Exception as e:
        logger.exception("Failed to load image model: %s", e)
        raise
    logger.info("Image model (%s) loaded successfully", backend)
    return model


def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
    if embedding is None:
        return None
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding, dtype=np.float32)
    if embedding.ndim != 1:
        embedding = embedding.reshape(-1)
    return embedding.astype(np.float32).tolist()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe that also reports which embedding models are loaded.

    ``status`` is always ``"ok"``. Each ``*_loaded`` flag is True only when
    the matching module-level model slot is populated.
    """
    text_ready = _text_model is not None
    image_ready = _image_model is not None
    return dict(
        status="ok",
        text_model_loaded=text_ready,
        image_model_loaded=image_ready,
    )


@app.post("/embed/text")
def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of texts. Output vector ``i`` belongs to ``texts[i]``.

    Each text has surrounding whitespace stripped before encoding. An empty
    input list returns an empty list without calling the model.

    Raises:
        HTTPException: 503 if the text model is not loaded. 400 if an entry
            is not a string or is blank after stripping.
        RuntimeError: if the model returns a different number of vectors
            than inputs, or a ``None`` vector.
    """
    if _text_model is None:
        # Report "service unavailable" explicitly instead of letting a bare
        # exception surface as a generic server error.
        raise HTTPException(status_code=503, detail="Text model not loaded")

    normalized: List[str] = []
    for i, t in enumerate(texts):
        # Defensive: request validation normally guarantees strings, but the
        # function may also be called directly.
        if not isinstance(t, str):
            raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string")
        s = t.strip()
        if not s:
            raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string")
        normalized.append(s)

    if not normalized:
        # Nothing to encode: skip the lock and the model call entirely.
        return []

    # Serialize use of the shared model instance across concurrent requests.
    with _text_encode_lock:
        embs = _text_model.encode_batch(
            normalized,
            batch_size=int(CONFIG.TEXT_BATCH_SIZE),
            device=CONFIG.TEXT_DEVICE,
            normalize_embeddings=bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS),
        )
    # Outputs are aligned by index, so a count mismatch would silently pair
    # texts with the wrong vectors. Fail loudly instead.
    if embs is None or len(embs) != len(normalized):
        raise RuntimeError(
            f"Text model response length mismatch: expected {len(normalized)}, "
            f"got {0 if embs is None else len(embs)}"
        )

    def _checked(i: int, emb: Any) -> List[float]:
        # Convert one model output and reject a missing vector.
        vec = _as_list(emb)
        if vec is None:
            raise RuntimeError(f"Text model returned empty embedding for index {i}")
        return vec

    return [_checked(i, emb) for i, emb in enumerate(embs)]


@app.post("/embed/image")
def embed_image(images: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of images. Output vector ``i`` belongs to ``images[i]``.

    Each entry is a URL or path string, with surrounding whitespace stripped.
    An empty input list returns an empty list without calling the model.

    Raises:
        HTTPException: 503 if the image model is not loaded. 400 if an entry
            is not a string or is blank after stripping.
        RuntimeError: if the model returns a different number of vectors
            than inputs, or a ``None`` vector.
    """
    if _image_model is None:
        # Report "service unavailable" explicitly instead of letting a bare
        # exception surface as a generic server error.
        raise HTTPException(status_code=503, detail="Image model not loaded")

    # NOTE(review): these URLs/paths come straight from the client and are
    # handed to the encoder. If the encoder fetches URLs or opens local paths,
    # this allows SSRF / arbitrary local file reads. Consider an allow-list of
    # schemes/hosts and rejecting local paths. Confirm the encoder's behavior.
    urls: List[str] = []
    for i, url_or_path in enumerate(images):
        # Defensive: request validation normally guarantees strings, but the
        # function may also be called directly.
        if not isinstance(url_or_path, str):
            raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: must be string URL/path")
        s = url_or_path.strip()
        if not s:
            raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path")
        urls.append(s)

    if not urls:
        # Nothing to encode: skip the lock and the model call entirely.
        return []

    # Serialize use of the shared encoder instance across concurrent requests.
    # The batch size is cast to int, matching the text endpoint.
    with _image_encode_lock:
        vectors = _image_model.encode_image_urls(urls, batch_size=int(CONFIG.IMAGE_BATCH_SIZE))
    # Outputs are aligned by index, so a count mismatch would silently pair
    # images with the wrong vectors. Fail loudly instead.
    if vectors is None or len(vectors) != len(urls):
        raise RuntimeError(
            f"Image model response length mismatch: expected {len(urls)}, "
            f"got {0 if vectors is None else len(vectors)}"
        )

    def _checked(i: int, vec: Any) -> List[float]:
        # Convert one encoder output and reject a missing vector.
        out_vec = _as_list(vec)
        if out_vec is None:
            raise RuntimeError(f"Image model returned empty embedding for index {i}")
        return out_vec

    return [_checked(i, vec) for i, vec in enumerate(vectors)]