# server.py 3.33 KB
"""
Embedding service (FastAPI).

API (simple list-in, list-out; results are aligned by index with the input;
entries that are empty or cannot be encoded come back as null):
- POST /embed/text   body: ["text1", "text2", ...] -> [[...], null, ...]
- POST /embed/image  body: ["url_or_path1", ...]  -> [[...], null, ...]
- GET  /health       -> {"status": "ok"}
"""

import logging
import threading
from typing import Any, Dict, List, Optional

import numpy as np
from fastapi import FastAPI

from embeddings.config import CONFIG
from embeddings.bge_model import BgeTextModel
from embeddings.clip_model import ClipImageModel


app = FastAPI(
    title="SearchEngine Embedding Service",
    version="1.0.0",
)

# Model singletons. They start unset and are built on first use by
# _get_text_model / _get_image_model.
_text_model: Optional[BgeTextModel] = None
_image_model: Optional[ClipImageModel] = None

# Guard one-time construction of each model.
_text_init_lock = threading.Lock()
_image_init_lock = threading.Lock()

# Serialize encode calls into each shared model.
_text_encode_lock = threading.Lock()
_image_encode_lock = threading.Lock()


def _get_text_model():
    """Return the shared BGE text model, building it on first call.

    The unlocked check keeps the already-initialized path lock-free. The
    re-check under ``_text_init_lock`` ensures that two concurrent first
    callers construct the model only once.
    """
    global _text_model
    if _text_model is not None:
        return _text_model
    with _text_init_lock:
        if _text_model is None:
            _text_model = BgeTextModel(model_dir=CONFIG.TEXT_MODEL_DIR)
    return _text_model


def _get_image_model():
    """Return the shared CLIP image model, building it on first call.

    The model name and device come from ``CONFIG``. Construction is
    re-checked under ``_image_init_lock`` so that concurrent first callers
    build it only once.
    """
    global _image_model
    if _image_model is not None:
        return _image_model
    with _image_init_lock:
        if _image_model is None:
            _image_model = ClipImageModel(
                model_name=CONFIG.IMAGE_MODEL_NAME,
                device=CONFIG.IMAGE_DEVICE,
            )
    return _image_model


def _as_list(embedding: Optional[np.ndarray]) -> Optional[List[float]]:
    if embedding is None:
        return None
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding, dtype=np.float32)
    if embedding.ndim != 1:
        embedding = embedding.reshape(-1)
    return embedding.astype(np.float32).tolist()


@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness check: always report ``{"status": "ok"}``.

    Neither model is loaded or touched here.
    """
    payload: Dict[str, Any] = {"status": "ok"}
    return payload


@app.post("/embed/text")
def embed_text(texts: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of texts with the shared BGE model.

    Args:
        texts: input strings. ``None`` and whitespace-only entries are
            skipped. Non-string entries are coerced with ``str()``, and
            every entry is stripped before encoding.

    Returns:
        A list aligned by index with *texts*. Each entry is an embedding
        (list of floats), or ``None`` when the input was skipped. All
        non-empty entries go through a single ``encode_batch`` call, so if
        that call fails (or returns the wrong number of embeddings), every
        entry comes back ``None`` and the error is logged.
    """
    model = _get_text_model()
    out: List[Optional[List[float]]] = [None] * len(texts)

    # Parallel lists: positions[j] is the index in `texts` of batch_texts[j],
    # used to scatter results back into `out`.
    positions: List[int] = []
    batch_texts: List[str] = []
    for i, t in enumerate(texts):
        if t is None:
            continue
        t = (t if isinstance(t, str) else str(t)).strip()
        if t:
            positions.append(i)
            batch_texts.append(t)

    if not batch_texts:
        return out

    try:
        # Serialize calls into the shared model.
        with _text_encode_lock:
            embs = model.encode_batch(
                batch_texts,
                batch_size=int(CONFIG.TEXT_BATCH_SIZE),
                device=CONFIG.TEXT_DEVICE,
            )
        # A count mismatch means results can't be matched to inputs safely.
        if len(embs) != len(batch_texts):
            raise ValueError(
                f"encode_batch returned {len(embs)} embeddings "
                f"for {len(batch_texts)} texts"
            )
        for j, idx in enumerate(positions):
            out[idx] = _as_list(embs[j])
    except Exception:
        # Best-effort contract: affected entries stay None, but make the
        # failure visible to operators instead of swallowing it.
        logging.getLogger(__name__).exception(
            "Text embedding failed for a batch of %d texts", len(batch_texts)
        )
    return out


@app.post("/embed/image")
def embed_image(images: List[str]) -> List[Optional[List[float]]]:
    """Embed a batch of images with the shared CLIP model, one at a time.

    Args:
        images: image URLs or paths, each passed to
            ``encode_image_from_url`` after stripping whitespace. ``None``
            and whitespace-only entries are skipped. Non-string entries are
            coerced with ``str()``.

    Returns:
        A list aligned by index with *images*. Each entry is an embedding
        for an item that encoded successfully, or ``None`` for a skipped
        item or one whose processing raised. Each failure is logged and
        does not affect the other items.
    """
    model = _get_image_model()
    out: List[Optional[List[float]]] = [None] * len(images)
    logger = logging.getLogger(__name__)

    # The lock is held for the whole batch, so calls into the shared model
    # are serialized across requests.
    with _image_encode_lock:
        for i, url_or_path in enumerate(images):
            try:
                if url_or_path is None:
                    continue
                if not isinstance(url_or_path, str):
                    url_or_path = str(url_or_path)
                url_or_path = url_or_path.strip()
                if not url_or_path:
                    continue
                out[i] = _as_list(model.encode_image_from_url(url_or_path))
            except Exception:
                # out[i] is still None; record why instead of hiding it.
                logger.exception(
                    "Image embedding failed for item %d (%r)", i, url_or_path
                )
    return out