Commit 74116f0573b1333975cc6a0eae2e8f58cd4972cf

Authored by tangwang
1 parent 971a0851

jina-reranker-v3性能测试和参数优化

config/config.yaml
... ... @@ -395,8 +395,11 @@ services:
395 395 jina_reranker_v3:
396 396 model_name: "jinaai/jina-reranker-v3"
397 397 device: null
398   - dtype: "auto"
  398 + dtype: "float16"
399 399 batch_size: 64
  400 + max_doc_length: 160
  401 + max_query_length: 64
  402 + sort_by_doc_length: true
400 403 cache_dir: "./model_cache"
401 404 trust_remote_code: true
402 405 qwen3_vllm:
... ...
docs/搜索API对接指南-07-微服务接口(Embedding-Reranker-Translation).md
... ... @@ -190,6 +190,16 @@ curl "http://localhost:6008/ready"
190 190  
191 191 说明:默认后端为 `qwen3_vllm`(`Qwen/Qwen3-Reranker-0.6B`),需要可用 GPU 显存。
192 192  
  193 +补充:若切换到 `jina_reranker_v3`,在当前 `Tesla T4` 上建议使用:
  194 +
  195 +- `dtype: float16`
  196 +- `batch_size: 64`
  197 +- `max_doc_length: 160`
  198 +- `max_query_length: 64`
  199 +- `sort_by_doc_length: true`
  200 +
  201 +原因:Tesla T4(Turing 架构,compute capability 7.5)没有原生 bfloat16 支持,而 `jina_reranker_v3` 的 `auto` 在当前机器上会落到 `bfloat16`,性能明显差于 `float16`;此外它的 listwise 架构在 T4 上对上下文长度更敏感,过大的 batch 会显著拉长延迟。
  202 +
193 203 补充:`docs` 的请求大小与模型推理 `batch size` 解耦。即使一次传入 1000 条文档,服务端也会按 `services.rerank.backends.qwen3_vllm.infer_batch_size` 自动拆分。
194 204  
195 205 #### 7.2.1 `POST /rerank` — 结果重排
... ... @@ -439,4 +449,3 @@ curl "http://localhost:6006/health"
439 449 请求/响应格式、示例及错误码见 [-05-索引接口(Indexer)](./搜索API对接指南-05-索引接口(Indexer).md#58-内容理解字段生成接口)。
440 450  
441 451 ---
442   -
... ...
requirements_reranker_jina_reranker_v3.txt
1 1 # Isolated dependencies for jina_reranker_v3 reranker backend.
  2 +#
  3 +# Keep this stack aligned with the validated CUDA runtime on our hosts.
  4 +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works.
  5 +# Cap transformers <5 to stay on the same family as the known-good reranker envs.
2 6  
3 7 -r requirements_reranker_base.txt
4   -torch>=2.0.0
5   -transformers>=4.51.0
  8 +torch==2.10.0
  9 +transformers>=4.51.0,<5
... ...
reranker/README.md
... ... @@ -152,12 +152,22 @@ services:
152 152 jina_reranker_v3:
153 153 model_name: "jinaai/jina-reranker-v3"
154 154 device: null
155   - dtype: "auto"
  155 + dtype: "float16"
156 156 batch_size: 64
  157 + max_doc_length: 160
  158 + max_query_length: 64
  159 + sort_by_doc_length: true
157 160 cache_dir: "./model_cache"
158 161 trust_remote_code: true
159 162 ```
160 163  
  164 +T4 实测建议:
  165 +
  166 +- `dtype` 优先使用 `float16`;在当前机器上 `auto` 会加载成 `bfloat16`,而 T4 没有原生 bfloat16 支持,明显更慢
  167 +- 在线短文本商品重排建议从 `batch_size: 64` 起步;它比更大的 listwise block 更快,但会牺牲一部分“完整 listwise”排序一致性
  168 +- 若你更看重接近完整 listwise 的排序结果,可提高到 `batch_size: 125`,代价是延迟明显上升
  169 +- `max_doc_length: 160`、`max_query_length: 64` 更适合当前商品标题 / 短 query 场景
  170 +
161 171 ## 当前最优方案:`qwen3_vllm_score`
162 172  
163 173  
... ...
reranker/backends/jina_reranker_v3.py
... ... @@ -35,19 +35,26 @@ class JinaRerankerV3Backend:
35 35 self._config.get("model_name") or "jinaai/jina-reranker-v3"
36 36 )
37 37 self._cache_dir = self._config.get("cache_dir") or "./model_cache"
38   - self._dtype = str(self._config.get("dtype") or "auto")
  38 + self._dtype = str(self._config.get("dtype") or "float16")
39 39 self._device = self._config.get("device")
40 40 self._batch_size = max(1, int(self._config.get("batch_size", 64)))
  41 + self._max_doc_length = max(1, int(self._config.get("max_doc_length", 160)))
  42 + self._max_query_length = max(1, int(self._config.get("max_query_length", 64)))
  43 + self._sort_by_doc_length = bool(self._config.get("sort_by_doc_length", True))
41 44 self._return_embeddings = bool(self._config.get("return_embeddings", False))
42 45 self._trust_remote_code = bool(self._config.get("trust_remote_code", True))
43 46 self._lock = threading.Lock()
44 47  
45 48 logger.info(
46   - "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s)",
  49 + "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s, "
  50 + "max_doc_length=%s, max_query_length=%s, sort_by_doc_length=%s)",
47 51 self._model_name,
48 52 self._dtype,
49 53 self._device,
50 54 self._batch_size,
  55 + self._max_doc_length,
  56 + self._max_query_length,
  57 + self._sort_by_doc_length,
51 58 )
52 59  
53 60 load_kwargs: Dict[str, Any] = {
... ... @@ -116,7 +123,6 @@ class JinaRerankerV3Backend:
116 123 }
117 124  
118 125 unique_texts: List[str] = []
119   - unique_first_indices: List[int] = []
120 126 text_to_unique_idx: Dict[str, int] = {}
121 127 for orig_idx, text in indexed:
122 128 unique_idx = text_to_unique_idx.get(text)
... ... @@ -124,7 +130,6 @@ class JinaRerankerV3Backend:
124 130 unique_idx = len(unique_texts)
125 131 text_to_unique_idx[text] = unique_idx
126 132 unique_texts.append(text)
127   - unique_first_indices.append(orig_idx)
128 133  
129 134 effective_top_n = min(top_n, len(unique_texts)) if top_n is not None else None
130 135  
... ... @@ -151,6 +156,9 @@ class JinaRerankerV3Backend:
151 156 "device": self._device,
152 157 "dtype": self._dtype,
153 158 "batch_size": self._batch_size,
  159 + "max_doc_length": self._max_doc_length,
  160 + "max_query_length": self._max_query_length,
  161 + "sort_by_doc_length": self._sort_by_doc_length,
154 162 "normalize": normalize,
155 163 "normalize_note": "jina_reranker_v3 returns model relevance scores directly",
156 164 }
... ... @@ -172,11 +180,15 @@ class JinaRerankerV3Backend:
172 180 if not docs:
173 181 return []
174 182  
175   - unique_scores: List[float] = [0.0] * len(docs)
  183 + ordered_indices = list(range(len(docs)))
  184 + if self._sort_by_doc_length and len(ordered_indices) > 1:
  185 + ordered_indices.sort(key=lambda idx: len(docs[idx]))
176 186  
  187 + unique_scores: List[float] = [0.0] * len(docs)
177 188 with self._lock:
178   - for start in range(0, len(docs), self._batch_size):
179   - batch_docs = docs[start : start + self._batch_size]
  189 + for start in range(0, len(ordered_indices), self._batch_size):
  190 + batch_indices = ordered_indices[start : start + self._batch_size]
  191 + batch_docs = [docs[idx] for idx in batch_indices]
180 192 batch_top_n = None
181 193 if top_n is not None and len(docs) <= self._batch_size:
182 194 batch_top_n = min(top_n, len(batch_docs))
... ... @@ -185,9 +197,13 @@ class JinaRerankerV3Backend:
185 197 batch_docs,
186 198 top_n=batch_top_n,
187 199 return_embeddings=self._return_embeddings,
  200 + max_doc_length=self._max_doc_length,
  201 + max_query_length=self._max_query_length,
188 202 )
189 203 for item in results:
190 204 batch_index = int(item["index"])
191   - unique_scores[start + batch_index] = float(item["relevance_score"])
  205 + unique_scores[batch_indices[batch_index]] = float(
  206 + item["relevance_score"]
  207 + )
192 208  
193 209 return unique_scores
... ...