""" Minimal BGE reranker for pairwise scoring (query, doc). Features: - Model loading with optional FP16 - Length-based sorting to reduce padding waste - Deduplication to avoid redundant inference - Scores returned in original doc order """ import logging import math import threading import time from typing import Any, Dict, List, Optional, Tuple import torch from modelscope import AutoModelForSequenceClassification, AutoTokenizer logger = logging.getLogger("reranker.core") class BGEReranker: def __init__( self, model_name: str = "BAAI/bge-reranker-v2-m3", device: Optional[str] = None, batch_size: int = 64, use_fp16: bool = True, max_length: int = 512, cache_dir: str = "./model_cache", enable_warmup: bool = True, ) -> None: self.model_name = model_name self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.batch_size = max(1, int(batch_size)) self.max_length = int(max_length) self.use_fp16 = bool(use_fp16 and self.device == "cuda") self._lock = threading.Lock() logger.info( "[BGE_RERANKER] Loading model %s on %s (fp16=%s)", self.model_name, self.device, self.use_fp16, ) self.tokenizer = AutoTokenizer.from_pretrained( self.model_name, trust_remote_code=True, cache_dir=cache_dir ) self.model = AutoModelForSequenceClassification.from_pretrained( self.model_name, trust_remote_code=True, cache_dir=cache_dir ) self.model = self.model.to(self.device) self.model.eval() if self.use_fp16: self.model = self.model.half() if self.device == "cuda": torch.backends.cudnn.benchmark = True if enable_warmup: self._warmup() logger.info( "[BGE_RERANKER] Model ready | model=%s device=%s fp16=%s batch=%s max_len=%s", self.model_name, self.device, self.use_fp16, self.batch_size, self.max_length, ) def _warmup(self) -> None: try: with torch.inference_mode(): pairs = [["warmup", "warmup"]] inputs = self.tokenizer( pairs, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length, ) inputs = {k: v.to(self.device) for k, v in inputs.items()} if self.use_fp16: inputs = { k: (v.half() if v.dtype == torch.float32 else v) for k, v in inputs.items() } _ = self.model(**inputs, return_dict=True).logits if self.device == "cuda": torch.cuda.synchronize() except Exception as exc: logger.warning("[BGE_RERANKER] Warmup failed: %s", exc) def score(self, query: str, docs: List[str], normalize: bool = True) -> List[float]: scores, _meta = self.score_with_meta(query, docs, normalize=normalize) return scores def score_with_meta( self, query: str, docs: List[str], normalize: bool = True ) -> Tuple[List[float], Dict[str, Any]]: start_ts = time.time() if docs is None: docs = [] query = "" if query is None else str(query).strip() total_docs = len(docs) output_scores: List[float] = [0.0] * total_docs # Log request summary (query + first 3 docs preview) preview_docs: List[str] = [] for d in docs[:3]: preview_docs.append("" if d is None else str(d)) logger.info( "[BGE_RERANKER] Request | query=%r | docs=%d | docs_preview=%s", query, total_docs, preview_docs, ) indexed_docs: List[Tuple[int, str]] = [] for i, doc in enumerate(docs): if doc is None: continue text = str(doc).strip() if not text: continue indexed_docs.append((i, text)) if not query or not indexed_docs: elapsed_ms = (time.time() - start_ts) * 1000.0 return output_scores, { "input_docs": total_docs, "usable_docs": len(indexed_docs), "unique_docs": 0, "dedup_ratio": 0.0, "elapsed_ms": round(elapsed_ms, 3), } # Sort by estimated length + text to cluster similar lengths indexed_docs.sort(key=lambda x: (len(x[1]), x[1])) unique_texts: List[str] = [] position_to_unique: List[int] = [] prev_text: Optional[str] = None for _idx, text in indexed_docs: if text != prev_text: unique_texts.append(text) prev_text = text position_to_unique.append(len(unique_texts) - 1) logger.debug( "[BGE_RERANKER] Preprocess | input=%d usable=%d unique=%d", total_docs, len(indexed_docs), len(unique_texts), ) unique_scores = self._score_unique( query=query, passages=unique_texts, normalize=normalize ) for (orig_idx, _text), unique_idx in zip(indexed_docs, position_to_unique): output_scores[orig_idx] = float(unique_scores[unique_idx]) # Log per-doc scores (aligned to original docs order) try: lines = [] for i, d in enumerate(docs[:100]): lines.append(f"{output_scores[i]},{'' if d is None else str(d)}") logger.info("[BGE_RERANKER] query:%s Scores (score,doc):\n%s", query, "\n".join(lines)) except Exception: pass elapsed_ms = (time.time() - start_ts) * 1000.0 dedup_ratio = 0.0 if indexed_docs: dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed_docs))) meta = { "input_docs": total_docs, "usable_docs": len(indexed_docs), "unique_docs": len(unique_texts), "dedup_ratio": round(dedup_ratio, 4), "elapsed_ms": round(elapsed_ms, 3), "model": self.model_name, "device": self.device, "fp16": self.use_fp16, "batch_size": self.batch_size, "max_length": self.max_length, "normalize": normalize, } logger.info( "[BGE_RERANKER] Done | input=%d usable=%d unique=%d dedup=%s elapsed_ms=%s", meta["input_docs"], meta["usable_docs"], meta["unique_docs"], meta["dedup_ratio"], meta["elapsed_ms"], ) return output_scores, meta def _compute_optimal_batch_size(self, total: int) -> int: if total <= 0: return 1 current_batch_size = self.batch_size + 8 current_batch_count = math.ceil(total / current_batch_size) optimal_batch_size = current_batch_size test_batch_size = current_batch_size - 4 while test_batch_size > 0: test_batch_count = math.ceil(total / test_batch_size) if test_batch_count <= current_batch_count: optimal_batch_size = test_batch_size test_batch_size -= 4 else: break return max(1, optimal_batch_size) def _score_unique( self, query: str, passages: List[str], normalize: bool = True ) -> List[float]: if not passages: return [] optimal_batch_size = self._compute_optimal_batch_size(len(passages)) logger.info( "[BGE_RERANKER] Reranking %d unique passages | batch=%d | device=%s | fp16=%s", len(passages), optimal_batch_size, self.device, self.use_fp16, ) scores: List[float] = [] with self._lock: for i in range(0, len(passages), optimal_batch_size): batch_passages = passages[i : i + optimal_batch_size] pairs = [[query, passage] for passage in batch_passages] with torch.inference_mode(): inputs = self.tokenizer( pairs, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length, add_special_tokens=True, ) inputs = {k: v.to(self.device) for k, v in inputs.items()} if self.use_fp16: inputs = { k: (v.half() if v.dtype == torch.float32 else v) for k, v in inputs.items() } logits = self.model(**inputs, return_dict=True).logits.view(-1).float() if normalize: logits = torch.sigmoid(logits) batch_scores = logits.detach().cpu().numpy().tolist() scores.extend(batch_scores) return scores