""" Jina reranker v3 backend using the model card's recommended AutoModel API. Reference: https://huggingface.co/jinaai/jina-reranker-v3 Requires: transformers, torch. """ from __future__ import annotations import logging import threading import time from typing import Any, Dict, List, Tuple import torch from transformers import AutoModel logger = logging.getLogger("reranker.backends.jina_reranker_v3") class JinaRerankerV3Backend: """ jina-reranker-v3 backend using `AutoModel(..., trust_remote_code=True)`. The official model card recommends calling: model = AutoModel.from_pretrained(..., trust_remote_code=True) model.rerank(query, documents, top_n=...) Config from services.rerank.backends.jina_reranker_v3. """ def __init__(self, config: Dict[str, Any]) -> None: self._config = config or {} self._model_name = str( self._config.get("model_name") or "jinaai/jina-reranker-v3" ) self._cache_dir = self._config.get("cache_dir") or "./model_cache" self._dtype = str(self._config.get("dtype") or "float16") self._device = self._config.get("device") self._batch_size = max(1, int(self._config.get("batch_size", 64))) self._max_doc_length = max(1, int(self._config.get("max_doc_length", 160))) self._max_query_length = max(1, int(self._config.get("max_query_length", 64))) self._sort_by_doc_length = bool(self._config.get("sort_by_doc_length", True)) self._return_embeddings = bool(self._config.get("return_embeddings", False)) self._trust_remote_code = bool(self._config.get("trust_remote_code", True)) self._lock = threading.Lock() logger.info( "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s, " "max_doc_length=%s, max_query_length=%s, sort_by_doc_length=%s)", self._model_name, self._dtype, self._device, self._batch_size, self._max_doc_length, self._max_query_length, self._sort_by_doc_length, ) load_kwargs: Dict[str, Any] = { "trust_remote_code": self._trust_remote_code, "cache_dir": self._cache_dir, "dtype": self._dtype, } self._model = AutoModel.from_pretrained(self._model_name, **load_kwargs) self._model.eval() if self._device is not None: self._model = self._model.to(self._device) elif torch.cuda.is_available(): self._device = "cuda" self._model = self._model.to(self._device) else: self._device = "cpu" logger.info( "[Jina_Reranker_V3] Model ready | model=%s device=%s", self._model_name, self._device, ) def score_with_meta( self, query: str, docs: List[str], normalize: bool = True, ) -> Tuple[List[float], Dict[str, Any]]: return self.score_with_meta_topn(query, docs, normalize=normalize, top_n=None) def score_with_meta_topn( self, query: str, docs: List[str], normalize: bool = True, top_n: int | None = None, ) -> Tuple[List[float], Dict[str, Any]]: start_ts = time.time() total_docs = len(docs) if docs else 0 output_scores: List[float] = [0.0] * total_docs query = "" if query is None else str(query).strip() indexed: List[Tuple[int, str]] = [] for i, doc in enumerate(docs or []): if doc is None: continue text = str(doc).strip() if not text: continue indexed.append((i, text)) if not query or not indexed: elapsed_ms = (time.time() - start_ts) * 1000.0 return output_scores, { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": 0, "dedup_ratio": 0.0, "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "jina_reranker_v3", "normalize": normalize, "normalize_note": "jina_reranker_v3 returns model relevance scores directly", } unique_texts: List[str] = [] text_to_unique_idx: Dict[str, int] = {} for orig_idx, text in indexed: unique_idx = text_to_unique_idx.get(text) if unique_idx is None: unique_idx = len(unique_texts) text_to_unique_idx[text] = unique_idx unique_texts.append(text) effective_top_n = min(top_n, len(unique_texts)) if top_n is not None else None unique_scores = self._rerank_unique( query=query, docs=unique_texts, top_n=effective_top_n, ) for orig_idx, text in indexed: unique_idx = text_to_unique_idx[text] output_scores[orig_idx] = float(unique_scores[unique_idx]) elapsed_ms = (time.time() - start_ts) * 1000.0 dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) if indexed else 0.0 meta = { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": len(unique_texts), "dedup_ratio": round(dedup_ratio, 4), "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "jina_reranker_v3", "device": self._device, "dtype": self._dtype, "batch_size": self._batch_size, "max_doc_length": self._max_doc_length, "max_query_length": self._max_query_length, "sort_by_doc_length": self._sort_by_doc_length, "normalize": normalize, "normalize_note": "jina_reranker_v3 returns model relevance scores directly", } if effective_top_n is not None: meta["top_n"] = effective_top_n if len(unique_texts) > self._batch_size: meta["top_n_note"] = ( "Applied as a request hint only; full scores were computed because " "global top_n across multiple local batches would be lossy." ) return output_scores, meta def _rerank_unique( self, query: str, docs: List[str], top_n: int | None, ) -> List[float]: if not docs: return [] ordered_indices = list(range(len(docs))) if self._sort_by_doc_length and len(ordered_indices) > 1: ordered_indices.sort(key=lambda idx: len(docs[idx])) unique_scores: List[float] = [0.0] * len(docs) with self._lock: for start in range(0, len(ordered_indices), self._batch_size): batch_indices = ordered_indices[start : start + self._batch_size] batch_docs = [docs[idx] for idx in batch_indices] batch_top_n = None if top_n is not None and len(docs) <= self._batch_size: batch_top_n = min(top_n, len(batch_docs)) results = self._model.rerank( query, batch_docs, top_n=batch_top_n, return_embeddings=self._return_embeddings, max_doc_length=self._max_doc_length, max_query_length=self._max_query_length, ) for item in results: batch_index = int(item["index"]) unique_scores[batch_indices[batch_index]] = float( item["relevance_score"] ) return unique_scores