"""
Jina reranker v3 backend using the model card's recommended AutoModel API.
Reference: https://huggingface.co/jinaai/jina-reranker-v3
Requires: transformers, torch.
"""
from __future__ import annotations
import logging
import threading
import time
from typing import Any, Dict, List, Tuple
import torch
from transformers import AutoModel
logger = logging.getLogger("reranker.backends.jina_reranker_v3")
class JinaRerankerV3Backend:
    """
    Backend for jina-reranker-v3 driven through ``AutoModel(..., trust_remote_code=True)``.

    The official model card recommends:

        model = AutoModel.from_pretrained(..., trust_remote_code=True)
        model.rerank(query, documents, top_n=...)

    All settings are read from services.rerank.backends.jina_reranker_v3.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Resolve config options, then load and place the model."""
        cfg = config or {}
        self._config = cfg
        self._model_name = str(cfg.get("model_name") or "jinaai/jina-reranker-v3")
        self._cache_dir = cfg.get("cache_dir") or "./model_cache"
        self._dtype = str(cfg.get("dtype") or "float16")
        self._device = cfg.get("device")
        # All size-like options are clamped to at least 1.
        self._batch_size = max(1, int(cfg.get("batch_size", 64)))
        self._max_doc_length = max(1, int(cfg.get("max_doc_length", 160)))
        self._max_query_length = max(1, int(cfg.get("max_query_length", 64)))
        self._sort_by_doc_length = bool(cfg.get("sort_by_doc_length", True))
        self._return_embeddings = bool(cfg.get("return_embeddings", False))
        self._trust_remote_code = bool(cfg.get("trust_remote_code", True))
        # Serializes inference calls; the remote-code rerank() path is not
        # known to be thread-safe.
        self._lock = threading.Lock()
        logger.info(
            "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s, "
            "max_doc_length=%s, max_query_length=%s, sort_by_doc_length=%s)",
            self._model_name,
            self._dtype,
            self._device,
            self._batch_size,
            self._max_doc_length,
            self._max_query_length,
            self._sort_by_doc_length,
        )
        self._model = AutoModel.from_pretrained(
            self._model_name,
            trust_remote_code=self._trust_remote_code,
            cache_dir=self._cache_dir,
            dtype=self._dtype,
        )
        self._model.eval()
        if self._device is not None:
            self._model = self._model.to(self._device)
        elif torch.cuda.is_available():
            # No device configured: prefer CUDA when it is available.
            self._device = "cuda"
            self._model = self._model.to(self._device)
        else:
            self._device = "cpu"
        logger.info(
            "[Jina_Reranker_V3] Model ready | model=%s device=%s",
            self._model_name,
            self._device,
        )

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score *docs* against *query*; thin wrapper without a top_n hint."""
        return self.score_with_meta_topn(query, docs, normalize=normalize, top_n=None)

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """
        Rerank *docs* for *query* and return ``(scores, meta)``.

        Scores are positional: ``scores[i]`` belongs to ``docs[i]``; None or
        blank documents keep a 0.0 score, as does an empty query. Duplicate
        texts are scored once and the value is fanned back out to every
        occurrence. ``top_n`` is only a hint forwarded to the model (see
        meta["top_n_note"]); ``normalize`` is recorded in meta but the model's
        relevance scores are returned unchanged.
        """
        t0 = time.time()
        n_docs = len(docs) if docs else 0
        out: List[float] = [0.0] * n_docs
        query = "" if query is None else str(query).strip()
        # (original position, stripped text) for every usable document.
        kept: List[Tuple[int, str]] = [
            (pos, str(raw).strip())
            for pos, raw in enumerate(docs or [])
            if raw is not None and str(raw).strip()
        ]
        if not query or not kept:
            return out, {
                "input_docs": n_docs,
                "usable_docs": len(kept),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round((time.time() - t0) * 1000.0, 3),
                "model": self._model_name,
                "backend": "jina_reranker_v3",
                "normalize": normalize,
                "normalize_note": "jina_reranker_v3 returns model relevance scores directly",
            }
        # Deduplicate while preserving first-seen order.
        uniq_pos: Dict[str, int] = {}
        uniq_texts: List[str] = []
        for _, text in kept:
            if text not in uniq_pos:
                uniq_pos[text] = len(uniq_texts)
                uniq_texts.append(text)
        effective_top_n = None if top_n is None else min(top_n, len(uniq_texts))
        uniq_scores = self._rerank_unique(
            query=query,
            docs=uniq_texts,
            top_n=effective_top_n,
        )
        # Fan unique scores back out to every original position.
        for pos, text in kept:
            out[pos] = float(uniq_scores[uniq_pos[text]])
        dedup_ratio = 1.0 - (len(uniq_texts) / float(len(kept))) if kept else 0.0
        meta: Dict[str, Any] = {
            "input_docs": n_docs,
            "usable_docs": len(kept),
            "unique_docs": len(uniq_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round((time.time() - t0) * 1000.0, 3),
            "model": self._model_name,
            "backend": "jina_reranker_v3",
            "device": self._device,
            "dtype": self._dtype,
            "batch_size": self._batch_size,
            "max_doc_length": self._max_doc_length,
            "max_query_length": self._max_query_length,
            "sort_by_doc_length": self._sort_by_doc_length,
            "normalize": normalize,
            "normalize_note": "jina_reranker_v3 returns model relevance scores directly",
        }
        if effective_top_n is not None:
            meta["top_n"] = effective_top_n
            if len(uniq_texts) > self._batch_size:
                meta["top_n_note"] = (
                    "Applied as a request hint only; full scores were computed because "
                    "global top_n across multiple local batches would be lossy."
                )
        return out, meta

    def _rerank_unique(
        self,
        query: str,
        docs: List[str],
        top_n: int | None,
    ) -> List[float]:
        """
        Score already-deduplicated *docs* in batches; result aligns with *docs*.

        Batches may be formed over a length-sorted ordering of the documents
        (``sort_by_doc_length``); model results are mapped back to the
        original positions. ``top_n`` is forwarded to the model only when all
        documents fit in a single batch — documents the model then omits keep
        a 0.0 score; with multiple batches it is ignored so every document
        receives a full score.
        """
        if not docs:
            return []
        order = list(range(len(docs)))
        if self._sort_by_doc_length and len(order) > 1:
            order.sort(key=lambda i: len(docs[i]))
        scores: List[float] = [0.0] * len(docs)
        with self._lock:
            for lo in range(0, len(order), self._batch_size):
                chunk = order[lo : lo + self._batch_size]
                chunk_docs = [docs[i] for i in chunk]
                chunk_top_n = (
                    min(top_n, len(chunk_docs))
                    if top_n is not None and len(docs) <= self._batch_size
                    else None
                )
                results = self._model.rerank(
                    query,
                    chunk_docs,
                    top_n=chunk_top_n,
                    return_embeddings=self._return_embeddings,
                    max_doc_length=self._max_doc_length,
                    max_query_length=self._max_query_length,
                )
                # Result indices are relative to chunk_docs; map back to docs.
                for item in results:
                    scores[chunk[int(item["index"])]] = float(item["relevance_score"])
        return scores