server.py 4.07 KB
"""
Reranker service - unified /rerank API backed by pluggable backends (BGE, Qwen3-vLLM).

POST /rerank
Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional bool }
Response: { "scores": [float], "meta": {...} }

The backend is chosen by the config key services.rerank.backend (bge | qwen3_vllm); the RERANK_BACKEND environment variable can also set it.
"""

import logging
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from config.services_config import get_rerank_backend_config
from reranker.backends import RerankBackendProtocol, get_rerank_backend
from reranker.config import CONFIG

_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("reranker.service")

app = FastAPI(version="1.0.0", title="saas-search Reranker Service")

# Process-wide backend state, filled in by the startup hook. `_reranker` stays
# None until a backend has been constructed; the request handlers check for that.
_reranker: Optional[RerankBackendProtocol] = None
_backend_name: str = ""


class RerankRequest(BaseModel):
    query: str = Field(..., description="Search query")
    docs: List[str] = Field(..., description="Documents/passages to rerank")
    normalize: Optional[bool] = Field(
        default=CONFIG.NORMALIZE, description="Apply sigmoid normalization"
    )


class RerankResponse(BaseModel):
    # JSON body returned by POST /rerank: one score per input doc, plus
    # backend/service metadata (empty dict when nothing is reported).
    scores: List[float] = Field(default=..., description="Scores aligned to input docs order")
    meta: Dict[str, Any] = Field(default_factory=dict)


@app.on_event("startup")
def load_model() -> None:
    """Build the configured rerank backend once, when the app starts.

    Reads ``(backend_name, backend_cfg)`` from ``get_rerank_backend_config()``
    and constructs the backend with ``get_rerank_backend``. The module globals
    ``_reranker`` and ``_backend_name`` are assigned only after construction
    succeeds, so a failed start never leaves a backend name recorded without
    a loaded backend.

    Raises:
        Exception: any error from config loading or backend construction is
            logged with its traceback and re-raised so startup fails loudly.
    """
    global _reranker, _backend_name
    logger.info("Starting reranker service on port %s", CONFIG.PORT)
    try:
        backend_name, backend_cfg = get_rerank_backend_config()
        reranker = get_rerank_backend(backend_name, backend_cfg)
        # Prefer the name the backend reports; fall back to the config entry,
        # then to the backend name itself.
        model_info = getattr(reranker, "_model_name", None) or backend_cfg.get(
            "model_name", backend_name
        )
    except Exception as exc:
        # logger.exception attaches the active traceback (same as exc_info=True).
        logger.exception("Failed to initialize reranker: %s", exc)
        raise

    # Publish state only once everything above has succeeded.
    _backend_name = backend_name
    _reranker = reranker
    logger.info(
        "Reranker ready | backend=%s model=%s",
        _backend_name,
        model_info,
    )


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report whether a rerank backend is loaded and which model it serves.

    ``status`` is ``"ok"`` once the startup hook has set ``_reranker`` and
    ``"unavailable"`` otherwise; ``model_loaded`` mirrors that as a bool.

    The model label comes from the backend's ``_model_name`` attribute,
    falling back to ``_config.get("model_name")`` and finally to the backend
    name. ``_config`` is consulted only when it actually has a callable
    ``get``, so a backend whose ``_config`` is ``None`` or a plain object
    cannot make this health probe fail with an AttributeError.
    """
    reranker = _reranker  # read the global once so all fields agree
    model_info: Any = ""
    if reranker is not None:
        model_info = getattr(reranker, "_model_name", None)
        if not model_info:
            config_get = getattr(getattr(reranker, "_config", None), "get", None)
            model_info = (
                config_get("model_name", _backend_name)
                if callable(config_get)
                else _backend_name
            )
    return {
        "status": "ok" if reranker is not None else "unavailable",
        "model_loaded": reranker is not None,
        "model": model_info,
        "backend": _backend_name,
    }


@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest) -> RerankResponse:
    """Score ``request.docs`` against ``request.query`` with the loaded backend.

    Returns one score per input doc, in input order, plus the backend's
    metadata augmented with ``service_elapsed_ms``, the handler's wall time
    around scoring, measured with a monotonic clock.

    Raises:
        HTTPException(503): no backend has been loaded.
        HTTPException(400): blank query, empty docs, or more than
            ``CONFIG.MAX_DOCS`` docs.
        HTTPException(500): the backend returned a different number of scores
            than docs, which would silently misalign the response.
    """
    reranker = _reranker
    if reranker is None:
        raise HTTPException(status_code=503, detail="Reranker model not loaded")

    query = (request.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="query cannot be empty")

    docs = request.docs
    if not docs:
        raise HTTPException(status_code=400, detail="docs cannot be empty")

    n_docs = len(docs)
    if n_docs > CONFIG.MAX_DOCS:
        raise HTTPException(
            status_code=400,
            detail=f"Too many docs: {n_docs} > {CONFIG.MAX_DOCS}",
        )

    normalize = CONFIG.NORMALIZE if request.normalize is None else bool(request.normalize)

    # perf_counter is monotonic; time.time() can jump with clock adjustments.
    start = time.perf_counter()
    logger.info(
        "Rerank request | docs=%d normalize=%s",
        n_docs,
        normalize,
    )
    scores, meta = reranker.score_with_meta(query, docs, normalize=normalize)

    if len(scores) != n_docs:
        logger.error(
            "Rerank backend returned %d scores for %d docs", len(scores), n_docs
        )
        raise HTTPException(
            status_code=500,
            detail=f"Backend returned {len(scores)} scores for {n_docs} docs",
        )

    # Copy so the backend's own dict is never mutated; tolerate a None meta.
    meta = dict(meta or {})
    meta["service_elapsed_ms"] = round((time.perf_counter() - start) * 1000.0, 3)
    # %s (not %d) so a backend that omits or stringifies these keys cannot
    # trigger a logging formatting error.
    logger.info(
        "Rerank done | docs=%s unique=%s dedup=%s elapsed_ms=%s",
        meta.get("input_docs", n_docs),
        meta.get("unique_docs"),
        meta.get("dedup_ratio"),
        meta.get("service_elapsed_ms"),
    )

    return RerankResponse(scores=scores, meta=meta)


if __name__ == "__main__":
    # Local import: uvicorn is only needed when this module is run directly.
    import uvicorn

    # Pass the app as an import string so uvicorn loads `reranker.server` itself.
    uvicorn.run(
        "reranker.server:app",
        log_level="info",
        reload=False,
        port=CONFIG.PORT,
        host=CONFIG.HOST,
    )