Commit 0d3e73ba8718d32d00a4f66eb749375ba333e634
1 parent
d387e05d
rerank mini batch
Showing
9 changed files
with
306 additions
and
45 deletions
Show diff stats
| @@ -35,9 +35,10 @@ CACHE_DIR=.cache | @@ -35,9 +35,10 @@ CACHE_DIR=.cache | ||
| 35 | API_BASE_URL=http://43.166.252.75:6002 | 35 | API_BASE_URL=http://43.166.252.75:6002 |
| 36 | 36 | ||
| 37 | 37 | ||
| 38 | -# 国内 | 38 | +# 通用 DashScope key(翻译/内容理解等模块) |
| 39 | DASHSCOPE_API_KEY=sk-c3b8d4db061840aa8effb748df2a997b | 39 | DASHSCOPE_API_KEY=sk-c3b8d4db061840aa8effb748df2a997b |
| 40 | -# 美国 | ||
| 41 | -DASHSCOPE_API_KEY=sk-482cc3ff37a8467dab134a7a46830556 | 40 | +# Reranker 专用 key(按地域) |
| 41 | +RERANK_DASHSCOPE_API_KEY_CN=sk-c3b8d4db061840aa8effb748df2a997b | ||
| 42 | +RERANK_DASHSCOPE_API_KEY_US=sk-482cc3ff37a8467dab134a7a46830556 | ||
| 42 | 43 | ||
| 43 | OPENAI_API_KEY=sk-HvmTMKtuznibZ75l7L2uF2jiaYocCthqd8Cbdkl09KTE7Ft0 | 44 | OPENAI_API_KEY=sk-HvmTMKtuznibZ75l7L2uF2jiaYocCthqd8Cbdkl09KTE7Ft0 |
.env.example
| @@ -45,7 +45,9 @@ TEI_HEALTH_TIMEOUT_SEC=300 | @@ -45,7 +45,9 @@ TEI_HEALTH_TIMEOUT_SEC=300 | ||
| 45 | RERANK_PROVIDER=http | 45 | RERANK_PROVIDER=http |
| 46 | RERANK_BACKEND=qwen3_vllm | 46 | RERANK_BACKEND=qwen3_vllm |
| 47 | # Optional for cloud rerank backend (RERANK_BACKEND=dashscope_rerank) | 47 | # Optional for cloud rerank backend (RERANK_BACKEND=dashscope_rerank) |
| 48 | -DASHSCOPE_API_KEY= | 48 | +# Reranker cloud API keys by region |
| 49 | +RERANK_DASHSCOPE_API_KEY_CN= | ||
| 50 | +RERANK_DASHSCOPE_API_KEY_US= | ||
| 49 | # Example: | 51 | # Example: |
| 50 | # RERANK_DASHSCOPE_ENDPOINT=https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks | 52 | # RERANK_DASHSCOPE_ENDPOINT=https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks |
| 51 | RERANK_DASHSCOPE_ENDPOINT= | 53 | RERANK_DASHSCOPE_ENDPOINT= |
config/config.yaml
| @@ -166,7 +166,7 @@ services: | @@ -166,7 +166,7 @@ services: | ||
| 166 | base_url: "http://127.0.0.1:6007" | 166 | base_url: "http://127.0.0.1:6007" |
| 167 | service_url: "http://127.0.0.1:6007/rerank" | 167 | service_url: "http://127.0.0.1:6007/rerank" |
| 168 | # 服务内后端(reranker 进程启动时读取) | 168 | # 服务内后端(reranker 进程启动时读取) |
| 169 | - backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | 169 | + backend: "dashscope_rerank" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank |
| 170 | backends: | 170 | backends: |
| 171 | bge: | 171 | bge: |
| 172 | model_name: "BAAI/bge-reranker-v2-m3" | 172 | model_name: "BAAI/bge-reranker-v2-m3" |
| @@ -203,9 +203,10 @@ services: | @@ -203,9 +203,10 @@ services: | ||
| 203 | # 新加坡: https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks | 203 | # 新加坡: https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks |
| 204 | # 美国: https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks | 204 | # 美国: https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks |
| 205 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | 205 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 206 | - api_key: null # 推荐通过环境变量 DASHSCOPE_API_KEY 设置 | ||
| 207 | - timeout_sec: 15.0 | 206 | + api_key_env: "RERANK_DASHSCOPE_API_KEY_CN" |
| 207 | + timeout_sec: 10.0 # | ||
| 208 | top_n_cap: 0 # 0 表示 top_n=当前请求文档数;>0 则限制 top_n 上限 | 208 | top_n_cap: 0 # 0 表示 top_n=当前请求文档数;>0 则限制 top_n 上限 |
| 209 | + batchsize: 64 # 0 关闭;>0 启用并发小包调度(top_n/top_n_cap 仍生效,分包后全局截断) | ||
| 209 | instruct: "Given a shopping query, rank product titles by relevance" | 210 | instruct: "Given a shopping query, rank product titles by relevance" |
| 210 | max_retries: 2 | 211 | max_retries: 2 |
| 211 | retry_backoff_sec: 0.2 | 212 | retry_backoff_sec: 0.2 |
docs/DEVELOPER_GUIDE.md
| @@ -334,7 +334,7 @@ services: | @@ -334,7 +334,7 @@ services: | ||
| 334 | 334 | ||
| 335 | - **单一路径**:Provider 和 backend 必须由 `config/config.yaml` 的 `services` 块显式指定;未知配置应直接报错。 | 335 | - **单一路径**:Provider 和 backend 必须由 `config/config.yaml` 的 `services` 块显式指定;未知配置应直接报错。 |
| 336 | - **无兼容回退**:不保留“旧配置自动推导/兜底默认值”机制,避免静默行为偏差。 | 336 | - **无兼容回退**:不保留“旧配置自动推导/兜底默认值”机制,避免静默行为偏差。 |
| 337 | -- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`DASHSCOPE_API_KEY`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 | 337 | +- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`RERANK_DASHSCOPE_API_KEY_CN`/`RERANK_DASHSCOPE_API_KEY_US`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 |
| 338 | 338 | ||
| 339 | --- | 339 | --- |
| 340 | 340 |
docs/QUICKSTART.md
| @@ -423,7 +423,7 @@ services: | @@ -423,7 +423,7 @@ services: | ||
| 423 | - `TEI_BASE_URL` | 423 | - `TEI_BASE_URL` |
| 424 | - `RERANKER_SERVICE_URL` | 424 | - `RERANKER_SERVICE_URL` |
| 425 | - `RERANK_BACKEND`(服务内后端) | 425 | - `RERANK_BACKEND`(服务内后端) |
| 426 | -- `DASHSCOPE_API_KEY`(`dashscope_rerank` 后端鉴权) | 426 | +- `RERANK_DASHSCOPE_API_KEY_CN` / `RERANK_DASHSCOPE_API_KEY_US`(`dashscope_rerank` 后端鉴权) |
| 427 | - `RERANK_DASHSCOPE_ENDPOINT`(`dashscope_rerank` 地域 endpoint 覆盖) | 427 | - `RERANK_DASHSCOPE_ENDPOINT`(`dashscope_rerank` 地域 endpoint 覆盖) |
| 428 | 428 | ||
| 429 | ### 3.3 新增 provider 的最小步骤 | 429 | ### 3.3 新增 provider 的最小步骤 |
reranker/README.md
| @@ -69,9 +69,10 @@ services: | @@ -69,9 +69,10 @@ services: | ||
| 69 | dashscope_rerank: | 69 | dashscope_rerank: |
| 70 | model_name: "qwen3-rerank" | 70 | model_name: "qwen3-rerank" |
| 71 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | 71 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 72 | - api_key: null # 推荐使用环境变量 DASHSCOPE_API_KEY | 72 | + api_key_env: "RERANK_DASHSCOPE_API_KEY_CN" |
| 73 | timeout_sec: 15.0 | 73 | timeout_sec: 15.0 |
| 74 | top_n_cap: 0 | 74 | top_n_cap: 0 |
| 75 | + batchsize: 64 # 0关闭;>0并发小包调度(top_n/top_n_cap 仍生效,分包后全局截断) | ||
| 75 | instruct: "Given a shopping query, rank product titles by relevance" | 76 | instruct: "Given a shopping query, rank product titles by relevance" |
| 76 | max_retries: 2 | 77 | max_retries: 2 |
| 77 | retry_backoff_sec: 0.2 | 78 | retry_backoff_sec: 0.2 |
| @@ -83,8 +84,10 @@ DashScope endpoint 地域示例: | @@ -83,8 +84,10 @@ DashScope endpoint 地域示例: | ||
| 83 | - 美国:`https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks` | 84 | - 美国:`https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks` |
| 84 | 85 | ||
| 85 | DashScope 认证: | 86 | DashScope 认证: |
| 86 | -- `api_key` 支持配置在 `config.yaml` | ||
| 87 | -- 推荐通过环境变量注入:`DASHSCOPE_API_KEY=...` | 87 | +- `api_key_env` 必填,表示该后端读取哪个环境变量作为 API Key |
| 88 | +- 推荐按地域分别注入: | ||
| 89 | + - `RERANK_DASHSCOPE_API_KEY_CN=...` | ||
| 90 | + - `RERANK_DASHSCOPE_API_KEY_US=...` | ||
| 88 | 91 | ||
| 89 | - 服务端口、请求限制等仍在 `reranker/config.py`(或环境变量 `RERANKER_PORT`、`RERANKER_HOST`)。 | 92 | - 服务端口、请求限制等仍在 `reranker/config.py`(或环境变量 `RERANKER_PORT`、`RERANKER_HOST`)。 |
| 90 | 93 |
reranker/backends/dashscope_rerank.py
| @@ -16,11 +16,12 @@ import logging | @@ -16,11 +16,12 @@ import logging | ||
| 16 | import math | 16 | import math |
| 17 | import os | 17 | import os |
| 18 | import time | 18 | import time |
| 19 | +from concurrent.futures import ThreadPoolExecutor, as_completed | ||
| 19 | from typing import Any, Dict, List, Tuple | 20 | from typing import Any, Dict, List, Tuple |
| 20 | from urllib import error as urllib_error | 21 | from urllib import error as urllib_error |
| 21 | from urllib import request as urllib_request | 22 | from urllib import request as urllib_request |
| 22 | 23 | ||
| 23 | -from reranker.backends.batching_utils import deduplicate_with_positions | 24 | +from reranker.backends.batching_utils import deduplicate_with_positions, iter_batches |
| 24 | 25 | ||
| 25 | logger = logging.getLogger("reranker.backends.dashscope_rerank") | 26 | logger = logging.getLogger("reranker.backends.dashscope_rerank") |
| 26 | 27 | ||
| @@ -32,19 +33,20 @@ class DashScopeRerankBackend: | @@ -32,19 +33,20 @@ class DashScopeRerankBackend: | ||
| 32 | Config from services.rerank.backends.dashscope_rerank: | 33 | Config from services.rerank.backends.dashscope_rerank: |
| 33 | - model_name: str, default "qwen3-rerank" | 34 | - model_name: str, default "qwen3-rerank" |
| 34 | - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | 35 | - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 35 | - - api_key: optional str (or env DASHSCOPE_API_KEY) | 36 | + - api_key_env: str, required env var name for this backend key |
| 36 | - timeout_sec: float, default 15.0 | 37 | - timeout_sec: float, default 15.0 |
| 37 | - top_n_cap: int, optional cap; 0 means use all docs in request | 38 | - top_n_cap: int, optional cap; 0 means use all docs in request |
| 39 | + - batchsize: int, optional; 0 disables batching; >0 enables concurrent small-batch scheduling | ||
| 38 | - instruct: optional str | 40 | - instruct: optional str |
| 39 | - max_retries: int, default 1 | 41 | - max_retries: int, default 1 |
| 40 | - retry_backoff_sec: float, default 0.2 | 42 | - retry_backoff_sec: float, default 0.2 |
| 41 | 43 | ||
| 42 | Env overrides: | 44 | Env overrides: |
| 43 | - - DASHSCOPE_API_KEY | ||
| 44 | - RERANK_DASHSCOPE_ENDPOINT | 45 | - RERANK_DASHSCOPE_ENDPOINT |
| 45 | - RERANK_DASHSCOPE_MODEL | 46 | - RERANK_DASHSCOPE_MODEL |
| 46 | - RERANK_DASHSCOPE_TIMEOUT_SEC | 47 | - RERANK_DASHSCOPE_TIMEOUT_SEC |
| 47 | - RERANK_DASHSCOPE_TOP_N_CAP | 48 | - RERANK_DASHSCOPE_TOP_N_CAP |
| 49 | + - RERANK_DASHSCOPE_BATCHSIZE | ||
| 48 | """ | 50 | """ |
| 49 | 51 | ||
| 50 | def __init__(self, config: Dict[str, Any]) -> None: | 52 | def __init__(self, config: Dict[str, Any]) -> None: |
| @@ -59,11 +61,8 @@ class DashScopeRerankBackend: | @@ -59,11 +61,8 @@ class DashScopeRerankBackend: | ||
| 59 | or self._config.get("endpoint") | 61 | or self._config.get("endpoint") |
| 60 | or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | 62 | or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 61 | ).strip() | 63 | ).strip() |
| 62 | - self._api_key = str( | ||
| 63 | - os.getenv("DASHSCOPE_API_KEY") | ||
| 64 | - or self._config.get("api_key") | ||
| 65 | - or "" | ||
| 66 | - ).strip().strip('"').strip("'") | 64 | + self._api_key_env = str(self._config.get("api_key_env") or "").strip() |
| 65 | + self._api_key = str(os.getenv(self._api_key_env) or "").strip().strip('"').strip("'") | ||
| 67 | self._timeout_sec = float( | 66 | self._timeout_sec = float( |
| 68 | os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC") | 67 | os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC") |
| 69 | or self._config.get("timeout_sec") | 68 | or self._config.get("timeout_sec") |
| @@ -74,21 +73,29 @@ class DashScopeRerankBackend: | @@ -74,21 +73,29 @@ class DashScopeRerankBackend: | ||
| 74 | or self._config.get("top_n_cap") | 73 | or self._config.get("top_n_cap") |
| 75 | or 0 | 74 | or 0 |
| 76 | ) | 75 | ) |
| 76 | + self._batchsize = int( | ||
| 77 | + os.getenv("RERANK_DASHSCOPE_BATCHSIZE") | ||
| 78 | + or self._config.get("batchsize") | ||
| 79 | + or 0 | ||
| 80 | + ) | ||
| 77 | self._instruct = str(self._config.get("instruct") or "").strip() | 81 | self._instruct = str(self._config.get("instruct") or "").strip() |
| 78 | self._max_retries = int(self._config.get("max_retries", 1)) | 82 | self._max_retries = int(self._config.get("max_retries", 1)) |
| 79 | self._retry_backoff_sec = float(self._config.get("retry_backoff_sec", 0.2)) | 83 | self._retry_backoff_sec = float(self._config.get("retry_backoff_sec", 0.2)) |
| 80 | 84 | ||
| 81 | if not self._endpoint: | 85 | if not self._endpoint: |
| 82 | raise ValueError("dashscope_rerank endpoint is required") | 86 | raise ValueError("dashscope_rerank endpoint is required") |
| 87 | + if not self._api_key_env: | ||
| 88 | + raise ValueError("dashscope_rerank api_key_env is required") | ||
| 83 | if not self._api_key: | 89 | if not self._api_key: |
| 84 | raise ValueError( | 90 | raise ValueError( |
| 85 | - "dashscope_rerank api_key is required (set services.rerank.backends.dashscope_rerank.api_key " | ||
| 86 | - "or env DASHSCOPE_API_KEY)" | 91 | + f"dashscope_rerank api key is required (set env {self._api_key_env})" |
| 87 | ) | 92 | ) |
| 88 | if self._timeout_sec <= 0: | 93 | if self._timeout_sec <= 0: |
| 89 | raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}") | 94 | raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}") |
| 90 | if self._top_n_cap < 0: | 95 | if self._top_n_cap < 0: |
| 91 | raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}") | 96 | raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}") |
| 97 | + if self._batchsize < 0: | ||
| 98 | + raise ValueError(f"dashscope_rerank batchsize must be >= 0, got {self._batchsize}") | ||
| 92 | if self._max_retries <= 0: | 99 | if self._max_retries <= 0: |
| 93 | raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}") | 100 | raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}") |
| 94 | if self._retry_backoff_sec < 0: | 101 | if self._retry_backoff_sec < 0: |
| @@ -97,11 +104,12 @@ class DashScopeRerankBackend: | @@ -97,11 +104,12 @@ class DashScopeRerankBackend: | ||
| 97 | ) | 104 | ) |
| 98 | 105 | ||
| 99 | logger.info( | 106 | logger.info( |
| 100 | - "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s", | 107 | + "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s batchsize=%s", |
| 101 | self._endpoint, | 108 | self._endpoint, |
| 102 | self._model_name, | 109 | self._model_name, |
| 103 | self._timeout_sec, | 110 | self._timeout_sec, |
| 104 | self._top_n_cap, | 111 | self._top_n_cap, |
| 112 | + self._batchsize, | ||
| 105 | ) | 113 | ) |
| 106 | 114 | ||
| 107 | def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]: | 115 | def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]: |
| @@ -162,6 +170,95 @@ class DashScopeRerankBackend: | @@ -162,6 +170,95 @@ class DashScopeRerankBackend: | ||
| 162 | 170 | ||
| 163 | raise RuntimeError(str(last_exc) if last_exc else "DashScope rerank failed with unknown error") | 171 | raise RuntimeError(str(last_exc) if last_exc else "DashScope rerank failed with unknown error") |
| 164 | 172 | ||
| 173 | + def _score_single_request( | ||
| 174 | + self, | ||
| 175 | + query: str, | ||
| 176 | + unique_texts: List[str], | ||
| 177 | + normalize: bool, | ||
| 178 | + top_n: int, | ||
| 179 | + ) -> Tuple[List[float], int]: | ||
| 180 | + response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n) | ||
| 181 | + results = self._extract_results(response) | ||
| 182 | + | ||
| 183 | + unique_scores: List[float] = [0.0] * len(unique_texts) | ||
| 184 | + for rank, item in enumerate(results): | ||
| 185 | + raw_idx = item.get("index", rank) | ||
| 186 | + try: | ||
| 187 | + idx = int(raw_idx) | ||
| 188 | + except (TypeError, ValueError): | ||
| 189 | + continue | ||
| 190 | + if idx < 0 or idx >= len(unique_scores): | ||
| 191 | + continue | ||
| 192 | + raw_score = item.get("relevance_score", item.get("score")) | ||
| 193 | + unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize) | ||
| 194 | + return unique_scores, len(results) | ||
| 195 | + | ||
| 196 | + def _score_batched_concurrent( | ||
| 197 | + self, | ||
| 198 | + query: str, | ||
| 199 | + unique_texts: List[str], | ||
| 200 | + normalize: bool, | ||
| 201 | + ) -> Tuple[List[float], Dict[str, int]]: | ||
| 202 | + """ | ||
| 203 | + Concurrent batch scoring. | ||
| 204 | + | ||
| 205 | + We intentionally request full local scores in each batch (top_n=len(batch)), | ||
| 206 | + then apply global top_n/top_n_cap truncation after merge if needed. | ||
| 207 | + """ | ||
| 208 | + indices = list(range(len(unique_texts))) | ||
| 209 | + batches = list(iter_batches(indices, batch_size=self._batchsize)) | ||
| 210 | + num_batches = len(batches) | ||
| 211 | + max_workers = min(8, num_batches) if num_batches > 0 else 1 | ||
| 212 | + unique_scores: List[float] = [0.0] * len(unique_texts) | ||
| 213 | + response_results = 0 | ||
| 214 | + | ||
| 215 | + def _run_one(batch_no: int, batch_indices: List[int]) -> Tuple[int, List[int], Dict[str, Any], float]: | ||
| 216 | + docs = [unique_texts[i] for i in batch_indices] | ||
| 217 | + # Ask each batch for all docs to avoid local truncation. | ||
| 218 | + start_ts = time.perf_counter() | ||
| 219 | + data = self._post_rerank(query=query, docs=docs, top_n=len(docs)) | ||
| 220 | + elapsed_ms = round((time.perf_counter() - start_ts) * 1000.0, 3) | ||
| 221 | + return batch_no, batch_indices, data, elapsed_ms | ||
| 222 | + | ||
| 223 | + with ThreadPoolExecutor(max_workers=max_workers) as ex: | ||
| 224 | + future_to_batch = {ex.submit(_run_one, i + 1, b): b for i, b in enumerate(batches)} | ||
| 225 | + for fut in as_completed(future_to_batch): | ||
| 226 | + batch_indices = future_to_batch[fut] | ||
| 227 | + try: | ||
| 228 | + batch_no, _, data, batch_elapsed_ms = fut.result() | ||
| 229 | + except Exception as exc: | ||
| 230 | + raise RuntimeError( | ||
| 231 | + f"DashScope rerank batch failed | batch_size={len(batch_indices)} error={exc}" | ||
| 232 | + ) from exc | ||
| 233 | + results = self._extract_results(data) | ||
| 234 | + logger.info( | ||
| 235 | + "DashScope batch response | batch=%d/%d docs=%d elapsed_ms=%s results=%d query=%r", | ||
| 236 | + batch_no, | ||
| 237 | + num_batches, | ||
| 238 | + len(batch_indices), | ||
| 239 | + batch_elapsed_ms, | ||
| 240 | + len(results), | ||
| 241 | + query[:80], | ||
| 242 | + ) | ||
| 243 | + response_results += len(results) | ||
| 244 | + for rank, item in enumerate(results): | ||
| 245 | + raw_idx = item.get("index", rank) | ||
| 246 | + try: | ||
| 247 | + local_idx = int(raw_idx) | ||
| 248 | + except (TypeError, ValueError): | ||
| 249 | + continue | ||
| 250 | + if local_idx < 0 or local_idx >= len(batch_indices): | ||
| 251 | + continue | ||
| 252 | + global_idx = batch_indices[local_idx] | ||
| 253 | + raw_score = item.get("relevance_score", item.get("score")) | ||
| 254 | + unique_scores[global_idx] = self._coerce_score(raw_score, normalize=normalize) | ||
| 255 | + | ||
| 256 | + return unique_scores, { | ||
| 257 | + "batches": num_batches, | ||
| 258 | + "batch_concurrency": max_workers, | ||
| 259 | + "response_results": response_results, | ||
| 260 | + } | ||
| 261 | + | ||
| 165 | @staticmethod | 262 | @staticmethod |
| 166 | def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]: | 263 | def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]: |
| 167 | # Compatible API style: {"results":[...]} | 264 | # Compatible API style: {"results":[...]} |
| @@ -240,21 +337,34 @@ class DashScopeRerankBackend: | @@ -240,21 +337,34 @@ class DashScopeRerankBackend: | ||
| 240 | top_n_effective = min(top_n_effective, int(top_n)) | 337 | top_n_effective = min(top_n_effective, int(top_n)) |
| 241 | if self._top_n_cap > 0: | 338 | if self._top_n_cap > 0: |
| 242 | top_n_effective = min(top_n_effective, self._top_n_cap) | 339 | top_n_effective = min(top_n_effective, self._top_n_cap) |
| 243 | - | ||
| 244 | - response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n_effective) | ||
| 245 | - results = self._extract_results(response) | ||
| 246 | - | ||
| 247 | - unique_scores: List[float] = [0.0] * len(unique_texts) | ||
| 248 | - for rank, item in enumerate(results): | ||
| 249 | - raw_idx = item.get("index", rank) | ||
| 250 | - try: | ||
| 251 | - idx = int(raw_idx) | ||
| 252 | - except (TypeError, ValueError): | ||
| 253 | - continue | ||
| 254 | - if idx < 0 or idx >= len(unique_scores): | ||
| 255 | - continue | ||
| 256 | - raw_score = item.get("relevance_score", item.get("score")) | ||
| 257 | - unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize) | 340 | + can_batch = ( |
| 341 | + self._batchsize > 0 | ||
| 342 | + and len(unique_texts) > self._batchsize | ||
| 343 | + ) | ||
| 344 | + if can_batch: | ||
| 345 | + unique_scores, batch_meta = self._score_batched_concurrent( | ||
| 346 | + query=query, | ||
| 347 | + unique_texts=unique_texts, | ||
| 348 | + normalize=normalize, | ||
| 349 | + ) | ||
| 350 | + if top_n_effective < len(unique_scores): | ||
| 351 | + order = sorted(range(len(unique_scores)), key=lambda i: (-unique_scores[i], i)) | ||
| 352 | + keep = set(order[:top_n_effective]) | ||
| 353 | + for i in range(len(unique_scores)): | ||
| 354 | + if i not in keep: | ||
| 355 | + unique_scores[i] = 0.0 | ||
| 356 | + response_results = int(batch_meta["response_results"]) | ||
| 357 | + batches = int(batch_meta["batches"]) | ||
| 358 | + batch_concurrency = int(batch_meta["batch_concurrency"]) | ||
| 359 | + else: | ||
| 360 | + unique_scores, response_results = self._score_single_request( | ||
| 361 | + query=query, | ||
| 362 | + unique_texts=unique_texts, | ||
| 363 | + normalize=normalize, | ||
| 364 | + top_n=top_n_effective, | ||
| 365 | + ) | ||
| 366 | + batches = 1 | ||
| 367 | + batch_concurrency = 1 | ||
| 258 | 368 | ||
| 259 | for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | 369 | for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): |
| 260 | output_scores[orig_idx] = float(unique_scores[unique_idx]) | 370 | output_scores[orig_idx] = float(unique_scores[unique_idx]) |
| @@ -275,7 +385,10 @@ class DashScopeRerankBackend: | @@ -275,7 +385,10 @@ class DashScopeRerankBackend: | ||
| 275 | "normalize": normalize, | 385 | "normalize": normalize, |
| 276 | "top_n": top_n_effective, | 386 | "top_n": top_n_effective, |
| 277 | "requested_top_n": int(top_n) if top_n is not None else None, | 387 | "requested_top_n": int(top_n) if top_n is not None else None, |
| 278 | - "response_results": len(results), | 388 | + "response_results": response_results, |
| 389 | + "batchsize": self._batchsize, | ||
| 390 | + "batches": batches, | ||
| 391 | + "batch_concurrency": batch_concurrency, | ||
| 279 | "endpoint": self._endpoint, | 392 | "endpoint": self._endpoint, |
| 280 | } | 393 | } |
| 281 | 394 |
reranker/server.py
| @@ -154,11 +154,14 @@ def rerank(request: RerankRequest) -> RerankResponse: | @@ -154,11 +154,14 @@ def rerank(request: RerankRequest) -> RerankResponse: | ||
| 154 | meta.update({"service_elapsed_ms": round((time.time() - start_ts) * 1000.0, 3)}) | 154 | meta.update({"service_elapsed_ms": round((time.time() - start_ts) * 1000.0, 3)}) |
| 155 | score_preview = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]] | 155 | score_preview = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]] |
| 156 | logger.info( | 156 | logger.info( |
| 157 | - "Rerank done | docs=%d unique=%s dedup=%s elapsed_ms=%s query=%r score_preview=%s", | 157 | + "Rerank done | docs=%d unique=%s dedup=%s elapsed_ms=%s batches=%s batchsize=%s batch_concurrency=%s query=%r score_preview=%s", |
| 158 | meta.get("input_docs"), | 158 | meta.get("input_docs"), |
| 159 | meta.get("unique_docs"), | 159 | meta.get("unique_docs"), |
| 160 | meta.get("dedup_ratio"), | 160 | meta.get("dedup_ratio"), |
| 161 | meta.get("service_elapsed_ms"), | 161 | meta.get("service_elapsed_ms"), |
| 162 | + meta.get("batches"), | ||
| 163 | + meta.get("batchsize"), | ||
| 164 | + meta.get("batch_concurrency"), | ||
| 162 | _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS), | 165 | _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS), |
| 163 | score_preview, | 166 | score_preview, |
| 164 | ) | 167 | ) |
tests/test_reranker_dashscope_backend.py
| 1 | from __future__ import annotations | 1 | from __future__ import annotations |
| 2 | 2 | ||
| 3 | +import time | ||
| 4 | + | ||
| 5 | +import pytest | ||
| 6 | + | ||
| 3 | from reranker.backends import get_rerank_backend | 7 | from reranker.backends import get_rerank_backend |
| 4 | from reranker.backends.dashscope_rerank import DashScopeRerankBackend | 8 | from reranker.backends.dashscope_rerank import DashScopeRerankBackend |
| 5 | 9 | ||
| 6 | 10 | ||
| 7 | -def test_dashscope_backend_factory_loads(): | 11 | +@pytest.fixture(autouse=True) |
| 12 | +def _clear_global_dashscope_key(monkeypatch): | ||
| 13 | + # Prevent accidental pass-through from unrelated global key. | ||
| 14 | + monkeypatch.delenv("DASHSCOPE_API_KEY", raising=False) | ||
| 15 | + | ||
| 16 | + | ||
| 17 | +def test_dashscope_backend_factory_loads(monkeypatch): | ||
| 18 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 8 | backend = get_rerank_backend( | 19 | backend = get_rerank_backend( |
| 9 | "dashscope_rerank", | 20 | "dashscope_rerank", |
| 10 | { | 21 | { |
| 11 | "model_name": "qwen3-rerank", | 22 | "model_name": "qwen3-rerank", |
| 12 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | 23 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 13 | - "api_key": "test-key", | 24 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", |
| 14 | }, | 25 | }, |
| 15 | ) | 26 | ) |
| 16 | assert isinstance(backend, DashScopeRerankBackend) | 27 | assert isinstance(backend, DashScopeRerankBackend) |
| 17 | 28 | ||
| 18 | 29 | ||
| 19 | def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): | 30 | def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): |
| 31 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 20 | backend = DashScopeRerankBackend( | 32 | backend = DashScopeRerankBackend( |
| 21 | { | 33 | { |
| 22 | "model_name": "qwen3-rerank", | 34 | "model_name": "qwen3-rerank", |
| 23 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | 35 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 24 | - "api_key": "test-key", | 36 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", |
| 25 | "top_n_cap": 0, | 37 | "top_n_cap": 0, |
| 26 | } | 38 | } |
| 27 | ) | 39 | ) |
| @@ -55,11 +67,12 @@ def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): | @@ -55,11 +67,12 @@ def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): | ||
| 55 | 67 | ||
| 56 | 68 | ||
| 57 | def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): | 69 | def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): |
| 70 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 58 | backend = DashScopeRerankBackend( | 71 | backend = DashScopeRerankBackend( |
| 59 | { | 72 | { |
| 60 | "model_name": "qwen3-rerank", | 73 | "model_name": "qwen3-rerank", |
| 61 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | 74 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 62 | - "api_key": "test-key", | 75 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", |
| 63 | "top_n_cap": 1, | 76 | "top_n_cap": 1, |
| 64 | } | 77 | } |
| 65 | ) | 78 | ) |
| @@ -81,11 +94,12 @@ def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): | @@ -81,11 +94,12 @@ def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): | ||
| 81 | 94 | ||
| 82 | 95 | ||
| 83 | def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): | 96 | def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): |
| 97 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 84 | backend = DashScopeRerankBackend( | 98 | backend = DashScopeRerankBackend( |
| 85 | { | 99 | { |
| 86 | "model_name": "qwen3-rerank", | 100 | "model_name": "qwen3-rerank", |
| 87 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | 101 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 88 | - "api_key": "test-key", | 102 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", |
| 89 | "top_n_cap": 0, | 103 | "top_n_cap": 0, |
| 90 | } | 104 | } |
| 91 | ) | 105 | ) |
| @@ -101,3 +115,127 @@ def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): | @@ -101,3 +115,127 @@ def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): | ||
| 101 | assert scores == [0.3, 0.0, 0.8] | 115 | assert scores == [0.3, 0.0, 0.8] |
| 102 | assert meta["top_n"] == 2 | 116 | assert meta["top_n"] == 2 |
| 103 | assert meta["requested_top_n"] == 2 | 117 | assert meta["requested_top_n"] == 2 |
| 118 | + | ||
| 119 | + | ||
| 120 | +def test_dashscope_backend_batchsize_concurrent_full_topn(monkeypatch): | ||
| 121 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 122 | + backend = DashScopeRerankBackend( | ||
| 123 | + { | ||
| 124 | + "model_name": "qwen3-rerank", | ||
| 125 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | ||
| 126 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | ||
| 127 | + "top_n_cap": 0, | ||
| 128 | + "batchsize": 2, | ||
| 129 | + } | ||
| 130 | + ) | ||
| 131 | + | ||
| 132 | + def _fake_post(query: str, docs: list[str], top_n: int): | ||
| 133 | + assert query == "q" | ||
| 134 | + # batching path asks every batch for full local list | ||
| 135 | + assert top_n == len(docs) | ||
| 136 | + time.sleep(0.05) | ||
| 137 | + return { | ||
| 138 | + "results": [ | ||
| 139 | + {"index": i, "relevance_score": float(i + 1) / 10.0} | ||
| 140 | + for i, _ in enumerate(docs) | ||
| 141 | + ] | ||
| 142 | + } | ||
| 143 | + | ||
| 144 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | ||
| 145 | + start = time.perf_counter() | ||
| 146 | + scores, meta = backend.score_with_meta(query="q", docs=["d1", "d2", "d3", "d4", "d5", "d6"]) | ||
| 147 | + elapsed = time.perf_counter() - start | ||
| 148 | + | ||
| 149 | + # 3 batches * 50ms serial ~=150ms; concurrent should be significantly lower. | ||
| 150 | + assert elapsed < 0.14 | ||
| 151 | + assert len(scores) == 6 | ||
| 152 | + assert meta["batches"] == 3 | ||
| 153 | + assert meta["batch_concurrency"] == 3 | ||
| 154 | + assert meta["response_results"] == 6 | ||
| 155 | + | ||
| 156 | + | ||
| 157 | +def test_dashscope_backend_batchsize_still_effective_when_topn_limited(monkeypatch): | ||
| 158 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 159 | + backend = DashScopeRerankBackend( | ||
| 160 | + { | ||
| 161 | + "model_name": "qwen3-rerank", | ||
| 162 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | ||
| 163 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | ||
| 164 | + "top_n_cap": 0, | ||
| 165 | + "batchsize": 2, | ||
| 166 | + } | ||
| 167 | + ) | ||
| 168 | + | ||
| 169 | + called = {"count": 0} | ||
| 170 | + | ||
| 171 | + def _fake_post(query: str, docs: list[str], top_n: int): | ||
| 172 | + called["count"] += 1 | ||
| 173 | + # batching remains enabled; each batch asks for full local scores | ||
| 174 | + assert top_n == len(docs) | ||
| 175 | + score_map = {"d1": 0.9, "d2": 0.1, "d3": 0.8, "d4": 0.2} | ||
| 176 | + return { | ||
| 177 | + "results": [ | ||
| 178 | + {"index": i, "relevance_score": score_map[doc]} | ||
| 179 | + for i, doc in enumerate(docs) | ||
| 180 | + ] | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | ||
| 184 | + scores, meta = backend.score_with_meta_topn(query="q", docs=["d1", "d2", "d3", "d4"], top_n=2) | ||
| 185 | + | ||
| 186 | + assert called["count"] == 2 | ||
| 187 | + assert scores == [0.9, 0.0, 0.8, 0.0] | ||
| 188 | + assert meta["batches"] == 2 | ||
| 189 | + assert meta["top_n"] == 2 | ||
| 190 | + | ||
| 191 | + | ||
| 192 | +def test_dashscope_backend_batchsize_raises_when_one_batch_fails(monkeypatch): | ||
| 193 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | ||
| 194 | + backend = DashScopeRerankBackend( | ||
| 195 | + { | ||
| 196 | + "model_name": "qwen3-rerank", | ||
| 197 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | ||
| 198 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | ||
| 199 | + "top_n_cap": 0, | ||
| 200 | + "batchsize": 2, | ||
| 201 | + } | ||
| 202 | + ) | ||
| 203 | + | ||
| 204 | + def _fake_post(query: str, docs: list[str], top_n: int): | ||
| 205 | + if docs == ["d3", "d4"]: | ||
| 206 | + raise RuntimeError("provider temporary error") | ||
| 207 | + return { | ||
| 208 | + "results": [ | ||
| 209 | + {"index": i, "relevance_score": 0.1} | ||
| 210 | + for i, _ in enumerate(docs) | ||
| 211 | + ] | ||
| 212 | + } | ||
| 213 | + | ||
| 214 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | ||
| 215 | + | ||
| 216 | + with pytest.raises(RuntimeError, match="DashScope rerank batch failed"): | ||
| 217 | + backend.score_with_meta(query="q", docs=["d1", "d2", "d3", "d4"]) | ||
| 218 | + | ||
| 219 | + | ||
| 220 | +def test_dashscope_backend_requires_api_key_env(): | ||
| 221 | + with pytest.raises(ValueError, match="api_key_env is required"): | ||
| 222 | + DashScopeRerankBackend( | ||
| 223 | + { | ||
| 224 | + "model_name": "qwen3-rerank", | ||
| 225 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | ||
| 226 | + "top_n_cap": 0, | ||
| 227 | + } | ||
| 228 | + ) | ||
| 229 | + | ||
| 230 | + | ||
| 231 | +def test_dashscope_backend_requires_api_key_env_value(monkeypatch): | ||
| 232 | + monkeypatch.delenv("TEST_RERANK_DASHSCOPE_API_KEY", raising=False) | ||
| 233 | + with pytest.raises(ValueError, match="set env TEST_RERANK_DASHSCOPE_API_KEY"): | ||
| 234 | + DashScopeRerankBackend( | ||
| 235 | + { | ||
| 236 | + "model_name": "qwen3-rerank", | ||
| 237 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | ||
| 238 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | ||
| 239 | + "top_n_cap": 0, | ||
| 240 | + } | ||
| 241 | + ) |