From daf66a51dbc4d9319f05454c1960a1fe44b71394 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 10 Mar 2026 22:10:49 +0800 Subject: [PATCH] 已完成接口级压测脚本,覆盖搜索、suggest 和微服务(embedding/translate/rerank)。 --- docs/搜索API对接指南.md | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/perf_api_benchmark.py | 464 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/perf_cases.json.example | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ suggestion/TROUBLESHOOTING.md | 41 +++++++++++++++++++++++++++++++++++++++++ suggestion/builder.py | 39 ++++++++++++++++++++++++++++++++------- tests/test_suggestions.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 746 insertions(+), 7 deletions(-) create mode 100755 scripts/perf_api_benchmark.py create mode 100644 scripts/perf_cases.json.example diff --git a/docs/搜索API对接指南.md b/docs/搜索API对接指南.md index bcc750e..bcc0601 100644 --- a/docs/搜索API对接指南.md +++ b/docs/搜索API对接指南.md @@ -2081,3 +2081,60 @@ curl "http://localhost:6006/health" | `hanlp_standard` ⚠️ TODO(暂不支持) | 中文 | 中文查询分析器(用于中文字段) | | `english` | 英文 | 标准英文分析器(用于英文字段) | | `lowercase` | - | 小写标准化器(用于keyword子字段) | + +--- + +## 10. 接口级压测脚本 + +仓库提供统一压测脚本:`scripts/perf_api_benchmark.py`,用于对以下接口做并发压测: + +- 后端搜索:`POST /search/` +- 搜索建议:`GET /search/suggestions` +- 向量服务:`POST /embed/text` +- 翻译服务:`POST /translate` +- 重排服务:`POST /rerank` + +### 10.1 快速示例 + +```bash +# suggest 压测(tenant 162) +python scripts/perf_api_benchmark.py \ + --scenario backend_suggest \ + --tenant-id 162 \ + --duration 30 \ + --concurrency 50 + +# search 压测 +python scripts/perf_api_benchmark.py \ + --scenario backend_search \ + --tenant-id 162 \ + --duration 30 \ + --concurrency 20 + +# 全链路压测(search + suggest + embedding + translate + rerank) +python scripts/perf_api_benchmark.py \ + --scenario all \ + --tenant-id 162 \ + --duration 60 \ + --concurrency 30 \ + --output perf_reports/all.json +``` + +### 10.2 自定义用例 + +可通过 `--cases-file` 覆盖默认请求模板。示例文件: + +```bash +scripts/perf_cases.json.example +``` + +执行示例: + +```bash +python scripts/perf_api_benchmark.py \ + --scenario all \ + --tenant-id 162 \ + --cases-file scripts/perf_cases.json.example \ + --duration 60 \ + --concurrency 40 +``` diff --git a/scripts/perf_api_benchmark.py b/scripts/perf_api_benchmark.py new file mode 100755 index 0000000..f2dba7d --- /dev/null +++ b/scripts/perf_api_benchmark.py @@ -0,0 +1,464 @@ +#!/usr/bin/env python3 +""" +API-level performance test script for search stack services. + +Default scenarios (aligned with docs/搜索API对接指南.md): +- backend_search POST /search/ +- backend_suggest GET /search/suggestions +- embed_text POST /embed/text +- translate POST /translate +- rerank POST /rerank + +Examples: + python scripts/perf_api_benchmark.py --scenario backend_search --duration 30 --concurrency 20 --tenant-id 162 + python scripts/perf_api_benchmark.py --scenario backend_suggest --duration 30 --concurrency 50 --tenant-id 162 + python scripts/perf_api_benchmark.py --scenario all --duration 60 --concurrency 80 --tenant-id 162 + python scripts/perf_api_benchmark.py --scenario all --cases-file scripts/perf_cases.json.example --output perf_result.json +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import math +import statistics +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import httpx + + +@dataclass +class RequestTemplate: + method: str + path: str + params: Optional[Dict[str, Any]] = None + json_body: Optional[Any] = None + headers: Optional[Dict[str, str]] = None + + +@dataclass +class Scenario: + name: str + templates: List[RequestTemplate] + timeout_sec: float + + +@dataclass +class RequestResult: + ok: bool + status_code: int + latency_ms: float + error: str = "" + + +def percentile(sorted_values: List[float], p: float) -> float: + if not sorted_values: + return 0.0 + if p <= 0: + return sorted_values[0] + if p >= 100: + return sorted_values[-1] + rank = (len(sorted_values) - 1) * (p / 100.0) + low = int(math.floor(rank)) + high = int(math.ceil(rank)) + if low == high: + return sorted_values[low] + weight = rank - low + return sorted_values[low] * (1.0 - weight) + sorted_values[high] * weight + + +def make_default_templates(tenant_id: str) -> Dict[str, List[RequestTemplate]]: + return { + "backend_search": [ + RequestTemplate( + method="POST", + path="/search/", + headers={"X-Tenant-ID": tenant_id}, + json_body={"query": "wireless mouse", "size": 10, "language": "en"}, + ), + RequestTemplate( + method="POST", + path="/search/", + headers={"X-Tenant-ID": tenant_id}, + json_body={"query": "芭比娃娃", "size": 10, "language": "zh"}, + ), + RequestTemplate( + method="POST", + path="/search/", + headers={"X-Tenant-ID": tenant_id}, + json_body={"query": "f", "size": 10, "language": "en"}, + ), + ], + "backend_suggest": [ + RequestTemplate( + method="GET", + path="/search/suggestions", + headers={"X-Tenant-ID": tenant_id}, + params={"q": "f", "size": 10, "language": "en"}, + ), + RequestTemplate( + method="GET", + path="/search/suggestions", + headers={"X-Tenant-ID": tenant_id}, + params={"q": "玩", "size": 10, "language": "zh"}, + ), + RequestTemplate( + method="GET", + path="/search/suggestions", + headers={"X-Tenant-ID": tenant_id}, + params={"q": "shi", "size": 10, "language": "en"}, + ), + ], + "embed_text": [ + RequestTemplate( + method="POST", + path="/embed/text", + json_body=["wireless mouse", "gaming keyboard", "barbie doll"], + ) + ], + "translate": [ + RequestTemplate( + method="POST", + path="/translate", + json_body={"text": "商品名称", "target_lang": "en", "source_lang": "zh", "model": "qwen"}, + ), + RequestTemplate( + method="POST", + path="/translate", + json_body={"text": "Product title", "target_lang": "zh", "model": "qwen"}, + ), + ], + "rerank": [ + RequestTemplate( + method="POST", + path="/rerank", + json_body={ + "query": "wireless mouse", + "docs": [ + "Wireless ergonomic mouse with rechargeable battery", + "USB-C cable 1m", + "Gaming mouse 26000 DPI", + ], + "normalize": True, + }, + ) + ], + } + + +def load_cases_from_file(path: Path, tenant_id: str) -> Dict[str, List[RequestTemplate]]: + data = json.loads(path.read_text(encoding="utf-8")) + out: Dict[str, List[RequestTemplate]] = {} + for scenario_name, requests_data in (data.get("scenarios") or {}).items(): + templates: List[RequestTemplate] = [] + for item in requests_data: + headers = dict(item.get("headers") or {}) + if "X-Tenant-ID" in headers and str(headers["X-Tenant-ID"]).strip() == "${tenant_id}": + headers["X-Tenant-ID"] = tenant_id + templates.append( + RequestTemplate( + method=str(item.get("method", "GET")).upper(), + path=str(item.get("path", "")).strip(), + params=item.get("params"), + json_body=item.get("json"), + headers=headers or None, + ) + ) + if templates: + out[scenario_name] = templates + return out + + +def build_scenarios(args: argparse.Namespace) -> Dict[str, Scenario]: + defaults = make_default_templates(args.tenant_id) + if args.cases_file: + custom = load_cases_from_file(Path(args.cases_file), tenant_id=args.tenant_id) + defaults.update(custom) + + scenario_base = { + "backend_search": args.backend_base, + "backend_suggest": args.backend_base, + "embed_text": args.embedding_base, + "translate": args.translator_base, + "rerank": args.reranker_base, + } + + scenarios: Dict[str, Scenario] = {} + for name, templates in defaults.items(): + if name not in scenario_base: + continue + base = scenario_base[name].rstrip("/") + rewritten: List[RequestTemplate] = [] + for t in templates: + path = t.path if t.path.startswith("/") else f"/{t.path}" + rewritten.append( + RequestTemplate( + method=t.method, + path=f"{base}{path}", + params=t.params, + json_body=t.json_body, + headers=t.headers, + ) + ) + scenarios[name] = Scenario(name=name, templates=rewritten, timeout_sec=args.timeout) + return scenarios + + +async def run_single_scenario( + scenario: Scenario, + duration_sec: int, + concurrency: int, + max_requests: int, + max_errors: int, +) -> Dict[str, Any]: + latencies: List[float] = [] + status_counter: Dict[int, int] = {} + err_counter: Dict[str, int] = {} + total_requests = 0 + success_requests = 0 + stop_flag = False + lock = asyncio.Lock() + start = time.perf_counter() + + timeout = httpx.Timeout(timeout=scenario.timeout_sec) + limits = httpx.Limits(max_connections=max(concurrency * 2, 20), max_keepalive_connections=max(concurrency, 10)) + + async def worker(worker_id: int, client: httpx.AsyncClient) -> None: + nonlocal total_requests, success_requests, stop_flag + idx = worker_id % len(scenario.templates) + + while not stop_flag: + elapsed = time.perf_counter() - start + if duration_sec > 0 and elapsed >= duration_sec: + break + + async with lock: + if max_requests > 0 and total_requests >= max_requests: + stop_flag = True + break + total_requests += 1 + + tpl = scenario.templates[idx % len(scenario.templates)] + idx += 1 + + t0 = time.perf_counter() + ok = False + status = 0 + err = "" + try: + resp = await client.request( + method=tpl.method, + url=tpl.path, + params=tpl.params, + json=tpl.json_body, + headers=tpl.headers, + ) + status = int(resp.status_code) + ok = 200 <= status < 300 + if not ok: + err = f"http_{status}" + except Exception as e: + err = type(e).__name__ + t1 = time.perf_counter() + cost_ms = (t1 - t0) * 1000.0 + + async with lock: + latencies.append(cost_ms) + if status: + status_counter[status] = status_counter.get(status, 0) + 1 + if ok: + success_requests += 1 + else: + err_counter[err or "unknown"] = err_counter.get(err or "unknown", 0) + 1 + total_err = sum(err_counter.values()) + if max_errors > 0 and total_err >= max_errors: + stop_flag = True + + async with httpx.AsyncClient(timeout=timeout, limits=limits) as client: + tasks = [asyncio.create_task(worker(i, client)) for i in range(concurrency)] + await asyncio.gather(*tasks) + + elapsed = max(time.perf_counter() - start, 1e-9) + lat_sorted = sorted(latencies) + + result = { + "scenario": scenario.name, + "duration_sec": round(elapsed, 3), + "total_requests": total_requests, + "success_requests": success_requests, + "failed_requests": max(total_requests - success_requests, 0), + "success_rate": round((success_requests / total_requests) * 100.0, 2) if total_requests else 0.0, + "throughput_rps": round(total_requests / elapsed, 2), + "latency_ms": { + "avg": round(statistics.mean(lat_sorted), 2) if lat_sorted else 0.0, + "p50": round(percentile(lat_sorted, 50), 2), + "p90": round(percentile(lat_sorted, 90), 2), + "p95": round(percentile(lat_sorted, 95), 2), + "p99": round(percentile(lat_sorted, 99), 2), + "max": round(max(lat_sorted), 2) if lat_sorted else 0.0, + }, + "status_codes": dict(sorted(status_counter.items(), key=lambda x: x[0])), + "errors": dict(sorted(err_counter.items(), key=lambda x: x[0])), + } + return result + + +def format_summary(result: Dict[str, Any]) -> str: + lines = [] + lines.append(f"\\n=== Scenario: {result['scenario']} ===") + lines.append( + "requests={total_requests} success={success_requests} fail={failed_requests} " + "success_rate={success_rate}% rps={throughput_rps}".format(**result) + ) + lat = result["latency_ms"] + lines.append( + f"latency(ms): avg={lat['avg']} p50={lat['p50']} p90={lat['p90']} p95={lat['p95']} p99={lat['p99']} max={lat['max']}" + ) + lines.append(f"status_codes: {result['status_codes']}") + if result["errors"]: + lines.append(f"errors: {result['errors']}") + return "\\n".join(lines) + + +def aggregate_results(results: List[Dict[str, Any]]) -> Dict[str, Any]: + if not results: + return {} + total_requests = sum(x["total_requests"] for x in results) + success_requests = sum(x["success_requests"] for x in results) + failed_requests = sum(x["failed_requests"] for x in results) + total_duration = sum(x["duration_sec"] for x in results) + weighted_avg_latency = 0.0 + if total_requests > 0: + weighted_avg_latency = sum(x["latency_ms"]["avg"] * x["total_requests"] for x in results) / total_requests + + return { + "scenario": "ALL", + "total_requests": total_requests, + "success_requests": success_requests, + "failed_requests": failed_requests, + "success_rate": round((success_requests / total_requests) * 100.0, 2) if total_requests else 0.0, + "aggregate_rps": round(total_requests / max(total_duration, 1e-9), 2), + "weighted_avg_latency_ms": round(weighted_avg_latency, 2), + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Interface-level load test for search and related microservices") + parser.add_argument( + "--scenario", + type=str, + default="all", + help="Scenario: backend_search | backend_suggest | embed_text | translate | rerank | all", + ) + parser.add_argument("--tenant-id", type=str, default="162", help="Tenant ID for backend search/suggest") + parser.add_argument("--duration", type=int, default=30, help="Duration seconds per scenario; <=0 means no duration cap") + parser.add_argument("--concurrency", type=int, default=20, help="Concurrent workers per scenario") + parser.add_argument("--max-requests", type=int, default=0, help="Stop after N requests per scenario (0 means unlimited)") + parser.add_argument("--timeout", type=float, default=10.0, help="Request timeout seconds") + parser.add_argument("--max-errors", type=int, default=0, help="Stop scenario when accumulated errors reach this value") + + parser.add_argument("--backend-base", type=str, default="http://127.0.0.1:6002", help="Base URL for backend search API") + parser.add_argument("--embedding-base", type=str, default="http://127.0.0.1:6005", help="Base URL for embedding service") + parser.add_argument("--translator-base", type=str, default="http://127.0.0.1:6006", help="Base URL for translation service") + parser.add_argument("--reranker-base", type=str, default="http://127.0.0.1:6007", help="Base URL for reranker service") + + parser.add_argument("--cases-file", type=str, default="", help="Optional JSON file to override/add request templates") + parser.add_argument("--output", type=str, default="", help="Optional output JSON path") + parser.add_argument("--pause", type=float, default=0.0, help="Pause seconds between scenarios in all mode") + return parser.parse_args() + + +async def main_async() -> int: + args = parse_args() + scenarios = build_scenarios(args) + + all_names = ["backend_search", "backend_suggest", "embed_text", "translate", "rerank"] + if args.scenario == "all": + run_names = [x for x in all_names if x in scenarios] + else: + if args.scenario not in scenarios: + print(f"Unknown scenario: {args.scenario}") + print(f"Available: {', '.join(sorted(scenarios.keys()))}") + return 2 + run_names = [args.scenario] + + if not run_names: + print("No scenarios to run.") + return 2 + + print("Load test config:") + print(f" scenario={args.scenario}") + print(f" tenant_id={args.tenant_id}") + print(f" duration={args.duration}s") + print(f" concurrency={args.concurrency}") + print(f" max_requests={args.max_requests}") + print(f" timeout={args.timeout}s") + print(f" max_errors={args.max_errors}") + print(f" backend_base={args.backend_base}") + print(f" embedding_base={args.embedding_base}") + print(f" translator_base={args.translator_base}") + print(f" reranker_base={args.reranker_base}") + + results: List[Dict[str, Any]] = [] + for i, name in enumerate(run_names, start=1): + scenario = scenarios[name] + print(f"\\n[{i}/{len(run_names)}] running {name} ...") + result = await run_single_scenario( + scenario=scenario, + duration_sec=args.duration, + concurrency=args.concurrency, + max_requests=args.max_requests, + max_errors=args.max_errors, + ) + print(format_summary(result)) + results.append(result) + + if args.pause > 0 and i < len(run_names): + await asyncio.sleep(args.pause) + + final = { + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "config": { + "scenario": args.scenario, + "tenant_id": args.tenant_id, + "duration_sec": args.duration, + "concurrency": args.concurrency, + "max_requests": args.max_requests, + "timeout_sec": args.timeout, + "max_errors": args.max_errors, + "backend_base": args.backend_base, + "embedding_base": args.embedding_base, + "translator_base": args.translator_base, + "reranker_base": args.reranker_base, + "cases_file": args.cases_file or None, + }, + "results": results, + "overall": aggregate_results(results), + } + + print("\\n=== Overall ===") + print(json.dumps(final["overall"], ensure_ascii=False, indent=2)) + + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(final, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"Saved JSON report: {out_path}") + + return 0 + + +def main() -> int: + try: + return asyncio.run(main_async()) + except KeyboardInterrupt: + print("Interrupted by user") + return 130 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/perf_cases.json.example b/scripts/perf_cases.json.example new file mode 100644 index 0000000..df4a5be --- /dev/null +++ b/scripts/perf_cases.json.example @@ -0,0 +1,62 @@ +{ + "scenarios": { + "backend_search": [ + { + "method": "POST", + "path": "/search/", + "headers": {"X-Tenant-ID": "${tenant_id}"}, + "json": {"query": "wireless mouse", "size": 20, "language": "en", "enable_rerank": false} + }, + { + "method": "POST", + "path": "/search/", + "headers": {"X-Tenant-ID": "${tenant_id}"}, + "json": {"query": "芭比娃娃", "size": 20, "language": "zh", "enable_rerank": false} + } + ], + "backend_suggest": [ + { + "method": "GET", + "path": "/search/suggestions", + "headers": {"X-Tenant-ID": "${tenant_id}"}, + "params": {"q": "f", "size": 20, "language": "en"} + }, + { + "method": "GET", + "path": "/search/suggestions", + "headers": {"X-Tenant-ID": "${tenant_id}"}, + "params": {"q": "玩", "size": 20, "language": "zh"} + } + ], + "embed_text": [ + { + "method": "POST", + "path": "/embed/text", + "json": ["wireless mouse", "gaming keyboard", "USB-C cable", "barbie doll"] + } + ], + "translate": [ + { + "method": "POST", + "path": "/translate", + "json": {"text": "商品标题", "target_lang": "en", "source_lang": "zh", "model": "qwen"} + } + ], + "rerank": [ + { + "method": "POST", + "path": "/rerank", + "json": { + "query": "wireless mouse", + "docs": [ + "Wireless ergonomic mouse", + "Bluetooth gaming mouse", + "USB cable 1 meter", + "Mouse pad large size" + ], + "normalize": true + } + } + ] + } +} diff --git a/suggestion/TROUBLESHOOTING.md b/suggestion/TROUBLESHOOTING.md index b0427b5..577aa54 100644 --- a/suggestion/TROUBLESHOOTING.md +++ b/suggestion/TROUBLESHOOTING.md @@ -93,3 +93,44 @@ curl -u "$ES_USERNAME:$ES_PASSWORD" "$ES_HOST" ``` 或先执行一次全量。 + +## 8. `q=F` 这类前缀为空,但商品里明明有 `F...` 标题 + +### 典型原因 + +- suggestion 索引里只写入了 query_log,没写入商品 title(例如商品文档缺少 `spu_id`,但有 `id`)。 +- 英文标题太长,被噪声过滤(现在会自动提取前导短语,例如 `Furby Furblets 2-Pack`)。 + +### 逐条排查 + +1. 看 suggestion alias 是否有 `en` 文档: + +```bash +ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_162_current" +curl -u "$ES_USERNAME:$ES_PASSWORD" "$ES_HOST/$ALIAS_NAME/_search?pretty" \ + -H 'Content-Type: application/json' \ + -d '{"size":0,"aggs":{"langs":{"terms":{"field":"lang","size":20}}}}' +``` + +2. 查 `en` 下是否有 `f` 前缀: + +```bash +curl -u "$ES_USERNAME:$ES_PASSWORD" "$ES_HOST/$ALIAS_NAME/_search?pretty" \ + -H 'Content-Type: application/json' \ + -d '{"size":20,"_source":["text","text_norm","lang"],"query":{"bool":{"filter":[{"term":{"lang":"en"}}],"must":[{"prefix":{"text_norm":"f"}}]}}}' +``` + +3. 对照商品索引确认源数据确实存在 `F...`: + +```bash +curl -u "$ES_USERNAME:$ES_PASSWORD" "$ES_HOST/search_products_tenant_162/_search?pretty" \ + -H 'Content-Type: application/json' \ + -d '{"size":20,"_source":["id","spu_id","title.en"],"query":{"match_phrase_prefix":{"title.en":"f"}}}' +``` + +4. 重建后再测 API: + +```bash +./scripts/rebuild_suggestions.sh 162 F en +curl "http://localhost:6002/search/suggestions?q=F&size=40&language=en&tenant_id=162" +``` diff --git a/suggestion/builder.py b/suggestion/builder.py index a7027f5..90731c0 100644 --- a/suggestion/builder.py +++ b/suggestion/builder.py @@ -128,6 +128,27 @@ class SuggestionIndexBuilder: return text_value @staticmethod + def _prepare_title_for_suggest(title: str, max_len: int = 120) -> str: + """ + Keep title-derived suggestions concise: + - keep raw title when short enough + - for long titles, keep the leading phrase before common separators + - fallback to hard truncate + """ + raw = str(title or "").strip() + if not raw: + return "" + if len(raw) <= max_len: + return raw + + head = re.split(r"[,,;;|/\\\\((\\[【]", raw, maxsplit=1)[0].strip() + if 1 < len(head) <= max_len: + return head + + truncated = raw[:max_len].rstrip(" ,,;;|/\\\\-—–()()[]【】") + return truncated or raw[:max_len] + + @staticmethod def _split_qanchors(value: Any) -> List[str]: if value is None: return [] @@ -252,8 +273,12 @@ class SuggestionIndexBuilder: while True: body: Dict[str, Any] = { "size": batch_size, - "_source": ["spu_id", "title", "qanchors"], - "sort": [{"spu_id": "asc"}], + "_source": ["id", "spu_id", "title", "qanchors"], + # Prefer spu_id when present; fall back to id.keyword for current mappings. + "sort": [ + {"spu_id": {"order": "asc", "missing": "_last"}}, + {"id.keyword": {"order": "asc", "missing": "_last"}}, + ], "query": {"match_all": {}}, } if search_after is not None: @@ -431,8 +456,8 @@ class SuggestionIndexBuilder: # Step 1: product title/qanchors for hit in self._iter_products(tenant_id, batch_size=batch_size): src = hit.get("_source", {}) or {} - spu_id = str(src.get("spu_id") or "") - if not spu_id: + product_id = str(src.get("spu_id") or src.get("id") or hit.get("_id") or "") + if not product_id: continue title_obj = src.get("title") or {} qanchor_obj = src.get("qanchors") or {} @@ -440,7 +465,7 @@ class SuggestionIndexBuilder: for lang in index_languages: title = "" if isinstance(title_obj, dict): - title = str(title_obj.get(lang) or "").strip() + title = self._prepare_title_for_suggest(title_obj.get(lang) or "") if title: text_norm = self._normalize_text(title) if not self._looks_noise(text_norm): @@ -449,7 +474,7 @@ class SuggestionIndexBuilder: if c is None: c = SuggestionCandidate(text=title, text_norm=text_norm, lang=lang) key_to_candidate[key] = c - c.add_product("title", spu_id=spu_id) + c.add_product("title", spu_id=product_id) q_raw = None if isinstance(qanchor_obj, dict): @@ -463,7 +488,7 @@ class SuggestionIndexBuilder: if c is None: c = SuggestionCandidate(text=q_text, text_norm=text_norm, lang=lang) key_to_candidate[key] = c - c.add_product("qanchor", spu_id=spu_id) + c.add_product("qanchor", spu_id=product_id) # Step 2: query logs now = datetime.now(timezone.utc) diff --git a/tests/test_suggestions.py b/tests/test_suggestions.py index 01a0443..af1d1c0 100644 --- a/tests/test_suggestions.py +++ b/tests/test_suggestions.py @@ -345,3 +345,93 @@ def test_incremental_updates_existing_index(monkeypatch): bulk_calls = [x for x in fake_es.calls if x.get("op") == "bulk_actions"] assert len(bulk_calls) == 1 assert len(bulk_calls[0]["actions"]) == 1 + + +@pytest.mark.unit +def test_build_full_candidates_fallback_to_id_when_spu_id_missing(monkeypatch): + fake_es = FakeESClient() + builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) + + monkeypatch.setattr( + builder, + "_iter_products", + lambda tenant_id, batch_size=500: iter( + [ + { + "_id": "521", + "_source": { + "id": "521", + "title": {"en": "Furby Toy"}, + "qanchors": {"en": "furby"}, + }, + } + ] + ), + ) + monkeypatch.setattr(builder, "_iter_query_log_rows", lambda **kwargs: iter([])) + + key_to_candidate = builder._build_full_candidates( + tenant_id="162", + index_languages=["en"], + primary_language="en", + days=365, + batch_size=100, + min_query_len=1, + ) + + title_key = ("en", "furby toy") + qanchor_key = ("en", "furby") + assert title_key in key_to_candidate + assert qanchor_key in key_to_candidate + assert key_to_candidate[title_key].title_spu_ids == {"521"} + assert key_to_candidate[qanchor_key].qanchor_spu_ids == {"521"} + + +@pytest.mark.unit +def test_build_full_candidates_splits_long_title_for_suggest(monkeypatch): + fake_es = FakeESClient() + builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) + + long_title = ( + "Furby Furblets 2-Pack, Mini Friends Ray-Vee & Hip-Bop, 45+ Sounds Each, " + "Music & Furbish Phrases, Electronic Plush Toys, Rainbow & Pink/Purple, " + "Ages 6+ (Amazon Exclusive)" + ) + monkeypatch.setattr( + builder, + "_iter_products", + lambda tenant_id, batch_size=500: iter( + [{"_id": "521", "_source": {"id": "521", "title": {"en": long_title}, "qanchors": {}}}] + ), + ) + monkeypatch.setattr(builder, "_iter_query_log_rows", lambda **kwargs: iter([])) + + key_to_candidate = builder._build_full_candidates( + tenant_id="162", + index_languages=["en"], + primary_language="en", + days=365, + batch_size=100, + min_query_len=1, + ) + + key = ("en", "furby furblets 2-pack") + assert key in key_to_candidate + assert key_to_candidate[key].text == "Furby Furblets 2-Pack" + + +@pytest.mark.unit +def test_iter_products_requests_dual_sort_and_fields(): + fake_es = FakeESClient() + builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) + + list(builder._iter_products(tenant_id="162", batch_size=10)) + + search_calls = [x for x in fake_es.calls if x.get("op") == "search"] + assert len(search_calls) >= 1 + body = search_calls[0]["body"] + sort = body.get("sort", []) + assert {"spu_id": {"order": "asc", "missing": "_last"}} in sort + assert {"id.keyword": {"order": "asc", "missing": "_last"}} in sort + assert "id" in body.get("_source", []) + assert "spu_id" in body.get("_source", []) -- libgit2 0.21.2