# server.py
"""
Reranker service - unified /rerank API backed by pluggable backends
(BGE, Qwen3-vLLM, Qwen3-Transformers, DashScope cloud rerank).

POST /rerank
Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional bool }
Response: { "scores": [float], "meta": {...} }

Backend selected via config: services.rerank.backend
(bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_transformers_packed | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank), env RERANK_BACKEND.
"""

import logging
import os
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from config.services_config import get_rerank_backend_config
from reranker.backends import RerankBackendProtocol, get_rerank_backend
from reranker.config import CONFIG

# Basic stdout logging; format carries timestamp, level, and logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s | %(message)s",
)
logger = logging.getLogger("reranker.service")

app = FastAPI(title="saas-search Reranker Service", version="1.0.0")

# Populated once by the startup hook below; stays None until a backend loads.
_reranker: Optional[RerankBackendProtocol] = None
_backend_name: str = ""
# Log-preview tuning via env vars, clamped to sane minimums so logging never breaks.
_LOG_DOC_PREVIEW_COUNT = max(1, int(os.getenv("RERANK_LOG_DOC_PREVIEW_COUNT", "3")))
_LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("RERANK_LOG_TEXT_PREVIEW_CHARS", "180")))


def _compact_preview(text: str, max_chars: int) -> str:
    compact = " ".join((text or "").split())
    if len(compact) <= max_chars:
        return compact
    return compact[:max_chars] + "..."


def _preview_docs(docs: List[str], max_items: int, max_chars: int) -> List[Dict[str, Any]]:
    """Build log-friendly previews for the first *max_items* entries of *docs*.

    Each entry records the doc's position, original length, and a
    whitespace-compacted preview capped at *max_chars* characters.
    """
    return [
        {"idx": position, "len": len(text), "preview": _compact_preview(text, max_chars)}
        for position, text in enumerate(docs[:max_items])
    ]


class RerankRequest(BaseModel):
    """Request payload for POST /rerank."""

    query: str = Field(..., description="Search query")
    docs: List[str] = Field(..., description="Documents/passages to rerank")
    # Default mirrors the service-wide CONFIG.NORMALIZE; clients may override per request.
    normalize: Optional[bool] = Field(
        default=CONFIG.NORMALIZE, description="Apply sigmoid normalization"
    )
    # Hint only — backends lacking score_with_meta_topn ignore it.
    top_n: Optional[int] = Field(
        default=None,
        description="Optional top_n hint for backends that support partial ranking",
    )


class RerankResponse(BaseModel):
    """Response payload: per-doc scores plus backend/service metadata."""

    scores: List[float] = Field(..., description="Scores aligned to input docs order")
    # Backend-provided metadata (batching, dedup stats, ...) plus timing added by the service.
    meta: Dict[str, Any] = Field(default_factory=dict)


@app.on_event("startup")
def load_model() -> None:
    """Instantiate the configured rerank backend once at service startup.

    Populates the module globals ``_reranker`` and ``_backend_name``.
    Any initialization failure is logged with a traceback and re-raised so
    the service refuses to start with a broken backend.
    """
    global _reranker, _backend_name
    logger.info("Starting reranker service on port %s", CONFIG.PORT)
    try:
        name, cfg = get_rerank_backend_config()
        _backend_name = name
        _reranker = get_rerank_backend(name, cfg)
        # Prefer the backend's own reported model name, then the config entry.
        shown_model = (
            getattr(_reranker, "_model_name", None)
            or cfg.get("model_name", name)
        )
        logger.info(
            "Reranker ready | backend=%s model=%s",
            _backend_name,
            shown_model,
        )
    except Exception as exc:
        logger.error("Failed to initialize reranker: %s", exc, exc_info=True)
        raise


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report readiness plus the active backend/model identifiers."""
    loaded = _reranker is not None
    model_info = ""
    if loaded:
        # Fall back: backend's model name attr -> its config entry -> backend name.
        model_info = getattr(_reranker, "_model_name", None) or getattr(
            _reranker, "_config", {}
        ).get("model_name", _backend_name)
    payload: Dict[str, Any] = {
        "status": "ok" if loaded else "unavailable",
        "model_loaded": loaded,
        "model": model_info,
        "backend": _backend_name,
    }
    if loaded:
        instruction_format = getattr(_reranker, "_instruction_format", None)
        if instruction_format is not None:
            payload["instruction_format"] = instruction_format
    return payload


@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequest) -> RerankResponse:
    """Score ``request.docs`` against ``request.query`` via the loaded backend.

    Validates the request (non-empty query/docs, doc-count cap, positive
    top_n), dispatches to the backend's top-n scorer when both a ``top_n``
    hint and ``score_with_meta_topn`` are available, and enriches the
    backend metadata with the service-side elapsed time.

    Raises:
        HTTPException: 503 if no backend is loaded, 400 on invalid input.
    """
    if _reranker is None:
        raise HTTPException(status_code=503, detail="Reranker model not loaded")

    query = (request.query or "").strip()
    if not query:
        raise HTTPException(status_code=400, detail="query cannot be empty")

    if not request.docs:
        raise HTTPException(status_code=400, detail="docs cannot be empty")

    doc_count = len(request.docs)
    if doc_count > CONFIG.MAX_DOCS:
        raise HTTPException(
            status_code=400,
            detail=f"Too many docs: {doc_count} > {CONFIG.MAX_DOCS}",
        )

    top_n = None if request.top_n is None else int(request.top_n)
    if top_n is not None and top_n <= 0:
        raise HTTPException(status_code=400, detail="top_n must be > 0")

    # Per-request override wins; otherwise use the service default.
    normalize = bool(request.normalize) if request.normalize is not None else CONFIG.NORMALIZE

    started = time.time()
    logger.info(
        "Rerank request | docs=%d normalize=%s query_len=%d query=%r doc_preview=%s",
        doc_count,
        normalize,
        len(query),
        _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS),
        _preview_docs(request.docs, _LOG_DOC_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS),
    )

    if top_n is not None and hasattr(_reranker, "score_with_meta_topn"):
        # Backend supports partial ranking; pass the hint through.
        scores, meta = _reranker.score_with_meta_topn(
            query,
            request.docs,
            normalize=normalize,
            top_n=top_n,
        )
    else:
        scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize)

    # Copy before mutating so a backend-owned dict is never modified in place.
    meta = dict(meta)
    if top_n is not None:
        meta.setdefault("requested_top_n", top_n)
    meta["service_elapsed_ms"] = round((time.time() - started) * 1000.0, 3)

    preview_scores = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]]
    logger.info(
        "Rerank done | docs=%d unique=%s dedup=%s elapsed_ms=%s batches=%s batchsize=%s batch_concurrency=%s query=%r score_preview=%s",
        meta.get("input_docs"),
        meta.get("unique_docs"),
        meta.get("dedup_ratio"),
        meta.get("service_elapsed_ms"),
        meta.get("batches"),
        meta.get("batchsize"),
        meta.get("batch_concurrency"),
        _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS),
        preview_scores,
    )

    return RerankResponse(scores=scores, meta=meta)


if __name__ == "__main__":
    # Local/dev entry point; deployments can run uvicorn against
    # "reranker.server:app" directly instead.
    import uvicorn

    uvicorn.run(
        "reranker.server:app",
        host=CONFIG.HOST,
        port=CONFIG.PORT,
        reload=False,
        log_level="info",
    )