Commit d31c7f65c06debea804f1f52ea2d72ca436acd7c
1 parent
a99e62ba
补充云服务reranker
Showing
16 changed files
with
624 additions
and
12 deletions
Show diff stats
.env.example
| ... | ... | @@ -44,6 +44,11 @@ TEI_MAX_CLIENT_BATCH_SIZE=8 |
| 44 | 44 | TEI_HEALTH_TIMEOUT_SEC=300 |
| 45 | 45 | RERANK_PROVIDER=http |
| 46 | 46 | RERANK_BACKEND=qwen3_vllm |
| 47 | +# Optional for cloud rerank backend (RERANK_BACKEND=dashscope_rerank) | |
| 48 | +DASHSCOPE_API_KEY= | |
| 49 | +# Example: | |
| 50 | +# RERANK_DASHSCOPE_ENDPOINT=https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks | |
| 51 | +RERANK_DASHSCOPE_ENDPOINT= | |
| 47 | 52 | |
| 48 | 53 | # Cache Directory |
| 49 | 54 | CACHE_DIR=.cache | ... | ... |
config/config.yaml
| ... | ... | @@ -166,7 +166,7 @@ services: |
| 166 | 166 | base_url: "http://127.0.0.1:6007" |
| 167 | 167 | service_url: "http://127.0.0.1:6007/rerank" |
| 168 | 168 | # 服务内后端(reranker 进程启动时读取) |
| 169 | - backend: "qwen3_vllm" # bge | qwen3_vllm | |
| 169 | + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | |
| 170 | 170 | backends: |
| 171 | 171 | bge: |
| 172 | 172 | model_name: "BAAI/bge-reranker-v2-m3" |
| ... | ... | @@ -189,6 +189,26 @@ services: |
| 189 | 189 | sort_by_doc_length: true |
| 190 | 190 | length_sort_mode: "char" # char | token |
| 191 | 191 | instruction: "Given a shopping query, rank product titles by relevance" |
| 192 | + qwen3_transformers: | |
| 193 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 194 | + instruction: "Given a shopping query, rank product titles by relevance" | |
| 195 | + max_length: 8192 | |
| 196 | + batch_size: 64 | |
| 197 | + use_fp16: true | |
| 198 | + attn_implementation: "flash_attention_2" | |
| 199 | + dashscope_rerank: | |
| 200 | + model_name: "qwen3-rerank" | |
| 201 | + # 按地域选择 endpoint: | |
| 202 | + # 中国: https://dashscope.aliyuncs.com/compatible-api/v1/reranks | |
| 203 | + # 新加坡: https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks | |
| 204 | + # 美国: https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks | |
| 205 | + endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | |
| 206 | + api_key: null # 推荐通过环境变量 DASHSCOPE_API_KEY 设置 | |
| 207 | + timeout_sec: 15.0 | |
| 208 | + top_n_cap: 0 # 0 表示 top_n=当前请求文档数;>0 则限制 top_n 上限 | |
| 209 | + instruct: "Given a shopping query, rank product titles by relevance" | |
| 210 | + max_retries: 2 | |
| 211 | + retry_backoff_sec: 0.2 | |
| 192 | 212 | |
| 193 | 213 | # SPU配置(已启用,使用嵌套skus) |
| 194 | 214 | spu_config: | ... | ... |
docs/DEVELOPER_GUIDE.md
| ... | ... | @@ -318,7 +318,7 @@ services: |
| 318 | 318 | |------|--------|------|--------| |
| 319 | 319 | | 调用方 | `services.<capability>.provider` | http | http | |
| 320 | 320 | | 调用方 | `services.<capability>.providers.http.base_url` | 6007 | 6005 | |
| 321 | -| 服务内 | `services.<capability>.backend` | qwen3_vllm / bge | tei / local_st | | |
| 321 | +| 服务内 | `services.<capability>.backend` | qwen3_vllm / qwen3_transformers / bge / dashscope_rerank | tei / local_st | | |
| 322 | 322 | | 服务内 | `services.<capability>.backends.<name>` | 模型名、batch、vLLM 参数 | 模型名、device 等 | |
| 323 | 323 | |
| 324 | 324 | ### 7.6 新增后端清单(以 Qwen3-Reranker 为例) |
| ... | ... | @@ -334,7 +334,7 @@ services: |
| 334 | 334 | |
| 335 | 335 | - **单一路径**:Provider 和 backend 必须由 `config/config.yaml` 的 `services` 块显式指定;未知配置应直接报错。 |
| 336 | 336 | - **无兼容回退**:不保留“旧配置自动推导/兜底默认值”机制,避免静默行为偏差。 |
| 337 | -- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 | |
| 337 | +- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`DASHSCOPE_API_KEY`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 | |
| 338 | 338 | |
| 339 | 339 | --- |
| 340 | 340 | ... | ... |
docs/QUICKSTART.md
| ... | ... | @@ -409,7 +409,7 @@ services: |
| 409 | 409 | tei: { base_url: "http://127.0.0.1:8080", timeout_sec: 60, model_id: "Qwen/Qwen3-Embedding-0.6B" } |
| 410 | 410 | rerank: |
| 411 | 411 | provider: "http" |
| 412 | - backend: "qwen3_vllm" | |
| 412 | + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | |
| 413 | 413 | providers: |
| 414 | 414 | http: { base_url: "http://127.0.0.1:6007", service_url: "http://127.0.0.1:6007/rerank" } |
| 415 | 415 | ``` |
| ... | ... | @@ -423,6 +423,8 @@ services: |
| 423 | 423 | - `TEI_BASE_URL` |
| 424 | 424 | - `RERANKER_SERVICE_URL` |
| 425 | 425 | - `RERANK_BACKEND`(服务内后端) |
| 426 | +- `DASHSCOPE_API_KEY`(`dashscope_rerank` 后端鉴权) | |
| 427 | +- `RERANK_DASHSCOPE_ENDPOINT`(`dashscope_rerank` 地域 endpoint 覆盖) | |
| 426 | 428 | |
| 427 | 429 | ### 3.3 新增 provider 的最小步骤 |
| 428 | 430 | |
| ... | ... | @@ -451,6 +453,8 @@ services: |
| 451 | 453 | - `reranker/backends/__init__.py`(工厂) |
| 452 | 454 | - `reranker/backends/bge.py` |
| 453 | 455 | - `reranker/backends/qwen3_vllm.py` |
| 456 | +- `reranker/backends/qwen3_transformers.py` | |
| 457 | +- `reranker/backends/dashscope_rerank.py` | |
| 454 | 458 | |
| 455 | 459 | 后端协议(服务内): |
| 456 | 460 | ... | ... |
docs/性能测试报告.md
| ... | ... | @@ -338,3 +338,46 @@ done |
| 338 | 338 | 异常说明: |
| 339 | 339 | - `tenant 0` 在并发 `20` 出现 `ReadTimeout`(25 次),该档成功率下降到 `59.02%` |
| 340 | 340 | - 其他租户在本轮口径下均为 `100%` 成功率 |
| 341 | + | |
| 342 | +## 13. Rerank 后端对比(qwen3_vllm vs DashScope 云服务) | |
| 343 | + | |
| 344 | +目标: | |
| 345 | +- 使用同一套构造数据,对比两个 rerank 微服务在电商搜索重排场景下的速度差异 | |
| 346 | +- 为后端选型提供直接依据 | |
| 347 | + | |
| 348 | +测试口径(两端一致): | |
| 349 | +- query:固定 `wireless mouse` | |
| 350 | +- docs:每次请求固定 `386` 条 | |
| 351 | +- 构造方式:从 `1000` 词池随机采样;每条 doc 句长随机 `15-40` | |
| 352 | +- `top_n`:`30`(模拟 `page+size`) | |
| 353 | +- 并发:`1 / 5 / 10 / 20` | |
| 354 | +- 每档时长:`20s` | |
| 355 | +- 每个后端跑 `2` 轮,以下表格为两轮均值 | |
| 356 | + | |
| 357 | +执行文件: | |
| 358 | +- vLLM:`perf_reports/2026-03-12/rerank_backend_compare/vllm_round1_topn30.json` | |
| 359 | +- vLLM:`perf_reports/2026-03-12/rerank_backend_compare/vllm_round2b_topn30.json` | |
| 360 | +- Cloud:`perf_reports/2026-03-12/rerank_backend_compare/cloud_round1_topn30.json` | |
| 361 | +- Cloud:`perf_reports/2026-03-12/rerank_backend_compare/cloud_round2_topn30.json` | |
| 362 | + | |
| 363 | +### 13.1 两轮均值对比 | |
| 364 | + | |
| 365 | +| 并发 | vLLM RPS | Cloud RPS | vLLM P95(ms) | Cloud P95(ms) | vLLM Avg(ms) | Cloud Avg(ms) | | |
| 366 | +|---:|---:|---:|---:|---:|---:|---:| | |
| 367 | +| 1 | 0.625 | 0.220 | 1937.68 | 6371.03 | 1602.37 | 4752.53 | | |
| 368 | +| 5 | 0.585 | 1.040 | 9421.37 | 7372.85 | 8480.29 | 4543.84 | | |
| 369 | +| 10 | 0.595 | 1.820 | 18040.65 | 7637.43 | 16767.64 | 4820.35 | | |
| 370 | +| 20 | 0.590 | 3.530 | 33766.06 | 8445.39 | 33563.23 | 4890.59 | | |
| 371 | + | |
| 372 | +### 13.2 结论 | |
| 373 | + | |
| 374 | +- 单并发(`c=1`)下,`qwen3_vllm` 更快(延迟更低,吞吐约为 Cloud 的 `2.8x`)。 | |
| 375 | +- 从 `c=5` 开始,DashScope 云后端明显更快: | |
| 376 | + - `c=5`:Cloud 吞吐约为 vLLM 的 `1.78x` | |
| 377 | + - `c=10`:Cloud 吞吐约为 vLLM 的 `3.06x` | |
| 378 | + - `c=20`:Cloud 吞吐约为 vLLM 的 `5.98x` | |
| 379 | +- 在“电商搜索在线重排(有并发)”场景下,基于当前实现,建议优先选择云后端。 | |
| 380 | + | |
| 381 | +说明: | |
| 382 | +- 本轮对比基于当前实现:`dashscope_rerank` 支持 `top_n`(本次取 `30`),`qwen3_vllm` 当前仍按全量 docs 评分。 | |
| 383 | +- 若后续为本地模型实现 `top_n` 局部重排能力,需要重新对比后再最终定版。 | ... | ... |
providers/rerank.py
| ... | ... | @@ -23,11 +23,14 @@ class HttpRerankProvider: |
| 23 | 23 | query: str, |
| 24 | 24 | docs: List[str], |
| 25 | 25 | timeout_sec: float, |
| 26 | + top_n: Optional[int] = None, | |
| 26 | 27 | ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: |
| 27 | 28 | if not docs: |
| 28 | 29 | return [], {} |
| 29 | 30 | try: |
| 30 | 31 | payload = {"query": (query or "").strip(), "docs": docs} |
| 32 | + if top_n is not None and int(top_n) > 0: | |
| 33 | + payload["top_n"] = int(top_n) | |
| 31 | 34 | response = requests.post(self.service_url, json=payload, timeout=timeout_sec) |
| 32 | 35 | if response.status_code != 200: |
| 33 | 36 | logger.warning( | ... | ... |
reranker/README.md
| ... | ... | @@ -4,10 +4,11 @@ |
| 4 | 4 | |
| 5 | 5 | --- |
| 6 | 6 | |
| 7 | -Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers)。调用方通过 HTTP 访问,不关心具体后端。 | |
| 7 | +Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers、DashScope 云重排)。调用方通过 HTTP 访问,不关心具体后端。 | |
| 8 | 8 | |
| 9 | 9 | **特性** |
| 10 | 10 | - 多后端:`qwen3_vllm`(默认,Qwen3-Reranker-0.6B + vLLM)、`qwen3_transformers`(纯 Transformers,无需 vLLM)、`bge`(兼容保留) |
| 11 | +- 云后端:`dashscope_rerank`(调用 DashScope `/compatible-api/v1/reranks`,支持按地域切换 endpoint) | |
| 11 | 12 | - 统一配置:`config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.<name>` |
| 12 | 13 | - 文档去重、分数与输入顺序一致、FP16/GPU 支持(视后端) |
| 13 | 14 | |
| ... | ... | @@ -18,6 +19,7 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe |
| 18 | 19 | - `backends/bge.py`:BGE 后端 |
| 19 | 20 | - `backends/qwen3_vllm.py`:Qwen3-Reranker-0.6B + vLLM 后端 |
| 20 | 21 | - `backends/qwen3_transformers.py`:Qwen3-Reranker-0.6B 纯 Transformers 后端(官方 Usage 方式) |
| 22 | + - `backends/dashscope_rerank.py`:DashScope 云重排后端(HTTP 调用) | |
| 21 | 23 | - `reranker/bge_reranker.py`:BGE 核心推理(被 bge 后端封装) |
| 22 | 24 | - `reranker/config.py`:服务端口、MAX_DOCS、NORMALIZE 等(后端参数在 config.yaml) |
| 23 | 25 | |
| ... | ... | @@ -30,7 +32,7 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe |
| 30 | 32 | ``` |
| 31 | 33 | |
| 32 | 34 | ## 配置 |
| 33 | -- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_transformers` | `bge`),或环境变量 `RERANK_BACKEND`。 | |
| 35 | +- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_transformers` | `bge` | `dashscope_rerank`),或环境变量 `RERANK_BACKEND`。 | |
| 34 | 36 | - **后端参数**:`services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`,例如: |
| 35 | 37 | |
| 36 | 38 | ```yaml |
| ... | ... | @@ -64,8 +66,26 @@ services: |
| 64 | 66 | tensor_parallel_size: 1 |
| 65 | 67 | gpu_memory_utilization: 0.8 |
| 66 | 68 | instruction: "Given a shopping query, rank product titles by relevance" |
| 69 | + dashscope_rerank: | |
| 70 | + model_name: "qwen3-rerank" | |
| 71 | + endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | |
| 72 | + api_key: null # 推荐使用环境变量 DASHSCOPE_API_KEY | |
| 73 | + timeout_sec: 15.0 | |
| 74 | + top_n_cap: 0 | |
| 75 | + instruct: "Given a shopping query, rank product titles by relevance" | |
| 76 | + max_retries: 2 | |
| 77 | + retry_backoff_sec: 0.2 | |
| 67 | 78 | ``` |
| 68 | 79 | |
| 80 | +DashScope endpoint 地域示例: | |
| 81 | +- 中国:`https://dashscope.aliyuncs.com/compatible-api/v1/reranks` | |
| 82 | +- 新加坡:`https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks` | |
| 83 | +- 美国:`https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks` | |
| 84 | + | |
| 85 | +DashScope 认证: | |
| 86 | +- `api_key` 支持配置在 `config.yaml` | |
| 87 | +- 推荐通过环境变量注入:`DASHSCOPE_API_KEY=...` | |
| 88 | + | |
| 69 | 89 | - 服务端口、请求限制等仍在 `reranker/config.py`(或环境变量 `RERANKER_PORT`、`RERANKER_HOST`)。 |
| 70 | 90 | |
| 71 | 91 | ## 运行 |
| ... | ... | @@ -94,10 +114,15 @@ Content-Type: application/json |
| 94 | 114 | |
| 95 | 115 | { |
| 96 | 116 | "query": "wireless mouse", |
| 97 | - "docs": ["logitech mx master", "usb cable", "wireless mouse bluetooth"] | |
| 117 | + "docs": ["logitech mx master", "usb cable", "wireless mouse bluetooth"], | |
| 118 | + "top_n": 10 | |
| 98 | 119 | } |
| 99 | 120 | ``` |
| 100 | 121 | |
| 122 | +`top_n` 为可选字段: | |
| 123 | +- 对本地后端(`qwen3_vllm` / `qwen3_transformers` / `bge`)通常会忽略,仍返回全量分数。 | |
| 124 | +- 对 `dashscope_rerank` 可用于控制云端返回的候选量,建议设置为 `page+size`(例如分页 `from=20,size=10` 时传 `30`)。 | |
| 125 | + | |
| 101 | 126 | Response: |
| 102 | 127 | ``` |
| 103 | 128 | { | ... | ... |
reranker/backends/__init__.py
| ... | ... | @@ -46,8 +46,11 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc |
| 46 | 46 | if name == "qwen3_transformers": |
| 47 | 47 | from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend |
| 48 | 48 | return Qwen3TransformersRerankerBackend(config) |
| 49 | + if name == "dashscope_rerank": | |
| 50 | + from reranker.backends.dashscope_rerank import DashScopeRerankBackend | |
| 51 | + return DashScopeRerankBackend(config) | |
| 49 | 52 | raise ValueError( |
| 50 | - f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers" | |
| 53 | + f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers, dashscope_rerank" | |
| 51 | 54 | ) |
| 52 | 55 | |
| 53 | 56 | ... | ... |
| ... | ... | @@ -0,0 +1,288 @@ |
| 1 | +""" | |
| 2 | +DashScope cloud reranker backend (OpenAI-compatible reranks API). | |
| 3 | + | |
| 4 | +Reference: | |
| 5 | +- https://dashscope.aliyuncs.com/compatible-api/v1/reranks | |
| 6 | +- Use region-specific domains when needed: | |
| 7 | + - China: https://dashscope.aliyuncs.com | |
| 8 | + - Singapore: https://dashscope-intl.aliyuncs.com | |
| 9 | + - US: https://dashscope-us.aliyuncs.com | |
| 10 | +""" | |
| 11 | + | |
| 12 | +from __future__ import annotations | |
| 13 | + | |
| 14 | +import json | |
| 15 | +import logging | |
| 16 | +import math | |
| 17 | +import os | |
| 18 | +import time | |
| 19 | +from typing import Any, Dict, List, Tuple | |
| 20 | +from urllib import error as urllib_error | |
| 21 | +from urllib import request as urllib_request | |
| 22 | + | |
| 23 | +from reranker.backends.batching_utils import deduplicate_with_positions | |
| 24 | + | |
| 25 | +logger = logging.getLogger("reranker.backends.dashscope_rerank") | |
| 26 | + | |
| 27 | + | |
| 28 | +class DashScopeRerankBackend: | |
| 29 | + """ | |
| 30 | + DashScope cloud reranker backend. | |
| 31 | + | |
| 32 | + Config from services.rerank.backends.dashscope_rerank: | |
| 33 | + - model_name: str, default "qwen3-rerank" | |
| 34 | + - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | |
| 35 | + - api_key: optional str (or env DASHSCOPE_API_KEY) | |
| 36 | + - timeout_sec: float, default 15.0 | |
| 37 | + - top_n_cap: int, optional cap; 0 means use all docs in request | |
| 38 | + - instruct: optional str | |
| 39 | + - max_retries: int, default 1 | |
| 40 | + - retry_backoff_sec: float, default 0.2 | |
| 41 | + | |
| 42 | + Env overrides: | |
| 43 | + - DASHSCOPE_API_KEY | |
| 44 | + - RERANK_DASHSCOPE_ENDPOINT | |
| 45 | + - RERANK_DASHSCOPE_MODEL | |
| 46 | + - RERANK_DASHSCOPE_TIMEOUT_SEC | |
| 47 | + - RERANK_DASHSCOPE_TOP_N_CAP | |
| 48 | + """ | |
| 49 | + | |
| 50 | + def __init__(self, config: Dict[str, Any]) -> None: | |
| 51 | + self._config = config or {} | |
| 52 | + self._model_name = str( | |
| 53 | + os.getenv("RERANK_DASHSCOPE_MODEL") | |
| 54 | + or self._config.get("model_name") | |
| 55 | + or "qwen3-rerank" | |
| 56 | + ) | |
| 57 | + self._endpoint = str( | |
| 58 | + os.getenv("RERANK_DASHSCOPE_ENDPOINT") | |
| 59 | + or self._config.get("endpoint") | |
| 60 | + or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" | |
| 61 | + ).strip() | |
| 62 | + self._api_key = str( | |
| 63 | + os.getenv("DASHSCOPE_API_KEY") | |
| 64 | + or self._config.get("api_key") | |
| 65 | + or "" | |
| 66 | + ).strip().strip('"').strip("'") | |
| 67 | + self._timeout_sec = float( | |
| 68 | + os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC") | |
| 69 | + or self._config.get("timeout_sec") | |
| 70 | + or 15.0 | |
| 71 | + ) | |
| 72 | + self._top_n_cap = int( | |
| 73 | + os.getenv("RERANK_DASHSCOPE_TOP_N_CAP") | |
| 74 | + or self._config.get("top_n_cap") | |
| 75 | + or 0 | |
| 76 | + ) | |
| 77 | + self._instruct = str(self._config.get("instruct") or "").strip() | |
| 78 | + self._max_retries = int(self._config.get("max_retries", 1)) | |
| 79 | + self._retry_backoff_sec = float(self._config.get("retry_backoff_sec", 0.2)) | |
| 80 | + | |
| 81 | + if not self._endpoint: | |
| 82 | + raise ValueError("dashscope_rerank endpoint is required") | |
| 83 | + if not self._api_key: | |
| 84 | + raise ValueError( | |
| 85 | + "dashscope_rerank api_key is required (set services.rerank.backends.dashscope_rerank.api_key " | |
| 86 | + "or env DASHSCOPE_API_KEY)" | |
| 87 | + ) | |
| 88 | + if self._timeout_sec <= 0: | |
| 89 | + raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}") | |
| 90 | + if self._top_n_cap < 0: | |
| 91 | + raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}") | |
| 92 | + if self._max_retries <= 0: | |
| 93 | + raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}") | |
| 94 | + if self._retry_backoff_sec < 0: | |
| 95 | + raise ValueError( | |
| 96 | + f"dashscope_rerank retry_backoff_sec must be >= 0, got {self._retry_backoff_sec}" | |
| 97 | + ) | |
| 98 | + | |
| 99 | + logger.info( | |
| 100 | + "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s", | |
| 101 | + self._endpoint, | |
| 102 | + self._model_name, | |
| 103 | + self._timeout_sec, | |
| 104 | + self._top_n_cap, | |
| 105 | + ) | |
| 106 | + | |
| 107 | + def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]: | |
| 108 | + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") | |
| 109 | + req = urllib_request.Request( | |
| 110 | + url=self._endpoint, | |
| 111 | + method="POST", | |
| 112 | + data=body, | |
| 113 | + headers={ | |
| 114 | + "Authorization": f"Bearer {self._api_key}", | |
| 115 | + "Content-Type": "application/json", | |
| 116 | + }, | |
| 117 | + ) | |
| 118 | + with urllib_request.urlopen(req, timeout=self._timeout_sec) as resp: | |
| 119 | + raw = resp.read().decode("utf-8", errors="replace") | |
| 120 | + try: | |
| 121 | + data = json.loads(raw) | |
| 122 | + except json.JSONDecodeError as exc: | |
| 123 | + raise RuntimeError(f"DashScope response is not valid JSON: {raw[:512]}") from exc | |
| 124 | + if not isinstance(data, dict): | |
| 125 | + raise RuntimeError(f"DashScope response must be JSON object, got: {type(data).__name__}") | |
| 126 | + return data | |
| 127 | + | |
| 128 | + def _post_rerank(self, query: str, docs: List[str], top_n: int) -> Dict[str, Any]: | |
| 129 | + payload: Dict[str, Any] = { | |
| 130 | + "model": self._model_name, | |
| 131 | + "query": query, | |
| 132 | + "documents": docs, | |
| 133 | + "top_n": top_n, | |
| 134 | + } | |
| 135 | + if self._instruct: | |
| 136 | + payload["instruct"] = self._instruct | |
| 137 | + | |
| 138 | + last_exc: Exception | None = None | |
| 139 | + for attempt in range(1, self._max_retries + 1): | |
| 140 | + try: | |
| 141 | + return self._http_post_json(payload) | |
| 142 | + except urllib_error.HTTPError as exc: | |
| 143 | + body = "" | |
| 144 | + try: | |
| 145 | + body = exc.read().decode("utf-8", errors="replace") | |
| 146 | + except Exception: | |
| 147 | + body = "" | |
| 148 | + last_exc = RuntimeError( | |
| 149 | + f"DashScope rerank HTTP {exc.code} (attempt {attempt}/{self._max_retries}): {body[:512]}" | |
| 150 | + ) | |
| 151 | + except urllib_error.URLError as exc: | |
| 152 | + last_exc = RuntimeError( | |
| 153 | + f"DashScope rerank network error (attempt {attempt}/{self._max_retries}): {exc}" | |
| 154 | + ) | |
| 155 | + except Exception as exc: # pragma: no cover - defensive | |
| 156 | + last_exc = RuntimeError( | |
| 157 | + f"DashScope rerank unexpected error (attempt {attempt}/{self._max_retries}): {exc}" | |
| 158 | + ) | |
| 159 | + | |
| 160 | + if attempt < self._max_retries and self._retry_backoff_sec > 0: | |
| 161 | + time.sleep(self._retry_backoff_sec * attempt) | |
| 162 | + | |
| 163 | + raise RuntimeError(str(last_exc) if last_exc else "DashScope rerank failed with unknown error") | |
| 164 | + | |
| 165 | + @staticmethod | |
| 166 | + def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]: | |
| 167 | + # Compatible API style: {"results":[...]} | |
| 168 | + results = data.get("results") | |
| 169 | + if isinstance(results, list): | |
| 170 | + return [x for x in results if isinstance(x, dict)] | |
| 171 | + | |
| 172 | + # Native style fallback: {"output":{"results":[...]}} | |
| 173 | + output = data.get("output") | |
| 174 | + if isinstance(output, dict): | |
| 175 | + output_results = output.get("results") | |
| 176 | + if isinstance(output_results, list): | |
| 177 | + return [x for x in output_results if isinstance(x, dict)] | |
| 178 | + | |
| 179 | + return [] | |
| 180 | + | |
| 181 | + @staticmethod | |
| 182 | + def _coerce_score(raw_score: Any, normalize: bool) -> float: | |
| 183 | + try: | |
| 184 | + score = float(raw_score) | |
| 185 | + except (TypeError, ValueError): | |
| 186 | + return 0.0 | |
| 187 | + | |
| 188 | + if not normalize: | |
| 189 | + return score | |
| 190 | + # DashScope relevance_score is typically already in [0,1]; keep it. | |
| 191 | + if 0.0 <= score <= 1.0: | |
| 192 | + return score | |
| 193 | + # Fallback when provider returns logits/raw scores. | |
| 194 | + if score > 60: | |
| 195 | + return 1.0 | |
| 196 | + if score < -60: | |
| 197 | + return 0.0 | |
| 198 | + return 1.0 / (1.0 + math.exp(-score)) | |
| 199 | + | |
| 200 | + def score_with_meta_topn( | |
| 201 | + self, | |
| 202 | + query: str, | |
| 203 | + docs: List[str], | |
| 204 | + normalize: bool = True, | |
| 205 | + top_n: int | None = None, | |
| 206 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 207 | + start_ts = time.time() | |
| 208 | + total_docs = len(docs) if docs else 0 | |
| 209 | + output_scores: List[float] = [0.0] * total_docs | |
| 210 | + | |
| 211 | + query = "" if query is None else str(query).strip() | |
| 212 | + indexed: List[Tuple[int, str]] = [] | |
| 213 | + for i, doc in enumerate(docs or []): | |
| 214 | + if doc is None: | |
| 215 | + continue | |
| 216 | + text = str(doc).strip() | |
| 217 | + if not text: | |
| 218 | + continue | |
| 219 | + indexed.append((i, text)) | |
| 220 | + | |
| 221 | + if not query or not indexed: | |
| 222 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 223 | + return output_scores, { | |
| 224 | + "input_docs": total_docs, | |
| 225 | + "usable_docs": len(indexed), | |
| 226 | + "unique_docs": 0, | |
| 227 | + "dedup_ratio": 0.0, | |
| 228 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 229 | + "model": self._model_name, | |
| 230 | + "backend": "dashscope_rerank", | |
| 231 | + "normalize": normalize, | |
| 232 | + "top_n": 0, | |
| 233 | + } | |
| 234 | + | |
| 235 | + indexed_texts = [text for _, text in indexed] | |
| 236 | + unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts) | |
| 237 | + | |
| 238 | + top_n_effective = len(unique_texts) | |
| 239 | + if top_n is not None and int(top_n) > 0: | |
| 240 | + top_n_effective = min(top_n_effective, int(top_n)) | |
| 241 | + if self._top_n_cap > 0: | |
| 242 | + top_n_effective = min(top_n_effective, self._top_n_cap) | |
| 243 | + | |
| 244 | + response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n_effective) | |
| 245 | + results = self._extract_results(response) | |
| 246 | + | |
| 247 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 248 | + for rank, item in enumerate(results): | |
| 249 | + raw_idx = item.get("index", rank) | |
| 250 | + try: | |
| 251 | + idx = int(raw_idx) | |
| 252 | + except (TypeError, ValueError): | |
| 253 | + continue | |
| 254 | + if idx < 0 or idx >= len(unique_scores): | |
| 255 | + continue | |
| 256 | + raw_score = item.get("relevance_score", item.get("score")) | |
| 257 | + unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize) | |
| 258 | + | |
| 259 | + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | |
| 260 | + output_scores[orig_idx] = float(unique_scores[unique_idx]) | |
| 261 | + | |
| 262 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 263 | + dedup_ratio = 0.0 | |
| 264 | + if indexed: | |
| 265 | + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) | |
| 266 | + | |
| 267 | + return output_scores, { | |
| 268 | + "input_docs": total_docs, | |
| 269 | + "usable_docs": len(indexed), | |
| 270 | + "unique_docs": len(unique_texts), | |
| 271 | + "dedup_ratio": round(dedup_ratio, 4), | |
| 272 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 273 | + "model": self._model_name, | |
| 274 | + "backend": "dashscope_rerank", | |
| 275 | + "normalize": normalize, | |
| 276 | + "top_n": top_n_effective, | |
| 277 | + "requested_top_n": int(top_n) if top_n is not None else None, | |
| 278 | + "response_results": len(results), | |
| 279 | + "endpoint": self._endpoint, | |
| 280 | + } | |
| 281 | + | |
| 282 | + def score_with_meta( | |
| 283 | + self, | |
| 284 | + query: str, | |
| 285 | + docs: List[str], | |
| 286 | + normalize: bool = True, | |
| 287 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 288 | + return self.score_with_meta_topn(query=query, docs=docs, normalize=normalize, top_n=None) | ... | ... |
reranker/server.py
| 1 | 1 | """ |
| 2 | -Reranker service - unified /rerank API backed by pluggable backends (BGE, Qwen3-vLLM). | |
| 2 | +Reranker service - unified /rerank API backed by pluggable backends | |
| 3 | +(BGE, Qwen3-vLLM, Qwen3-Transformers, DashScope cloud rerank). | |
| 3 | 4 | |
| 4 | 5 | POST /rerank |
| 5 | 6 | Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional bool } |
| 6 | 7 | Response: { "scores": [float], "meta": {...} } |
| 7 | 8 | |
| 8 | -Backend selected via config: services.rerank.backend (bge | qwen3_vllm), env RERANK_BACKEND. | |
| 9 | +Backend selected via config: services.rerank.backend | |
| 10 | +(bge | qwen3_vllm | qwen3_transformers | dashscope_rerank), env RERANK_BACKEND. | |
| 9 | 11 | """ |
| 10 | 12 | |
| 11 | 13 | import logging |
| ... | ... | @@ -60,6 +62,10 @@ class RerankRequest(BaseModel): |
| 60 | 62 | normalize: Optional[bool] = Field( |
| 61 | 63 | default=CONFIG.NORMALIZE, description="Apply sigmoid normalization" |
| 62 | 64 | ) |
| 65 | + top_n: Optional[int] = Field( | |
| 66 | + default=None, | |
| 67 | + description="Optional top_n hint for backends that support partial ranking", | |
| 68 | + ) | |
| 63 | 69 | |
| 64 | 70 | |
| 65 | 71 | class RerankResponse(BaseModel): |
| ... | ... | @@ -118,8 +124,11 @@ def rerank(request: RerankRequest) -> RerankResponse: |
| 118 | 124 | status_code=400, |
| 119 | 125 | detail=f"Too many docs: {len(request.docs)} > {CONFIG.MAX_DOCS}", |
| 120 | 126 | ) |
| 127 | + if request.top_n is not None and int(request.top_n) <= 0: | |
| 128 | + raise HTTPException(status_code=400, detail="top_n must be > 0") | |
| 121 | 129 | |
| 122 | 130 | normalize = CONFIG.NORMALIZE if request.normalize is None else bool(request.normalize) |
| 131 | + top_n = int(request.top_n) if request.top_n is not None else None | |
| 123 | 132 | |
| 124 | 133 | start_ts = time.time() |
| 125 | 134 | logger.info( |
| ... | ... | @@ -130,8 +139,18 @@ def rerank(request: RerankRequest) -> RerankResponse: |
| 130 | 139 | _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS), |
| 131 | 140 | _preview_docs(request.docs, _LOG_DOC_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), |
| 132 | 141 | ) |
| 133 | - scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize) | |
| 142 | + if top_n is not None and hasattr(_reranker, "score_with_meta_topn"): | |
| 143 | + scores, meta = getattr(_reranker, "score_with_meta_topn")( | |
| 144 | + query, | |
| 145 | + request.docs, | |
| 146 | + normalize=normalize, | |
| 147 | + top_n=top_n, | |
| 148 | + ) | |
| 149 | + else: | |
| 150 | + scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize) | |
| 134 | 151 | meta = dict(meta) |
| 152 | + if top_n is not None: | |
| 153 | + meta.setdefault("requested_top_n", top_n) | |
| 135 | 154 | meta.update({"service_elapsed_ms": round((time.time() - start_ts) * 1000.0, 3)}) |
| 136 | 155 | score_preview = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]] |
| 137 | 156 | logger.info( | ... | ... |
scripts/perf_api_benchmark.py
| ... | ... | @@ -467,6 +467,12 @@ def parse_args() -> argparse.Namespace: |
| 467 | 467 | parser.add_argument("--rerank-sentence-max-words", type=int, default=40, help="Maximum words per generated doc sentence") |
| 468 | 468 | parser.add_argument("--rerank-query", type=str, default="wireless mouse", help="Fixed query used for rerank dynamic docs mode") |
| 469 | 469 | parser.add_argument("--rerank-seed", type=int, default=20260312, help="Base random seed for rerank dynamic docs mode") |
| 470 | + parser.add_argument( | |
| 471 | + "--rerank-top-n", | |
| 472 | + type=int, | |
| 473 | + default=0, | |
| 474 | + help="Optional top_n for rerank requests in dynamic docs mode (0 means omit top_n).", | |
| 475 | + ) | |
| 470 | 476 | return parser.parse_args() |
| 471 | 477 | |
| 472 | 478 | |
| ... | ... | @@ -487,6 +493,8 @@ def build_rerank_dynamic_cfg(args: argparse.Namespace) -> Dict[str, Any]: |
| 487 | 493 | ) |
| 488 | 494 | if args.rerank_seed < 0: |
| 489 | 495 | raise ValueError(f"rerank-seed must be >= 0, got {args.rerank_seed}") |
| 496 | + if int(args.rerank_top_n) < 0: | |
| 497 | + raise ValueError(f"rerank-top-n must be >= 0, got {args.rerank_top_n}") | |
| 490 | 498 | |
| 491 | 499 | # Use deterministic, letter-only pseudo words to avoid long tokenization of numeric strings. |
| 492 | 500 | syllables = [ |
| ... | ... | @@ -513,6 +521,7 @@ def build_rerank_dynamic_cfg(args: argparse.Namespace) -> Dict[str, Any]: |
| 513 | 521 | "max_words": max_words, |
| 514 | 522 | "seed": int(args.rerank_seed), |
| 515 | 523 | "normalize": True, |
| 524 | + "top_n": int(args.rerank_top_n), | |
| 516 | 525 | "word_pool": word_pool, |
| 517 | 526 | } |
| 518 | 527 | |
| ... | ... | @@ -530,6 +539,7 @@ def build_random_rerank_payload( |
| 530 | 539 | "query": cfg["query"], |
| 531 | 540 | "docs": docs, |
| 532 | 541 | "normalize": bool(cfg.get("normalize", True)), |
| 542 | + **({"top_n": int(cfg["top_n"])} if int(cfg.get("top_n", 0)) > 0 else {}), | |
| 533 | 543 | } |
| 534 | 544 | |
| 535 | 545 | |
| ... | ... | @@ -595,6 +605,7 @@ async def main_async() -> int: |
| 595 | 605 | print(f" rerank_sentence_words=[{args.rerank_sentence_min_words},{args.rerank_sentence_max_words}]") |
| 596 | 606 | print(f" rerank_query={args.rerank_query}") |
| 597 | 607 | print(f" rerank_seed={args.rerank_seed}") |
| 608 | + print(f" rerank_top_n={args.rerank_top_n}") | |
| 598 | 609 | |
| 599 | 610 | results: List[Dict[str, Any]] = [] |
| 600 | 611 | total_jobs = len(run_names) * len(concurrency_values) |
| ... | ... | @@ -643,6 +654,7 @@ async def main_async() -> int: |
| 643 | 654 | "rerank_sentence_max_words": args.rerank_sentence_max_words, |
| 644 | 655 | "rerank_query": args.rerank_query, |
| 645 | 656 | "rerank_seed": args.rerank_seed, |
| 657 | + "rerank_top_n": args.rerank_top_n, | |
| 646 | 658 | }, |
| 647 | 659 | "results": results, |
| 648 | 660 | "overall": aggregate_results(results), | ... | ... |
search/rerank_client.py
| ... | ... | @@ -80,6 +80,7 @@ def call_rerank_service( |
| 80 | 80 | query: str, |
| 81 | 81 | docs: List[str], |
| 82 | 82 | timeout_sec: float = DEFAULT_TIMEOUT_SEC, |
| 83 | + top_n: Optional[int] = None, | |
| 83 | 84 | ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: |
| 84 | 85 | """ |
| 85 | 86 | 调用重排服务 POST /rerank,返回分数列表与 meta。 |
| ... | ... | @@ -89,7 +90,7 @@ def call_rerank_service( |
| 89 | 90 | return [], {} |
| 90 | 91 | try: |
| 91 | 92 | client = create_rerank_provider() |
| 92 | - return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec) | |
| 93 | + return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n) | |
| 93 | 94 | except Exception as e: |
| 94 | 95 | logger.warning("Rerank request failed: %s", e, exc_info=True) |
| 95 | 96 | return None, None |
| ... | ... | @@ -176,10 +177,12 @@ def run_rerank( |
| 176 | 177 | weight_ai: float = DEFAULT_WEIGHT_AI, |
| 177 | 178 | rerank_query_template: str = "{query}", |
| 178 | 179 | rerank_doc_template: str = "{title}", |
| 180 | + top_n: Optional[int] = None, | |
| 179 | 181 | ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: |
| 180 | 182 | """ |
| 181 | 183 | 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 |
| 182 | 184 | Provider 和 URL 从 services_config 读取。 |
| 185 | + top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。 | |
| 183 | 186 | """ |
| 184 | 187 | hits = es_response.get("hits", {}).get("hits") or [] |
| 185 | 188 | if not hits: |
| ... | ... | @@ -191,6 +194,7 @@ def run_rerank( |
| 191 | 194 | query_text, |
| 192 | 195 | docs, |
| 193 | 196 | timeout_sec=timeout_sec, |
| 197 | + top_n=top_n, | |
| 194 | 198 | ) |
| 195 | 199 | |
| 196 | 200 | if scores is None or len(scores) != len(hits): | ... | ... |
search/searcher.py
| ... | ... | @@ -0,0 +1,34 @@ |
| 1 | +from __future__ import annotations | |
| 2 | + | |
| 3 | +from typing import Any, Dict | |
| 4 | + | |
| 5 | +from providers.rerank import HttpRerankProvider | |
| 6 | + | |
| 7 | + | |
| 8 | +class _FakeResponse: | |
| 9 | + def __init__(self, status_code: int, data: Dict[str, Any]): | |
| 10 | + self.status_code = status_code | |
| 11 | + self._data = data | |
| 12 | + self.text = str(data) | |
| 13 | + | |
| 14 | + def json(self): | |
| 15 | + return self._data | |
| 16 | + | |
| 17 | + | |
| 18 | +def test_http_rerank_provider_includes_top_n(monkeypatch): | |
| 19 | + captured: Dict[str, Any] = {} | |
| 20 | + | |
| 21 | + def _fake_post(url, json, timeout): | |
| 22 | + captured["url"] = url | |
| 23 | + captured["json"] = json | |
| 24 | + captured["timeout"] = timeout | |
| 25 | + return _FakeResponse(200, {"scores": [0.1, 0.2], "meta": {"ok": True}}) | |
| 26 | + | |
| 27 | + monkeypatch.setattr("providers.rerank.requests.post", _fake_post) | |
| 28 | + | |
| 29 | + provider = HttpRerankProvider("http://127.0.0.1:6007/rerank") | |
| 30 | + scores, meta = provider.rerank("q", ["a", "b"], timeout_sec=3.0, top_n=2) | |
| 31 | + | |
| 32 | + assert scores == [0.1, 0.2] | |
| 33 | + assert meta == {"ok": True} | |
| 34 | + assert captured["json"]["top_n"] == 2 | ... | ... |
| ... | ... | @@ -0,0 +1,103 @@ |
| 1 | +from __future__ import annotations | |
| 2 | + | |
| 3 | +from reranker.backends import get_rerank_backend | |
| 4 | +from reranker.backends.dashscope_rerank import DashScopeRerankBackend | |
| 5 | + | |
| 6 | + | |
| 7 | +def test_dashscope_backend_factory_loads(): | |
| 8 | + backend = get_rerank_backend( | |
| 9 | + "dashscope_rerank", | |
| 10 | + { | |
| 11 | + "model_name": "qwen3-rerank", | |
| 12 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 13 | + "api_key": "test-key", | |
| 14 | + }, | |
| 15 | + ) | |
| 16 | + assert isinstance(backend, DashScopeRerankBackend) | |
| 17 | + | |
| 18 | + | |
| 19 | +def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch): | |
| 20 | + backend = DashScopeRerankBackend( | |
| 21 | + { | |
| 22 | + "model_name": "qwen3-rerank", | |
| 23 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 24 | + "api_key": "test-key", | |
| 25 | + "top_n_cap": 0, | |
| 26 | + } | |
| 27 | + ) | |
| 28 | + | |
| 29 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 30 | + assert query == "wireless mouse" | |
| 31 | + # deduplicated docs | |
| 32 | + assert docs == ["doc-a", "doc-b"] | |
| 33 | + assert top_n == 2 | |
| 34 | + return { | |
| 35 | + "results": [ | |
| 36 | + {"index": 1, "relevance_score": 0.9}, | |
| 37 | + {"index": 0, "relevance_score": 0.2}, | |
| 38 | + ] | |
| 39 | + } | |
| 40 | + | |
| 41 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 42 | + scores, meta = backend.score_with_meta( | |
| 43 | + query="wireless mouse", | |
| 44 | + docs=["doc-a", "doc-b", "doc-a", "", " ", None], | |
| 45 | + normalize=True, | |
| 46 | + ) | |
| 47 | + | |
| 48 | + assert scores == [0.2, 0.9, 0.2, 0.0, 0.0, 0.0] | |
| 49 | + assert meta["input_docs"] == 6 | |
| 50 | + assert meta["usable_docs"] == 3 | |
| 51 | + assert meta["unique_docs"] == 2 | |
| 52 | + assert meta["top_n"] == 2 | |
| 53 | + assert meta["response_results"] == 2 | |
| 54 | + assert meta["backend"] == "dashscope_rerank" | |
| 55 | + | |
| 56 | + | |
| 57 | +def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch): | |
| 58 | + backend = DashScopeRerankBackend( | |
| 59 | + { | |
| 60 | + "model_name": "qwen3-rerank", | |
| 61 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 62 | + "api_key": "test-key", | |
| 63 | + "top_n_cap": 1, | |
| 64 | + } | |
| 65 | + ) | |
| 66 | + | |
| 67 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 68 | + assert query == "q" | |
| 69 | + assert len(docs) == 2 | |
| 70 | + assert top_n == 1 | |
| 71 | + # Only top-1 returned, score outside [0,1] to trigger sigmoid fallback | |
| 72 | + return {"results": [{"index": 1, "score": 3.0}]} | |
| 73 | + | |
| 74 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 75 | + scores_norm, _ = backend.score_with_meta(query="q", docs=["a", "b"], normalize=True) | |
| 76 | + scores_raw, _ = backend.score_with_meta(query="q", docs=["a", "b"], normalize=False) | |
| 77 | + | |
| 78 | + assert scores_norm[0] == 0.0 | |
| 79 | + assert 0.95 < scores_norm[1] < 0.96 | |
| 80 | + assert scores_raw == [0.0, 3.0] | |
| 81 | + | |
| 82 | + | |
| 83 | +def test_dashscope_backend_score_with_meta_topn_request(monkeypatch): | |
| 84 | + backend = DashScopeRerankBackend( | |
| 85 | + { | |
| 86 | + "model_name": "qwen3-rerank", | |
| 87 | + "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks", | |
| 88 | + "api_key": "test-key", | |
| 89 | + "top_n_cap": 0, | |
| 90 | + } | |
| 91 | + ) | |
| 92 | + | |
| 93 | + def _fake_post(query: str, docs: list[str], top_n: int): | |
| 94 | + assert query == "q" | |
| 95 | + assert docs == ["d1", "d2", "d3"] | |
| 96 | + assert top_n == 2 | |
| 97 | + return {"results": [{"index": 2, "relevance_score": 0.8}, {"index": 0, "relevance_score": 0.3}]} | |
| 98 | + | |
| 99 | + monkeypatch.setattr(backend, "_post_rerank", _fake_post) | |
| 100 | + scores, meta = backend.score_with_meta_topn(query="q", docs=["d1", "d2", "d3"], top_n=2) | |
| 101 | + assert scores == [0.3, 0.0, 0.8] | |
| 102 | + assert meta["top_n"] == 2 | |
| 103 | + assert meta["requested_top_n"] == 2 | ... | ... |
| ... | ... | @@ -0,0 +1,48 @@ |
| 1 | +from __future__ import annotations | |
| 2 | + | |
| 3 | +from typing import Any, Dict, List | |
| 4 | + | |
| 5 | +from fastapi.testclient import TestClient | |
| 6 | + | |
| 7 | + | |
| 8 | +class _FakeTopNReranker: | |
| 9 | + _model_name = "fake-topn-reranker" | |
| 10 | + | |
| 11 | + def score_with_meta(self, query: str, docs: List[str], normalize: bool = True): | |
| 12 | + return [0.1 for _ in docs], {"input_docs": len(docs), "path": "base"} | |
| 13 | + | |
| 14 | + def score_with_meta_topn( | |
| 15 | + self, | |
| 16 | + query: str, | |
| 17 | + docs: List[str], | |
| 18 | + normalize: bool = True, | |
| 19 | + top_n: int | None = None, | |
| 20 | + ): | |
| 21 | + scores = [0.0 for _ in docs] | |
| 22 | + if docs and top_n: | |
| 23 | + scores[0] = 1.0 | |
| 24 | + return scores, {"input_docs": len(docs), "path": "topn", "top_n": top_n} | |
| 25 | + | |
| 26 | + | |
| 27 | +def test_reranker_server_forwards_top_n(): | |
| 28 | + import reranker.server as reranker_server | |
| 29 | + | |
| 30 | + reranker_server.app.router.on_startup.clear() | |
| 31 | + reranker_server._reranker = _FakeTopNReranker() | |
| 32 | + reranker_server._backend_name = "fake_topn" | |
| 33 | + | |
| 34 | + with TestClient(reranker_server.app) as client: | |
| 35 | + response = client.post( | |
| 36 | + "/rerank", | |
| 37 | + json={ | |
| 38 | + "query": "wireless mouse", | |
| 39 | + "docs": ["a", "b", "c"], | |
| 40 | + "top_n": 2, | |
| 41 | + }, | |
| 42 | + ) | |
| 43 | + assert response.status_code == 200 | |
| 44 | + data: Dict[str, Any] = response.json() | |
| 45 | + assert data["scores"] == [1.0, 0.0, 0.0] | |
| 46 | + assert data["meta"]["path"] == "topn" | |
| 47 | + assert data["meta"]["requested_top_n"] == 2 | |
| 48 | + assert data["meta"]["top_n"] == 2 | ... | ... |