diff --git a/config/config.yaml b/config/config.yaml index 7a8b268..b569ff1 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -395,8 +395,11 @@ services: jina_reranker_v3: model_name: "jinaai/jina-reranker-v3" device: null - dtype: "auto" + dtype: "float16" batch_size: 64 + max_doc_length: 160 + max_query_length: 64 + sort_by_doc_length: true cache_dir: "./model_cache" trust_remote_code: true qwen3_vllm: diff --git a/docs/搜索API对接指南-07-微服务接口(Embedding-Reranker-Translation).md b/docs/搜索API对接指南-07-微服务接口(Embedding-Reranker-Translation).md index de8b10e..fdf774a 100644 --- a/docs/搜索API对接指南-07-微服务接口(Embedding-Reranker-Translation).md +++ b/docs/搜索API对接指南-07-微服务接口(Embedding-Reranker-Translation).md @@ -190,6 +190,16 @@ curl "http://localhost:6008/ready" 说明:默认后端为 `qwen3_vllm`(`Qwen/Qwen3-Reranker-0.6B`),需要可用 GPU 显存。 +补充:若切换到 `jina_reranker_v3`,在当前 `Tesla T4` 上建议使用: + +- `dtype: float16` +- `batch_size: 64` +- `max_doc_length: 160` +- `max_query_length: 64` +- `sort_by_doc_length: true` + +原因:`jina_reranker_v3` 的 `auto` 在当前机器上会落到 `bfloat16`,性能明显差于 `float16`;而它的 listwise 架构在 T4 上对上下文长度更敏感,过大的 batch 会显著拉长延迟。 + 补充:`docs` 的请求大小与模型推理 `batch size` 解耦。即使一次传入 1000 条文档,服务端也会按 `services.rerank.backends.qwen3_vllm.infer_batch_size` 自动拆分。 #### 7.2.1 `POST /rerank` — 结果重排 @@ -439,4 +449,3 @@ curl "http://localhost:6006/health" 请求/响应格式、示例及错误码见 [-05-索引接口(Indexer)](./搜索API对接指南-05-索引接口(Indexer).md#58-内容理解字段生成接口)。 --- - diff --git a/requirements_reranker_jina_reranker_v3.txt b/requirements_reranker_jina_reranker_v3.txt index 9bfce7f..f245b55 100644 --- a/requirements_reranker_jina_reranker_v3.txt +++ b/requirements_reranker_jina_reranker_v3.txt @@ -1,5 +1,9 @@ # Isolated dependencies for jina_reranker_v3 reranker backend. +# +# Keep this stack aligned with the validated CUDA runtime on our hosts. +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works. +# Cap transformers <5 to stay on the same family as the known-good reranker envs. -r requirements_reranker_base.txt -torch>=2.0.0 -transformers>=4.51.0 +torch==2.10.0 +transformers>=4.51.0,<5 diff --git a/reranker/README.md b/reranker/README.md index 91ed711..e2a0e0f 100644 --- a/reranker/README.md +++ b/reranker/README.md @@ -152,12 +152,22 @@ services: jina_reranker_v3: model_name: "jinaai/jina-reranker-v3" device: null - dtype: "auto" + dtype: "float16" batch_size: 64 + max_doc_length: 160 + max_query_length: 64 + sort_by_doc_length: true cache_dir: "./model_cache" trust_remote_code: true ``` +T4 实测建议: + +- `dtype` 优先使用 `float16`;在当前机器上 `auto` 会加载成 `bfloat16`,明显更慢 +- 在线短文本商品重排建议从 `batch_size: 64` 起步;它比更大的 listwise block 更快,但会牺牲一部分“完整 listwise”排序一致性 +- 若你更看重接近完整 listwise 的排序结果,可提高到 `batch_size: 125`,代价是延迟明显上升 +- `max_doc_length: 160`、`max_query_length: 64` 更适合当前商品标题 / 短 query 场景 + ## 当前最优方案:`qwen3_vllm_score` diff --git a/reranker/backends/jina_reranker_v3.py b/reranker/backends/jina_reranker_v3.py index 0551e1e..6517618 100644 --- a/reranker/backends/jina_reranker_v3.py +++ b/reranker/backends/jina_reranker_v3.py @@ -35,19 +35,26 @@ class JinaRerankerV3Backend: self._config.get("model_name") or "jinaai/jina-reranker-v3" ) self._cache_dir = self._config.get("cache_dir") or "./model_cache" - self._dtype = str(self._config.get("dtype") or "auto") + self._dtype = str(self._config.get("dtype") or "float16") self._device = self._config.get("device") self._batch_size = max(1, int(self._config.get("batch_size", 64))) + self._max_doc_length = max(1, int(self._config.get("max_doc_length", 160))) + self._max_query_length = max(1, int(self._config.get("max_query_length", 64))) + self._sort_by_doc_length = bool(self._config.get("sort_by_doc_length", True)) self._return_embeddings = bool(self._config.get("return_embeddings", False)) self._trust_remote_code = bool(self._config.get("trust_remote_code", True)) self._lock = threading.Lock() logger.info( - "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s)", + "[Jina_Reranker_V3] Loading model %s (dtype=%s, device=%s, batch=%s, " + "max_doc_length=%s, max_query_length=%s, sort_by_doc_length=%s)", self._model_name, self._dtype, self._device, self._batch_size, + self._max_doc_length, + self._max_query_length, + self._sort_by_doc_length, ) load_kwargs: Dict[str, Any] = { @@ -116,7 +123,6 @@ class JinaRerankerV3Backend: } unique_texts: List[str] = [] - unique_first_indices: List[int] = [] text_to_unique_idx: Dict[str, int] = {} for orig_idx, text in indexed: unique_idx = text_to_unique_idx.get(text) @@ -124,7 +130,6 @@ class JinaRerankerV3Backend: unique_idx = len(unique_texts) text_to_unique_idx[text] = unique_idx unique_texts.append(text) - unique_first_indices.append(orig_idx) effective_top_n = min(top_n, len(unique_texts)) if top_n is not None else None @@ -151,6 +156,9 @@ class JinaRerankerV3Backend: "device": self._device, "dtype": self._dtype, "batch_size": self._batch_size, + "max_doc_length": self._max_doc_length, + "max_query_length": self._max_query_length, + "sort_by_doc_length": self._sort_by_doc_length, "normalize": normalize, "normalize_note": "jina_reranker_v3 returns model relevance scores directly", } @@ -172,11 +180,15 @@ class JinaRerankerV3Backend: if not docs: return [] - unique_scores: List[float] = [0.0] * len(docs) + ordered_indices = list(range(len(docs))) + if self._sort_by_doc_length and len(ordered_indices) > 1: + ordered_indices.sort(key=lambda idx: len(docs[idx])) + unique_scores: List[float] = [0.0] * len(docs) with self._lock: - for start in range(0, len(docs), self._batch_size): - batch_docs = docs[start : start + self._batch_size] + for start in range(0, len(ordered_indices), self._batch_size): + batch_indices = ordered_indices[start : start + self._batch_size] + batch_docs = [docs[idx] for idx in batch_indices] batch_top_n = None if top_n is not None and len(docs) <= self._batch_size: batch_top_n = min(top_n, len(batch_docs)) @@ -185,9 +197,13 @@ class JinaRerankerV3Backend: batch_docs, top_n=batch_top_n, return_embeddings=self._return_embeddings, + max_doc_length=self._max_doc_length, + max_query_length=self._max_query_length, ) for item in results: batch_index = int(item["index"]) - unique_scores[start + batch_index] = float(item["relevance_score"]) + unique_scores[batch_indices[batch_index]] = float( + item["relevance_score"] + ) return unique_scores -- libgit2 0.21.2