# jina_reranker_v3.py (6.67 KB)
"""
Jina reranker v3 backend using the model card's recommended AutoModel API.

Reference: https://huggingface.co/jinaai/jina-reranker-v3
Requires: transformers, torch.
"""

from __future__ import annotations

import logging
import threading
import time
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoModel

# Fixed logger name (not ``__name__``); presumably chosen so logging config
# keyed on "reranker.backends.*" applies regardless of import path — confirm.
_LOGGER_NAME = "reranker.backends.jina_reranker_v3"
logger = logging.getLogger(_LOGGER_NAME)


class JinaRerankerV3Backend:
    """
    jina-reranker-v3 backend using `AutoModel(..., trust_remote_code=True)`.

    The official model card recommends calling:
      model = AutoModel.from_pretrained(..., trust_remote_code=True)
      model.rerank(query, documents, top_n=...)

    Config from services.rerank.backends.jina_reranker_v3. Recognised keys
    (all optional): ``model_name``, ``cache_dir``, ``dtype``, ``device``,
    ``batch_size``, ``return_embeddings``, ``trust_remote_code``.

    Calls into the model are serialised with an instance-level lock.
    """

    _BACKEND = "jina_reranker_v3"
    _NORMALIZE_NOTE = "jina_reranker_v3 returns model relevance scores directly"
    _DEFAULT_BATCH_SIZE = 64

    def __init__(self, config: Dict[str, Any]) -> None:
        """Load the model and move it to the configured (or best available) device.

        Args:
            config: backend settings; ``None`` or an empty mapping selects the
                defaults for every key.
        """
        self._config = config or {}
        self._model_name = str(
            self._config.get("model_name") or "jinaai/jina-reranker-v3"
        )
        self._cache_dir = self._config.get("cache_dir") or "./model_cache"
        self._dtype = str(self._config.get("dtype") or "auto")
        self._device = self._config.get("device")
        # An explicit ``batch_size: null`` falls back to the default instead of
        # crashing in int(None); zero/negative values are clamped to 1.
        raw_batch_size = self._config.get("batch_size")
        if raw_batch_size is None:
            raw_batch_size = self._DEFAULT_BATCH_SIZE
        self._batch_size = max(1, int(raw_batch_size))
        self._return_embeddings = bool(self._config.get("return_embeddings", False))
        self._trust_remote_code = bool(self._config.get("trust_remote_code", True))
        self._lock = threading.Lock()

        logger.info(
            "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s)",
            self._model_name,
            self._dtype,
            self._device,
            self._batch_size,
        )

        load_kwargs: Dict[str, Any] = {
            "trust_remote_code": self._trust_remote_code,
            "cache_dir": self._cache_dir,
            # NOTE(review): recent transformers releases take ``dtype``; older
            # releases expect ``torch_dtype`` — confirm against the pinned version.
            "dtype": self._dtype,
        }
        self._model = AutoModel.from_pretrained(self._model_name, **load_kwargs)
        self._model.eval()

        # Explicit device wins; otherwise prefer CUDA, else stay on CPU.
        if self._device is not None:
            self._model = self._model.to(self._device)
        elif torch.cuda.is_available():
            self._device = "cuda"
            self._model = self._model.to(self._device)
        else:
            self._device = "cpu"

        logger.info(
            "[Jina_Reranker_V3] Model ready | model=%s device=%s",
            self._model_name,
            self._device,
        )

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score every document in *docs* against *query* (no top-n cut-off).

        Equivalent to :meth:`score_with_meta_topn` with ``top_n=None``.
        """
        return self.score_with_meta_topn(query, docs, normalize=normalize, top_n=None)

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score *docs* against *query* and return ``(scores, meta)``.

        ``None``/blank documents are never sent to the model and score ``0.0``;
        a blank query scores everything ``0.0``. Identical (post-strip) texts
        are scored once and the score is copied to every duplicate position.

        Args:
            query: search query; ``None`` is treated as empty.
            docs: candidate documents; the returned scores are parallel to it.
            normalize: recorded in ``meta`` only — scores are returned exactly
                as produced by the model.
            top_n: optional cut-off, forwarded to ``model.rerank`` only when all
                unique docs fit in one batch. Docs the model does not return keep
                ``0.0``. With several batches, every doc is fully scored.

        Returns:
            A list of float scores parallel to *docs* and a metadata dict.
        """
        start_ts = time.perf_counter()  # monotonic, unlike time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed = self._usable_docs(docs)

        if not query or not indexed:
            return output_scores, self._build_meta(
                input_docs=total_docs,
                usable_docs=len(indexed),
                unique_docs=0,
                start_ts=start_ts,
                normalize=normalize,
            )

        # Score each distinct text once; dict.fromkeys keeps first-seen order.
        unique_texts: List[str] = list(dict.fromkeys(text for _, text in indexed))
        text_to_unique_idx: Dict[str, int] = {
            text: i for i, text in enumerate(unique_texts)
        }

        effective_top_n = min(top_n, len(unique_texts)) if top_n is not None else None

        unique_scores = self._rerank_unique(
            query=query,
            docs=unique_texts,
            top_n=effective_top_n,
        )

        for orig_idx, text in indexed:
            output_scores[orig_idx] = float(unique_scores[text_to_unique_idx[text]])

        meta = self._build_meta(
            input_docs=total_docs,
            usable_docs=len(indexed),
            unique_docs=len(unique_texts),
            start_ts=start_ts,
            normalize=normalize,
        )
        if effective_top_n is not None:
            meta["top_n"] = effective_top_n
            if len(unique_texts) > self._batch_size:
                meta["top_n_note"] = (
                    "Applied as a request hint only; full scores were computed because "
                    "global top_n across multiple local batches would be lossy."
                )
        return output_scores, meta

    @staticmethod
    def _usable_docs(docs: List[str] | None) -> List[Tuple[int, str]]:
        """Return ``(original_index, stripped_text)`` for every non-blank doc."""
        usable: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                usable.append((i, text))
        return usable

    def _build_meta(
        self,
        *,
        input_docs: int,
        usable_docs: int,
        unique_docs: int,
        start_ts: float,
        normalize: bool,
    ) -> Dict[str, Any]:
        """Assemble the metadata dict shared by every return path.

        ``dedup_ratio`` is the fraction of usable docs that duplicated an
        earlier doc; it is ``0.0`` when nothing was sent to the model.
        ``start_ts`` must come from ``time.perf_counter()``.
        """
        dedup_ratio = 1.0 - unique_docs / float(usable_docs) if unique_docs else 0.0
        return {
            "input_docs": input_docs,
            "usable_docs": usable_docs,
            "unique_docs": unique_docs,
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round((time.perf_counter() - start_ts) * 1000.0, 3),
            "model": self._model_name,
            "backend": self._BACKEND,
            "device": self._device,
            "dtype": self._dtype,
            "batch_size": self._batch_size,
            "normalize": normalize,
            "normalize_note": self._NORMALIZE_NOTE,
        }

    def _rerank_unique(
        self,
        query: str,
        docs: List[str],
        top_n: int | None,
    ) -> List[float]:
        """Run ``model.rerank`` over *docs* in batches; return scores parallel to *docs*.

        Entries the model does not return (e.g. outside ``top_n``) stay ``0.0``.
        """
        if not docs:
            return []

        unique_scores: List[float] = [0.0] * len(docs)
        # top_n is forwarded only when one model call sees every doc; a
        # per-batch cut-off would drop docs that are globally relevant.
        single_batch = len(docs) <= self._batch_size

        # no_grad is harmless if the remote ``rerank`` already disables
        # autograd, and guarantees no graph is retained if it does not.
        with self._lock, torch.no_grad():
            for start in range(0, len(docs), self._batch_size):
                batch_docs = docs[start : start + self._batch_size]
                batch_top_n = (
                    min(top_n, len(batch_docs))
                    if top_n is not None and single_batch
                    else None
                )
                results = self._model.rerank(
                    query,
                    batch_docs,
                    top_n=batch_top_n,
                    return_embeddings=self._return_embeddings,
                )
                # ``index`` is treated as relative to batch_docs; shift by offset.
                for item in results:
                    unique_scores[start + int(item["index"])] = float(
                        item["relevance_score"]
                    )

        return unique_scores