Blame view

reranker/backends/jina_reranker_v3.py 7.61 KB
971a0851   tangwang   补充reranker-jina,探...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  """
  Jina reranker v3 backend using the model card's recommended AutoModel API.
  
  Reference: https://huggingface.co/jinaai/jina-reranker-v3
  Requires: transformers, torch.
  """
  
  from __future__ import annotations
  
  import logging
  import threading
  import time
  from typing import Any, Dict, List, Tuple
  
  import torch
  from transformers import AutoModel
  
  logger = logging.getLogger("reranker.backends.jina_reranker_v3")
  
  
  class JinaRerankerV3Backend:
      """
      jina-reranker-v3 backend using `AutoModel(..., trust_remote_code=True)`.

      The official model card recommends calling:
        model = AutoModel.from_pretrained(..., trust_remote_code=True)
        model.rerank(query, documents, top_n=...)

      Config from services.rerank.backends.jina_reranker_v3.

      Config keys read by this class (all optional):
        model_name, cache_dir, dtype, device, batch_size, max_doc_length,
        max_query_length, sort_by_doc_length, return_embeddings,
        trust_remote_code.
      """

      def __init__(self, config: Dict[str, Any]) -> None:
          """
          Read the backend options from *config* and load the model.

          Args:
              config: Backend options; ``None`` or an empty mapping selects the
                  defaults for every key.

          Integer options that are missing or explicitly ``None`` fall back to
          their defaults, and every integer option is clamped to at least 1.
          When no device is configured, "cuda" is used if
          ``torch.cuda.is_available()`` and "cpu" otherwise.
          """
          self._config = config or {}
          self._model_name = str(
              self._config.get("model_name") or "jinaai/jina-reranker-v3"
          )
          self._cache_dir = self._config.get("cache_dir") or "./model_cache"
          self._dtype = str(self._config.get("dtype") or "float16")
          self._device = self._config.get("device")
          self._batch_size = self._config_int(self._config, "batch_size", 64)
          self._max_doc_length = self._config_int(self._config, "max_doc_length", 160)
          self._max_query_length = self._config_int(
              self._config, "max_query_length", 64
          )
          self._sort_by_doc_length = bool(self._config.get("sort_by_doc_length", True))
          self._return_embeddings = bool(self._config.get("return_embeddings", False))
          self._trust_remote_code = bool(self._config.get("trust_remote_code", True))
          # Held around every model.rerank call so only one thread runs
          # inference at a time.
          self._lock = threading.Lock()

          logger.info(
              "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s, "
              "max_doc_length=%s, max_query_length=%s, sort_by_doc_length=%s)",
              self._model_name,
              self._dtype,
              self._device,
              self._batch_size,
              self._max_doc_length,
              self._max_query_length,
              self._sort_by_doc_length,
          )

          load_kwargs: Dict[str, Any] = {
              "trust_remote_code": self._trust_remote_code,
              "cache_dir": self._cache_dir,
              "dtype": self._dtype,
          }
          self._model = AutoModel.from_pretrained(self._model_name, **load_kwargs)
          self._model.eval()

          if self._device is not None:
              self._model = self._model.to(self._device)
          elif torch.cuda.is_available():
              self._device = "cuda"
              self._model = self._model.to(self._device)
          else:
              # No explicit move: the model stays wherever from_pretrained put it.
              self._device = "cpu"

          logger.info(
              "[Jina_Reranker_V3] Model ready | model=%s device=%s",
              self._model_name,
              self._device,
          )

      @staticmethod
      def _config_int(config: Dict[str, Any], key: str, default: int) -> int:
          """
          Return ``config[key]`` as an int clamped to at least 1.

          A missing key or an explicit ``None`` value yields *default*; any other
          value is passed through ``int()`` (so a non-numeric value still raises).
          """
          value = config.get(key)
          if value is None:
              value = default
          return max(1, int(value))

      def score_with_meta(
          self,
          query: str,
          docs: List[str],
          normalize: bool = True,
      ) -> Tuple[List[float], Dict[str, Any]]:
          """Score *docs* against *query* without a top-n hint.

          Thin wrapper around :meth:`score_with_meta_topn` with ``top_n=None``.
          """
          return self.score_with_meta_topn(query, docs, normalize=normalize, top_n=None)

      def score_with_meta_topn(
          self,
          query: str,
          docs: List[str],
          normalize: bool = True,
          top_n: int | None = None,
      ) -> Tuple[List[float], Dict[str, Any]]:
          """
          Score every document against *query* and report request metadata.

          Args:
              query: Query text; ``None`` or whitespace-only short-circuits to
                  all-zero scores.
              docs: Documents to score. ``None`` entries and entries that are
                  empty after stripping keep a score of 0.0 and are not sent to
                  the model.
              normalize: Only echoed back in the metadata; the scores themselves
                  are not post-processed here.
              top_n: Optional cap forwarded to the model when all unique
                  documents fit into one batch; otherwise recorded as a hint.

          Returns:
              ``(scores, meta)`` where ``scores`` is aligned with *docs* and
              ``meta`` describes the request (counts, timing, settings).
          """
          # perf_counter is monotonic, so the reported duration cannot be skewed
          # by wall-clock adjustments.
          started = time.perf_counter()
          total_docs = len(docs) if docs else 0
          output_scores: List[float] = [0.0] * total_docs

          query = "" if query is None else str(query).strip()
          indexed: List[Tuple[int, str]] = []
          for position, doc in enumerate(docs or []):
              if doc is None:
                  continue
              text = str(doc).strip()
              if text:
                  indexed.append((position, text))

          if not query or not indexed:
              elapsed_ms = (time.perf_counter() - started) * 1000.0
              return output_scores, {
                  "input_docs": total_docs,
                  "usable_docs": len(indexed),
                  "unique_docs": 0,
                  "dedup_ratio": 0.0,
                  "elapsed_ms": round(elapsed_ms, 3),
                  "model": self._model_name,
                  "backend": "jina_reranker_v3",
                  "normalize": normalize,
                  "normalize_note": "jina_reranker_v3 returns model relevance scores directly",
              }

          # Deduplicate identical texts so each one is scored once; remember for
          # every usable input position which unique text it maps to.
          unique_texts: List[str] = []
          text_to_unique_idx: Dict[str, int] = {}
          placements: List[Tuple[int, int]] = []
          for position, text in indexed:
              unique_idx = text_to_unique_idx.get(text)
              if unique_idx is None:
                  unique_idx = len(unique_texts)
                  text_to_unique_idx[text] = unique_idx
                  unique_texts.append(text)
              placements.append((position, unique_idx))

          effective_top_n = min(top_n, len(unique_texts)) if top_n is not None else None

          unique_scores = self._rerank_unique(
              query=query,
              docs=unique_texts,
              top_n=effective_top_n,
          )

          for position, unique_idx in placements:
              output_scores[position] = float(unique_scores[unique_idx])

          elapsed_ms = (time.perf_counter() - started) * 1000.0
          dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
          meta: Dict[str, Any] = {
              "input_docs": total_docs,
              "usable_docs": len(indexed),
              "unique_docs": len(unique_texts),
              "dedup_ratio": round(dedup_ratio, 4),
              "elapsed_ms": round(elapsed_ms, 3),
              "model": self._model_name,
              "backend": "jina_reranker_v3",
              "device": self._device,
              "dtype": self._dtype,
              "batch_size": self._batch_size,
              "max_doc_length": self._max_doc_length,
              "max_query_length": self._max_query_length,
              "sort_by_doc_length": self._sort_by_doc_length,
              "normalize": normalize,
              "normalize_note": "jina_reranker_v3 returns model relevance scores directly",
          }
          if effective_top_n is not None:
              meta["top_n"] = effective_top_n
              if len(unique_texts) > self._batch_size:
                  meta["top_n_note"] = (
                      "Applied as a request hint only; full scores were computed because "
                      "global top_n across multiple local batches would be lossy."
                  )
          return output_scores, meta

      def _rerank_unique(
          self,
          query: str,
          docs: List[str],
          top_n: int | None,
      ) -> List[float]:
          """
          Score already-deduplicated *docs* in batches of ``self._batch_size``.

          Args:
              query: Non-empty query text.
              docs: Unique document texts.
              top_n: Forwarded to the model only when all of *docs* fit into a
                  single batch; ignored otherwise.

          Returns:
              One score per entry of *docs*, in the same order. Entries the model
              does not return (e.g. cut off by ``top_n``) stay at 0.0.
          """
          if not docs:
              return []

          ordered_indices = list(range(len(docs)))
          if self._sort_by_doc_length and len(ordered_indices) > 1:
              # Batch documents of similar character length together.
              ordered_indices.sort(key=lambda idx: len(docs[idx]))

          # top_n is only meaningful when a single model call sees every document.
          single_batch = len(docs) <= self._batch_size

          unique_scores: List[float] = [0.0] * len(docs)
          with self._lock:
              for start in range(0, len(ordered_indices), self._batch_size):
                  batch_indices = ordered_indices[start : start + self._batch_size]
                  batch_docs = [docs[idx] for idx in batch_indices]
                  batch_top_n = None
                  if top_n is not None and single_batch:
                      batch_top_n = min(top_n, len(batch_docs))
                  results = self._model.rerank(
                      query,
                      batch_docs,
                      top_n=batch_top_n,
                      return_embeddings=self._return_embeddings,
                      max_doc_length=self._max_doc_length,
                      max_query_length=self._max_query_length,
                  )
                  for item in results:
                      # item["index"] is relative to batch_docs; map it back to
                      # the position in the caller's docs list.
                      batch_index = int(item["index"])
                      unique_scores[batch_indices[batch_index]] = float(
                          item["relevance_score"]
                      )

          return unique_scores