diff --git a/api/models.py b/api/models.py index 472d6e7..9b9d384 100644 --- a/api/models.py +++ b/api/models.py @@ -151,9 +151,12 @@ class SearchRequest(BaseModel): min_score: Optional[float] = Field(None, ge=0, description="最小相关性分数阈值") highlight: bool = Field(False, description="是否高亮搜索关键词(暂不实现)") debug: bool = Field(False, description="是否返回调试信息") - enable_rerank: bool = Field( - False, - description="是否开启重排(调用外部重排服务对 ES 结果进行二次排序)" + enable_rerank: Optional[bool] = Field( + None, + description=( + "是否开启重排(调用外部重排服务对 ES 结果进行二次排序)。" + "不传则使用服务端配置 rerank.enabled(默认开启)。" + ) ) rerank_query_template: Optional[str] = Field( None, diff --git a/api/routes/search.py b/api/routes/search.py index 72e91b8..9cc10df 100644 --- a/api/routes/search.py +++ b/api/routes/search.py @@ -133,6 +133,16 @@ async def search(request: SearchRequest, http_request: Request): # Include performance summary in response performance_summary = context.get_summary() if context else None + stage_timings = { + k: round(v, 2) for k, v in context.performance_metrics.stage_timings.items() + } + total_ms = round(float(context.performance_metrics.total_duration or result.took_ms), 2) + context.logger.info( + "Before response | total_ms: %s | stage_timings_ms: %s", + total_ms, + stage_timings, + extra={'reqid': context.reqid, 'uid': context.uid} + ) # Convert to response model response = SearchResponse( diff --git a/config/config.yaml b/config/config.yaml index 71b874a..1a3f322 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -113,6 +113,7 @@ function_score: # 重排配置(provider/URL 在 services.rerank) rerank: + enabled: true rerank_window: 1000 timeout_sec: 15.0 weight_es: 0.4 diff --git a/config/config_loader.py b/config/config_loader.py index df34e32..afa64fb 100644 --- a/config/config_loader.py +++ b/config/config_loader.py @@ -106,6 +106,7 @@ class RankingConfig: @dataclass class RerankConfig: """重排配置(provider/URL 在 services.rerank)""" + enabled: bool = True rerank_window: int = 1000 timeout_sec: 
float = 15.0 weight_es: float = 0.4 @@ -310,6 +311,7 @@ class ConfigLoader: # Parse Rerank (provider/URL in services.rerank) rerank_data = config_data.get("rerank", {}) rerank = RerankConfig( + enabled=bool(rerank_data.get("enabled", True)), rerank_window=int(rerank_data.get("rerank_window", 1000)), timeout_sec=float(rerank_data.get("timeout_sec", 15.0)), weight_es=float(rerank_data.get("weight_es", 0.4)), @@ -518,6 +520,7 @@ class ConfigLoader: "functions": config.function_score.functions }, "rerank": { + "enabled": config.rerank.enabled, "rerank_window": config.rerank.rerank_window, "timeout_sec": config.rerank.timeout_sec, "weight_es": config.rerank.weight_es, diff --git a/docs/性能测试报告.md b/docs/性能测试报告.md new file mode 100644 index 0000000..211dc95 --- /dev/null +++ b/docs/性能测试报告.md @@ -0,0 +1,230 @@ +# 性能测试报告 + +## 1. 文档目标 + +本报告用于沉淀 `search / suggest / embedding / reranker` 四类接口的并发性能基线,并提供可复现的完整执行流程。 +新同事可直接按本文命令重跑全流程,得到同结构结果文件并横向对比。 + +## 2. 本次测试范围与方法 + +测试范围: +- `backend_search` -> `POST /search/` +- `backend_suggest` -> `GET /search/suggestions` +- `embed_text` -> `POST /embed/text` +- `rerank` -> `POST /rerank` + +并发矩阵: +- `1 / 5 / 10 / 20` + +执行方式: +- 每组压测持续 `20s` +- 使用统一脚本 `scripts/perf_api_benchmark.py` +- 通过 `--scenario` 多值 + `--concurrency-list` 一次性跑完 `场景 x 并发` + +## 3. 压测工具优化说明(复用现有脚本) + +为了解决原脚本“一次只能跑一个场景+一个并发”的可用性问题,本次直接扩展现有脚本: +- `scripts/perf_api_benchmark.py` + +能力: +- 一条命令执行 `场景列表 x 并发列表` 全矩阵 +- 输出单份 JSON 报告(含每组结果与 overall 汇总) + +示例: + +```bash +.venv/bin/python scripts/perf_api_benchmark.py \ + --scenario backend_search,backend_suggest,embed_text,rerank \ + --concurrency-list 1,5,10,20 \ + --duration 20 \ + --tenant-id 162 \ + --output perf_reports/$(date +%F)/perf_matrix_report.json +``` + +## 4. 
测试环境快照(本次) + +时间: +- `2026-03-12 08:11:34 CST` + +代码版本: +- Git commit: `28e57bb` +- Python: `3.12.3` + +机器信息: +- OS: `Linux ai-db 6.8.0-71-generic` +- CPU: `Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz` +- vCPU: `8` +- 内存: `30Gi`(可用约 `15Gi`) + +服务健康: +- `GET http://127.0.0.1:6002/health` -> healthy +- `GET http://127.0.0.1:6005/health` -> embedding loaded (`tei`) +- `GET http://127.0.0.1:6006/health` -> translation healthy +- `GET http://127.0.0.1:6007/health` -> reranker loaded (`Qwen/Qwen3-Reranker-0.6B`) + +索引doc数/租户基本信息: +tenant_id = 162 :注意当前该租户总 doc 数只有53,reranker、suggest、search的性能指标跟租户的doc数高度相关。以后要补充一个 doc 规模更大的租户作为对照基线,以评估指标随数据量的变化。 +``` +curl -u 'saas:<ES_PASSWORD>' -X GET 'http://localhost:9200/search_products_tenant_162/_count?pretty' -H 'Content-Type: application/json' -d '{ + "query": { + "match_all": {} + } +}' +``` + +## 5. 执行前准备(可复现步骤) + +### 5.1 环境与依赖 + +```bash +cd /data/saas-search +source activate.sh +.venv/bin/python --version +``` + +### 5.2 启动服务 + +推荐: + +```bash +./scripts/service_ctl.sh start embedding translator reranker backend +``` + +如果 `backend` 未成功常驻,可临时手动启动: + +```bash +.venv/bin/python main.py serve --host 0.0.0.0 --port 6002 --es-host http://localhost:9200 +``` + +### 5.3 健康检查 + +```bash +curl -sS http://127.0.0.1:6002/health +curl -sS http://127.0.0.1:6005/health +curl -sS http://127.0.0.1:6006/health +curl -sS http://127.0.0.1:6007/health +``` + +## 6. 压测执行命令(本次实际) + +```bash +cd /data/saas-search +.venv/bin/python scripts/perf_api_benchmark.py \ + --scenario backend_search,backend_suggest,embed_text,rerank \ + --concurrency-list 1,5,10,20 \ + --duration 20 \ + --tenant-id 162 \ + --backend-base http://127.0.0.1:6002 \ + --embedding-base http://127.0.0.1:6005 \ + --translator-base http://127.0.0.1:6006 \ + --reranker-base http://127.0.0.1:6007 \ + --output perf_reports/2026-03-12/perf_matrix_report.json +``` + +产物文件: +- `perf_reports/2026-03-12/perf_matrix_report.json` +- `results[]` 中每条包含 `scenario + concurrency` 的单组结果 +- `overall` 为本次执行总体汇总 + +## 7. 
结果总览(本次实测) + +### 7.1 Search(backend_search) + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P95(ms) | Max(ms) | +|---:|---:|---:|---:|---:|---:|---:| +| 1 | 160 | 100.0% | 7.98 | 124.89 | 228.06 | 345.49 | +| 5 | 161 | 100.0% | 7.89 | 628.91 | 1271.49 | 1441.02 | +| 10 | 181 | 100.0% | 8.78 | 1129.23 | 1295.88 | 1330.96 | +| 20 | 161 | 100.0% | 7.63 | 2594.00 | 4706.44 | 4783.05 | + +### 7.2 Suggest(backend_suggest) + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P95(ms) | Max(ms) | +|---:|---:|---:|---:|---:|---:|---:| +| 1 | 3502 | 100.0% | 175.09 | 5.68 | 8.70 | 15.98 | +| 5 | 4168 | 100.0% | 208.10 | 23.93 | 36.93 | 59.53 | +| 10 | 4152 | 100.0% | 207.25 | 48.05 | 59.45 | 127.20 | +| 20 | 4190 | 100.0% | 208.99 | 95.20 | 110.74 | 181.37 | + +### 7.3 Embedding(embed_text) + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P95(ms) | Max(ms) | +|---:|---:|---:|---:|---:|---:|---:| +| 1 | 966 | 100.0% | 48.27 | 20.63 | 23.41 | 49.80 | +| 5 | 1796 | 100.0% | 89.57 | 55.55 | 69.62 | 109.84 | +| 10 | 2095 | 100.0% | 104.42 | 95.22 | 117.66 | 152.48 | +| 20 | 2393 | 100.0% | 118.70 | 167.37 | 212.21 | 318.70 | + +### 7.4 Reranker(rerank) + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P95(ms) | Max(ms) | +|---:|---:|---:|---:|---:|---:|---:| +| 1 | 802 | 100.0% | 40.06 | 24.87 | 37.45 | 49.63 | +| 5 | 796 | 100.0% | 39.53 | 125.70 | 190.02 | 218.60 | +| 10 | 853 | 100.0% | 41.89 | 235.87 | 315.37 | 402.27 | +| 20 | 836 | 100.0% | 40.92 | 481.98 | 723.56 | 781.81 | + +## 8. 指标解读与并发建议 + +### 8.1 关键观察 + +- `backend_search`:吞吐约 `8 rps` 平台化,延迟随并发上升明显,属于重链路(检索+向量+重排)特征。 +- `backend_suggest`:吞吐高且稳定(约 `200+ rps`),对并发更友好。 +- `embed_text`:随并发提升吞吐持续增长,延迟平滑上升,扩展性较好。 +- `rerank`:吞吐在 `~40 rps` 附近平台化,延迟随并发线性抬升,符合模型推理瓶颈特征。 + +### 8.2 并发压测建议 + +- 冒烟并发:`1/5` +- 常规回归:`1/5/10/20` +- 稳态评估:建议把 `duration` 提升到 `60~300s` +- 峰值评估:在确认 timeout 与 max_errors 策略后,追加 `30/50` 并发 + +## 9. 如何复现“完整全过程” + +1. 准备环境(第 5 节) +2. 启动服务并通过健康检查 +3. 执行矩阵命令(第 6 节) +4. 
查看结果: + - 原始明细:`perf_reports//perf_matrix_report.json` 的 `results[]` + - 汇总结果:同文件中的 `overall` +5. 若需导出到周报或 PR,直接拷贝本报告第 7 节四张表 + +## 10. 常见问题与排障 + +### 10.1 backend 端口起来又掉 + +现象: +- `service_ctl status backend` 显示 `running=no` + +处理: +- 先看 `logs/backend.log` +- 用手动命令前台启动,确认根因: + +```bash +.venv/bin/python main.py serve --host 0.0.0.0 --port 6002 --es-host http://localhost:9200 +``` + +### 10.2 压测脚本依赖缺失 + +现象: +- 报 `ModuleNotFoundError: httpx` + +处理: +- 使用项目虚拟环境执行: + +```bash +.venv/bin/python scripts/perf_api_benchmark.py -h +``` + +### 10.3 某场景成功率下降 + +排查顺序: +1. 看 `errors` 字段(HTTP码、timeout、payload校验失败) +2. 检查对应服务健康与日志 +3. 缩小并发重跑单场景定位阈值 + +## 11. 关联文件 + +- 压测脚本:`scripts/perf_api_benchmark.py` +- 本次结果:`perf_reports/2026-03-12/perf_matrix_report.json` diff --git a/docs/搜索API对接指南.md b/docs/搜索API对接指南.md index e62a8c6..ef4a0c7 100644 --- a/docs/搜索API对接指南.md +++ b/docs/搜索API对接指南.md @@ -201,7 +201,7 @@ response = requests.post(url, headers=headers, json={"query": "芭比娃娃"}) "min_score": 0.0, "sku_filter_dimension": ["string"], "debug": false, - "enable_rerank": false, + "enable_rerank": null, "rerank_query_template": "{query}", "rerank_doc_template": "{title}", "user_id": "string", @@ -225,7 +225,7 @@ response = requests.post(url, headers=headers, json={"query": "芭比娃娃"}) | `min_score` | float | N | null | 最小相关性分数阈值 | | `sku_filter_dimension` | array[string] | N | null | 子SKU筛选维度列表(见[SKU筛选维度](#35-sku筛选维度)) | | `debug` | boolean | N | false | 是否返回调试信息 | -| `enable_rerank` | boolean | N | false | 是否开启重排(调用外部重排服务对 ES 结果进行二次排序)。开启后若 `from+size<=rerank_window` 才会触发重排 | +| `enable_rerank` | boolean/null | N | null | 是否开启重排(调用外部重排服务对 ES 结果进行二次排序)。不传/传 null 使用服务端 `rerank.enabled`(默认开启)。开启后会先对 ES Top1000(`rerank_window`)重排,再按分页截取;若 `from+size>1000`,则不重排,直接按分页从 ES 返回 | | `rerank_query_template` | string | N | null | 重排 query 模板(可选)。支持 `{query}` 占位符;不传则使用服务端配置 | | `rerank_doc_template` | string | N | null | 重排 doc 模板(可选)。支持 `{title} {brief} {vendor} {description} 
{category_path}`;不传则使用服务端配置 | | `user_id` | string | N | null | 用户ID(用于个性化,预留) | diff --git a/perf_reports/2026-03-12/matrix_report/summary.md b/perf_reports/2026-03-12/matrix_report/summary.md new file mode 100644 index 0000000..ab61e7c --- /dev/null +++ b/perf_reports/2026-03-12/matrix_report/summary.md @@ -0,0 +1,41 @@ +# 性能测试矩阵结果 + +- 生成时间: 2026-03-12 08:11:03 +- 场景: backend_search, backend_suggest, embed_text, rerank +- 并发: 1, 5, 10, 20 + +## backend_search + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P50 | P90 | P95 | P99 | Max | +|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 1 | 160 | 100.0% | 7.98 | 124.89 | 109.37 | 162.61 | 228.06 | 329.57 | 345.49 | +| 5 | 161 | 100.0% | 7.89 | 628.91 | 541.87 | 726.7 | 1271.49 | 1285.88 | 1441.02 | +| 10 | 181 | 100.0% | 8.78 | 1129.23 | 1100.46 | 1251.53 | 1295.88 | 1320.78 | 1330.96 | +| 20 | 161 | 100.0% | 7.63 | 2594.0 | 2303.96 | 4681.35 | 4706.44 | 4727.58 | 4783.05 | + +## backend_suggest + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P50 | P90 | P95 | P99 | Max | +|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 1 | 3502 | 100.0% | 175.09 | 5.68 | 4.36 | 8.42 | 8.7 | 10.43 | 15.98 | +| 5 | 4168 | 100.0% | 208.1 | 23.93 | 21.72 | 35.72 | 36.93 | 42.08 | 59.53 | +| 10 | 4152 | 100.0% | 207.25 | 48.05 | 46.72 | 55.72 | 59.45 | 70.74 | 127.2 | +| 20 | 4190 | 100.0% | 208.99 | 95.2 | 93.51 | 104.44 | 110.74 | 164.22 | 181.37 | + +## embed_text + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P50 | P90 | P95 | P99 | Max | +|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 1 | 966 | 100.0% | 48.27 | 20.63 | 20.0 | 21.14 | 23.41 | 30.03 | 49.8 | +| 5 | 1796 | 100.0% | 89.57 | 55.55 | 54.43 | 66.64 | 69.62 | 75.85 | 109.84 | +| 10 | 2095 | 100.0% | 104.42 | 95.22 | 98.74 | 112.52 | 117.66 | 135.94 | 152.48 | +| 20 | 2393 | 100.0% | 118.7 | 167.37 | 169.0 | 198.66 | 212.21 | 251.56 | 318.7 | + +## rerank + +| 并发 | 请求数 | 成功率 | 吞吐(RPS) | Avg(ms) | P50 | P90 | P95 | P99 | Max | 
+|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 1 | 802 | 100.0% | 40.06 | 24.87 | 23.0 | 30.96 | 37.45 | 43.98 | 49.63 | +| 5 | 796 | 100.0% | 39.53 | 125.7 | 113.04 | 178.39 | 190.02 | 202.09 | 218.6 | +| 10 | 853 | 100.0% | 41.89 | 235.87 | 224.75 | 274.4 | 315.37 | 383.74 | 402.27 | +| 20 | 836 | 100.0% | 40.92 | 481.98 | 454.32 | 565.75 | 723.56 | 764.15 | 781.81 | diff --git a/scripts/perf_api_benchmark.py b/scripts/perf_api_benchmark.py index ceb4ef0..510181b 100755 --- a/scripts/perf_api_benchmark.py +++ b/scripts/perf_api_benchmark.py @@ -398,13 +398,34 @@ def aggregate_results(results: List[Dict[str, Any]]) -> Dict[str, Any]: } +def parse_csv_items(raw: str) -> List[str]: + return [x.strip() for x in str(raw or "").split(",") if x.strip()] + + +def parse_csv_ints(raw: str) -> List[int]: + values: List[int] = [] + seen = set() + for item in parse_csv_items(raw): + try: + value = int(item) + except ValueError as exc: + raise ValueError(f"Invalid integer in CSV list: {item}") from exc + if value <= 0: + raise ValueError(f"Concurrency must be > 0, got {value}") + if value in seen: + continue + seen.add(value) + values.append(value) + return values + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Interface-level load test for search and related microservices") parser.add_argument( "--scenario", type=str, default="all", - help="Scenario: backend_search | backend_suggest | embed_text | translate | rerank | all", + help="Scenario: backend_search | backend_suggest | embed_text | translate | rerank | all | comma-separated list", ) parser.add_argument("--tenant-id", type=str, default="162", help="Tenant ID for backend search/suggest") parser.add_argument("--duration", type=int, default=30, help="Duration seconds per scenario; <=0 means no duration cap") @@ -421,6 +442,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--cases-file", type=str, default="", help="Optional JSON file to override/add request 
templates") parser.add_argument("--output", type=str, default="", help="Optional output JSON path") parser.add_argument("--pause", type=float, default=0.0, help="Pause seconds between scenarios in all mode") + parser.add_argument( + "--concurrency-list", + type=str, + default="", + help="Comma-separated concurrency list (e.g. 1,5,10,20). If set, overrides --concurrency.", + ) return parser.parse_args() @@ -432,21 +459,38 @@ async def main_async() -> int: if args.scenario == "all": run_names = [x for x in all_names if x in scenarios] else: - if args.scenario not in scenarios: - print(f"Unknown scenario: {args.scenario}") + requested = parse_csv_items(args.scenario) + if not requested: + print("No scenario specified.") + return 2 + unknown = [name for name in requested if name not in scenarios] + if unknown: + print(f"Unknown scenario(s): {', '.join(unknown)}") print(f"Available: {', '.join(sorted(scenarios.keys()))}") return 2 - run_names = [args.scenario] + run_names = requested if not run_names: print("No scenarios to run.") return 2 + concurrency_values = [args.concurrency] + if args.concurrency_list: + try: + concurrency_values = parse_csv_ints(args.concurrency_list) + except ValueError as exc: + print(str(exc)) + return 2 + if not concurrency_values: + print("concurrency-list is empty after parsing.") + return 2 + print("Load test config:") print(f" scenario={args.scenario}") print(f" tenant_id={args.tenant_id}") print(f" duration={args.duration}s") print(f" concurrency={args.concurrency}") + print(f" concurrency_list={concurrency_values}") print(f" max_requests={args.max_requests}") print(f" timeout={args.timeout}s") print(f" max_errors={args.max_errors}") @@ -456,29 +500,36 @@ async def main_async() -> int: print(f" reranker_base={args.reranker_base}") results: List[Dict[str, Any]] = [] - for i, name in enumerate(run_names, start=1): + total_jobs = len(run_names) * len(concurrency_values) + job_idx = 0 + for name in run_names: scenario = scenarios[name] - 
print(f"\\n[{i}/{len(run_names)}] running {name} ...") - result = await run_single_scenario( - scenario=scenario, - duration_sec=args.duration, - concurrency=args.concurrency, - max_requests=args.max_requests, - max_errors=args.max_errors, - ) - print(format_summary(result)) - results.append(result) - - if args.pause > 0 and i < len(run_names): - await asyncio.sleep(args.pause) + for c in concurrency_values: + job_idx += 1 + print(f"\\n[{job_idx}/{total_jobs}] running {name} @ concurrency={c} ...") + result = await run_single_scenario( + scenario=scenario, + duration_sec=args.duration, + concurrency=c, + max_requests=args.max_requests, + max_errors=args.max_errors, + ) + result["concurrency"] = c + print(format_summary(result)) + results.append(result) + + if args.pause > 0 and job_idx < total_jobs: + await asyncio.sleep(args.pause) final = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), "config": { "scenario": args.scenario, + "run_names": run_names, "tenant_id": args.tenant_id, "duration_sec": args.duration, "concurrency": args.concurrency, + "concurrency_list": concurrency_values, "max_requests": args.max_requests, "timeout_sec": args.timeout, "max_errors": args.max_errors, diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 8a5edd8..8620a1f 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -815,45 +815,6 @@ class ESQueryBuilder: return filter_clauses - def add_spu_collapse( - self, - es_query: Dict[str, Any], - spu_field: str, - inner_hits_size: int = 3 - ) -> Dict[str, Any]: - """ - Add SPU aggregation/collapse to query. 
- - Args: - es_query: Existing ES query - spu_field: Field containing SPU ID - inner_hits_size: Number of SKUs to return per SPU - - Returns: - Modified ES query - """ - # Add collapse - es_query["collapse"] = { - "field": spu_field, - "inner_hits": { - "_source": False, - "name": "top_docs", - "size": inner_hits_size - } - } - - # Add cardinality aggregation to count unique SPUs - if "aggs" not in es_query: - es_query["aggs"] = {} - - es_query["aggs"]["unique_count"] = { - "cardinality": { - "field": spu_field - } - } - - return es_query - def add_sorting( self, es_query: Dict[str, Any], diff --git a/search/searcher.py b/search/searcher.py index 378e0a7..c5642ff 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -9,6 +9,7 @@ import os import time, json import logging import hashlib +from string import Formatter from utils.es_client import ESClient from query import QueryParser, ParsedQuery @@ -157,6 +158,75 @@ class Searcher: return es_query["_source"] = {"includes": self.source_fields} + def _resolve_rerank_source_filter(self, doc_template: str) -> Dict[str, Any]: + """ + Build a lightweight _source filter for rerank prefetch. + + Only fetch fields required by rerank doc template to reduce ES payload. + """ + field_map = { + "title": "title", + "brief": "brief", + "vendor": "vendor", + "description": "description", + "category_path": "category_path", + } + includes: set[str] = set() + template = str(doc_template or "{title}") + for _, field_name, _, _ in Formatter().parse(template): + if not field_name: + continue + key = field_name.split(".", 1)[0].split("!", 1)[0].split(":", 1)[0] + mapped = field_map.get(key) + if mapped: + includes.add(mapped) + + # Fallback to title-only to keep rerank docs usable. 
+ if not includes: + includes.add("title") + + return {"includes": sorted(includes)} + + def _fetch_hits_by_ids( + self, + index_name: str, + doc_ids: List[str], + source_spec: Optional[Any], + ) -> tuple[Dict[str, Dict[str, Any]], int]: + """ + Fetch page documents by IDs for final response fill. + + Returns: + (hits_by_id, es_took_ms) + """ + if not doc_ids: + return {}, 0 + + body: Dict[str, Any] = { + "query": { + "ids": { + "values": doc_ids, + } + } + } + if source_spec is not None: + body["_source"] = source_spec + + resp = self.es_client.search( + index_name=index_name, + body=body, + size=len(doc_ids), + from_=0, + ) + hits = resp.get("hits", {}).get("hits") or [] + hits_by_id: Dict[str, Dict[str, Any]] = {} + for hit in hits: + hid = hit.get("_id") + if hid is None: + continue + hits_by_id[str(hid)] = hit + return hits_by_id, int(resp.get("took", 0) or 0) + def search( self, query: str, @@ -173,7 +243,7 @@ class Searcher: debug: bool = False, language: str = "en", sku_filter_dimension: Optional[List[str]] = None, - enable_rerank: bool = False, + enable_rerank: Optional[bool] = None, rerank_query_template: Optional[str] = None, rerank_doc_template: Optional[str] = None, ) -> SearchResult: @@ -206,9 +276,13 @@ class Searcher: index_langs = tenant_cfg.get("index_languages") or [] enable_translation = len(index_langs) > 0 enable_embedding = self.config.query_config.enable_text_embedding - # 重排仅由请求参数 enable_rerank 控制,唯一实现为调用外部 BGE 重排服务 - do_rerank = bool(enable_rerank) - rerank_window = self.config.rerank.rerank_window or 1000 + rc = self.config.rerank + effective_query_template = rerank_query_template or rc.rerank_query_template + effective_doc_template = rerank_doc_template or rc.rerank_doc_template + # 重排开关优先级:请求参数显式传值 > 服务端配置(默认开启) + rerank_enabled_by_config = bool(rc.enabled) + do_rerank = rerank_enabled_by_config if enable_rerank is None else bool(enable_rerank) + rerank_window = rc.rerank_window or 1000 # 若开启重排且请求范围在窗口内:从 ES 取前 rerank_window 条、重排后再按 
from/size 分页;否则不重排,按原 from/size 查 ES in_rerank_window = do_rerank and (from_ + size) <= rerank_window es_fetch_from = 0 if in_rerank_window else from_ @@ -219,7 +293,9 @@ class Searcher: context.logger.info( f"开始搜索请求 | 查询: '{query}' | 参数: size={size}, from_={from_}, " - f"enable_rerank={do_rerank}, in_rerank_window={in_rerank_window}, es_fetch=({es_fetch_from},{es_fetch_size}) | " + f"enable_rerank(request)={enable_rerank}, enable_rerank(config)={rerank_enabled_by_config}, " + f"enable_rerank(effective)={do_rerank}, in_rerank_window={in_rerank_window}, " + f"es_fetch=({es_fetch_from},{es_fetch_size}) | " f"enable_translation={enable_translation}, enable_embedding={enable_embedding}, min_score={min_score}", extra={'reqid': context.reqid, 'uid': context.uid} ) @@ -231,8 +307,10 @@ class Searcher: 'es_fetch_from': es_fetch_from, 'es_fetch_size': es_fetch_size, 'in_rerank_window': in_rerank_window, - 'rerank_query_template': rerank_query_template, - 'rerank_doc_template': rerank_doc_template, + 'rerank_enabled_by_config': rerank_enabled_by_config, + 'enable_rerank_request': enable_rerank, + 'rerank_query_template': effective_query_template, + 'rerank_doc_template': effective_doc_template, 'filters': filters, 'range_filters': range_filters, 'facets': facets, @@ -323,26 +401,40 @@ class Searcher: if sort_by: es_query = self.query_builder.add_sorting(es_query, sort_by, sort_order) + # Keep requested response _source semantics for the final response fill. + response_source_spec = es_query.get("_source") + + # In rerank window, first pass only fetches minimal fields required by rerank template. 
+ es_query_for_fetch = es_query + rerank_prefetch_source = None + if in_rerank_window: + rerank_prefetch_source = self._resolve_rerank_source_filter(effective_doc_template) + es_query_for_fetch = dict(es_query) + es_query_for_fetch["_source"] = rerank_prefetch_source + # Extract size and from from body for ES client parameters - body_for_es = {k: v for k, v in es_query.items() if k not in ['size', 'from']} + body_for_es = {k: v for k, v in es_query_for_fetch.items() if k not in ['size', 'from']} # Store ES query in context context.store_intermediate_result('es_query', es_query) + if in_rerank_window and rerank_prefetch_source is not None: + context.store_intermediate_result('es_query_rerank_prefetch_source', rerank_prefetch_source) context.store_intermediate_result('es_body_for_search', body_for_es) # Serialize ES query to compute a compact size + stable digest for correlation - es_query_compact = json.dumps(es_query, ensure_ascii=False, separators=(",", ":")) + es_query_compact = json.dumps(es_query_for_fetch, ensure_ascii=False, separators=(",", ":")) es_query_digest = hashlib.sha256(es_query_compact.encode("utf-8")).hexdigest()[:16] knn_enabled = bool(enable_embedding and parsed_query.query_vector is not None) vector_dims = int(len(parsed_query.query_vector)) if parsed_query.query_vector is not None else 0 context.logger.info( - "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s", + "ES query built | size: %s chars | digest: %s | KNN: %s | vector_dims: %s | facets: %s | rerank_prefetch_source: %s", len(es_query_compact), es_query_digest, "yes" if knn_enabled else "no", vector_dims, "yes" if facets else "no", + rerank_prefetch_source, extra={'reqid': context.reqid, 'uid': context.uid} ) _log_backend_verbose({ @@ -355,7 +447,7 @@ class Searcher: "knn_enabled": knn_enabled, "vector_dims": vector_dims, "has_facets": bool(facets), - "query": es_query, + "query": es_query_for_fetch, }) except Exception as e: context.set_error(e) @@ 
-406,9 +498,6 @@ class Searcher: from .rerank_client import run_rerank rerank_query = parsed_query.original_query if parsed_query else query - rc = self.config.rerank - effective_query_template = rerank_query_template or rc.rerank_query_template - effective_doc_template = rerank_doc_template or rc.rerank_doc_template es_response, rerank_meta, fused_debug = run_rerank( query=rerank_query, es_response=es_response, @@ -457,6 +546,41 @@ class Searcher: es_response["hits"]["max_score"] = 0.0 else: es_response["hits"]["max_score"] = 0.0 + + # Page fill: fetch detailed fields only for final page hits. + if sliced: + if response_source_spec is False: + for hit in sliced: + hit.pop("_source", None) + context.logger.info( + "分页详情回填跳过 | 原查询 _source=false", + extra={'reqid': context.reqid, 'uid': context.uid} + ) + else: + page_ids = [str(h.get("_id")) for h in sliced if h.get("_id") is not None] + details_by_id, fill_took = self._fetch_hits_by_ids( + index_name=index_name, + doc_ids=page_ids, + source_spec=response_source_spec, + ) + filled = 0 + for hit in sliced: + hid = hit.get("_id") + if hid is None: + continue + detail_hit = details_by_id.get(str(hid)) + if detail_hit is None: + continue + if "_source" in detail_hit: + hit["_source"] = detail_hit.get("_source") or {} + filled += 1 + if fill_took: + es_response["took"] = int((es_response.get("took", 0) or 0) + fill_took) + context.logger.info( + f"分页详情回填 | ids={len(page_ids)} | filled={filled} | took={fill_took}ms", + extra={'reqid': context.reqid, 'uid': context.uid} + ) + context.logger.info( f"重排分页切片 | from={from_}, size={size}, 返回={len(sliced)}条", extra={'reqid': context.reqid, 'uid': context.uid} diff --git a/tests/test_search_rerank_window.py b/tests/test_search_rerank_window.py new file mode 100644 index 0000000..77f6006 --- /dev/null +++ b/tests/test_search_rerank_window.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from types import 
SimpleNamespace +from typing import Any, Dict, List + +import yaml + +from config import ( + ConfigLoader, + FunctionScoreConfig, + IndexConfig, + QueryConfig, + RankingConfig, + RerankConfig, + SPUConfig, + SearchConfig, +) +from context import create_request_context +from search.searcher import Searcher + + +@dataclass +class _FakeParsedQuery: + original_query: str + query_normalized: str + rewritten_query: str + detected_language: str = "en" + translations: Dict[str, str] = None + query_vector: Any = None + domain: str = "default" + + def to_dict(self) -> Dict[str, Any]: + return { + "original_query": self.original_query, + "query_normalized": self.query_normalized, + "rewritten_query": self.rewritten_query, + "detected_language": self.detected_language, + "translations": self.translations or {}, + "domain": self.domain, + } + + +class _FakeQueryParser: + def parse(self, query: str, tenant_id: str, generate_vector: bool, context: Any): + return _FakeParsedQuery( + original_query=query, + query_normalized=query, + rewritten_query=query, + translations={}, + ) + + +class _FakeQueryBuilder: + def build_query(self, **kwargs): + return { + "query": {"match_all": {}}, + "size": kwargs["size"], + "from": kwargs["from_"], + } + + def build_facets(self, facets: Any): + return {} + + def add_sorting(self, es_query: Dict[str, Any], sort_by: str, sort_order: str): + return es_query + + +class _FakeESClient: + def __init__(self, total_hits: int = 5000): + self.calls: List[Dict[str, Any]] = [] + self.total_hits = total_hits + + @staticmethod + def _apply_source_filter(src: Dict[str, Any], source_spec: Any) -> Dict[str, Any]: + if source_spec is None: + return dict(src) + if source_spec is False: + return {} + if isinstance(source_spec, dict): + includes = source_spec.get("includes") or [] + elif isinstance(source_spec, list): + includes = source_spec + else: + includes = [] + if not includes: + return dict(src) + return {k: v for k, v in src.items() if k in set(includes)} + + 
@staticmethod + def _full_source(doc_id: str) -> Dict[str, Any]: + return { + "spu_id": doc_id, + "title": {"en": f"product-{doc_id}"}, + "brief": {"en": f"brief-{doc_id}"}, + "vendor": {"en": f"vendor-{doc_id}"}, + "skus": [], + } + + def search(self, index_name: str, body: Dict[str, Any], size: int, from_: int): + self.calls.append( + {"index_name": index_name, "body": body, "size": size, "from_": from_} + ) + ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values") + source_spec = (body or {}).get("_source") + + if isinstance(ids_query, list): + # Return reversed order intentionally; caller should restore original ranking order. + ids = [str(i) for i in ids_query][::-1] + hits = [] + for doc_id in ids: + src = self._apply_source_filter(self._full_source(doc_id), source_spec) + hit = {"_id": doc_id, "_score": 1.0} + if source_spec is not False: + hit["_source"] = src + hits.append(hit) + else: + end = min(from_ + size, self.total_hits) + hits = [] + for i in range(from_, end): + doc_id = str(i) + src = self._apply_source_filter(self._full_source(doc_id), source_spec) + hit = {"_id": doc_id, "_score": float(self.total_hits - i)} + if source_spec is not False: + hit["_source"] = src + hits.append(hit) + + return { + "took": 8, + "hits": { + "total": {"value": self.total_hits}, + "max_score": hits[0]["_score"] if hits else 0.0, + "hits": hits, + }, + } + + +def _build_search_config(*, rerank_enabled: bool = True, rerank_window: int = 1000): + return SearchConfig( + field_boosts={"title.en": 3.0}, + indexes=[IndexConfig(name="default", label="default", fields=["title.en"])], + query_config=QueryConfig(enable_text_embedding=False, enable_query_rewrite=False), + ranking=RankingConfig(), + function_score=FunctionScoreConfig(), + rerank=RerankConfig(enabled=rerank_enabled, rerank_window=rerank_window), + spu_config=SPUConfig(enabled=False), + es_index_name="test_products", + tenant_config={}, + es_settings={}, + services={}, + ) + + +def 
_build_searcher(config: SearchConfig, es_client: _FakeESClient) -> Searcher: + searcher = Searcher( + es_client=es_client, + config=config, + query_parser=_FakeQueryParser(), + ) + searcher.query_builder = _FakeQueryBuilder() + return searcher + + +def test_config_loader_rerank_enabled_defaults_true(tmp_path: Path): + config_data = { + "es_index_name": "test_products", + "field_boosts": {"title.en": 3.0}, + "indexes": [{"name": "default", "label": "default", "fields": ["title.en"]}], + "query_config": {"supported_languages": ["en"], "default_language": "en"}, + "spu_config": {"enabled": False}, + "ranking": {"expression": "bm25()", "description": "test"}, + "function_score": {"score_mode": "sum", "boost_mode": "multiply", "functions": []}, + "rerank": {"rerank_window": 1000}, + } + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump(config_data), encoding="utf-8") + + loader = ConfigLoader(config_path) + loaded = loader.load_config(validate=False) + + assert loaded.rerank.enabled is True + + +def test_searcher_reranks_top_window_by_default(monkeypatch): + es_client = _FakeESClient() + searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) + context = create_request_context(reqid="t1", uid="u1") + + monkeypatch.setattr( + "search.searcher.get_tenant_config_loader", + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), + ) + + called: Dict[str, Any] = {"count": 0, "docs": 0} + + def _fake_run_rerank(**kwargs): + called["count"] += 1 + called["docs"] = len(kwargs["es_response"]["hits"]["hits"]) + return kwargs["es_response"], None, [] + + monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) + + result = searcher.search( + query="toy", + tenant_id="162", + from_=20, + size=10, + context=context, + enable_rerank=None, + ) + + assert called["count"] == 1 + assert called["docs"] == 1000 + assert es_client.calls[0]["from_"] == 0 + assert es_client.calls[0]["size"] 
== 1000 + assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]} + assert len(es_client.calls) == 2 + assert es_client.calls[1]["size"] == 10 + assert es_client.calls[1]["from_"] == 0 + assert es_client.calls[1]["body"]["query"]["ids"]["values"] == [str(i) for i in range(20, 30)] + assert len(result.results) == 10 + assert result.results[0].spu_id == "20" + assert result.results[0].brief == "brief-20" + + +def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): + es_client = _FakeESClient() + searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) + context = create_request_context(reqid="t1b", uid="u1b") + + monkeypatch.setattr( + "search.searcher.get_tenant_config_loader", + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), + ) + monkeypatch.setattr("search.rerank_client.run_rerank", lambda **kwargs: (kwargs["es_response"], None, [])) + + searcher.search( + query="toy", + tenant_id="162", + from_=0, + size=5, + context=context, + enable_rerank=None, + rerank_doc_template="{title} {vendor} {brief}", + ) + + assert es_client.calls[0]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} + + +def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): + es_client = _FakeESClient() + searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) + context = create_request_context(reqid="t2", uid="u2") + + monkeypatch.setattr( + "search.searcher.get_tenant_config_loader", + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), + ) + + called: Dict[str, int] = {"count": 0} + + def _fake_run_rerank(**kwargs): + called["count"] += 1 + return kwargs["es_response"], None, [] + + monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) + + searcher.search( + query="toy", + tenant_id="162", + from_=20, + size=10, + context=context, + enable_rerank=False, + ) + + assert 
called["count"] == 0 + assert es_client.calls[0]["from_"] == 20 + assert es_client.calls[0]["size"] == 10 + assert len(es_client.calls) == 1 + + +def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch): + es_client = _FakeESClient() + searcher = _build_searcher(_build_search_config(rerank_enabled=True, rerank_window=1000), es_client) + context = create_request_context(reqid="t3", uid="u3") + + monkeypatch.setattr( + "search.searcher.get_tenant_config_loader", + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), + ) + + called: Dict[str, int] = {"count": 0} + + def _fake_run_rerank(**kwargs): + called["count"] += 1 + return kwargs["es_response"], None, [] + + monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) + + searcher.search( + query="toy", + tenant_id="162", + from_=995, + size=10, + context=context, + enable_rerank=None, + ) + + assert called["count"] == 0 + assert es_client.calls[0]["from_"] == 995 + assert es_client.calls[0]["size"] == 10 + assert len(es_client.calls) == 1 diff --git a/utils/es_client.py b/utils/es_client.py index 4d0e1d0..8896e32 100644 --- a/utils/es_client.py +++ b/utils/es_client.py @@ -258,13 +258,23 @@ class ESClient: body.pop("collapse", None) try: - return self.client.search( + response = self.client.search( index=index_name, body=body, size=size, from_=from_, routing=routing, ) + # elasticsearch-py 8.x returns ObjectApiResponse; normalize to mutable dict + # so caller can safely patch hits/took during post-processing. + if hasattr(response, "body"): + payload = response.body + if isinstance(payload, dict): + return dict(payload) + return payload + if isinstance(response, dict): + return response + return dict(response) except Exception as e: logger.error(f"Search failed: {e}", exc_info=True) raise RuntimeError(f"Elasticsearch search failed for index '{index_name}': {e}") from e -- libgit2 0.21.2