# bge_reranker.py
"""
Minimal BGE reranker for pairwise scoring (query, doc).

Features:
- Model loading with optional FP16
- Length-based sorting to reduce padding waste
- Deduplication to avoid redundant inference
- Scores returned in original doc order
"""

import logging
import math
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger("reranker.core")


class BGEReranker:
    """Score (query, document) pairs with a sequence-classification reranker.

    The model and tokenizer are loaded once in ``__init__``. Scoring skips
    ``None``/blank documents (they keep a score of ``0.0``), sorts the usable
    documents by length so that batches contain similarly sized texts,
    scores each distinct text only once, and returns the scores in the
    caller's original document order. Model inference is serialized through
    an internal ``threading.Lock``.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-m3",
        device: Optional[str] = None,
        batch_size: int = 64,
        use_fp16: bool = True,
        max_length: int = 512,
        cache_dir: str = "./model_cache",
        enable_warmup: bool = True,
    ) -> None:
        """Load the tokenizer and model and prepare them for inference.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            device: Target device string (``"cpu"``, ``"cuda"``, ``"cuda:1"``,
                ...). When falsy, ``"cuda"`` is used if available, else
                ``"cpu"``.
            batch_size: Base batch size; clamped to at least 1. See
                ``_compute_optimal_batch_size`` for how it is applied.
            use_fp16: Request half precision. Only honoured on CUDA devices.
            max_length: Truncation length handed to the tokenizer.
            cache_dir: Directory handed to ``from_pretrained`` as ``cache_dir``.
            enable_warmup: Run one dummy forward pass after loading.
        """
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = max(1, int(batch_size))
        self.max_length = int(max_length)

        # Match on the device *type* so that indexed devices ("cuda:0") get
        # the same treatment as plain "cuda".
        on_cuda = self._is_cuda_device(self.device)
        self.use_fp16 = bool(use_fp16 and on_cuda)
        self._lock = threading.Lock()

        logger.info(
            "[BGE_RERANKER] Loading model %s on %s (fp16=%s)",
            self.model_name,
            self.device,
            self.use_fp16,
        )

        # NOTE(security): trust_remote_code=True lets the model repository
        # execute arbitrary Python at load time. Only point model_name at
        # repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name, trust_remote_code=True, cache_dir=cache_dir
        )

        self.model = self.model.to(self.device)
        self.model.eval()

        if self.use_fp16:
            self.model = self.model.half()

        if on_cuda:
            # Process-wide switch: lets cuDNN autotune its kernels.
            torch.backends.cudnn.benchmark = True

        if enable_warmup:
            self._warmup()

        logger.info(
            "[BGE_RERANKER] Model ready | model=%s device=%s fp16=%s batch=%s max_len=%s",
            self.model_name,
            self.device,
            self.use_fp16,
            self.batch_size,
            self.max_length,
        )

    @staticmethod
    def _is_cuda_device(device: Any) -> bool:
        """Return True when *device* names a CUDA device ("cuda", "cuda:0", ...)."""
        return str(device).split(":", 1)[0] == "cuda"

    def _encode_pairs(self, pairs: List[List[str]]) -> Dict[str, Any]:
        """Tokenize ``[query, passage]`` pairs and move the tensors to the device.

        Float32 tensors are cast to half precision when FP16 is active; all
        other tensors are left untouched.
        """
        encoded = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
            add_special_tokens=True,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        if self.use_fp16:
            encoded = {
                name: (tensor.half() if tensor.dtype == torch.float32 else tensor)
                for name, tensor in encoded.items()
            }
        return encoded

    def _warmup(self) -> None:
        """Run one throwaway forward pass; failures are logged, never raised."""
        try:
            with torch.inference_mode():
                inputs = self._encode_pairs([["warmup", "warmup"]])
                _ = self.model(**inputs, return_dict=True).logits
                if self._is_cuda_device(self.device):
                    torch.cuda.synchronize()
        except Exception as exc:
            # Warmup is best-effort: a failure here must not block construction.
            logger.warning("[BGE_RERANKER] Warmup failed: %s", exc)

    def score(self, query: str, docs: List[str], normalize: bool = True) -> List[float]:
        """Return one score per entry of *docs*, in the same order.

        Thin wrapper around ``score_with_meta`` that drops the metadata.
        """
        scores, _meta = self.score_with_meta(query, docs, normalize=normalize)
        return scores

    def _build_meta(
        self,
        input_docs: int,
        usable_docs: int,
        unique_docs: int,
        elapsed_ms: float,
        normalize: bool,
    ) -> Dict[str, Any]:
        """Assemble the metadata dict returned alongside the scores.

        Used for both the early-return and the normal path so that callers
        always see the same set of keys. ``dedup_ratio`` is the share of
        usable documents that were duplicates, or ``0.0`` when nothing was
        scored.
        """
        dedup_ratio = 0.0
        if usable_docs and unique_docs:
            dedup_ratio = 1.0 - (unique_docs / float(usable_docs))
        return {
            "input_docs": input_docs,
            "usable_docs": usable_docs,
            "unique_docs": unique_docs,
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self.model_name,
            "device": self.device,
            "fp16": self.use_fp16,
            "batch_size": self.batch_size,
            "max_length": self.max_length,
            "normalize": normalize,
        }

    def score_with_meta(
        self, query: str, docs: List[str], normalize: bool = True
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score *docs* against *query* and report processing metadata.

        Args:
            query: Query text; ``None`` is treated as empty. Stripped before use.
            docs: Documents to score; ``None`` is treated as an empty list.
                ``None`` or blank entries are not sent to the model.
            normalize: Apply a sigmoid to the raw logits when True.

        Returns:
            ``(scores, meta)``. ``scores`` has one float per input document in
            the original order; skipped documents keep ``0.0``. If the query
            is empty or no document is usable, every score is ``0.0`` and the
            model is not invoked.
        """
        started = time.perf_counter()

        if docs is None:
            docs = []

        query = "" if query is None else str(query).strip()
        total_docs = len(docs)
        output_scores: List[float] = [0.0] * total_docs

        # Log request summary (query + first 3 docs preview)
        preview_docs: List[str] = ["" if d is None else str(d) for d in docs[:3]]
        logger.info(
            "[BGE_RERANKER] Request | query=%r | docs=%d | docs_preview=%s",
            query,
            total_docs,
            preview_docs,
        )

        # Keep (original_index, stripped_text) for every non-blank document.
        indexed_docs: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed_docs.append((i, text))

        if not query or not indexed_docs:
            elapsed_ms = (time.perf_counter() - started) * 1000.0
            return output_scores, self._build_meta(
                input_docs=total_docs,
                usable_docs=len(indexed_docs),
                unique_docs=0,
                elapsed_ms=elapsed_ms,
                normalize=normalize,
            )

        # Sort by (length, text): clusters similar lengths to reduce padding
        # and places identical texts next to each other for the dedup pass.
        indexed_docs.sort(key=lambda item: (len(item[1]), item[1]))

        unique_texts: List[str] = []
        position_to_unique: List[int] = []
        prev_text: Optional[str] = None

        for _idx, text in indexed_docs:
            if text != prev_text:
                unique_texts.append(text)
                prev_text = text
            position_to_unique.append(len(unique_texts) - 1)

        logger.debug(
            "[BGE_RERANKER] Preprocess | input=%d usable=%d unique=%d",
            total_docs,
            len(indexed_docs),
            len(unique_texts),
        )

        unique_scores = self._score_unique(
            query=query, passages=unique_texts, normalize=normalize
        )

        # Scatter the per-unique-text scores back to the original positions.
        for (orig_idx, _text), unique_idx in zip(indexed_docs, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        # Per-doc score dump (first 100 docs, original order); only built
        # when DEBUG logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            lines = [
                f"{output_scores[i]},{'' if d is None else str(d)}"
                for i, d in enumerate(docs[:100])
            ]
            logger.debug(
                "[BGE_RERANKER] query:%s Scores (score,doc):\n%s",
                query,
                "\n".join(lines),
            )

        elapsed_ms = (time.perf_counter() - started) * 1000.0
        meta = self._build_meta(
            input_docs=total_docs,
            usable_docs=len(indexed_docs),
            unique_docs=len(unique_texts),
            elapsed_ms=elapsed_ms,
            normalize=normalize,
        )

        logger.info(
            "[BGE_RERANKER] Done | input=%d usable=%d unique=%d dedup=%s elapsed_ms=%s",
            meta["input_docs"],
            meta["usable_docs"],
            meta["unique_docs"],
            meta["dedup_ratio"],
            meta["elapsed_ms"],
        )

        return output_scores, meta

    def _compute_optimal_batch_size(self, total: int) -> int:
        """Pick a batch size that evens out batches without adding batches.

        Starts from ``self.batch_size + 8`` and shrinks in steps of 4 for as
        long as the number of batches needed for *total* items does not
        exceed the number needed at the starting size.

        NOTE: the result can exceed ``self.batch_size`` by up to 8.

        Returns:
            The chosen batch size (at least 1); 1 when ``total <= 0``.
        """
        if total <= 0:
            return 1

        ceiling = self.batch_size + 8
        baseline_batches = math.ceil(total / ceiling)

        chosen = ceiling
        candidate = ceiling - 4
        while candidate > 0 and math.ceil(total / candidate) <= baseline_batches:
            chosen = candidate
            candidate -= 4

        return max(1, chosen)

    def _score_unique(
        self, query: str, passages: List[str], normalize: bool = True
    ) -> List[float]:
        """Run the model over *passages* paired with *query*.

        Args:
            query: Query text paired with every passage.
            passages: Texts to score; returned scores follow this order.
            normalize: Apply a sigmoid to the logits when True.

        Returns:
            The flattened model logits (or their sigmoid) as Python floats.
            NOTE(review): assumes the model emits exactly one logit per pair;
            confirm for models other than the default.
        """
        if not passages:
            return []

        optimal_batch_size = self._compute_optimal_batch_size(len(passages))

        logger.info(
            "[BGE_RERANKER] Reranking %d unique passages | batch=%d | device=%s | fp16=%s",
            len(passages),
            optimal_batch_size,
            self.device,
            self.use_fp16,
        )

        scores: List[float] = []

        # The lock is held for the whole request, so concurrent callers run
        # one after another.
        with self._lock:
            for start in range(0, len(passages), optimal_batch_size):
                batch_passages = passages[start : start + optimal_batch_size]
                pairs = [[query, passage] for passage in batch_passages]

                with torch.inference_mode():
                    inputs = self._encode_pairs(pairs)
                    # Cast to float32 so the sigmoid and the conversion to
                    # Python floats do not run in half precision.
                    logits = self.model(**inputs, return_dict=True).logits.view(-1).float()
                    if normalize:
                        logits = torch.sigmoid(logits)
                    scores.extend(logits.detach().cpu().tolist())

        return scores