Commit 0d3e73ba8718d32d00a4f66eb749375ba333e634
1 parent
d387e05d
rerank mini batch
Showing
9 changed files
with
306 additions
and
45 deletions
Show diff stats
| ... | ... | @@ -35,9 +35,10 @@ CACHE_DIR=.cache |
| 35 | 35 | API_BASE_URL=http://43.166.252.75:6002 |
| 36 | 36 | |
| 37 | 37 | |
| 38 | -# 国内 | |
| 38 | +# 通用 DashScope key(翻译/内容理解等模块) | |
| 39 | 39 | DASHSCOPE_API_KEY=sk-c3b8d4db061840aa8effb748df2a997b |
| 40 | -# 美国 | |
| 41 | -DASHSCOPE_API_KEY=sk-482cc3ff37a8467dab134a7a46830556 | |
| 40 | +# Reranker 专用 key(按地域) | |
| 41 | +RERANK_DASHSCOPE_API_KEY_CN=sk-c3b8d4db061840aa8effb748df2a997b | |
| 42 | +RERANK_DASHSCOPE_API_KEY_US=sk-482cc3ff37a8467dab134a7a46830556 | |
| 42 | 43 | |
| 43 | 44 | OPENAI_API_KEY=sk-HvmTMKtuznibZ75l7L2uF2jiaYocCthqd8Cbdkl09KTE7Ft0 | ... | ... |
.env.example
| ... | ... | @@ -45,7 +45,9 @@ TEI_HEALTH_TIMEOUT_SEC=300 |
| 45 | 45 | RERANK_PROVIDER=http |
| 46 | 46 | RERANK_BACKEND=qwen3_vllm |
| 47 | 47 | # Optional for cloud rerank backend (RERANK_BACKEND=dashscope_rerank) |
| 48 | -DASHSCOPE_API_KEY= | |
| 48 | +# Reranker cloud API keys by region | |
| 49 | +RERANK_DASHSCOPE_API_KEY_CN= | |
| 50 | +RERANK_DASHSCOPE_API_KEY_US= | |
| 49 | 51 | # Example: |
| 50 | 52 | # RERANK_DASHSCOPE_ENDPOINT=https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks |
| 51 | 53 | RERANK_DASHSCOPE_ENDPOINT= | ... | ... |
config/config.yaml
| ... | ... | @@ -166,7 +166,7 @@ services: |
| 166 | 166 | base_url: "http://127.0.0.1:6007" |
| 167 | 167 | service_url: "http://127.0.0.1:6007/rerank" |
| 168 | 168 | # 服务内后端(reranker 进程启动时读取) |
| 169 | - backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | |
| 169 | + backend: "dashscope_rerank" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | |
| 170 | 170 | backends: |
| 171 | 171 | bge: |
| 172 | 172 | model_name: "BAAI/bge-reranker-v2-m3" |
| ... | ... | @@ -203,9 +203,10 @@ services: |
| 203 | 203 | # 新加坡: https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks |
| 204 | 204 | # 美国: https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks |
| 205 | 205 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 206 | - api_key: null # 推荐通过环境变量 DASHSCOPE_API_KEY 设置 | |
| 207 | - timeout_sec: 15.0 | |
| 206 | + api_key_env: "RERANK_DASHSCOPE_API_KEY_CN" | |
| 207 | + timeout_sec: 10.0 # 单请求超时(秒) | |
| 208 | 208 | top_n_cap: 0 # 0 表示 top_n=当前请求文档数;>0 则限制 top_n 上限 |
| 209 | + batchsize: 64 # 0 关闭;>0 启用并发小包调度(top_n/top_n_cap 仍生效,分包后全局截断) | |
| 209 | 210 | instruct: "Given a shopping query, rank product titles by relevance" |
| 210 | 211 | max_retries: 2 |
| 211 | 212 | retry_backoff_sec: 0.2 | ... | ... |
docs/DEVELOPER_GUIDE.md
| ... | ... | @@ -334,7 +334,7 @@ services: |
| 334 | 334 | |
| 335 | 335 | - **单一路径**:Provider 和 backend 必须由 `config/config.yaml` 的 `services` 块显式指定;未知配置应直接报错。 |
| 336 | 336 | - **无兼容回退**:不保留“旧配置自动推导/兜底默认值”机制,避免静默行为偏差。 |
| 337 | -- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`DASHSCOPE_API_KEY`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 | |
| 337 | +- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`RERANK_DASHSCOPE_API_KEY_CN`/`RERANK_DASHSCOPE_API_KEY_US`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 | |
| 338 | 338 | |
| 339 | 339 | --- |
| 340 | 340 | ... | ... |
docs/QUICKSTART.md
| ... | ... | @@ -423,7 +423,7 @@ services: |
| 423 | 423 | - `TEI_BASE_URL` |
| 424 | 424 | - `RERANKER_SERVICE_URL` |
| 425 | 425 | - `RERANK_BACKEND`(服务内后端) |
| 426 | -- `DASHSCOPE_API_KEY`(`dashscope_rerank` 后端鉴权) | |
| 426 | +- `RERANK_DASHSCOPE_API_KEY_CN` / `RERANK_DASHSCOPE_API_KEY_US`(`dashscope_rerank` 后端鉴权) | |
| 427 | 427 | - `RERANK_DASHSCOPE_ENDPOINT`(`dashscope_rerank` 地域 endpoint 覆盖) |
| 428 | 428 | |
| 429 | 429 | ### 3.3 新增 provider 的最小步骤 | ... | ... |
reranker/README.md
| ... | ... | @@ -69,9 +69,10 @@ services: |
| 69 | 69 | dashscope_rerank: |
| 70 | 70 | model_name: "qwen3-rerank" |
| 71 | 71 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 72 | - api_key: null # 推荐使用环境变量 DASHSCOPE_API_KEY | |
| 72 | + api_key_env: "RERANK_DASHSCOPE_API_KEY_CN" | |
| 73 | 73 | timeout_sec: 15.0 |
| 74 | 74 | top_n_cap: 0 |
| 75 | + batchsize: 64 # 0 关闭;>0 启用并发小包调度(top_n/top_n_cap 仍生效,分包后全局截断) | |
| 75 | 76 | instruct: "Given a shopping query, rank product titles by relevance" |
| 76 | 77 | max_retries: 2 |
| 77 | 78 | retry_backoff_sec: 0.2 |
| ... | ... | @@ -83,8 +84,10 @@ DashScope endpoint 地域示例: |
| 83 | 84 | - 美国:`https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks` |
| 84 | 85 | |
| 85 | 86 | DashScope 认证: |
| 86 | -- `api_key` 支持配置在 `config.yaml` | |
| 87 | -- 推荐通过环境变量注入:`DASHSCOPE_API_KEY=...` | |
| 87 | +- `api_key_env` 必填,表示该后端读取哪个环境变量作为 API Key | |
| 88 | +- 推荐按地域分别注入: | |
| 89 | + - `RERANK_DASHSCOPE_API_KEY_CN=...` | |
| 90 | + - `RERANK_DASHSCOPE_API_KEY_US=...` | |
| 88 | 91 | |
| 89 | 92 | - 服务端口、请求限制等仍在 `reranker/config.py`(或环境变量 `RERANKER_PORT`、`RERANKER_HOST`)。 |
| 90 | 93 | ... | ... |
reranker/backends/dashscope_rerank.py
| ... | ... | @@ -16,11 +16,12 @@ import logging |
| 16 | 16 | import math |
| 17 | 17 | import os |
| 18 | 18 | import time |
| 19 | +from concurrent.futures import ThreadPoolExecutor, as_completed | |
| 19 | 20 | from typing import Any, Dict, List, Tuple |
| 20 | 21 | from urllib import error as urllib_error |
| 21 | 22 | from urllib import request as urllib_request |
| 22 | 23 | |
| 23 | -from reranker.backends.batching_utils import deduplicate_with_positions | |
| 24 | +from reranker.backends.batching_utils import deduplicate_with_positions, iter_batches | |
| 24 | 25 | |
| 25 | 26 | logger = logging.getLogger("reranker.backends.dashscope_rerank") |
| 26 | 27 | |
| ... | ... | @@ -32,19 +33,20 @@ class DashScopeRerankBackend: |
| 32 | 33 | Config from services.rerank.backends.dashscope_rerank: |
| 33 | 34 | - model_name: str, default "qwen3-rerank" |
| 34 | 35 | - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 35 | - - api_key: optional str (or env DASHSCOPE_API_KEY) | |
| 36 | + - api_key_env: str, required env var name for this backend key | |
| 36 | 37 | - timeout_sec: float, default 15.0 |
| 37 | 38 | - top_n_cap: int, optional cap; 0 means use all docs in request |
| 39 | + - batchsize: int, optional; 0 disables batching; >0 enables concurrent small-batch scheduling | |
| 38 | 40 | - instruct: optional str |
| 39 | 41 | - max_retries: int, default 1 |
| 40 | 42 | - retry_backoff_sec: float, default 0.2 |
| 41 | 43 | |
| 42 | 44 | Env overrides: |
| 43 | - - DASHSCOPE_API_KEY | |
| 44 | 45 | - RERANK_DASHSCOPE_ENDPOINT |
| 45 | 46 | - RERANK_DASHSCOPE_MODEL |
| 46 | 47 | - RERANK_DASHSCOPE_TIMEOUT_SEC |
| 47 | 48 | - RERANK_DASHSCOPE_TOP_N_CAP |
| 49 | + - RERANK_DASHSCOPE_BATCHSIZE | |
| 48 | 50 | """ |
| 49 | 51 | |
| 50 | 52 | def __init__(self, config: Dict[str, Any]) -> None: |
| ... | ... | @@ -59,11 +61,8 @@ class DashScopeRerankBackend: |
| 59 | 61 | or self._config.get("endpoint") |
| 60 | 62 | or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| 61 | 63 | ).strip() |
| 62 | - self._api_key = str( | |
| 63 | - os.getenv("DASHSCOPE_API_KEY") | |
| 64 | - or self._config.get("api_key") | |
| 65 | - or "" | |
| 66 | - ).strip().strip('"').strip("'") | |
| 64 | + self._api_key_env = str(self._config.get("api_key_env") or "").strip() | |
| 65 | + self._api_key = str(os.getenv(self._api_key_env) or "").strip().strip('"').strip("'") | |
| 67 | 66 | self._timeout_sec = float( |
| 68 | 67 | os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC") |
| 69 | 68 | or self._config.get("timeout_sec") |
| ... | ... | @@ -74,21 +73,29 @@ class DashScopeRerankBackend: |
| 74 | 73 | or self._config.get("top_n_cap") |
| 75 | 74 | or 0 |
| 76 | 75 | ) |
| 76 | + self._batchsize = int( | |
| 77 | + os.getenv("RERANK_DASHSCOPE_BATCHSIZE") | |
| 78 | + or self._config.get("batchsize") | |
| 79 | + or 0 | |
| 80 | + ) | |
| 77 | 81 | self._instruct = str(self._config.get("instruct") or "").strip() |
| 78 | 82 | self._max_retries = int(self._config.get("max_retries", 1)) |
| 79 | 83 | self._retry_backoff_sec = float(self._config.get("retry_backoff_sec", 0.2)) |
| 80 | 84 | |
| 81 | 85 | if not self._endpoint: |
| 82 | 86 | raise ValueError("dashscope_rerank endpoint is required") |
| 87 | + if not self._api_key_env: | |
| 88 | + raise ValueError("dashscope_rerank api_key_env is required") | |
| 83 | 89 | if not self._api_key: |
| 84 | 90 | raise ValueError( |
| 85 | - "dashscope_rerank api_key is required (set services.rerank.backends.dashscope_rerank.api_key " | |
| 86 | - "or env DASHSCOPE_API_KEY)" | |
| 91 | + f"dashscope_rerank api key is required (set env {self._api_key_env})" | |
| 87 | 92 | ) |
| 88 | 93 | if self._timeout_sec <= 0: |
| 89 | 94 | raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}") |
| 90 | 95 | if self._top_n_cap < 0: |
| 91 | 96 | raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}") |
| 97 | + if self._batchsize < 0: | |
| 98 | + raise ValueError(f"dashscope_rerank batchsize must be >= 0, got {self._batchsize}") | |
| 92 | 99 | if self._max_retries <= 0: |
| 93 | 100 | raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}") |
| 94 | 101 | if self._retry_backoff_sec < 0: |
| ... | ... | @@ -97,11 +104,12 @@ class DashScopeRerankBackend: |
| 97 | 104 | ) |
| 98 | 105 | |
| 99 | 106 | logger.info( |
| 100 | - "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s", | |
| 107 | + "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s batchsize=%s", | |
| 101 | 108 | self._endpoint, |
| 102 | 109 | self._model_name, |
| 103 | 110 | self._timeout_sec, |
| 104 | 111 | self._top_n_cap, |
| 112 | + self._batchsize, | |
| 105 | 113 | ) |
| 106 | 114 | |
| 107 | 115 | def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]: |
| ... | ... | @@ -162,6 +170,95 @@ class DashScopeRerankBackend: |
| 162 | 170 | |
| 163 | 171 | raise RuntimeError(str(last_exc) if last_exc else "DashScope rerank failed with unknown error") |
| 164 | 172 | |
| 173 | + def _score_single_request( | |
| 174 | + self, | |
| 175 | + query: str, | |
| 176 | + unique_texts: List[str], | |
| 177 | + normalize: bool, | |
| 178 | + top_n: int, | |
| 179 | + ) -> Tuple[List[float], int]: | |
| 180 | + response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n) | |
| 181 | + results = self._extract_results(response) | |
| 182 | + | |
| 183 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 184 | + for rank, item in enumerate(results): | |
| 185 | + raw_idx = item.get("index", rank) | |
| 186 | + try: | |
| 187 | + idx = int(raw_idx) | |
| 188 | + except (TypeError, ValueError): | |
| 189 | + continue | |
| 190 | + if idx < 0 or idx >= len(unique_scores): | |
| 191 | + continue | |
| 192 | + raw_score = item.get("relevance_score", item.get("score")) | |
| 193 | + unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize) | |
| 194 | + return unique_scores, len(results) | |
| 195 | + | |
| 196 | + def _score_batched_concurrent( | |
| 197 | + self, | |
| 198 | + query: str, | |
| 199 | + unique_texts: List[str], | |
| 200 | + normalize: bool, | |
| 201 | + ) -> Tuple[List[float], Dict[str, int]]: | |
| 202 | + """ | |
| 203 | + Concurrent batch scoring. | |
| 204 | + | |
| 205 | + We intentionally request full local scores in each batch (top_n=len(batch)), | |
| 206 | + then apply global top_n/top_n_cap truncation after merge if needed. | |
| 207 | + """ | |
| 208 | + indices = list(range(len(unique_texts))) | |
| 209 | + batches = list(iter_batches(indices, batch_size=self._batchsize)) | |
| 210 | + num_batches = len(batches) | |
| 211 | + max_workers = min(8, num_batches) if num_batches > 0 else 1 | |
| 212 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 213 | + response_results = 0 | |
| 214 | + | |
| 215 | + def _run_one(batch_no: int, batch_indices: List[int]) -> Tuple[int, List[int], Dict[str, Any], float]: | |
| 216 | + docs = [unique_texts[i] for i in batch_indices] | |
| 217 | + # Ask each batch for all docs to avoid local truncation. | |
| 218 | + start_ts = time.perf_counter() | |
| 219 | + data = self._post_rerank(query=query, docs=docs, top_n=len(docs)) | |
| 220 | + elapsed_ms = round((time.perf_counter() - start_ts) * 1000.0, 3) | |
| 221 | + return batch_no, batch_indices, data, elapsed_ms | |
| 222 | + | |
| 223 | + with ThreadPoolExecutor(max_workers=max_workers) as ex: | |
| 224 | + future_to_batch = {ex.submit(_run_one, i + 1, b): b for i, b in enumerate(batches)} | |
| 225 | + for fut in as_completed(future_to_batch): | |
| 226 | + batch_indices = future_to_batch[fut] | |
| 227 | + try: | |
| 228 | + batch_no, _, data, batch_elapsed_ms = fut.result() | |
| 229 | + except Exception as exc: | |
| 230 | + raise RuntimeError( | |
| 231 | + f"DashScope rerank batch failed | batch_size={len(batch_indices)} error={exc}" | |
| 232 | + ) from exc | |
| 233 | + results = self._extract_results(data) | |
| 234 | + logger.info( | |
| 235 | + "DashScope batch response | batch=%d/%d docs=%d elapsed_ms=%s results=%d query=%r", | |
| 236 | + batch_no, | |
| 237 | + num_batches, | |
| 238 | + len(batch_indices), | |
| 239 | + batch_elapsed_ms, | |
| 240 | + len(results), | |
| 241 | + query[:80], | |
| 242 | + ) | |
| 243 | + response_results += len(results) | |
| 244 | + for rank, item in enumerate(results): | |
| 245 | + raw_idx = item.get("index", rank) | |
| 246 | + try: | |
| 247 | + local_idx = int(raw_idx) | |
| 248 | + except (TypeError, ValueError): | |
| 249 | + continue | |
| 250 | + if local_idx < 0 or local_idx >= len(batch_indices): | |
| 251 | + continue | |
| 252 | + global_idx = batch_indices[local_idx] | |
| 253 | + raw_score = item.get("relevance_score", item.get("score")) | |
| 254 | + unique_scores[global_idx] = self._coerce_score(raw_score, normalize=normalize) | |
| 255 | + | |
| 256 | + return unique_scores, { | |
| 257 | + "batches": num_batches, | |
| 258 | + "batch_concurrency": max_workers, | |
| 259 | + "response_results": response_results, | |
| 260 | + } | |
| 261 | + | |
| 165 | 262 | @staticmethod |
| 166 | 263 | def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]: |
| 167 | 264 | # Compatible API style: {"results":[...]} |
| ... | ... | @@ -240,21 +337,34 @@ class DashScopeRerankBackend: |
| 240 | 337 | top_n_effective = min(top_n_effective, int(top_n)) |
| 241 | 338 | if self._top_n_cap > 0: |
| 242 | 339 | top_n_effective = min(top_n_effective, self._top_n_cap) |
| 243 | - | |
| 244 | - response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n_effective) | |
| 245 | - results = self._extract_results(response) | |
| 246 | - | |
| 247 | - unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 248 | - for rank, item in enumerate(results): | |
| 249 | - raw_idx = item.get("index", rank) | |
| 250 | - try: | |
| 251 | - idx = int(raw_idx) | |
| 252 | - except (TypeError, ValueError): | |
| 253 | - continue | |
| 254 | - if idx < 0 or idx >= len(unique_scores): | |
| 255 | - continue | |
| 256 | - raw_score = item.get("relevance_score", item.get("score")) | |
| 257 | - unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize) | |
| 340 | + can_batch = ( | |
| 341 | + self._batchsize > 0 | |
| 342 | + and len(unique_texts) > self._batchsize | |
| 343 | + ) | |
| 344 | + if can_batch: | |
| 345 | + unique_scores, batch_meta = self._score_batched_concurrent( | |
| 346 | + query=query, | |
| 347 | + unique_texts=unique_texts, | |
| 348 | + normalize=normalize, | |
| 349 | + ) | |
| 350 | + if top_n_effective < len(unique_scores): | |
| 351 | + order = sorted(range(len(unique_scores)), key=lambda i: (-unique_scores[i], i)) | |
| 352 | + keep = set(order[:top_n_effective]) | |
| 353 | + for i in range(len(unique_scores)): | |
| 354 | + if i not in keep: | |
| 355 | + unique_scores[i] = 0.0 | |
| 356 | + response_results = int(batch_meta["response_results"]) | |
| 357 | + batches = int(batch_meta["batches"]) | |
| 358 | + batch_concurrency = int(batch_meta["batch_concurrency"]) | |
| 359 | + else: | |
| 360 | + unique_scores, response_results = self._score_single_request( | |
| 361 | + query=query, | |
| 362 | + unique_texts=unique_texts, | |
| 363 | + normalize=normalize, | |
| 364 | + top_n=top_n_effective, | |
| 365 | + ) | |
| 366 | + batches = 1 | |
| 367 | + batch_concurrency = 1 | |
| 258 | 368 | |
| 259 | 369 | for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): |
| 260 | 370 | output_scores[orig_idx] = float(unique_scores[unique_idx]) |
| ... | ... | @@ -275,7 +385,10 @@ class DashScopeRerankBackend: |
| 275 | 385 | "normalize": normalize, |
| 276 | 386 | "top_n": top_n_effective, |
| 277 | 387 | "requested_top_n": int(top_n) if top_n is not None else None, |
| 278 | - "response_results": len(results), | |
| 388 | + "response_results": response_results, | |
| 389 | + "batchsize": self._batchsize, | |
| 390 | + "batches": batches, | |
| 391 | + "batch_concurrency": batch_concurrency, | |
| 279 | 392 | "endpoint": self._endpoint, |
| 280 | 393 | } |
| 281 | 394 | ... | ... |
reranker/server.py
| ... | ... | @@ -154,11 +154,14 @@ def rerank(request: RerankRequest) -> RerankResponse: |
| 154 | 154 | meta.update({"service_elapsed_ms": round((time.time() - start_ts) * 1000.0, 3)}) |
| 155 | 155 | score_preview = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]] |
| 156 | 156 | logger.info( |
| 157 | - "Rerank done | docs=%d unique=%s dedup=%s elapsed_ms=%s query=%r score_preview=%s", | |
| 157 | + "Rerank done | docs=%d unique=%s dedup=%s elapsed_ms=%s batches=%s batchsize=%s batch_concurrency=%s query=%r score_preview=%s", | |
| 158 | 158 | meta.get("input_docs"), |
| 159 | 159 | meta.get("unique_docs"), |
| 160 | 160 | meta.get("dedup_ratio"), |
| 161 | 161 | meta.get("service_elapsed_ms"), |
| 162 | + meta.get("batches"), | |
| 163 | + meta.get("batchsize"), | |
| 164 | + meta.get("batch_concurrency"), | |
| 162 | 165 | _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS), |
| 163 | 166 | score_preview, |
| 164 | 167 | ) | ... | ... |
tests/test_reranker_dashscope_backend.py
| 1 | 1 | from __future__ import annotations |
| 2 | 2 | |
| 3 | +import time | |
| 4 | + | |
| 5 | +import pytest | |
| 6 | + | |
| 3 | 7 | from reranker.backends import get_rerank_backend |
| 4 | 8 | from reranker.backends.dashscope_rerank import DashScopeRerankBackend |
| 5 | 9 | |
| 6 | 10 | |
| 7 | -def test_dashscope_backend_factory_loads(): | |
| 11 | +@pytest.fixture(autouse=True) | |
| 12 | +def _clear_global_dashscope_key(monkeypatch): | |
| 13 | + # Prevent accidental pass-through from unrelated global key. | |
| 14 | + monkeypatch.delenv("DASHSCOPE_API_KEY", raising=False) | |
| 15 | + | |
| 16 | + | |
| 17 | +def test_dashscope_backend_factory_loads(monkeypatch): | |
| 18 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 8 | 19 | backend = get_rerank_backend( |
| 9 | 20 | "dashscope_rerank", |
| 10 | 21 | { |
| 11 | 22 | "model_name": "qwen3-rerank", |
| 12 | 23 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 13 | - "api_key": "test-key", | |
| 24 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 14 | 25 | }, |
| 15 | 26 | ) |
| 16 | 27 | assert isinstance(backend, DashScopeRerankBackend) |
| 17 | 28 | |
| 18 | 29 | |
| 19 | 30 | def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): |
| 31 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 20 | 32 | backend = DashScopeRerankBackend( |
| 21 | 33 | { |
| 22 | 34 | "model_name": "qwen3-rerank", |
| 23 | 35 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 24 | - "api_key": "test-key", | |
| 36 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 25 | 37 | "top_n_cap": 0, |
| 26 | 38 | } |
| 27 | 39 | ) |
| ... | ... | @@ -55,11 +67,12 @@ def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): |
| 55 | 67 | |
| 56 | 68 | |
| 57 | 69 | def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): |
| 70 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 58 | 71 | backend = DashScopeRerankBackend( |
| 59 | 72 | { |
| 60 | 73 | "model_name": "qwen3-rerank", |
| 61 | 74 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 62 | - "api_key": "test-key", | |
| 75 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 63 | 76 | "top_n_cap": 1, |
| 64 | 77 | } |
| 65 | 78 | ) |
| ... | ... | @@ -81,11 +94,12 @@ def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): |
| 81 | 94 | |
| 82 | 95 | |
| 83 | 96 | def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): |
| 97 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 84 | 98 | backend = DashScopeRerankBackend( |
| 85 | 99 | { |
| 86 | 100 | "model_name": "qwen3-rerank", |
| 87 | 101 | "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", |
| 88 | - "api_key": "test-key", | |
| 102 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 89 | 103 | "top_n_cap": 0, |
| 90 | 104 | } |
| 91 | 105 | ) |
| ... | ... | @@ -101,3 +115,127 @@ def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): |
| 101 | 115 | assert scores == [0.3, 0.0, 0.8] |
| 102 | 116 | assert meta["top_n"] == 2 |
| 103 | 117 | assert meta["requested_top_n"] == 2 |
| 118 | + | |
| 119 | + | |
| 120 | +def test_dashscope_backend_batchsize_concurrent_full_topn(monkeypatch): | |
| 121 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 122 | + backend = DashScopeRerankBackend( | |
| 123 | + { | |
| 124 | + "model_name": "qwen3-rerank", | |
| 125 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 126 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 127 | + "top_n_cap": 0, | |
| 128 | + "batchsize": 2, | |
| 129 | + } | |
| 130 | + ) | |
| 131 | + | |
| 132 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 133 | + assert query == "q" | |
| 134 | + # batching path asks every batch for full local list | |
| 135 | + assert top_n == len(docs) | |
| 136 | + time.sleep(0.05) | |
| 137 | + return { | |
| 138 | + "results": [ | |
| 139 | + {"index": i, "relevance_score": float(i + 1) / 10.0} | |
| 140 | + for i, _ in enumerate(docs) | |
| 141 | + ] | |
| 142 | + } | |
| 143 | + | |
| 144 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 145 | + start = time.perf_counter() | |
| 146 | + scores, meta = backend.score_with_meta(query="q", docs=["d1", "d2", "d3", "d4", "d5", "d6"]) | |
| 147 | + elapsed = time.perf_counter() - start | |
| 148 | + | |
| 149 | + # 3 batches * 50ms run serially would take ~150ms; concurrent execution should be significantly faster. | |
| 150 | + assert elapsed < 0.14 | |
| 151 | + assert len(scores) == 6 | |
| 152 | + assert meta["batches"] == 3 | |
| 153 | + assert meta["batch_concurrency"] == 3 | |
| 154 | + assert meta["response_results"] == 6 | |
| 155 | + | |
| 156 | + | |
| 157 | +def test_dashscope_backend_batchsize_still_effective_when_topn_limited(monkeypatch): | |
| 158 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 159 | + backend = DashScopeRerankBackend( | |
| 160 | + { | |
| 161 | + "model_name": "qwen3-rerank", | |
| 162 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 163 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 164 | + "top_n_cap": 0, | |
| 165 | + "batchsize": 2, | |
| 166 | + } | |
| 167 | + ) | |
| 168 | + | |
| 169 | + called = {"count": 0} | |
| 170 | + | |
| 171 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 172 | + called["count"] += 1 | |
| 173 | + # batching remains enabled; each batch asks for full local scores | |
| 174 | + assert top_n == len(docs) | |
| 175 | + score_map = {"d1": 0.9, "d2": 0.1, "d3": 0.8, "d4": 0.2} | |
| 176 | + return { | |
| 177 | + "results": [ | |
| 178 | + {"index": i, "relevance_score": score_map[doc]} | |
| 179 | + for i, doc in enumerate(docs) | |
| 180 | + ] | |
| 181 | + } | |
| 182 | + | |
| 183 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 184 | + scores, meta = backend.score_with_meta_topn(query="q", docs=["d1", "d2", "d3", "d4"], top_n=2) | |
| 185 | + | |
| 186 | + assert called["count"] == 2 | |
| 187 | + assert scores == [0.9, 0.0, 0.8, 0.0] | |
| 188 | + assert meta["batches"] == 2 | |
| 189 | + assert meta["top_n"] == 2 | |
| 190 | + | |
| 191 | + | |
| 192 | +def test_dashscope_backend_batchsize_raises_when_one_batch_fails(monkeypatch): | |
| 193 | + monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key") | |
| 194 | + backend = DashScopeRerankBackend( | |
| 195 | + { | |
| 196 | + "model_name": "qwen3-rerank", | |
| 197 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 198 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 199 | + "top_n_cap": 0, | |
| 200 | + "batchsize": 2, | |
| 201 | + } | |
| 202 | + ) | |
| 203 | + | |
| 204 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 205 | + if docs == ["d3", "d4"]: | |
| 206 | + raise RuntimeError("provider temporary error") | |
| 207 | + return { | |
| 208 | + "results": [ | |
| 209 | + {"index": i, "relevance_score": 0.1} | |
| 210 | + for i, _ in enumerate(docs) | |
| 211 | + ] | |
| 212 | + } | |
| 213 | + | |
| 214 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 215 | + | |
| 216 | + with pytest.raises(RuntimeError, match="DashScope rerank batch failed"): | |
| 217 | + backend.score_with_meta(query="q", docs=["d1", "d2", "d3", "d4"]) | |
| 218 | + | |
| 219 | + | |
| 220 | +def test_dashscope_backend_requires_api_key_env(): | |
| 221 | + with pytest.raises(ValueError, match="api_key_env is required"): | |
| 222 | + DashScopeRerankBackend( | |
| 223 | + { | |
| 224 | + "model_name": "qwen3-rerank", | |
| 225 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 226 | + "top_n_cap": 0, | |
| 227 | + } | |
| 228 | + ) | |
| 229 | + | |
| 230 | + | |
| 231 | +def test_dashscope_backend_requires_api_key_env_value(monkeypatch): | |
| 232 | + monkeypatch.delenv("TEST_RERANK_DASHSCOPE_API_KEY", raising=False) | |
| 233 | + with pytest.raises(ValueError, match="set env TEST_RERANK_DASHSCOPE_API_KEY"): | |
| 234 | + DashScopeRerankBackend( | |
| 235 | + { | |
| 236 | + "model_name": "qwen3-rerank", | |
| 237 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 238 | + "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY", | |
| 239 | + "top_n_cap": 0, | |
| 240 | + } | |
| 241 | + ) | ... | ... |