# rerank.py 2.12 KB
"""
Rerank provider - HTTP service (vllm reserved).
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Tuple

import requests

from config.services_config import get_rerank_config, get_rerank_service_url

logger = logging.getLogger(__name__)


class HttpRerankProvider:
    """Rerank documents by POSTing them to an HTTP rerank service.

    The provider never raises from ``rerank``: every failure is logged and
    reported to the caller as ``(None, None)``.
    """

    def __init__(self, service_url: str):
        """Store the service endpoint.

        Args:
            service_url: URL the rerank request is POSTed to. Trailing
                slashes are stripped; a falsy value is stored as ``""``.
        """
        self.service_url = (service_url or "").rstrip("/")

    def rerank(
        self,
        query: str,
        docs: List[str],
        timeout_sec: float,
    ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
        """Request scores for ``docs`` against ``query``.

        Args:
            query: Query text; ``None``/empty is sent as ``""`` and
                surrounding whitespace is stripped.
            docs: Documents to score. When empty, no request is made.
            timeout_sec: Timeout passed to ``requests.post``.

        Returns:
            ``([], {})`` when ``docs`` is empty.
            ``(scores, meta)`` on success, where ``scores`` is the list found
            under the response's ``"scores"`` key and ``meta`` is the value
            under ``"meta"`` (``{}`` when missing or falsy).
            ``(None, None)`` on any failure: non-200 status, timeout,
            transport/JSON error, or a payload without a list of scores.
        """
        if not docs:
            return [], {}

        # Only the network call and JSON decoding can raise; keep the try
        # body limited to them so payload validation below is not masked.
        try:
            payload = {"query": (query or "").strip(), "docs": docs}
            response = requests.post(self.service_url, json=payload, timeout=timeout_sec)
            if response.status_code != 200:
                logger.warning(
                    "Rerank service HTTP %s: %s",
                    response.status_code,
                    (response.text or "")[:200],
                )
                return None, None
            data = response.json()
        except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as exc:
            logger.warning(
                "Rerank request timed out after %.1fs (docs=%d); returning ES order. %s",
                timeout_sec,
                len(docs),
                exc,
            )
            return None, None
        except Exception as exc:
            # Deliberate best-effort boundary: reranking is optional, so any
            # other failure is logged with traceback and reported as None.
            logger.warning("Rerank request failed: %s", exc, exc_info=True)
            return None, None

        # A 200 response may still carry a body that is not a JSON object
        # (e.g. null or a list); report it explicitly instead of relying on
        # an AttributeError from ``.get``.
        if not isinstance(data, dict):
            logger.warning(
                "Rerank service returned unexpected payload type %s; ignoring result.",
                type(data).__name__,
            )
            return None, None

        scores = data.get("scores")
        if not isinstance(scores, list):
            # Previously this path returned silently, hiding service bugs.
            logger.warning(
                "Rerank service response has no list under 'scores' (got %s); ignoring result.",
                type(scores).__name__,
            )
            return None, None
        return scores, data.get("meta") or {}


def create_rerank_provider() -> HttpRerankProvider:
    """Build the rerank provider described by the services config.

    Only the HTTP provider is implemented. A configured provider of
    ``"vllm"`` (case-insensitive, surrounding whitespace ignored) is
    accepted but logs a warning and still yields the HTTP provider. A
    missing/empty provider is treated as ``"http"``.

    Returns:
        An ``HttpRerankProvider`` pointed at the configured service URL.
    """
    configured = get_rerank_config().provider
    provider_name = (configured if configured else "http").strip().lower()

    # "vllm" is a reserved name without its own implementation yet.
    if provider_name == "vllm":
        logger.warning("rerank provider 'vllm' is reserved, using HTTP.")

    return HttpRerankProvider(service_url=get_rerank_service_url())