# server.py (3.78 KB)
"""
FastAPI service for BGE reranking.

POST /rerank
Request:
{
  "query": "...",
  "docs": ["doc1", "doc2", ...]
}

Response:
{
  "scores": [0.98, 0.12, ...],
  "meta": {...}
}
"""

import logging
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from reranker.bge_reranker import BGEReranker
from reranker.config import CONFIG

# Configure root logging at import time: INFO level, with timestamp, level and
# logger name prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s | %(message)s",
)
# Logger shared by every handler in this module.
logger = logging.getLogger("reranker.service")

# ASGI application object; the route and startup handlers below register on it.
app = FastAPI(title="SearchEngine Reranker Service", version="1.0.0")

# Process-wide reranker instance. It stays None until the startup hook
# (load_model) assigns it; /health and /rerank both check for None.
_reranker: Optional[BGEReranker] = None


class RerankRequest(BaseModel):
    # Request body for POST /rerank.
    # The /rerank handler strips `query` and rejects it with HTTP 400 when the
    # result is empty.
    query: str = Field(..., description="Search query")
    # Must be non-empty and no longer than CONFIG.MAX_DOCS; the /rerank handler
    # enforces both limits with HTTP 400.
    docs: List[str] = Field(..., description="Documents/passages to rerank")
    # Defaults to CONFIG.NORMALIZE. An explicit null is accepted as well; the
    # /rerank handler then falls back to CONFIG.NORMALIZE.
    normalize: Optional[bool] = Field(
        default=CONFIG.NORMALIZE, description="Apply sigmoid normalization"
    )


class RerankResponse(BaseModel):
    # Response body for POST /rerank.
    # One score per submitted document, in the same order as the request's docs.
    scores: List[float] = Field(..., description="Scores aligned to input docs order")
    # Metadata returned by the reranker, plus "service_elapsed_ms" added by the
    # /rerank handler. Defaults to an empty dict.
    meta: Dict[str, Any] = Field(default_factory=dict)


@app.on_event("startup")
def load_model() -> None:
    """Build the process-wide reranker when the application starts.

    The instance is constructed from the values in ``CONFIG`` and stored in
    the module-level ``_reranker``.

    Raises:
        Exception: whatever construction raised, re-raised after it has been
            logged together with its traceback.
    """
    global _reranker
    logger.info("Starting reranker service on port %s", CONFIG.PORT)

    try:
        settings = {
            "model_name": CONFIG.MODEL_NAME,
            "device": CONFIG.DEVICE,
            "batch_size": CONFIG.BATCH_SIZE,
            "use_fp16": CONFIG.USE_FP16,
            "max_length": CONFIG.MAX_LENGTH,
            "cache_dir": CONFIG.CACHE_DIR,
            "enable_warmup": CONFIG.ENABLE_WARMUP,
        }
        _reranker = BGEReranker(**settings)

        # Report the values exposed by the instance itself rather than the
        # requested configuration (only the model name comes from CONFIG).
        effective = (
            CONFIG.MODEL_NAME,
            _reranker.device,
            _reranker.use_fp16,
            _reranker.batch_size,
            _reranker.max_length,
        )
        logger.info(
            "Reranker ready | model=%s device=%s fp16=%s batch=%s max_len=%s",
            *effective,
        )
    except Exception as exc:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Failed to initialize reranker: %s", exc)
        raise


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report whether the reranker is loaded, plus the configured model/device.

    Returns:
        A dict with ``status`` ("ok" or "unavailable"), ``model_loaded``,
        ``model`` and ``device``; the last two are read from ``CONFIG``.
    """
    loaded = _reranker is not None
    status = "ok" if loaded else "unavailable"
    return {
        "status": status,
        "model_loaded": loaded,
        "model": CONFIG.MODEL_NAME,
        "device": CONFIG.DEVICE,
    }


@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest) -> RerankResponse:
    """Score the request's documents against its query.

    Args:
        request: Parsed request body carrying ``query``, ``docs`` and the
            optional ``normalize`` flag.

    Returns:
        A ``RerankResponse`` holding the scores produced by the reranker and
        its metadata, extended with ``service_elapsed_ms``.

    Raises:
        HTTPException: 503 when the model has not been loaded; 400 when the
            query is blank, ``docs`` is empty, or ``docs`` has more than
            ``CONFIG.MAX_DOCS`` entries.
    """
    if _reranker is None:
        raise HTTPException(status_code=503, detail="Reranker model not loaded")

    query = (request.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="query cannot be empty")

    if not request.docs:
        raise HTTPException(status_code=400, detail="docs cannot be empty")

    doc_count = len(request.docs)
    if doc_count > CONFIG.MAX_DOCS:
        raise HTTPException(
            status_code=400,
            detail=f"Too many docs: {doc_count} > {CONFIG.MAX_DOCS}",
        )

    # An explicit null in the request means "use the service default".
    normalize = CONFIG.NORMALIZE if request.normalize is None else bool(request.normalize)

    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the system clock is adjusted (unlike time.time()).
    started = time.perf_counter()
    logger.info(
        "Rerank request | docs=%d normalize=%s",
        doc_count,
        normalize,
    )
    scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize)

    # Copy before adding our own key so the reranker's mapping is not mutated.
    meta = dict(meta)
    meta["service_elapsed_ms"] = round((time.perf_counter() - started) * 1000.0, 3)

    # %s rather than %d: the metadata keys are optional, and formatting None
    # with %d would produce a logging format error instead of a log line.
    logger.info(
        "Rerank done | docs=%s unique=%s dedup=%s elapsed_ms=%s",
        meta.get("input_docs", doc_count),
        meta.get("unique_docs"),
        meta.get("dedup_ratio"),
        meta.get("service_elapsed_ms"),
    )

    return RerankResponse(scores=scores, meta=meta)


if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn, auto-reload disabled.
    # uvicorn is imported here so it is only required when run as a script.
    from uvicorn import run as run_server

    server_options = {
        "host": CONFIG.HOST,
        "port": CONFIG.PORT,
        "reload": False,
        "log_level": "info",
    }
    run_server("reranker.server:app", **server_options)