diff --git a/.env.example b/.env.example index c65fca8..a43a538 100644 --- a/.env.example +++ b/.env.example @@ -44,6 +44,11 @@ TEI_MAX_CLIENT_BATCH_SIZE=8 TEI_HEALTH_TIMEOUT_SEC=300 RERANK_PROVIDER=http RERANK_BACKEND=qwen3_vllm +# Optional for cloud rerank backend (RERANK_BACKEND=dashscope_rerank) +DASHSCOPE_API_KEY= +# Example: +# RERANK_DASHSCOPE_ENDPOINT=https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks +RERANK_DASHSCOPE_ENDPOINT= # Cache Directory CACHE_DIR=.cache diff --git a/config/config.yaml b/config/config.yaml index ae3abb4..73c68df 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -166,7 +166,7 @@ services: base_url: "http://127.0.0.1:6007" service_url: "http://127.0.0.1:6007/rerank" # 服务内后端(reranker 进程启动时读取) - backend: "qwen3_vllm" # bge | qwen3_vllm + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank backends: bge: model_name: "BAAI/bge-reranker-v2-m3" @@ -189,6 +189,26 @@ services: sort_by_doc_length: true length_sort_mode: "char" # char | token instruction: "Given a shopping query, rank product titles by relevance" + qwen3_transformers: + model_name: "Qwen/Qwen3-Reranker-0.6B" + instruction: "Given a shopping query, rank product titles by relevance" + max_length: 8192 + batch_size: 64 + use_fp16: true + attn_implementation: "flash_attention_2" + dashscope_rerank: + model_name: "qwen3-rerank" + # 按地域选择 endpoint: + # 中国: https://dashscope.aliyuncs.com/compatible-api/v1/reranks + # 新加坡: https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks + # 美国: https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks + endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" + api_key: null # 推荐通过环境变量 DASHSCOPE_API_KEY 设置 + timeout_sec: 15.0 + top_n_cap: 0 # 0 表示 top_n=当前请求文档数;>0 则限制 top_n 上限 + instruct: "Given a shopping query, rank product titles by relevance" + max_retries: 2 + retry_backoff_sec: 0.2 # SPU配置(已启用,使用嵌套skus) spu_config: diff --git a/docs/DEVELOPER_GUIDE.md 
b/docs/DEVELOPER_GUIDE.md index 020c54d..70dd1ad 100644 --- a/docs/DEVELOPER_GUIDE.md +++ b/docs/DEVELOPER_GUIDE.md @@ -318,7 +318,7 @@ services: |------|--------|------|--------| | 调用方 | `services..provider` | http | http | | 调用方 | `services..providers.http.base_url` | 6007 | 6005 | -| 服务内 | `services..backend` | qwen3_vllm / bge | tei / local_st | +| 服务内 | `services..backend` | qwen3_vllm / qwen3_transformers / bge / dashscope_rerank | tei / local_st | | 服务内 | `services..backends.` | 模型名、batch、vLLM 参数 | 模型名、device 等 | ### 7.6 新增后端清单(以 Qwen3-Reranker 为例) @@ -334,7 +334,7 @@ services: - **单一路径**:Provider 和 backend 必须由 `config/config.yaml` 的 `services` 块显式指定;未知配置应直接报错。 - **无兼容回退**:不保留“旧配置自动推导/兜底默认值”机制,避免静默行为偏差。 -- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 +- **环境变量覆盖**:允许环境变量覆盖(如 `RERANKER_SERVICE_URL`、`RERANK_BACKEND`、`DASHSCOPE_API_KEY`、`RERANK_DASHSCOPE_ENDPOINT`、`EMBEDDING_SERVICE_URL`、`EMBEDDING_BACKEND`、`TEI_BASE_URL`),但覆盖后仍需满足合法性校验。 --- diff --git a/docs/QUICKSTART.md b/docs/QUICKSTART.md index 9f8e74c..846682d 100644 --- a/docs/QUICKSTART.md +++ b/docs/QUICKSTART.md @@ -409,7 +409,7 @@ services: tei: { base_url: "http://127.0.0.1:8080", timeout_sec: 60, model_id: "Qwen/Qwen3-Embedding-0.6B" } rerank: provider: "http" - backend: "qwen3_vllm" + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank providers: http: { base_url: "http://127.0.0.1:6007", service_url: "http://127.0.0.1:6007/rerank" } ``` @@ -423,6 +423,8 @@ services: - `TEI_BASE_URL` - `RERANKER_SERVICE_URL` - `RERANK_BACKEND`(服务内后端) +- `DASHSCOPE_API_KEY`(`dashscope_rerank` 后端鉴权) +- `RERANK_DASHSCOPE_ENDPOINT`(`dashscope_rerank` 地域 endpoint 覆盖) ### 3.3 新增 provider 的最小步骤 @@ -451,6 +453,8 @@ services: - `reranker/backends/__init__.py`(工厂) - `reranker/backends/bge.py` - `reranker/backends/qwen3_vllm.py` +- `reranker/backends/qwen3_transformers.py` +- 
`reranker/backends/dashscope_rerank.py` 后端协议(服务内): diff --git a/docs/性能测试报告.md b/docs/性能测试报告.md index c50264e..7f60505 100644 --- a/docs/性能测试报告.md +++ b/docs/性能测试报告.md @@ -338,3 +338,46 @@ done 异常说明: - `tenant 0` 在并发 `20` 出现 `ReadTimeout`(25 次),该档成功率下降到 `59.02%` - 其他租户在本轮口径下均为 `100%` 成功率 + +## 13. Rerank 后端对比(qwen3_vllm vs DashScope 云服务) + +目标: +- 使用同一套构造数据,对比两个 rerank 微服务在电商搜索重排场景下的速度差异 +- 为后端选型提供直接依据 + +测试口径(两端一致): +- query:固定 `wireless mouse` +- docs:每次请求固定 `386` 条 +- 构造方式:从 `1000` 词池随机采样;每条 doc 句长随机 `15-40` +- `top_n`:`30`(模拟 `page+size`) +- 并发:`1 / 5 / 10 / 20` +- 每档时长:`20s` +- 每个后端跑 `2` 轮,以下表格为两轮均值 + +执行文件: +- vLLM:`perf_reports/2026-03-12/rerank_backend_compare/vllm_round1_topn30.json` +- vLLM:`perf_reports/2026-03-12/rerank_backend_compare/vllm_round2b_topn30.json` +- Cloud:`perf_reports/2026-03-12/rerank_backend_compare/cloud_round1_topn30.json` +- Cloud:`perf_reports/2026-03-12/rerank_backend_compare/cloud_round2_topn30.json` + +### 13.1 两轮均值对比 + +| 并发 | vLLM RPS | Cloud RPS | vLLM P95(ms) | Cloud P95(ms) | vLLM Avg(ms) | Cloud Avg(ms) | +|---:|---:|---:|---:|---:|---:|---:| +| 1 | 0.625 | 0.220 | 1937.68 | 6371.03 | 1602.37 | 4752.53 | +| 5 | 0.585 | 1.040 | 9421.37 | 7372.85 | 8480.29 | 4543.84 | +| 10 | 0.595 | 1.820 | 18040.65 | 7637.43 | 16767.64 | 4820.35 | +| 20 | 0.590 | 3.530 | 33766.06 | 8445.39 | 33563.23 | 4890.59 | + +### 13.2 结论 + +- 单并发(`c=1`)下,`qwen3_vllm` 更快(更低延迟、略高吞吐)。 +- 从 `c=5` 开始,DashScope 云后端明显更快: + - `c=5`:Cloud 吞吐约为 vLLM 的 `1.78x` + - `c=10`:Cloud 吞吐约为 vLLM 的 `3.06x` + - `c=20`:Cloud 吞吐约为 vLLM 的 `5.98x` +- 在“电商搜索在线重排(有并发)”场景下,当前实现建议优先选云后端。 + +说明: +- 本轮对比基于当前实现:`dashscope_rerank` 支持 `top_n`(本次取 `30`),`qwen3_vllm` 当前仍按全量 docs 评分。 +- 若后续为本地模型实现 `top_n` 局部重排能力,需要重新对比后再最终定版。 diff --git a/providers/rerank.py b/providers/rerank.py index fc16dd5..885b806 100644 --- a/providers/rerank.py +++ b/providers/rerank.py @@ -23,11 +23,14 @@ class HttpRerankProvider: query: str, docs: List[str], timeout_sec: float, + top_n: Optional[int] = None, ) 
-> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: if not docs: return [], {} try: payload = {"query": (query or "").strip(), "docs": docs} + if top_n is not None and int(top_n) > 0: + payload["top_n"] = int(top_n) response = requests.post(self.service_url, json=payload, timeout=timeout_sec) if response.status_code != 200: logger.warning( diff --git a/reranker/README.md b/reranker/README.md index 878c527..19d9153 100644 --- a/reranker/README.md +++ b/reranker/README.md @@ -4,10 +4,11 @@ --- -Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers)。调用方通过 HTTP 访问,不关心具体后端。 +Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers、DashScope 云重排)。调用方通过 HTTP 访问,不关心具体后端。 **特性** - 多后端:`qwen3_vllm`(默认,Qwen3-Reranker-0.6B + vLLM)、`qwen3_transformers`(纯 Transformers,无需 vLLM)、`bge`(兼容保留) +- 云后端:`dashscope_rerank`(调用 DashScope `/compatible-api/v1/reranks`,支持按地域切换 endpoint) - 统一配置:`config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.` - 文档去重、分数与输入顺序一致、FP16/GPU 支持(视后端) @@ -18,6 +19,7 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe - `backends/bge.py`:BGE 后端 - `backends/qwen3_vllm.py`:Qwen3-Reranker-0.6B + vLLM 后端 - `backends/qwen3_transformers.py`:Qwen3-Reranker-0.6B 纯 Transformers 后端(官方 Usage 方式) + - `backends/dashscope_rerank.py`:DashScope 云重排后端(HTTP 调用) - `reranker/bge_reranker.py`:BGE 核心推理(被 bge 后端封装) - `reranker/config.py`:服务端口、MAX_DOCS、NORMALIZE 等(后端参数在 config.yaml) @@ -30,7 +32,7 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe ``` ## 配置 -- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_transformers` | `bge`),或环境变量 `RERANK_BACKEND`。 +- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_transformers` | `bge` | `dashscope_rerank`),或环境变量 `RERANK_BACKEND`。 - **后端参数**:`services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`,例如: ```yaml @@ -64,8 +66,26 @@ services: tensor_parallel_size: 1 gpu_memory_utilization: 0.8 instruction: "Given a 
shopping query, rank product titles by relevance" + dashscope_rerank: + model_name: "qwen3-rerank" + endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" + api_key: null # 推荐使用环境变量 DASHSCOPE_API_KEY + timeout_sec: 15.0 + top_n_cap: 0 + instruct: "Given a shopping query, rank product titles by relevance" + max_retries: 2 + retry_backoff_sec: 0.2 ``` +DashScope endpoint 地域示例: +- 中国:`https://dashscope.aliyuncs.com/compatible-api/v1/reranks` +- 新加坡:`https://dashscope-intl.aliyuncs.com/compatible-api/v1/reranks` +- 美国:`https://dashscope-us.aliyuncs.com/compatible-api/v1/reranks` + +DashScope 认证: +- `api_key` 支持配置在 `config.yaml` +- 推荐通过环境变量注入:`DASHSCOPE_API_KEY=...` + - 服务端口、请求限制等仍在 `reranker/config.py`(或环境变量 `RERANKER_PORT`、`RERANKER_HOST`)。 ## 运行 @@ -94,10 +114,15 @@ Content-Type: application/json { "query": "wireless mouse", - "docs": ["logitech mx master", "usb cable", "wireless mouse bluetooth"] + "docs": ["logitech mx master", "usb cable", "wireless mouse bluetooth"], + "top_n": 10 } ``` +`top_n` 为可选字段: +- 对本地后端(`qwen3_vllm` / `qwen3_transformers` / `bge`)通常会忽略,仍返回全量分数。 +- 对 `dashscope_rerank` 可用于控制云端返回的候选量,建议设置为 `page+size`(例如分页 `from=20,size=10` 时传 `30`)。 + Response: ``` { diff --git a/reranker/backends/__init__.py b/reranker/backends/__init__.py index 7edd115..f68d472 100644 --- a/reranker/backends/__init__.py +++ b/reranker/backends/__init__.py @@ -46,8 +46,11 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc if name == "qwen3_transformers": from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend return Qwen3TransformersRerankerBackend(config) + if name == "dashscope_rerank": + from reranker.backends.dashscope_rerank import DashScopeRerankBackend + return DashScopeRerankBackend(config) raise ValueError( - f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers" + f"Unknown rerank backend: {name!r}. 
class DashScopeRerankBackend:
    """
    DashScope cloud reranker backend.

    Sends (query, documents) to the DashScope reranks HTTP endpoint and maps the
    returned per-document scores back onto the caller's original doc order.

    Config from services.rerank.backends.dashscope_rerank:
    - model_name: str, default "qwen3-rerank"
    - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
    - api_key: optional str (or env DASHSCOPE_API_KEY)
    - timeout_sec: float, default 15.0
    - top_n_cap: int, optional cap; 0 means use all docs in request
    - instruct: optional str
    - max_retries: int, default 1; this is the TOTAL number of attempts
      (1 means "try once, never retry") and must be > 0
    - retry_backoff_sec: float, default 0.2; sleep is backoff * attempt number

    Env overrides (take precedence over config when set to a non-empty value):
    - DASHSCOPE_API_KEY
    - RERANK_DASHSCOPE_ENDPOINT
    - RERANK_DASHSCOPE_MODEL
    - RERANK_DASHSCOPE_TIMEOUT_SEC
    - RERANK_DASHSCOPE_TOP_N_CAP

    Note: values are resolved with ``env or config or default``, so a falsy
    config value (e.g. ``timeout_sec: 0``) falls through to the default.
    """

    # Non-5xx statuses that are still worth retrying (request timeout, rate limit).
    # Every other 4xx is a caller-side problem that a retry cannot fix.
    _RETRYABLE_HTTP_STATUSES = frozenset({408, 429})

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Resolve settings from env/config and validate them.

        Raises:
            ValueError: if endpoint or api_key is empty, timeout_sec <= 0,
                top_n_cap < 0, max_retries <= 0 or retry_backoff_sec < 0.
        """
        self._config = config or {}
        self._model_name = str(
            os.getenv("RERANK_DASHSCOPE_MODEL")
            or self._config.get("model_name")
            or "qwen3-rerank"
        )
        self._endpoint = str(
            os.getenv("RERANK_DASHSCOPE_ENDPOINT")
            or self._config.get("endpoint")
            or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
        ).strip()
        # Strip whitespace and one layer of surrounding quotes, which commonly
        # sneak in when the key is copied into a .env file.
        self._api_key = str(
            os.getenv("DASHSCOPE_API_KEY")
            or self._config.get("api_key")
            or ""
        ).strip().strip('"').strip("'")
        self._timeout_sec = float(
            os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC")
            or self._config.get("timeout_sec")
            or 15.0
        )
        self._top_n_cap = int(
            os.getenv("RERANK_DASHSCOPE_TOP_N_CAP")
            or self._config.get("top_n_cap")
            or 0
        )
        self._instruct = str(self._config.get("instruct") or "").strip()

        # An explicit YAML ``null`` arrives as None; treat it like "not set"
        # instead of crashing in int()/float().
        retries_raw = self._config.get("max_retries")
        self._max_retries = int(1 if retries_raw is None else retries_raw)
        backoff_raw = self._config.get("retry_backoff_sec")
        self._retry_backoff_sec = float(0.2 if backoff_raw is None else backoff_raw)

        if not self._endpoint:
            raise ValueError("dashscope_rerank endpoint is required")
        if not self._api_key:
            raise ValueError(
                "dashscope_rerank api_key is required (set services.rerank.backends.dashscope_rerank.api_key "
                "or env DASHSCOPE_API_KEY)"
            )
        if self._timeout_sec <= 0:
            raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}")
        if self._top_n_cap < 0:
            raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}")
        if self._max_retries <= 0:
            raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}")
        if self._retry_backoff_sec < 0:
            raise ValueError(
                f"dashscope_rerank retry_backoff_sec must be >= 0, got {self._retry_backoff_sec}"
            )

        # The API key is deliberately not logged.
        logger.info(
            "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s",
            self._endpoint,
            self._model_name,
            self._timeout_sec,
            self._top_n_cap,
        )

    def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST ``payload`` as JSON to the configured endpoint and return the parsed body.

        Raises:
            urllib.error.HTTPError / urllib.error.URLError: propagated from urlopen.
            RuntimeError: if the response body is not a JSON object.
        """
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = urllib_request.Request(
            url=self._endpoint,
            method="POST",
            data=body,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )
        with urllib_request.urlopen(req, timeout=self._timeout_sec) as resp:
            raw = resp.read().decode("utf-8", errors="replace")
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"DashScope response is not valid JSON: {raw[:512]}") from exc
        if not isinstance(data, dict):
            raise RuntimeError(f"DashScope response must be JSON object, got: {type(data).__name__}")
        return data

    @classmethod
    def _is_retryable_status(cls, status: Any) -> bool:
        """Return True for HTTP statuses that may succeed on retry (408, 429, 5xx)."""
        if not isinstance(status, int):
            # Unknown status: keep the original "retry everything" behavior.
            return True
        return status in cls._RETRYABLE_HTTP_STATUSES or 500 <= status <= 599

    def _post_rerank(self, query: str, docs: List[str], top_n: int) -> Dict[str, Any]:
        """
        Call the rerank endpoint, retrying transient failures.

        Up to ``max_retries`` attempts are made in total. Network errors, unexpected
        errors and retryable HTTP statuses (408/429/5xx) are retried with a linear
        backoff of ``retry_backoff_sec * attempt``. Other HTTP errors (e.g. 400, 401,
        403) fail immediately because repeating the same request cannot succeed.

        Raises:
            RuntimeError: describing the last failure; the underlying exception is
                attached as ``__cause__``.
        """
        payload: Dict[str, Any] = {
            "model": self._model_name,
            "query": query,
            "documents": docs,
            "top_n": top_n,
        }
        if self._instruct:
            payload["instruct"] = self._instruct

        last_error: RuntimeError | None = None
        last_cause: BaseException | None = None
        for attempt in range(1, self._max_retries + 1):
            try:
                return self._http_post_json(payload)
            except urllib_error.HTTPError as exc:
                # HTTPError must be handled before URLError (it is a subclass).
                try:
                    err_body = exc.read().decode("utf-8", errors="replace")
                except Exception:
                    # Best effort only: the error body is purely diagnostic.
                    err_body = ""
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank HTTP {exc.code} (attempt {attempt}/{self._max_retries}): {err_body[:512]}"
                )
                if not self._is_retryable_status(exc.code):
                    break
            except urllib_error.URLError as exc:
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank network error (attempt {attempt}/{self._max_retries}): {exc}"
                )
            except Exception as exc:  # pragma: no cover - defensive
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank unexpected error (attempt {attempt}/{self._max_retries}): {exc}"
                )

            if attempt < self._max_retries and self._retry_backoff_sec > 0:
                time.sleep(self._retry_backoff_sec * attempt)

        if last_error is None:
            raise RuntimeError("DashScope rerank failed with unknown error")
        raise last_error from last_cause

    @staticmethod
    def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Pull the list of result objects out of a response body.

        Accepts both ``{"results": [...]}`` and ``{"output": {"results": [...]}}``;
        non-dict entries are dropped. Returns an empty list when neither shape matches.
        """
        # Compatible API style: {"results":[...]}
        results = data.get("results")
        if isinstance(results, list):
            return [x for x in results if isinstance(x, dict)]

        # Native style fallback: {"output":{"results":[...]}}
        output = data.get("output")
        if isinstance(output, dict):
            output_results = output.get("results")
            if isinstance(output_results, list):
                return [x for x in output_results if isinstance(x, dict)]

        return []

    @staticmethod
    def _coerce_score(raw_score: Any, normalize: bool) -> float:
        """
        Convert a raw score from the response into a float.

        Missing, non-numeric, overflowing or NaN values become 0.0. When
        ``normalize`` is true, values already in [0, 1] are kept and anything
        else is squashed with a sigmoid (clamped beyond +/-60 to avoid overflow).
        """
        try:
            score = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        # NaN poisons sorting and cannot be serialized as strict JSON.
        if math.isnan(score):
            return 0.0

        if not normalize:
            return score
        # DashScope relevance_score is typically already in [0,1]; keep it.
        if 0.0 <= score <= 1.0:
            return score
        # Fallback when provider returns logits/raw scores.
        if score > 60:
            return 1.0
        if score < -60:
            return 0.0
        return 1.0 / (1.0 + math.exp(-score))

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """
        Score ``docs`` against ``query`` and return (scores, meta).

        ``scores`` has one entry per input doc, in input order. None/blank docs,
        and docs the service did not return (because of ``top_n``), keep 0.0.
        Duplicate doc texts are sent once and share the resulting score.

        ``top_n`` (if > 0) limits how many results are requested; it is further
        limited by the number of unique docs and by ``top_n_cap`` when that is > 0.

        Raises:
            RuntimeError: if the remote call fails (see ``_post_rerank``).
        """
        started = time.perf_counter()  # monotonic: immune to wall-clock jumps
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        # (original position, cleaned text) for every usable doc.
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed.append((i, text))

        if not query or not indexed:
            # Nothing to rank: skip the network call entirely.
            elapsed_ms = (time.perf_counter() - started) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "dashscope_rerank",
                "normalize": normalize,
                "top_n": 0,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        top_n_effective = len(unique_texts)
        if top_n is not None and int(top_n) > 0:
            top_n_effective = min(top_n_effective, int(top_n))
        if self._top_n_cap > 0:
            top_n_effective = min(top_n_effective, self._top_n_cap)

        response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n_effective)
        results = self._extract_results(response)

        unique_scores: List[float] = [0.0] * len(unique_texts)
        for rank, item in enumerate(results):
            # Fall back to list position when the item carries no "index".
            raw_idx = item.get("index", rank)
            try:
                idx = int(raw_idx)
            except (TypeError, ValueError):
                continue
            if idx < 0 or idx >= len(unique_scores):
                continue
            raw_score = item.get("relevance_score", item.get("score"))
            unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize)

        # Fan the unique scores back out to every original position.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.perf_counter() - started) * 1000.0
        dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        return output_scores, {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "dashscope_rerank",
            "normalize": normalize,
            "top_n": top_n_effective,
            "requested_top_n": int(top_n) if top_n is not None else None,
            "response_results": len(results),
            "endpoint": self._endpoint,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score all docs with no caller-supplied ``top_n`` (``top_n_cap`` still applies)."""
        return self.score_with_meta_topn(query=query, docs=docs, normalize=normalize, top_n=None)
""" import logging @@ -60,6 +62,10 @@ class RerankRequest(BaseModel): normalize: Optional[bool] = Field( default=CONFIG.NORMALIZE, description="Apply sigmoid normalization" ) + top_n: Optional[int] = Field( + default=None, + description="Optional top_n hint for backends that support partial ranking", + ) class RerankResponse(BaseModel): @@ -118,8 +124,11 @@ def rerank(request: RerankRequest) -> RerankResponse: status_code=400, detail=f"Too many docs: {len(request.docs)} > {CONFIG.MAX_DOCS}", ) + if request.top_n is not None and int(request.top_n) <= 0: + raise HTTPException(status_code=400, detail="top_n must be > 0") normalize = CONFIG.NORMALIZE if request.normalize is None else bool(request.normalize) + top_n = int(request.top_n) if request.top_n is not None else None start_ts = time.time() logger.info( @@ -130,8 +139,18 @@ def rerank(request: RerankRequest) -> RerankResponse: _compact_preview(query, _LOG_TEXT_PREVIEW_CHARS), _preview_docs(request.docs, _LOG_DOC_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), ) - scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize) + if top_n is not None and hasattr(_reranker, "score_with_meta_topn"): + scores, meta = getattr(_reranker, "score_with_meta_topn")( + query, + request.docs, + normalize=normalize, + top_n=top_n, + ) + else: + scores, meta = _reranker.score_with_meta(query, request.docs, normalize=normalize) meta = dict(meta) + if top_n is not None: + meta.setdefault("requested_top_n", top_n) meta.update({"service_elapsed_ms": round((time.time() - start_ts) * 1000.0, 3)}) score_preview = [round(float(s), 6) for s in scores[:_LOG_DOC_PREVIEW_COUNT]] logger.info( diff --git a/scripts/perf_api_benchmark.py b/scripts/perf_api_benchmark.py index 27039aa..7f0defc 100755 --- a/scripts/perf_api_benchmark.py +++ b/scripts/perf_api_benchmark.py @@ -467,6 +467,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--rerank-sentence-max-words", type=int, default=40, help="Maximum words per 
generated doc sentence") parser.add_argument("--rerank-query", type=str, default="wireless mouse", help="Fixed query used for rerank dynamic docs mode") parser.add_argument("--rerank-seed", type=int, default=20260312, help="Base random seed for rerank dynamic docs mode") + parser.add_argument( + "--rerank-top-n", + type=int, + default=0, + help="Optional top_n for rerank requests in dynamic docs mode (0 means omit top_n).", + ) return parser.parse_args() @@ -487,6 +493,8 @@ def build_rerank_dynamic_cfg(args: argparse.Namespace) -> Dict[str, Any]: ) if args.rerank_seed < 0: raise ValueError(f"rerank-seed must be >= 0, got {args.rerank_seed}") + if int(args.rerank_top_n) < 0: + raise ValueError(f"rerank-top-n must be >= 0, got {args.rerank_top_n}") # Use deterministic, letter-only pseudo words to avoid long tokenization of numeric strings. syllables = [ @@ -513,6 +521,7 @@ def build_rerank_dynamic_cfg(args: argparse.Namespace) -> Dict[str, Any]: "max_words": max_words, "seed": int(args.rerank_seed), "normalize": True, + "top_n": int(args.rerank_top_n), "word_pool": word_pool, } @@ -530,6 +539,7 @@ def build_random_rerank_payload( "query": cfg["query"], "docs": docs, "normalize": bool(cfg.get("normalize", True)), + **({"top_n": int(cfg["top_n"])} if int(cfg.get("top_n", 0)) > 0 else {}), } @@ -595,6 +605,7 @@ async def main_async() -> int: print(f" rerank_sentence_words=[{args.rerank_sentence_min_words},{args.rerank_sentence_max_words}]") print(f" rerank_query={args.rerank_query}") print(f" rerank_seed={args.rerank_seed}") + print(f" rerank_top_n={args.rerank_top_n}") results: List[Dict[str, Any]] = [] total_jobs = len(run_names) * len(concurrency_values) @@ -643,6 +654,7 @@ async def main_async() -> int: "rerank_sentence_max_words": args.rerank_sentence_max_words, "rerank_query": args.rerank_query, "rerank_seed": args.rerank_seed, + "rerank_top_n": args.rerank_top_n, }, "results": results, "overall": aggregate_results(results), diff --git a/search/rerank_client.py 
b/search/rerank_client.py index bdb414d..28b6c22 100644 --- a/search/rerank_client.py +++ b/search/rerank_client.py @@ -80,6 +80,7 @@ def call_rerank_service( query: str, docs: List[str], timeout_sec: float = DEFAULT_TIMEOUT_SEC, + top_n: Optional[int] = None, ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: """ 调用重排服务 POST /rerank,返回分数列表与 meta。 @@ -89,7 +90,7 @@ def call_rerank_service( return [], {} try: client = create_rerank_provider() - return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec) + return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n) except Exception as e: logger.warning("Rerank request failed: %s", e, exc_info=True) return None, None @@ -176,10 +177,12 @@ def run_rerank( weight_ai: float = DEFAULT_WEIGHT_AI, rerank_query_template: str = "{query}", rerank_doc_template: str = "{title}", + top_n: Optional[int] = None, ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: """ 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 Provider 和 URL 从 services_config 读取。 + top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。 """ hits = es_response.get("hits", {}).get("hits") or [] if not hits: @@ -191,6 +194,7 @@ def run_rerank( query_text, docs, timeout_sec=timeout_sec, + top_n=top_n, ) if scores is None or len(scores) != len(hits): diff --git a/search/searcher.py b/search/searcher.py index 091fa97..6c4057e 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -507,6 +507,7 @@ class Searcher: weight_ai=rc.weight_ai, rerank_query_template=effective_query_template, rerank_doc_template=effective_doc_template, + top_n=(from_ + size), ) if rerank_meta is not None: diff --git a/tests/test_rerank_provider_topn.py b/tests/test_rerank_provider_topn.py new file mode 100644 index 0000000..d0819fd --- /dev/null +++ b/tests/test_rerank_provider_topn.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Any, Dict + +from providers.rerank import 
# --- tests/test_rerank_provider_topn.py -------------------------------------


class _FakeResponse:
    """Minimal stand-in for ``requests.Response`` (status_code, text, json())."""

    def __init__(self, status_code: int, data: Dict[str, Any]):
        self.status_code = status_code
        self._data = data
        self.text = str(data)

    def json(self):
        return self._data


def test_http_rerank_provider_includes_top_n(monkeypatch):
    """A positive ``top_n`` passed to the provider must appear in the POSTed JSON."""
    captured: Dict[str, Any] = {}

    def _fake_post(url, json, timeout):
        captured["url"] = url
        captured["json"] = json
        captured["timeout"] = timeout
        return _FakeResponse(200, {"scores": [0.1, 0.2], "meta": {"ok": True}})

    monkeypatch.setattr("providers.rerank.requests.post", _fake_post)

    provider = HttpRerankProvider("http://127.0.0.1:6007/rerank")
    scores, meta = provider.rerank("q", ["a", "b"], timeout_sec=3.0, top_n=2)

    assert scores == [0.1, 0.2]
    assert meta == {"ok": True}
    assert captured["json"]["top_n"] == 2


# --- tests/test_reranker_dashscope_backend.py -------------------------------


def _clear_dashscope_env_overrides(monkeypatch) -> None:
    """Drop RERANK_DASHSCOPE_* overrides so the config dict alone drives the backend.

    The backend lets these env vars take precedence over config; a developer
    shell that exports e.g. RERANK_DASHSCOPE_TOP_N_CAP would otherwise change
    the ``top_n`` values asserted below.
    """
    for name in (
        "RERANK_DASHSCOPE_MODEL",
        "RERANK_DASHSCOPE_ENDPOINT",
        "RERANK_DASHSCOPE_TIMEOUT_SEC",
        "RERANK_DASHSCOPE_TOP_N_CAP",
    ):
        monkeypatch.delenv(name, raising=False)


def test_dashscope_backend_factory_loads():
    """The backend factory resolves "dashscope_rerank" to DashScopeRerankBackend."""
    backend = get_rerank_backend(
        "dashscope_rerank",
        {
            "model_name": "qwen3-rerank",
            "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
            "api_key": "test-key",
        },
    )
    assert isinstance(backend, DashScopeRerankBackend)


def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch):
    """Duplicates are sent once; blank/None docs are skipped and scored 0.0."""
    _clear_dashscope_env_overrides(monkeypatch)
    backend = DashScopeRerankBackend(
        {
            "model_name": "qwen3-rerank",
            "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
            "api_key": "test-key",
            "top_n_cap": 0,
        }
    )

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "wireless mouse"
        # deduplicated docs
        assert docs == ["doc-a", "doc-b"]
        assert top_n == 2
        return {
            "results": [
                {"index": 1, "relevance_score": 0.9},
                {"index": 0, "relevance_score": 0.2},
            ]
        }

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores, meta = backend.score_with_meta(
        query="wireless mouse",
        docs=["doc-a", "doc-b", "doc-a", "", " ", None],
        normalize=True,
    )

    assert scores == [0.2, 0.9, 0.2, 0.0, 0.0, 0.0]
    assert meta["input_docs"] == 6
    assert meta["usable_docs"] == 3
    assert meta["unique_docs"] == 2
    assert meta["top_n"] == 2
    assert meta["response_results"] == 2
    assert meta["backend"] == "dashscope_rerank"


def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch):
    """``top_n_cap`` limits the request; out-of-range scores go through sigmoid."""
    _clear_dashscope_env_overrides(monkeypatch)
    backend = DashScopeRerankBackend(
        {
            "model_name": "qwen3-rerank",
            "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
            "api_key": "test-key",
            "top_n_cap": 1,
        }
    )

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "q"
        assert len(docs) == 2
        assert top_n == 1
        # Only top-1 returned, score outside [0,1] to trigger sigmoid fallback
        return {"results": [{"index": 1, "score": 3.0}]}

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores_norm, _ = backend.score_with_meta(query="q", docs=["a", "b"], normalize=True)
    scores_raw, _ = backend.score_with_meta(query="q", docs=["a", "b"], normalize=False)

    assert scores_norm[0] == 0.0
    assert 0.95 < scores_norm[1] < 0.96
    assert scores_raw == [0.0, 3.0]


def test_dashscope_backend_score_with_meta_topn_request(monkeypatch):
    """A caller-supplied ``top_n`` is forwarded and echoed back in meta."""
    _clear_dashscope_env_overrides(monkeypatch)
    backend = DashScopeRerankBackend(
        {
            "model_name": "qwen3-rerank",
            "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
            "api_key": "test-key",
            "top_n_cap": 0,
        }
    )

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "q"
        assert docs == ["d1", "d2", "d3"]
        assert top_n == 2
        return {"results": [{"index": 2, "relevance_score": 0.8}, {"index": 0, "relevance_score": 0.3}]}

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores, meta = backend.score_with_meta_topn(query="q", docs=["d1", "d2", "d3"], top_n=2)
    assert scores == [0.3, 0.0, 0.8]
    assert meta["top_n"] == 2
    assert meta["requested_top_n"] == 2


# --- tests/test_reranker_server_topn.py -------------------------------------


class _FakeTopNReranker:
    """Fake backend exposing both scoring entry points so the test can tell which ran."""

    _model_name = "fake-topn-reranker"

    def score_with_meta(self, query: str, docs: List[str], normalize: bool = True):
        return [0.1 for _ in docs], {"input_docs": len(docs), "path": "base"}

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ):
        scores = [0.0 for _ in docs]
        if docs and top_n:
            scores[0] = 1.0
        return scores, {"input_docs": len(docs), "path": "topn", "top_n": top_n}


def test_reranker_server_forwards_top_n():
    """POST /rerank with ``top_n`` must route to ``score_with_meta_topn``."""
    import reranker.server as reranker_server

    # Snapshot the module/app state we are about to replace so it can be put
    # back afterwards; otherwise the fake backend and the emptied startup hooks
    # leak into every test that imports reranker.server later in the session.
    saved_startup = list(reranker_server.app.router.on_startup)
    saved_reranker = getattr(reranker_server, "_reranker", None)
    saved_backend_name = getattr(reranker_server, "_backend_name", None)

    # Skip the real startup handlers and install the fake backend directly.
    reranker_server.app.router.on_startup.clear()
    reranker_server._reranker = _FakeTopNReranker()
    reranker_server._backend_name = "fake_topn"
    try:
        with TestClient(reranker_server.app) as client:
            response = client.post(
                "/rerank",
                json={
                    "query": "wireless mouse",
                    "docs": ["a", "b", "c"],
                    "top_n": 2,
                },
            )
    finally:
        reranker_server.app.router.on_startup[:] = saved_startup
        reranker_server._reranker = saved_reranker
        reranker_server._backend_name = saved_backend_name

    assert response.status_code == 200
    data: Dict[str, Any] = response.json()
    assert data["scores"] == [1.0, 0.0, 0.0]
    assert data["meta"]["path"] == "topn"
    assert data["meta"]["requested_top_n"] == 2
    assert data["meta"]["top_n"] == 2