Commit 8c8b9d840a72ba73e1ab1d219cbf1bca1e2ad8ce
1 parent: ceaf6d03
ES 拉取 coarse_rank.input_window 条 -> 粗排按 text/knn 融合裁到
coarse_rank.output_window -> 再做 SKU 选择和 title suffix -> 精排调用轻量 reranker 裁到 fine_rank.output_window -> 最终重排调用现有 reranker,并在最终融合里加入 fine_score。同时把 reranker client/provider 改成了按 service_profile 选不同 service_url,这样 fine/final 可以共用同一套服务代码,只起不同实例。
Showing 20 changed files with 1432 additions and 53 deletions
Show diff stats
config/__init__.py
| ... | ... | @@ -4,6 +4,9 @@ from config.config_loader import ConfigLoader, ConfigurationError |
| 4 | 4 | from config.loader import AppConfigLoader, get_app_config, reload_app_config |
| 5 | 5 | from config.schema import ( |
| 6 | 6 | AppConfig, |
| 7 | + CoarseRankConfig, | |
| 8 | + CoarseRankFusionConfig, | |
| 9 | + FineRankConfig, | |
| 7 | 10 | FunctionScoreConfig, |
| 8 | 11 | IndexConfig, |
| 9 | 12 | QueryConfig, |
| ... | ... | @@ -31,8 +34,11 @@ from config.utils import get_domain_fields, get_match_fields_for_index |
| 31 | 34 | __all__ = [ |
| 32 | 35 | "AppConfig", |
| 33 | 36 | "AppConfigLoader", |
| 37 | + "CoarseRankConfig", | |
| 38 | + "CoarseRankFusionConfig", | |
| 34 | 39 | "ConfigLoader", |
| 35 | 40 | "ConfigurationError", |
| 41 | + "FineRankConfig", | |
| 36 | 42 | "FunctionScoreConfig", |
| 37 | 43 | "IndexConfig", |
| 38 | 44 | "QueryConfig", | ... | ... |
config/config.yaml
| ... | ... | @@ -228,15 +228,40 @@ function_score: |
| 228 | 228 | boost_mode: "multiply" |
| 229 | 229 | functions: [] |
| 230 | 230 | |
| 231 | +# 粗排配置(仅融合 ES 文本/向量信号,不调用模型) | |
| 232 | +coarse_rank: | |
| 233 | + enabled: true | |
| 234 | + input_window: 700 | |
| 235 | + output_window: 240 | |
| 236 | + fusion: | |
| 237 | + text_bias: 0.1 | |
| 238 | + text_exponent: 0.35 | |
| 239 | + knn_text_weight: 1.0 | |
| 240 | + knn_image_weight: 1.0 | |
| 241 | + knn_tie_breaker: 0.1 | |
| 242 | + knn_bias: 0.6 | |
| 243 | + knn_exponent: 0.0 | |
| 244 | + | |
| 245 | +# 精排配置(轻量 reranker) | |
| 246 | +fine_rank: | |
| 247 | + enabled: true | |
| 248 | + input_window: 240 | |
| 249 | + output_window: 80 | |
| 250 | + timeout_sec: 10.0 | |
| 251 | + rerank_query_template: "{query}" | |
| 252 | + rerank_doc_template: "{title}" | |
| 253 | + service_profile: "fine" | |
| 254 | + | |
| 231 | 255 | # 重排配置(provider/URL 在 services.rerank) |
| 232 | 256 | rerank: |
| 233 | 257 | enabled: true |
| 234 | - rerank_window: 400 | |
| 258 | + rerank_window: 80 | |
| 235 | 259 | timeout_sec: 15.0 |
| 236 | 260 | weight_es: 0.4 |
| 237 | 261 | weight_ai: 0.6 |
| 238 | 262 | rerank_query_template: "{query}" |
| 239 | 263 | rerank_doc_template: "{title}" |
| 264 | + service_profile: "default" | |
| 240 | 265 | # 乘法融合:fused = Π (max(score,0) + bias) ** exponent(rerank / text / knn 三项) |
| 241 | 266 | # 其中 knn_score 先做一层 dis_max: |
| 242 | 267 | # max(knn_text_weight * text_knn, knn_image_weight * image_knn) |
| ... | ... | @@ -244,6 +269,8 @@ rerank: |
| 244 | 269 | fusion: |
| 245 | 270 | rerank_bias: 0.00001 |
| 246 | 271 | rerank_exponent: 1.0 |
| 272 | + fine_bias: 0.00001 | |
| 273 | + fine_exponent: 1.0 | |
| 247 | 274 | text_bias: 0.1 |
| 248 | 275 | text_exponent: 0.35 |
| 249 | 276 | knn_text_weight: 1.0 |
| ... | ... | @@ -399,6 +426,9 @@ services: |
| 399 | 426 | http: |
| 400 | 427 | base_url: "http://127.0.0.1:6007" |
| 401 | 428 | service_url: "http://127.0.0.1:6007/rerank" |
| 429 | + service_urls: | |
| 430 | + default: "http://127.0.0.1:6007/rerank" | |
| 431 | + fine: "http://127.0.0.1:6009/rerank" | |
| 402 | 432 | request: |
| 403 | 433 | max_docs: 1000 |
| 404 | 434 | normalize: true | ... | ... |
config/loader.py
| ... | ... | @@ -27,10 +27,13 @@ except Exception: # pragma: no cover |
| 27 | 27 | from config.schema import ( |
| 28 | 28 | AppConfig, |
| 29 | 29 | AssetsConfig, |
| 30 | + CoarseRankConfig, | |
| 31 | + CoarseRankFusionConfig, | |
| 30 | 32 | ConfigMetadata, |
| 31 | 33 | DatabaseSettings, |
| 32 | 34 | ElasticsearchSettings, |
| 33 | 35 | EmbeddingServiceConfig, |
| 36 | + FineRankConfig, | |
| 34 | 37 | FunctionScoreConfig, |
| 35 | 38 | IndexConfig, |
| 36 | 39 | InfrastructureConfig, |
| ... | ... | @@ -464,6 +467,11 @@ class AppConfigLoader: |
| 464 | 467 | ) |
| 465 | 468 | |
| 466 | 469 | function_score_cfg = raw.get("function_score") if isinstance(raw.get("function_score"), dict) else {} |
| 470 | + coarse_rank_cfg = raw.get("coarse_rank") if isinstance(raw.get("coarse_rank"), dict) else {} | |
| 471 | + coarse_fusion_raw = ( | |
| 472 | + coarse_rank_cfg.get("fusion") if isinstance(coarse_rank_cfg.get("fusion"), dict) else {} | |
| 473 | + ) | |
| 474 | + fine_rank_cfg = raw.get("fine_rank") if isinstance(raw.get("fine_rank"), dict) else {} | |
| 467 | 475 | rerank_cfg = raw.get("rerank") if isinstance(raw.get("rerank"), dict) else {} |
| 468 | 476 | fusion_raw = rerank_cfg.get("fusion") if isinstance(rerank_cfg.get("fusion"), dict) else {} |
| 469 | 477 | spu_cfg = raw.get("spu_config") if isinstance(raw.get("spu_config"), dict) else {} |
| ... | ... | @@ -477,6 +485,33 @@ class AppConfigLoader: |
| 477 | 485 | boost_mode=str(function_score_cfg.get("boost_mode") or "multiply"), |
| 478 | 486 | functions=list(function_score_cfg.get("functions") or []), |
| 479 | 487 | ), |
| 488 | + coarse_rank=CoarseRankConfig( | |
| 489 | + enabled=bool(coarse_rank_cfg.get("enabled", True)), | |
| 490 | + input_window=int(coarse_rank_cfg.get("input_window", 700)), | |
| 491 | + output_window=int(coarse_rank_cfg.get("output_window", 240)), | |
| 492 | + fusion=CoarseRankFusionConfig( | |
| 493 | + text_bias=float(coarse_fusion_raw.get("text_bias", 0.1)), | |
| 494 | + text_exponent=float(coarse_fusion_raw.get("text_exponent", 0.35)), | |
| 495 | + knn_text_weight=float(coarse_fusion_raw.get("knn_text_weight", 1.0)), | |
| 496 | + knn_image_weight=float(coarse_fusion_raw.get("knn_image_weight", 1.0)), | |
| 497 | + knn_tie_breaker=float(coarse_fusion_raw.get("knn_tie_breaker", 0.0)), | |
| 498 | + knn_bias=float(coarse_fusion_raw.get("knn_bias", 0.6)), | |
| 499 | + knn_exponent=float(coarse_fusion_raw.get("knn_exponent", 0.2)), | |
| 500 | + ), | |
| 501 | + ), | |
| 502 | + fine_rank=FineRankConfig( | |
| 503 | + enabled=bool(fine_rank_cfg.get("enabled", True)), | |
| 504 | + input_window=int(fine_rank_cfg.get("input_window", 240)), | |
| 505 | + output_window=int(fine_rank_cfg.get("output_window", 80)), | |
| 506 | + timeout_sec=float(fine_rank_cfg.get("timeout_sec", 10.0)), | |
| 507 | + rerank_query_template=str(fine_rank_cfg.get("rerank_query_template") or "{query}"), | |
| 508 | + rerank_doc_template=str(fine_rank_cfg.get("rerank_doc_template") or "{title}"), | |
| 509 | + service_profile=( | |
| 510 | + str(v) | |
| 511 | + if (v := fine_rank_cfg.get("service_profile")) not in (None, "") | |
| 512 | + else "fine" | |
| 513 | + ), | |
| 514 | + ), | |
| 480 | 515 | rerank=RerankConfig( |
| 481 | 516 | enabled=bool(rerank_cfg.get("enabled", True)), |
| 482 | 517 | rerank_window=int(rerank_cfg.get("rerank_window", 384)), |
| ... | ... | @@ -485,6 +520,11 @@ class AppConfigLoader: |
| 485 | 520 | weight_ai=float(rerank_cfg.get("weight_ai", 0.6)), |
| 486 | 521 | rerank_query_template=str(rerank_cfg.get("rerank_query_template") or "{query}"), |
| 487 | 522 | rerank_doc_template=str(rerank_cfg.get("rerank_doc_template") or "{title}"), |
| 523 | + service_profile=( | |
| 524 | + str(v) | |
| 525 | + if (v := rerank_cfg.get("service_profile")) not in (None, "") | |
| 526 | + else None | |
| 527 | + ), | |
| 488 | 528 | fusion=RerankFusionConfig( |
| 489 | 529 | rerank_bias=float(fusion_raw.get("rerank_bias", 0.00001)), |
| 490 | 530 | rerank_exponent=float(fusion_raw.get("rerank_exponent", 1.0)), |
| ... | ... | @@ -495,6 +535,8 @@ class AppConfigLoader: |
| 495 | 535 | knn_tie_breaker=float(fusion_raw.get("knn_tie_breaker", 0.0)), |
| 496 | 536 | knn_bias=float(fusion_raw.get("knn_bias", 0.6)), |
| 497 | 537 | knn_exponent=float(fusion_raw.get("knn_exponent", 0.2)), |
| 538 | + fine_bias=float(fusion_raw.get("fine_bias", 0.00001)), | |
| 539 | + fine_exponent=float(fusion_raw.get("fine_exponent", 1.0)), | |
| 498 | 540 | ), |
| 499 | 541 | ), |
| 500 | 542 | spu_config=SPUConfig( | ... | ... |
config/schema.py
| ... | ... | @@ -117,6 +117,48 @@ class RerankFusionConfig: |
| 117 | 117 | knn_tie_breaker: float = 0.0 |
| 118 | 118 | knn_bias: float = 0.6 |
| 119 | 119 | knn_exponent: float = 0.2 |
| 120 | + fine_bias: float = 0.00001 | |
| 121 | + fine_exponent: float = 1.0 | |
| 122 | + | |
| 123 | + | |
| 124 | +@dataclass(frozen=True) | |
| 125 | +class CoarseRankFusionConfig: | |
| 126 | + """ | |
| 127 | + Multiplicative fusion without model score: | |
| 128 | + fused = (max(text, 0) + text_bias) ** text_exponent | |
| 129 | + * (max(knn, 0) + knn_bias) ** knn_exponent | |
| 130 | + """ | |
| 131 | + | |
| 132 | + text_bias: float = 0.1 | |
| 133 | + text_exponent: float = 0.35 | |
| 134 | + knn_text_weight: float = 1.0 | |
| 135 | + knn_image_weight: float = 1.0 | |
| 136 | + knn_tie_breaker: float = 0.0 | |
| 137 | + knn_bias: float = 0.6 | |
| 138 | + knn_exponent: float = 0.2 | |
| 139 | + | |
| 140 | + | |
| 141 | +@dataclass(frozen=True) | |
| 142 | +class CoarseRankConfig: | |
| 143 | + """Search-time coarse ranking configuration.""" | |
| 144 | + | |
| 145 | + enabled: bool = True | |
| 146 | + input_window: int = 700 | |
| 147 | + output_window: int = 240 | |
| 148 | + fusion: CoarseRankFusionConfig = field(default_factory=CoarseRankFusionConfig) | |
| 149 | + | |
| 150 | + | |
| 151 | +@dataclass(frozen=True) | |
| 152 | +class FineRankConfig: | |
| 153 | + """Search-time lightweight rerank configuration.""" | |
| 154 | + | |
| 155 | + enabled: bool = True | |
| 156 | + input_window: int = 240 | |
| 157 | + output_window: int = 80 | |
| 158 | + timeout_sec: float = 10.0 | |
| 159 | + rerank_query_template: str = "{query}" | |
| 160 | + rerank_doc_template: str = "{title}" | |
| 161 | + service_profile: Optional[str] = "fine" | |
| 120 | 162 | |
| 121 | 163 | |
| 122 | 164 | @dataclass(frozen=True) |
| ... | ... | @@ -130,6 +172,7 @@ class RerankConfig: |
| 130 | 172 | weight_ai: float = 0.6 |
| 131 | 173 | rerank_query_template: str = "{query}" |
| 132 | 174 | rerank_doc_template: str = "{title}" |
| 175 | + service_profile: Optional[str] = None | |
| 133 | 176 | fusion: RerankFusionConfig = field(default_factory=RerankFusionConfig) |
| 134 | 177 | |
| 135 | 178 | |
| ... | ... | @@ -141,6 +184,8 @@ class SearchConfig: |
| 141 | 184 | indexes: List[IndexConfig] = field(default_factory=list) |
| 142 | 185 | query_config: QueryConfig = field(default_factory=QueryConfig) |
| 143 | 186 | function_score: FunctionScoreConfig = field(default_factory=FunctionScoreConfig) |
| 187 | + coarse_rank: CoarseRankConfig = field(default_factory=CoarseRankConfig) | |
| 188 | + fine_rank: FineRankConfig = field(default_factory=FineRankConfig) | |
| 144 | 189 | rerank: RerankConfig = field(default_factory=RerankConfig) |
| 145 | 190 | spu_config: SPUConfig = field(default_factory=SPUConfig) |
| 146 | 191 | es_index_name: str = "search_products" | ... | ... |
config/services_config.py
| ... | ... | @@ -71,13 +71,20 @@ def get_rerank_backend_config() -> Tuple[str, Dict[str, Any]]: |
| 71 | 71 | return cfg.backend, cfg.get_backend_config() |
| 72 | 72 | |
| 73 | 73 | |
| 74 | -def get_rerank_base_url() -> str: | |
| 74 | +def get_rerank_base_url(profile: str | None = None) -> str: | |
| 75 | 75 | provider_cfg = get_app_config().services.rerank.get_provider_config() |
| 76 | - base = provider_cfg.get("service_url") or provider_cfg.get("base_url") | |
| 76 | + base = None | |
| 77 | + profile_name = str(profile).strip() if profile else "" | |
| 78 | + if profile_name: | |
| 79 | + service_urls = provider_cfg.get("service_urls") | |
| 80 | + if isinstance(service_urls, dict): | |
| 81 | + base = service_urls.get(profile_name) | |
| 82 | + if not base: | |
| 83 | + base = provider_cfg.get("service_url") or provider_cfg.get("base_url") | |
| 77 | 84 | if not base: |
| 78 | 85 | raise ValueError("Rerank service URL is not configured") |
| 79 | 86 | return str(base).rstrip("/") |
| 80 | 87 | |
| 81 | 88 | |
| 82 | -def get_rerank_service_url() -> str: | |
| 83 | - return get_rerank_base_url() | |
| 89 | +def get_rerank_service_url(profile: str | None = None) -> str: | |
| 90 | + return get_rerank_base_url(profile=profile) | ... | ... |
context/request_context.py
| ... | ... | @@ -26,6 +26,8 @@ class RequestContextStage(Enum): |
| 26 | 26 | # ES 按 ID 回源分页详情回填 |
| 27 | 27 | ELASTICSEARCH_PAGE_FILL = "elasticsearch_page_fill" |
| 28 | 28 | RESULT_PROCESSING = "result_processing" |
| 29 | + COARSE_RANKING = "coarse_ranking" | |
| 30 | + FINE_RANKING = "fine_ranking" | |
| 29 | 31 | RERANKING = "reranking" |
| 30 | 32 | # 款式意图 SKU 预筛选(StyleSkuSelector.prepare_hits) |
| 31 | 33 | STYLE_SKU_PREPARE_HITS = "style_sku_prepare_hits" |
| ... | ... | @@ -407,4 +409,4 @@ def clear_current_request_context() -> None: |
| 407 | 409 | reset_request_log_context(tokens) |
| 408 | 410 | delattr(threading.current_thread(), 'request_log_tokens') |
| 409 | 411 | if hasattr(threading.current_thread(), 'request_context'): |
| 410 | - delattr(threading.current_thread(), 'request_context') | |
| 411 | 412 | \ No newline at end of file |
| 413 | + delattr(threading.current_thread(), 'request_context') | ... | ... |
docs/TODO-ES能力提升.md renamed to docs/issue-2026-03-21-ES能力提升.md
docs/TODO-意图判断-done.md renamed to docs/issue-2026-03-21-意图判断-done03-24.md
docs/issue-2026-03-26-ES文本搜索-补充多模态knn放入should-done-0327.md
0 → 100644
| ... | ... | @@ -0,0 +1,72 @@ |
| 1 | +目前knn跟query里面是并列的层级,如下: | |
| 2 | +{ | |
| 3 | + "size": 400, | |
| 4 | + "from": 0, | |
| 5 | + "query": { | |
| 6 | + "bool": { | |
| 7 | + "must": [... | |
| 8 | + ], | |
| 9 | + } | |
| 10 | + }, | |
| 11 | + "knn": { | |
| 12 | + "field": "title_embedding", | |
| 13 | + "query_vector": [...], | |
| 14 | + "k": 120, | |
| 15 | + "num_candidates": 400, | |
| 16 | + "boost": 2, | |
| 17 | + "_name": "knn_query" | |
| 18 | + }, | |
| 19 | +其中query的结构是这样的: | |
| 20 | +"query": { | |
| 21 | + "bool": { | |
| 22 | + "should": [ | |
| 23 | + { | |
| 24 | + "bool": { | |
| 25 | + "_name": "base_query", | |
| 26 | +\# 原始query | |
| 27 | + } | |
| 28 | + }, | |
| 29 | + { | |
| 30 | + "bool": { | |
| 31 | + "_name": "base_query_trans_zh", | |
| 32 | +\# 翻译query。有可能是base_query_trans_en,也有可能两者都有 | |
| 33 | + "boost": 0.75 | |
| 34 | + } | |
| 35 | + } | |
| 36 | + ], | |
| 37 | + "minimum_should_match": 1 | |
| 38 | + } | |
| 39 | + }, | |
| 40 | +我想把knn放到should里面,和base_query、base_query_trans_zh并列。 | |
| 41 | +另外,现在过滤是在knn里面单独加了一遍: | |
| 42 | + "knn": { | |
| 43 | + "field": "title_embedding", | |
| 44 | + "query_vector": [...], | |
| 45 | + "k": 120, | |
| 46 | + "num_candidates": 400, | |
| 47 | + "boost": 2, | |
| 48 | + "_name": "knn_query", | |
| 49 | + "filter": { | |
| 50 | + "range": { | |
| 51 | + "min_price": { | |
| 52 | + "gte": 100, | |
| 53 | + "lt": 200 | |
| 54 | + } | |
| 55 | + } | |
| 56 | + } | |
| 57 | + } | |
| 58 | +现在不需要了。因为knn在query的内层了。共用过滤。 | |
| 59 | + | |
| 60 | +另外: | |
| 61 | +我需要再增加一个knn。 | |
| 62 | +需要参考文本embedding获得的逻辑, | |
| 63 | +通过 | |
| 64 | +curl -X POST "http://localhost:6008/embed/clip_text?normalize=true&priority=1" \ | |
| 65 | + -H "Content-Type: application/json" \ | |
| 66 | + -d '["纯棉短袖", "street tee"]' | |
| 67 | +(用 POST /embed/clip_text 生成多模态文本向量。和文本embedding获取方法类似。注意思考代码如何精简,不要冗余。) | |
| 68 | +得到文本的多模态embedding。 | |
| 69 | +然后在这里补充一个多模态embedding,寻找图片相似的结果,对应的商品图片字段为image_embedding.vector。 | |
| 70 | +重排融合:之前有knn的配置bias和exponential。现在,文本和图片的embedding相似需要融合,融合方式是dis_max,因此需要配置: | |
| 71 | +1)各自的权重和tie_breaker | |
| 72 | +2)整个向量方面的权重(bias和exponential) | |
| 0 | 73 | \ No newline at end of file | ... | ... |
docs/TODO-keywords限定-done.txt renamed to docs/issue-2026-03-27-keywords限定-done-0327.txt
docs/TODO.md renamed to docs/issue.md
docs/TODO.txt renamed to docs/issue.txt
providers/rerank.py
| ... | ... | @@ -57,7 +57,7 @@ class HttpRerankProvider: |
| 57 | 57 | return None, None |
| 58 | 58 | |
| 59 | 59 | |
| 60 | -def create_rerank_provider() -> HttpRerankProvider: | |
| 60 | +def create_rerank_provider(service_profile: Optional[str] = None) -> HttpRerankProvider: | |
| 61 | 61 | """Create rerank provider from services config.""" |
| 62 | 62 | cfg = get_rerank_config() |
| 63 | 63 | provider = (cfg.provider or "http").strip().lower() |
| ... | ... | @@ -65,5 +65,5 @@ def create_rerank_provider() -> HttpRerankProvider: |
| 65 | 65 | if provider != "http": |
| 66 | 66 | raise ValueError(f"Unsupported rerank provider: {provider}") |
| 67 | 67 | |
| 68 | - url = get_rerank_service_url() | |
| 68 | + url = get_rerank_service_url(profile=service_profile) | |
| 69 | 69 | return HttpRerankProvider(service_url=url) | ... | ... |
scripts/experiments/english_query_bucketing_demo.py
0 → 100644
| ... | ... | @@ -0,0 +1,554 @@ |
| 1 | +#!/usr/bin/env python3 | |
| 2 | +""" | |
| 3 | +Offline experiment: English query bucketing (intersection / boost / drop). | |
| 4 | + | |
| 5 | +Scheme A: spaCy noun_chunks + head + lemma + rule buckets | |
| 6 | +Scheme B: spaCy NP candidates + KeyBERT rerank → intersection vs boost | |
| 7 | +Scheme C: YAKE + spaCy noun/POS filter | |
| 8 | + | |
| 9 | +Run (after deps): python scripts/experiments/english_query_bucketing_demo.py | |
| 10 | +Optional: pip install -r scripts/experiments/requirements_query_bucketing_experiments.txt | |
| 11 | +""" | |
| 12 | + | |
| 13 | +from __future__ import annotations | |
| 14 | + | |
| 15 | +import argparse | |
| 16 | +import json | |
| 17 | +import re | |
| 18 | +import sys | |
| 19 | +from dataclasses import dataclass, field | |
| 20 | +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple | |
| 21 | + | |
| 22 | + | |
| 23 | +# --- shared ----------------------------------------------------------------- | |
| 24 | + | |
| 25 | +_POSSESSIVE_RE = re.compile(r"(['’]s)\b", re.IGNORECASE) | |
| 26 | + | |
| 27 | + | |
| 28 | +def normalize_query(s: str) -> str: | |
| 29 | + s = (s or "").strip() | |
| 30 | + s = _POSSESSIVE_RE.sub("", s) | |
| 31 | + return s | |
| 32 | + | |
| 33 | + | |
| 34 | +@dataclass | |
| 35 | +class BucketResult: | |
| 36 | + intersection_terms: List[str] = field(default_factory=list) | |
| 37 | + boost_terms: List[str] = field(default_factory=list) | |
| 38 | + drop_terms: List[str] = field(default_factory=list) | |
| 39 | + | |
| 40 | + def to_dict(self) -> Dict[str, Any]: | |
| 41 | + return { | |
| 42 | + "intersection_terms": self.intersection_terms, | |
| 43 | + "boost_terms": self.boost_terms, | |
| 44 | + "drop_terms": self.drop_terms, | |
| 45 | + } | |
| 46 | + | |
| 47 | + | |
| 48 | +def _dedupe_preserve(seq: Sequence[str]) -> List[str]: | |
| 49 | + seen: Set[str] = set() | |
| 50 | + out: List[str] = [] | |
| 51 | + for x in seq: | |
| 52 | + k = x.strip().lower() | |
| 53 | + if not k or k in seen: | |
| 54 | + continue | |
| 55 | + seen.add(k) | |
| 56 | + out.append(x.strip()) | |
| 57 | + return out | |
| 58 | + | |
| 59 | + | |
| 60 | +# --- Scheme A: spaCy + rules ------------------------------------------------- | |
| 61 | + | |
| 62 | +WEAK_BOOST_ADJS = frozenset( | |
| 63 | + { | |
| 64 | + "best", | |
| 65 | + "good", | |
| 66 | + "great", | |
| 67 | + "new", | |
| 68 | + "free", | |
| 69 | + "cheap", | |
| 70 | + "top", | |
| 71 | + "fine", | |
| 72 | + "real", | |
| 73 | + } | |
| 74 | +) | |
| 75 | + | |
| 76 | +FUNCTIONAL_DEP = frozenset( | |
| 77 | + { | |
| 78 | + "det", | |
| 79 | + "aux", | |
| 80 | + "auxpass", | |
| 81 | + "prep", | |
| 82 | + "mark", | |
| 83 | + "expl", | |
| 84 | + "cc", | |
| 85 | + "punct", | |
| 86 | + "case", | |
| 87 | + } | |
| 88 | +) | |
| 89 | + | |
| 90 | +# Second pobj under list-like INTJ roots often encodes audience/size (boost, not must-match). | |
| 91 | +_DEMOGRAPHIC_NOUNS = frozenset( | |
| 92 | + { | |
| 93 | + "women", | |
| 94 | + "woman", | |
| 95 | + "men", | |
| 96 | + "man", | |
| 97 | + "kids", | |
| 98 | + "kid", | |
| 99 | + "boys", | |
| 100 | + "boy", | |
| 101 | + "girls", | |
| 102 | + "girl", | |
| 103 | + "baby", | |
| 104 | + "babies", | |
| 105 | + "toddler", | |
| 106 | + "adult", | |
| 107 | + "adults", | |
| 108 | + } | |
| 109 | +) | |
| 110 | + | |
| 111 | + | |
| 112 | +def _lemma_lower(t) -> str: | |
| 113 | + return ((t.lemma_ or t.text) or "").lower().strip() | |
| 114 | + | |
| 115 | + | |
| 116 | +def _surface_lower(t) -> str: | |
| 117 | + """Lowercased surface form (keeps plural 'headphones' vs lemma 'headphone').""" | |
| 118 | + return (t.text or "").lower().strip() | |
| 119 | + | |
| 120 | + | |
| 121 | +_PRICE_PREP_LEMMAS = frozenset({"under", "over", "below", "above", "within", "between", "near"}) | |
| 122 | + | |
| 123 | + | |
| 124 | +def bucket_scheme_a_spacy(query: str, nlp) -> BucketResult: | |
| 125 | + """ | |
| 126 | + Dependency-first bucketing: noun_chunks alone mis-parse verbal queries like | |
| 127 | + "noise cancelling headphones" (ROOT verb). Prefer dobj / ROOT product nouns, | |
| 128 | + purpose PP (for …), and brand INTJ/PROPN. | |
| 129 | + """ | |
| 130 | + import spacy # noqa: F401 | |
| 131 | + | |
| 132 | + # Do not strip possessives ('s) before spaCy: it changes the parse tree | |
| 133 | + # (e.g. "women's running shoes size 8" vs "women running shoes size 8"). | |
| 134 | + text = (query or "").strip() | |
| 135 | + doc = nlp(text) | |
| 136 | + intersection: Set[str] = set() | |
| 137 | + boost: Set[str] = set() | |
| 138 | + drop: Set[str] = set() | |
| 139 | + | |
| 140 | + stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS | |
| 141 | + | |
| 142 | + def mark_drop(t) -> None: | |
| 143 | + if not t.is_space and not t.is_punct: | |
| 144 | + drop.add(t.text.lower()) | |
| 145 | + | |
| 146 | + # --- Drops: function words / question words --- | |
| 147 | + for token in doc: | |
| 148 | + if token.is_space or token.is_punct: | |
| 149 | + continue | |
| 150 | + lem = _lemma_lower(token) | |
| 151 | + if token.pos_ in ("DET", "PRON", "AUX", "ADP", "PART", "SCONJ", "CCONJ"): | |
| 152 | + mark_drop(token) | |
| 153 | + continue | |
| 154 | + if token.dep_ in FUNCTIONAL_DEP: | |
| 155 | + mark_drop(token) | |
| 156 | + continue | |
| 157 | + if token.pos_ == "ADV" and lem in {"where", "how", "when", "why", "what", "which"}: | |
| 158 | + mark_drop(token) | |
| 159 | + continue | |
| 160 | + if token.text.lower() in ("'s", "’s"): | |
| 161 | + mark_drop(token) | |
| 162 | + continue | |
| 163 | + if lem in stops and token.pos_ != "PROPN": | |
| 164 | + mark_drop(token) | |
| 165 | + | |
| 166 | + pobj_heads_to_demote: Set[int] = set() | |
| 167 | + | |
| 168 | + # Purpose / context: "for airplane travel" → boost phrase; demote bare head from intersection | |
| 169 | + for token in doc: | |
| 170 | + if token.dep_ == "prep" and token.text.lower() == "for": | |
| 171 | + for c in token.children: | |
| 172 | + if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN"): | |
| 173 | + span = doc[c.left_edge.i : c.right_edge.i + 1] | |
| 174 | + phrase = span.text.strip().lower() | |
| 175 | + if phrase: | |
| 176 | + boost.add(phrase) | |
| 177 | + pobj_heads_to_demote.add(c.i) | |
| 178 | + | |
| 179 | + # Price / range: "under 500 dollars" → boost only | |
| 180 | + for token in doc: | |
| 181 | + if token.dep_ != "prep" or _lemma_lower(token) not in _PRICE_PREP_LEMMAS: | |
| 182 | + continue | |
| 183 | + for c in token.children: | |
| 184 | + if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN"): | |
| 185 | + span = doc[c.left_edge.i : c.right_edge.i + 1] | |
| 186 | + phrase = span.text.strip().lower() | |
| 187 | + if phrase: | |
| 188 | + boost.add(phrase) | |
| 189 | + pobj_heads_to_demote.add(c.i) | |
| 190 | + | |
| 191 | + # Direct object product nouns (handles "noise cancelling … headphones") | |
| 192 | + for token in doc: | |
| 193 | + if token.dep_ == "dobj" and token.pos_ in ("NOUN", "PROPN"): | |
| 194 | + if token.i in pobj_heads_to_demote: | |
| 195 | + continue | |
| 196 | + intersection.add(_surface_lower(token)) | |
| 197 | + | |
| 198 | + # Copular questions / definitions: "what is the best smartphone …" | |
| 199 | + for token in doc: | |
| 200 | + if token.dep_ != "nsubj" or token.pos_ not in ("NOUN", "PROPN"): | |
| 201 | + continue | |
| 202 | + h = token.head | |
| 203 | + if h.pos_ == "AUX" and h.dep_ == "ROOT": | |
| 204 | + intersection.add(_surface_lower(token)) | |
| 205 | + | |
| 206 | + # Verbal ROOT: modifiers left of dobj → boost phrase (e.g. "noise cancelling") | |
| 207 | + roots = [t for t in doc if t.dep_ == "ROOT"] | |
| 208 | + if roots and roots[0].pos_ == "VERB": | |
| 209 | + root_v = roots[0] | |
| 210 | + for t in doc: | |
| 211 | + if t.dep_ != "dobj" or t.pos_ not in ("NOUN", "PROPN"): | |
| 212 | + continue | |
| 213 | + if t.i in pobj_heads_to_demote: | |
| 214 | + continue | |
| 215 | + parts: List[str] = [] | |
| 216 | + for x in doc[: t.i]: | |
| 217 | + if x.is_punct or x.is_space: | |
| 218 | + continue | |
| 219 | + if x.pos_ in ("DET", "ADP", "PRON"): | |
| 220 | + continue | |
| 221 | + xl = _lemma_lower(x) | |
| 222 | + if xl in stops: | |
| 223 | + continue | |
| 224 | + parts.append(x.text.lower()) | |
| 225 | + if len(parts) >= 1: | |
| 226 | + boost.add(" ".join(parts)) | |
| 227 | + | |
| 228 | + # Brand / query lead: INTJ/PROPN ROOT (e.g. Nike …) | |
| 229 | + for token in doc: | |
| 230 | + if token.dep_ == "ROOT" and token.pos_ in ("INTJ", "PROPN"): | |
| 231 | + intersection.add(_surface_lower(token)) | |
| 232 | + if token.pos_ == "PROPN": | |
| 233 | + intersection.add(_surface_lower(token)) | |
| 234 | + | |
| 235 | + _DIMENSION_ROOTS = frozenset({"size", "width", "length", "height", "weight"}) | |
| 236 | + | |
| 237 | + # "women's running shoes size 8" → shoes ∩, "size 8" boost (not size alone) | |
| 238 | + for token in doc: | |
| 239 | + if token.dep_ != "ROOT" or token.pos_ != "NOUN": | |
| 240 | + continue | |
| 241 | + if _lemma_lower(token) not in _DIMENSION_ROOTS: | |
| 242 | + continue | |
| 243 | + for c in token.children: | |
| 244 | + if c.dep_ == "nsubj" and c.pos_ in ("NOUN", "PROPN"): | |
| 245 | + intersection.add(_surface_lower(c)) | |
| 246 | + for ch in c.children: | |
| 247 | + if ch.dep_ == "compound" and ch.pos_ in ("NOUN", "VERB", "ADJ"): | |
| 248 | + boost.add(_surface_lower(ch)) | |
| 249 | + # Only the dimension head + numbers (not full subtree: left_edge/right_edge is huge) | |
| 250 | + dim_parts = [token.text.lower()] | |
| 251 | + for ch in token.children: | |
| 252 | + if ch.dep_ == "nummod": | |
| 253 | + dim_parts.append(ch.text.lower()) | |
| 254 | + boost.add(" ".join(dim_parts)) | |
| 255 | + | |
| 256 | + # ROOT noun product (e.g. "plastic toy car") | |
| 257 | + for token in doc: | |
| 258 | + if token.dep_ == "ROOT" and token.pos_ in ("NOUN", "PROPN"): | |
| 259 | + if _lemma_lower(token) in _DIMENSION_ROOTS and any( | |
| 260 | + c.dep_ == "nsubj" and c.pos_ in ("NOUN", "PROPN") for c in token.children | |
| 261 | + ): | |
| 262 | + continue | |
| 263 | + intersection.add(_surface_lower(token)) | |
| 264 | + for c in token.children: | |
| 265 | + if c.dep_ == "compound" and c.pos_ == "NOUN": | |
| 266 | + boost.add(c.text.lower()) | |
| 267 | + if token.i - token.left_edge.i >= 1: | |
| 268 | + comps = [x.text.lower() for x in doc[token.left_edge.i : token.i] if x.dep_ == "compound"] | |
| 269 | + if len(comps) >= 2: | |
| 270 | + boost.add(" ".join(comps)) | |
| 271 | + | |
| 272 | + # List-like INTJ head with multiple pobj: first pobj = product head, rest often demographic | |
| 273 | + for token in doc: | |
| 274 | + if token.dep_ != "ROOT" or token.pos_ not in ("INTJ", "VERB", "NOUN"): | |
| 275 | + continue | |
| 276 | + pobjs = sorted( | |
| 277 | + [c for c in token.children if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN")], | |
| 278 | + key=lambda x: x.i, | |
| 279 | + ) | |
| 280 | + if len(pobjs) >= 2 and token.pos_ == "INTJ": | |
| 281 | + intersection.add(_surface_lower(pobjs[0])) | |
| 282 | + for extra in pobjs[1:]: | |
| 283 | + if _lemma_lower(extra) in _DEMOGRAPHIC_NOUNS: | |
| 284 | + boost.add(_surface_lower(extra)) | |
| 285 | + else: | |
| 286 | + intersection.add(_surface_lower(extra)) | |
| 287 | + elif len(pobjs) == 1 and token.pos_ == "INTJ": | |
| 288 | + intersection.add(_surface_lower(pobjs[0])) | |
| 289 | + | |
| 290 | + # amod under pobj (running → shoes) | |
| 291 | + for token in doc: | |
| 292 | + if token.dep_ == "amod" and token.head.pos_ in ("NOUN", "PROPN"): | |
| 293 | + if token.pos_ == "VERB": | |
| 294 | + boost.add(_surface_lower(token)) | |
| 295 | + elif token.pos_ == "ADJ": | |
| 296 | + boost.add(_lemma_lower(token)) | |
| 297 | + | |
| 298 | + # Genitive possessor (women's shoes → women boost) | |
| 299 | + for token in doc: | |
| 300 | + if token.dep_ == "poss" and token.head.pos_ in ("NOUN", "PROPN"): | |
| 301 | + boost.add(_surface_lower(token)) | |
| 302 | + | |
| 303 | + # noun_chunks fallback when no dobj/ROOT intersection yet | |
| 304 | + if not intersection: | |
| 305 | + for chunk in doc.noun_chunks: | |
| 306 | + head = chunk.root | |
| 307 | + if head.pos_ not in ("NOUN", "PROPN"): | |
| 308 | + continue | |
| 309 | + # Price / range: "under 500 dollars" → boost, not a product head | |
| 310 | + if head.dep_ == "pobj" and head.head.dep_ == "prep": | |
| 311 | + prep = head.head | |
| 312 | + if _lemma_lower(prep) in _PRICE_PREP_LEMMAS: | |
| 313 | + boost.add(chunk.text.strip().lower()) | |
| 314 | + continue | |
| 315 | + hl = _surface_lower(head) | |
| 316 | + if hl: | |
| 317 | + intersection.add(hl) | |
| 318 | + for t in chunk: | |
| 319 | + if t == head or t.pos_ != "PROPN": | |
| 320 | + continue | |
| 321 | + intersection.add(_surface_lower(t)) | |
| 322 | + for t in chunk: | |
| 323 | + if t == head: | |
| 324 | + continue | |
| 325 | + if t.pos_ == "ADJ" or (t.pos_ == "NOUN" and t.dep_ == "compound"): | |
| 326 | + boost.add(_lemma_lower(t)) | |
| 327 | + | |
| 328 | + # Remove demoted pobj heads from intersection (purpose / price clause) | |
| 329 | + for i in pobj_heads_to_demote: | |
| 330 | + t = doc[i] | |
| 331 | + intersection.discard(_lemma_lower(t)) | |
| 332 | + intersection.discard(_surface_lower(t)) | |
| 333 | + | |
| 334 | + boost -= intersection | |
| 335 | + boost = {b for b in boost if b.lower() not in stops and b.strip()} | |
| 336 | + | |
| 337 | + return BucketResult( | |
| 338 | + intersection_terms=_dedupe_preserve(sorted(intersection)), | |
| 339 | + boost_terms=_dedupe_preserve(sorted(boost)), | |
| 340 | + drop_terms=_dedupe_preserve(sorted(drop)), | |
| 341 | + ) | |
| 342 | + | |
| 343 | + | |
| 344 | +# --- Scheme B: spaCy candidates + KeyBERT ----------------------------------- | |
| 345 | + | |
| 346 | +def _spacy_np_candidates(doc) -> List[str]: | |
| 347 | + phrases: List[str] = [] | |
| 348 | + for chunk in doc.noun_chunks: | |
| 349 | + t = chunk.text.strip() | |
| 350 | + if len(t) < 2: | |
| 351 | + continue | |
| 352 | + root = chunk.root | |
| 353 | + if root.pos_ not in ("NOUN", "PROPN"): | |
| 354 | + continue | |
| 355 | + phrases.append(t) | |
| 356 | + return phrases | |
| 357 | + | |
| 358 | + | |
def bucket_scheme_b_keybert(query: str, nlp, kw_model) -> BucketResult:
    """Scheme B: rank spaCy noun-phrase candidates with KeyBERT.

    The top-ranked phrase becomes the intersection bucket, the remaining
    ranked phrases plus leftover noun-chunk heads go to boost, and
    stopword-like tokens are dropped.
    """
    text = (query or "").strip()
    doc = nlp(text)
    candidates = _spacy_np_candidates(doc) or [text]

    top_n = min(8, max(4, len(candidates) + 2))
    # KeyBERT API changed the parameter name; fall back to the older
    # candidate_keywords=... spelling when candidates=... raises TypeError.
    try:
        keywords = kw_model.extract_keywords(text, candidates=candidates, top_n=top_n)
    except TypeError:
        keywords = kw_model.extract_keywords(
            text, candidate_keywords=candidates, top_n=top_n
        )
    ranked = [kw[0].lower().strip() for kw in (keywords or []) if kw and kw[0].strip()]

    intersection: List[str] = ranked[:1]
    boost: List[str] = ranked[1:]
    # Fold in any noun-chunk heads KeyBERT did not already surface.
    heads: List[str] = [
        _surface_lower(chunk.root)
        for chunk in doc.noun_chunks
        if chunk.root.pos_ in ("NOUN", "PROPN")
    ]
    for head in heads:
        if head and head not in intersection and head not in boost:
            boost.append(head)
    if not intersection and heads:
        intersection.append(heads[0])
        boost = [term for term in boost if term != heads[0]]

    stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS
    drop_tokens: Set[str] = set()
    for token in doc:
        if token.is_punct:
            continue
        lemma = (token.lemma_ or token.text).lower()
        if token.pos_ in ("DET", "ADP", "PART", "PRON", "AUX") or lemma in stops:
            drop_tokens.add(token.text.lower())

    return BucketResult(
        intersection_terms=_dedupe_preserve(intersection),
        boost_terms=_dedupe_preserve(boost),
        drop_terms=sorted(drop_tokens),
    )
| 414 | + | |
| 415 | + | |
| 416 | +# --- Scheme C: YAKE + noun filter -------------------------------------------- | |
| 417 | + | |
def bucket_scheme_c_yake(query: str, nlp, yake_extractor) -> BucketResult:
    """Scheme C: YAKE keyphrases filtered down to noun-bearing phrases.

    The one or two most important short head-noun phrases become the
    intersection bucket, other noun phrases go to boost, and stopword-like
    tokens are dropped.
    """
    text = (query or "").strip()
    doc = nlp(text)

    kws = yake_extractor.extract_keywords(text)  # List[Tuple[str, float]] newest yake API may differ

    if kws and isinstance(kws[0], (list, tuple)) and len(kws[0]) >= 2:
        scored: List[Tuple[str, float]] = [(str(a).strip(), float(b)) for a, b in kws]
    else:
        # older yake returns list of tuples (kw, score)
        scored = [(str(item[0]).strip(), float(item[1])) for item in kws]

    intersection: List[str] = []
    boost: List[str] = []
    # Lower YAKE score means more important, so iterate best-first.
    for phrase, _ in sorted(scored, key=lambda pair: pair[1]):
        normalized = phrase.lower().strip()
        if len(normalized) < 2:
            continue
        sub = nlp(normalized)
        has_noun = False
        has_head_noun = False
        for tok in sub:
            if tok.is_punct or tok.is_space:
                continue
            if tok.pos_ in ("NOUN", "PROPN"):
                has_noun = True
                if tok.dep_ == "ROOT" or tok == sub[-1]:
                    has_head_noun = True
        if not has_noun:
            continue
        # top 1–2 important short head-noun phrases → intersection (very small)
        if len(intersection) < 2 and has_head_noun and len(normalized.split()) <= 2:
            intersection.append(normalized)
        else:
            boost.append(normalized)

    stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS
    dropped: Set[str] = set()
    for token in doc:
        if token.is_punct:
            continue
        lemma = (token.lemma_ or token.text).lower()
        if token.pos_ in ("DET", "ADP", "PART", "PRON", "AUX") or lemma in stops:
            dropped.add(token.text.lower())

    return BucketResult(
        intersection_terms=_dedupe_preserve(intersection),
        boost_terms=_dedupe_preserve(boost),
        drop_terms=sorted(dropped),
    )
| 469 | + | |
| 470 | + | |
| 471 | +# --- CLI --------------------------------------------------------------------- | |
| 472 | + | |
# Built-in example queries used when --queries is not supplied on the CLI.
DEFAULT_QUERIES = [
    "best noise cancelling headphones for airplane travel",
    "nike running shoes women",
    "plastic toy car",
    "what is the best smartphone under 500 dollars",
    "women's running shoes size 8",
]
| 480 | + | |
| 481 | + | |
def _load_spacy():
    """Load the small English spaCy pipeline, telling the user how to install it if missing."""
    import spacy

    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:
        hint = "Missing model: run: python -m spacy download en_core_web_sm"
        print(hint, file=sys.stderr)
        raise
    return nlp
| 493 | + | |
| 494 | + | |
def _load_keybert():
    """Build a KeyBERT instance backed by a small sentence-transformers model."""
    from keybert import KeyBERT

    # small & fast for demo; swap for larger if needed
    model_name = "paraphrase-MiniLM-L6-v2"
    return KeyBERT(model=model_name)
| 500 | + | |
| 501 | + | |
def _load_yake():
    """Configure a YAKE extractor: English, up to trigrams, top 20 phrases."""
    import yake

    settings = {
        "lan": "en",
        "n": 3,
        "dedupLim": 0.9,
        "top": 20,
        "features": None,
    }
    return yake.KeywordExtractor(**settings)
| 512 | + | |
| 513 | + | |
def main() -> None:
    """CLI entry: run the selected bucketing scheme(s) over the given queries."""
    parser = argparse.ArgumentParser(description="English query bucketing experiments")
    parser.add_argument(
        "--queries",
        nargs="*",
        default=DEFAULT_QUERIES,
        help="Queries to run (default: built-in examples)",
    )
    parser.add_argument(
        "--scheme",
        choices=("a", "b", "c", "all"),
        default="all",
    )
    args = parser.parse_args()

    nlp = _load_spacy()
    want_b = args.scheme in ("b", "all")
    want_c = args.scheme in ("c", "all")
    # Only load the heavier extractors when their scheme was requested.
    kb = _load_keybert() if want_b else None
    yk = _load_yake() if want_c else None

    for query in args.queries:
        print("=" * 72)
        print("QUERY:", query)
        print("-" * 72)
        if args.scheme in ("a", "all"):
            result_a = bucket_scheme_a_spacy(query, nlp)
            print("A spaCy+rules:", json.dumps(result_a.to_dict(), ensure_ascii=False))
        if want_b and kb is not None:
            result_b = bucket_scheme_b_keybert(query, nlp, kb)
            print("B spaCy+KeyBERT:", json.dumps(result_b.to_dict(), ensure_ascii=False))
        if want_c and yk is not None:
            result_c = bucket_scheme_c_yake(query, nlp, yk)
            print("C YAKE+noun filter:", json.dumps(result_c.to_dict(), ensure_ascii=False))
        print()
| 552 | + | |
# Script entry point.
if __name__ == "__main__":
    main()
scripts/experiments/requirements_query_bucketing_experiments.txt
0 → 100644
| ... | ... | @@ -0,0 +1,246 @@ |
#!/usr/bin/env python3
"""
Temporary script: scroll a tenant's image_url values from ES and batch-call
the image embedding service. 5 worker processes, at most 8 URLs per request.
Logs are printed to stdout.

Usage:
    source activate.sh  # loads .env, providing ES_HOST / ES_USERNAME / ES_PASSWORD
    python scripts/temp_embed_tenant_image_urls.py

If activate.sh was not sourced, the script also tries to load the project-root .env.
"""
| 12 | + | |
| 13 | +from __future__ import annotations | |
| 14 | + | |
| 15 | +import json | |
| 16 | +import multiprocessing as mp | |
| 17 | +import os | |
| 18 | +import sys | |
| 19 | +import time | |
| 20 | +from dataclasses import dataclass | |
| 21 | +from pathlib import Path | |
| 22 | +from typing import Any, Dict, List, Optional, Tuple | |
| 23 | +from urllib.parse import urlencode | |
| 24 | + | |
| 25 | +import requests | |
| 26 | +from elasticsearch import Elasticsearch | |
| 27 | +from elasticsearch.helpers import scan | |
| 28 | + | |
# When activate.sh was not sourced, still try to load the project-root .env
# (same ES_HOST / ES_USERNAME / ES_PASSWORD variables).
try:
    from dotenv import load_dotenv

    _ROOT = Path(__file__).resolve().parents[1]
    load_dotenv(_ROOT / ".env")
except ImportError:
    # python-dotenv is optional; fall back to the ambient environment.
    pass
| 37 | + | |
# ---------------------------------------------------------------------------
# Configuration (edit as needed; defaults match the ES_* values in .env,
# see config/loader.py)
# ---------------------------------------------------------------------------

# Elasticsearch (read from env vars ES_HOST / ES_USERNAME / ES_PASSWORD by default)
ES_HOST: str = os.getenv("ES_HOST", "http://localhost:9200")
ES_USERNAME: Optional[str] = os.getenv("ES_USERNAME") or None
ES_PASSWORD: Optional[str] = os.getenv("ES_PASSWORD") or None
ES_INDEX: str = "search_products_tenant_163"

# Tenant id (keyword field, stored as a string)
TENANT_ID: str = "163"

# Image embedding service (matches integration doc section 7.1.2)
EMBED_BASE_URL: str = "http://localhost:6008"
EMBED_PATH: str = "/embed/image"
EMBED_QUERY: Dict[str, Any] = {
    "normalize": "true",
    "priority": "1",  # matches the doc's curl example; offline batch jobs may use "0"
}

# Concurrency and batching
WORKER_PROCESSES: int = 5
URLS_PER_REQUEST: int = 8

# HTTP
REQUEST_TIMEOUT_SEC: float = 120.0

# ES scan (elasticsearch-py 8+ / ES 9: `scan(..., query=...)` is expanded into
# `client.search(**kwargs)`, so the kwargs must use Search API parameter names,
# e.g. a top-level `query` = the DSL query clause — not a bare `match_all`.)
SCROLL_CHUNK_SIZE: int = 500
| 69 | + | |
| 70 | +# --------------------------------------------------------------------------- | |
| 71 | + | |
| 72 | + | |
@dataclass
class BatchResult:
    """Outcome of a single embedding request for one batch of image URLs."""

    batch_index: int  # position of this batch in the overall task list
    url_count: int  # number of URLs sent in this batch
    ok: bool  # True when HTTP 200 and the response list matched the request length
    status_code: Optional[int]  # HTTP status code, or None on transport failure
    elapsed_sec: float  # wall-clock duration of the request
    error: Optional[str] = None  # short error description when ok is False
| 81 | + | |
| 82 | + | |
def _build_embed_url() -> str:
    """Full embedding endpoint URL with the fixed query string appended."""
    base = EMBED_BASE_URL.rstrip("/")
    return f"{base}{EMBED_PATH}?{urlencode(EMBED_QUERY)}"
| 86 | + | |
| 87 | + | |
def _process_batch(payload: Tuple[int, List[str]]) -> BatchResult:
    """Send one batch of image URLs to the embedding service.

    Args:
        payload: (batch_index, urls) tuple as produced by main().

    Returns:
        A BatchResult with the HTTP status, elapsed time and validity of the
        response (must be a JSON list with one entry per URL). Never raises:
        transport errors are captured in the result.
    """
    batch_index, urls = payload
    if not urls:
        return BatchResult(batch_index, 0, True, None, 0.0, None)

    url = _build_embed_url()
    t0 = time.perf_counter()
    try:
        resp = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            data=json.dumps(urls),
            timeout=REQUEST_TIMEOUT_SEC,
        )
        elapsed = time.perf_counter() - t0
        ok = resp.status_code == 200
        err: Optional[str] = None
        if ok:
            try:
                body = resp.json()
                # Service contract: one embedding per requested URL.
                if not isinstance(body, list) or len(body) != len(urls):
                    ok = False
                    err = f"response length mismatch or not list: got {type(body).__name__}"
            except Exception as e:
                ok = False
                err = f"json decode: {e}"
        else:
            err = resp.text[:500] if resp.text else f"HTTP {resp.status_code}"

        worker = mp.current_process().name
        # BUG FIX: the previous `resp.status_code if resp else None` relied on
        # Response.__bool__, which is False for any non-2xx status, so failed
        # requests were logged/returned with http=None. `resp` is always bound
        # here, so read the status code directly.
        status = resp.status_code
        ms = elapsed * 1000.0
        tail = "ok" if ok else f"FAIL err={err}"
        print(
            f"[embed] worker={worker} batch={batch_index} urls={len(urls)} "
            f"http={status} elapsed_ms={ms:.2f} {tail}",
            flush=True,
        )
        return BatchResult(batch_index, len(urls), ok, status, elapsed, err)
    except Exception as e:
        # Transport-level failure (connect/timeout/etc.): no response object.
        elapsed = time.perf_counter() - t0
        worker = mp.current_process().name
        print(
            f"[embed] worker={worker} batch={batch_index} urls={len(urls)} "
            f"http=None elapsed_ms={elapsed * 1000.0:.2f} FAIL err={e}",
            flush=True,
        )
        return BatchResult(batch_index, len(urls), False, None, elapsed, str(e))
| 142 | + | |
| 143 | + | |
def _iter_image_urls(es: Elasticsearch) -> List[str]:
    """Scroll the tenant's documents and collect every non-empty image_url."""
    # Equivalent search body: { "query": { "term": { "tenant_id": "..." } } }
    search_kwargs: Dict[str, Any] = {
        "query": {"term": {"tenant_id": TENANT_ID}},
        "source_includes": ["image_url"],
    }
    collected: List[str] = []
    for hit in scan(es, query=search_kwargs, index=ES_INDEX, size=SCROLL_CHUNK_SIZE):
        source = hit.get("_source") or {}
        raw = source.get("image_url")
        if raw is None:
            continue
        value = str(raw).strip()
        if value:
            collected.append(value)
    return collected
| 166 | + | |
| 167 | + | |
def main() -> int:
    """Collect the tenant's image URLs from ES and batch-embed them.

    Returns:
        0 on full success, 1 when ES is unreachable, 2 when any batch failed.
    """
    t_wall0 = time.perf_counter()

    auth = None
    if ES_USERNAME and ES_PASSWORD:
        auth = (ES_USERNAME, ES_PASSWORD)

    es = Elasticsearch([ES_HOST], basic_auth=auth)
    if not es.ping():
        print("ERROR: Elasticsearch ping failed", file=sys.stderr)
        return 1

    print(
        f"[main] ES={ES_HOST} basic_auth={'yes' if auth else 'no'} "
        f"index={ES_INDEX} tenant_id={TENANT_ID} "
        f"workers={WORKER_PROCESSES} urls_per_req={URLS_PER_REQUEST}",
        flush=True,
    )
    print(f"[main] embed_url={_build_embed_url()}", flush=True)

    # Phase 1: scroll all image URLs for the tenant.
    t_fetch0 = time.perf_counter()
    all_urls = _iter_image_urls(es)
    fetch_elapsed = time.perf_counter() - t_fetch0
    print(
        f"[main] collected image_url count={len(all_urls)} es_scan_elapsed_sec={fetch_elapsed:.3f}",
        flush=True,
    )

    # Phase 2: split into fixed-size request batches.
    batches: List[List[str]] = []
    for i in range(0, len(all_urls), URLS_PER_REQUEST):
        batches.append(all_urls[i : i + URLS_PER_REQUEST])

    if not batches:
        print("[main] no URLs to process; done.", flush=True)
        return 0

    tasks = [(idx, batch) for idx, batch in enumerate(batches)]
    print(f"[main] batches={len(tasks)} (parallel processes={WORKER_PROCESSES})", flush=True)

    # Phase 3: fan the batches out to worker processes and aggregate counters.
    t_run0 = time.perf_counter()
    total_urls = 0  # accumulated for completeness; not shown in the summary
    success_urls = 0
    failed_urls = 0
    ok_batches = 0
    fail_batches = 0
    sum_req_sec = 0.0

    with mp.Pool(processes=WORKER_PROCESSES) as pool:
        for res in pool.imap_unordered(_process_batch, tasks, chunksize=1):
            total_urls += res.url_count
            sum_req_sec += res.elapsed_sec
            if res.ok:
                ok_batches += 1
                success_urls += res.url_count
            else:
                fail_batches += 1
                failed_urls += res.url_count

    wall_total = time.perf_counter() - t_wall0
    run_elapsed = time.perf_counter() - t_run0

    # Phase 4: human-readable summary.
    print("---------- summary ----------", flush=True)
    print(f"tenant_id: {TENANT_ID}", flush=True)
    print(f"total documents w/ url: {len(all_urls)}", flush=True)
    print(f"total batches: {len(batches)}", flush=True)
    print(f"batches succeeded: {ok_batches}", flush=True)
    print(f"batches failed: {fail_batches}", flush=True)
    print(f"urls (success path): {success_urls}", flush=True)
    print(f"urls (failed path): {failed_urls}", flush=True)
    print(f"ES scan elapsed (s): {fetch_elapsed:.3f}", flush=True)
    print(f"embed phase wall (s): {run_elapsed:.3f}", flush=True)
    print(f"sum request time (s): {sum_req_sec:.3f} (sequential sum, for reference)", flush=True)
    print(f"total wall time (s): {wall_total:.3f}", flush=True)
    print("-----------------------------", flush=True)
    return 0 if fail_batches == 0 else 2
| 244 | + | |
# Script entry point; exit code propagates the batch failure status.
if __name__ == "__main__":
    raise SystemExit(main())
search/rerank_client.py
| ... | ... | @@ -10,7 +10,7 @@ |
| 10 | 10 | from typing import Dict, Any, List, Optional, Tuple |
| 11 | 11 | import logging |
| 12 | 12 | |
| 13 | -from config.schema import RerankFusionConfig | |
| 13 | +from config.schema import CoarseRankFusionConfig, RerankFusionConfig | |
| 14 | 14 | from providers import create_rerank_provider |
| 15 | 15 | |
| 16 | 16 | logger = logging.getLogger(__name__) |
| ... | ... | @@ -120,6 +120,7 @@ def call_rerank_service( |
| 120 | 120 | docs: List[str], |
| 121 | 121 | timeout_sec: float = DEFAULT_TIMEOUT_SEC, |
| 122 | 122 | top_n: Optional[int] = None, |
| 123 | + service_profile: Optional[str] = None, | |
| 123 | 124 | ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: |
| 124 | 125 | """ |
| 125 | 126 | 调用重排服务 POST /rerank,返回分数列表与 meta。 |
| ... | ... | @@ -128,7 +129,7 @@ def call_rerank_service( |
| 128 | 129 | if not docs: |
| 129 | 130 | return [], {} |
| 130 | 131 | try: |
| 131 | - client = create_rerank_provider() | |
| 132 | + client = create_rerank_provider(service_profile=service_profile) | |
| 132 | 133 | return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n) |
| 133 | 134 | except Exception as e: |
| 134 | 135 | logger.warning("Rerank request failed: %s", e, exc_info=True) |
| ... | ... | @@ -240,24 +241,105 @@ def _collect_text_score_components(matched_queries: Any, fallback_es_score: floa |
| 240 | 241 | |
| 241 | 242 | def _multiply_fusion_factors( |
| 242 | 243 | rerank_score: float, |
| 244 | + fine_score: Optional[float], | |
| 243 | 245 | text_score: float, |
| 244 | 246 | knn_score: float, |
| 245 | 247 | fusion: RerankFusionConfig, |
| 246 | -) -> Tuple[float, float, float, float]: | |
| 247 | - """(rerank_factor, text_factor, knn_factor, fused_without_style_boost).""" | |
| 248 | +) -> Tuple[float, float, float, float, float]: | |
| 249 | + """(rerank_factor, fine_factor, text_factor, knn_factor, fused_without_style_boost).""" | |
| 248 | 250 | r = (max(rerank_score, 0.0) + fusion.rerank_bias) ** fusion.rerank_exponent |
| 251 | + if fine_score is None: | |
| 252 | + f = 1.0 | |
| 253 | + else: | |
| 254 | + f = (max(fine_score, 0.0) + fusion.fine_bias) ** fusion.fine_exponent | |
| 249 | 255 | t = (max(text_score, 0.0) + fusion.text_bias) ** fusion.text_exponent |
| 250 | 256 | k = (max(knn_score, 0.0) + fusion.knn_bias) ** fusion.knn_exponent |
| 251 | - return r, t, k, r * t * k | |
| 257 | + return r, f, t, k, r * f * t * k | |
| 258 | + | |
| 259 | + | |
def _multiply_coarse_fusion_factors(
    text_score: float,
    knn_score: float,
    fusion: CoarseRankFusionConfig,
) -> Tuple[float, float, float]:
    """(text_factor, knn_factor, coarse_score) for text/knn-only fusion."""
    clamped_text = max(text_score, 0.0)
    clamped_knn = max(knn_score, 0.0)
    text_factor = (clamped_text + fusion.text_bias) ** fusion.text_exponent
    knn_factor = (clamped_knn + fusion.knn_bias) ** fusion.knn_exponent
    return text_factor, knn_factor, text_factor * knn_factor
| 252 | 268 | |
| 253 | 269 | |
| 254 | 270 | def _has_selected_sku(hit: Dict[str, Any]) -> bool: |
| 255 | 271 | return bool(str(hit.get("_style_rerank_suffix") or "").strip()) |
| 256 | 272 | |
| 257 | 273 | |
def coarse_resort_hits(
    es_hits: List[Dict[str, Any]],
    fusion: Optional[CoarseRankFusionConfig] = None,
    debug: bool = False,
) -> List[Dict[str, Any]]:
    """Coarse rank using text/knn fusion only (no model call).

    Annotates every hit with its fused ``_coarse_score`` plus the text/knn
    component scores, then re-sorts ``es_hits`` in place by that score,
    descending.

    Args:
        es_hits: ES hits; mutated in place.
        fusion: coarse fusion parameters; defaults to ``CoarseRankFusionConfig()``.
        debug: when true, collect a per-hit score breakdown.

    Returns:
        The per-hit debug rows (empty list unless ``debug`` is true).
    """
    if not es_hits:
        return []

    f = fusion or CoarseRankFusionConfig()
    # Rows are only appended when debug is set (was `[] if debug else []`,
    # a no-op ternary).
    coarse_debug: List[Dict[str, Any]] = []
    for hit in es_hits:
        es_score = _to_score(hit.get("_score"))
        matched_queries = hit.get("matched_queries")
        knn_components = _collect_knn_score_components(matched_queries, f)
        text_components = _collect_text_score_components(matched_queries, es_score)
        text_score = text_components["text_score"]
        knn_score = knn_components["knn_score"]
        text_factor, knn_factor, coarse_score = _multiply_coarse_fusion_factors(
            text_score=text_score,
            knn_score=knn_score,
            fusion=f,
        )

        hit["_text_score"] = text_score
        hit["_knn_score"] = knn_score
        hit["_text_knn_score"] = knn_components["text_knn_score"]
        hit["_image_knn_score"] = knn_components["image_knn_score"]
        hit["_coarse_score"] = coarse_score

        if debug:
            coarse_debug.append(
                {
                    "doc_id": hit.get("_id"),
                    "es_score": es_score,
                    "text_score": text_score,
                    "text_source_score": text_components["source_score"],
                    "text_translation_score": text_components["translation_score"],
                    "text_weighted_source_score": text_components["weighted_source_score"],
                    "text_weighted_translation_score": text_components["weighted_translation_score"],
                    "text_primary_score": text_components["primary_text_score"],
                    "text_support_score": text_components["support_text_score"],
                    "text_score_fallback_to_es": (
                        text_score == es_score
                        and text_components["source_score"] <= 0.0
                        and text_components["translation_score"] <= 0.0
                    ),
                    "text_knn_score": knn_components["text_knn_score"],
                    "image_knn_score": knn_components["image_knn_score"],
                    "weighted_text_knn_score": knn_components["weighted_text_knn_score"],
                    "weighted_image_knn_score": knn_components["weighted_image_knn_score"],
                    "knn_primary_score": knn_components["primary_knn_score"],
                    "knn_support_score": knn_components["support_knn_score"],
                    "knn_score": knn_score,
                    "coarse_text_factor": text_factor,
                    "coarse_knn_factor": knn_factor,
                    "coarse_score": coarse_score,
                    "matched_queries": matched_queries,
                }
            )

    # Every hit got _coarse_score above; the ES-score fallback is defensive only.
    es_hits.sort(key=lambda h: h.get("_coarse_score", h.get("_score", 0.0)), reverse=True)
    return coarse_debug
| 337 | + | |
| 338 | + | |
| 258 | 339 | def fuse_scores_and_resort( |
| 259 | 340 | es_hits: List[Dict[str, Any]], |
| 260 | 341 | rerank_scores: List[float], |
| 342 | + fine_scores: Optional[List[float]] = None, | |
| 261 | 343 | weight_es: float = DEFAULT_WEIGHT_ES, |
| 262 | 344 | weight_ai: float = DEFAULT_WEIGHT_AI, |
| 263 | 345 | fusion: Optional[RerankFusionConfig] = None, |
| ... | ... | @@ -290,6 +372,8 @@ def fuse_scores_and_resort( |
| 290 | 372 | n = len(es_hits) |
| 291 | 373 | if n == 0 or len(rerank_scores) != n: |
| 292 | 374 | return [] |
| 375 | + if fine_scores is not None and len(fine_scores) != n: | |
| 376 | + fine_scores = None | |
| 293 | 377 | |
| 294 | 378 | f = fusion or RerankFusionConfig() |
| 295 | 379 | fused_debug: List[Dict[str, Any]] = [] if debug else [] |
| ... | ... | @@ -297,13 +381,14 @@ def fuse_scores_and_resort( |
| 297 | 381 | for idx, hit in enumerate(es_hits): |
| 298 | 382 | es_score = _to_score(hit.get("_score")) |
| 299 | 383 | rerank_score = _to_score(rerank_scores[idx]) |
| 384 | + fine_score = _to_score(fine_scores[idx]) if fine_scores is not None else _to_score(hit.get("_fine_score")) | |
| 300 | 385 | matched_queries = hit.get("matched_queries") |
| 301 | 386 | knn_components = _collect_knn_score_components(matched_queries, f) |
| 302 | 387 | knn_score = knn_components["knn_score"] |
| 303 | 388 | text_components = _collect_text_score_components(matched_queries, es_score) |
| 304 | 389 | text_score = text_components["text_score"] |
| 305 | - rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( | |
| 306 | - rerank_score, text_score, knn_score, f | |
| 390 | + rerank_factor, fine_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( | |
| 391 | + rerank_score, fine_score if fine_scores is not None or "_fine_score" in hit else None, text_score, knn_score, f | |
| 307 | 392 | ) |
| 308 | 393 | sku_selected = _has_selected_sku(hit) |
| 309 | 394 | style_boost = style_intent_selected_sku_boost if sku_selected else 1.0 |
| ... | ... | @@ -311,6 +396,7 @@ def fuse_scores_and_resort( |
| 311 | 396 | |
| 312 | 397 | hit["_original_score"] = hit.get("_score") |
| 313 | 398 | hit["_rerank_score"] = rerank_score |
| 399 | + hit["_fine_score"] = fine_score | |
| 314 | 400 | hit["_text_score"] = text_score |
| 315 | 401 | hit["_knn_score"] = knn_score |
| 316 | 402 | hit["_text_knn_score"] = knn_components["text_knn_score"] |
| ... | ... | @@ -330,6 +416,7 @@ def fuse_scores_and_resort( |
| 330 | 416 | "doc_id": hit.get("_id"), |
| 331 | 417 | "es_score": es_score, |
| 332 | 418 | "rerank_score": rerank_score, |
| 419 | + "fine_score": fine_score, | |
| 333 | 420 | "text_score": text_score, |
| 334 | 421 | "text_source_score": text_components["source_score"], |
| 335 | 422 | "text_translation_score": text_components["translation_score"], |
| ... | ... | @@ -350,6 +437,7 @@ def fuse_scores_and_resort( |
| 350 | 437 | "knn_support_score": knn_components["support_knn_score"], |
| 351 | 438 | "knn_score": knn_score, |
| 352 | 439 | "rerank_factor": rerank_factor, |
| 440 | + "fine_factor": fine_factor, | |
| 353 | 441 | "text_factor": text_factor, |
| 354 | 442 | "knn_factor": knn_factor, |
| 355 | 443 | "style_intent_selected_sku": sku_selected, |
| ... | ... | @@ -381,6 +469,8 @@ def run_rerank( |
| 381 | 469 | debug: bool = False, |
| 382 | 470 | fusion: Optional[RerankFusionConfig] = None, |
| 383 | 471 | style_intent_selected_sku_boost: float = 1.2, |
| 472 | + fine_scores: Optional[List[float]] = None, | |
| 473 | + service_profile: Optional[str] = None, | |
| 384 | 474 | ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: |
| 385 | 475 | """ |
| 386 | 476 | 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 |
| ... | ... | @@ -404,6 +494,7 @@ def run_rerank( |
| 404 | 494 | docs, |
| 405 | 495 | timeout_sec=timeout_sec, |
| 406 | 496 | top_n=top_n, |
| 497 | + service_profile=service_profile, | |
| 407 | 498 | ) |
| 408 | 499 | |
| 409 | 500 | if scores is None or len(scores) != len(hits): |
| ... | ... | @@ -412,6 +503,7 @@ def run_rerank( |
| 412 | 503 | fused_debug = fuse_scores_and_resort( |
| 413 | 504 | hits, |
| 414 | 505 | scores, |
| 506 | + fine_scores=fine_scores, | |
| 415 | 507 | weight_es=weight_es, |
| 416 | 508 | weight_ai=weight_ai, |
| 417 | 509 | fusion=fusion, |
| ... | ... | @@ -427,3 +519,53 @@ def run_rerank( |
| 427 | 519 | es_response["hits"]["max_score"] = top |
| 428 | 520 | |
| 429 | 521 | return es_response, meta, fused_debug |
| 522 | + | |
| 523 | + | |
def run_lightweight_rerank(
    query: str,
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    rerank_query_template: str = "{query}",
    rerank_doc_template: str = "{title}",
    top_n: Optional[int] = None,
    debug: bool = False,
    service_profile: Optional[str] = "fine",
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """Call lightweight reranker and attach scores to hits without final fusion.

    Annotates each hit with ``_fine_score`` and re-sorts ``es_hits`` in place,
    descending by that score.

    Returns:
        ``(scores, meta, debug_rows)``; ``([], {}, [])`` for empty input and
        ``(None, None, [])`` when the service fails or returns a score list
        whose length does not match the hits.
    """
    if not es_hits:
        return [], {}, []

    query_text = str(rerank_query_template).format_map({"query": query})
    rerank_debug_rows: Optional[List[Dict[str, Any]]] = [] if debug else None
    docs = build_docs_from_hits(
        es_hits,
        language=language,
        doc_template=rerank_doc_template,
        debug_rows=rerank_debug_rows,
    )
    scores, meta = call_rerank_service(
        query_text,
        docs,
        timeout_sec=timeout_sec,
        top_n=top_n,
        service_profile=service_profile,
    )
    if scores is None or len(scores) != len(es_hits):
        return None, None, []

    # Rows are only appended when debug is set (was `[] if debug else []`,
    # a no-op ternary).
    debug_rows: List[Dict[str, Any]] = []
    for idx, hit in enumerate(es_hits):
        fine_score = _to_score(scores[idx])
        hit["_fine_score"] = fine_score
        if debug:
            row: Dict[str, Any] = {
                "doc_id": hit.get("_id"),
                "fine_score": fine_score,
            }
            if rerank_debug_rows is not None and idx < len(rerank_debug_rows):
                row["rerank_input"] = rerank_debug_rows[idx]
            debug_rows.append(row)

    es_hits.sort(key=lambda h: h.get("_fine_score", 0.0), reverse=True)
    return scores, meta, debug_rows
search/searcher.py
| ... | ... | @@ -251,6 +251,30 @@ class Searcher: |
| 251 | 251 | return hits_by_id, int(resp.get("took", 0) or 0) |
| 252 | 252 | |
| 253 | 253 | @staticmethod |
| 254 | + def _restore_hits_in_doc_order( | |
| 255 | + doc_ids: List[str], | |
| 256 | + hits_by_id: Dict[str, Dict[str, Any]], | |
| 257 | + ) -> List[Dict[str, Any]]: | |
| 258 | + ordered_hits: List[Dict[str, Any]] = [] | |
| 259 | + for doc_id in doc_ids: | |
| 260 | + hit = hits_by_id.get(str(doc_id)) | |
| 261 | + if hit is not None: | |
| 262 | + ordered_hits.append(hit) | |
| 263 | + return ordered_hits | |
| 264 | + | |
| 265 | + @staticmethod | |
| 266 | + def _merge_source_specs(*source_specs: Any) -> Optional[Dict[str, Any]]: | |
| 267 | + includes: set[str] = set() | |
| 268 | + for source_spec in source_specs: | |
| 269 | + if not isinstance(source_spec, dict): | |
| 270 | + continue | |
| 271 | + for field_name in source_spec.get("includes") or []: | |
| 272 | + includes.add(str(field_name)) | |
| 273 | + if not includes: | |
| 274 | + return None | |
| 275 | + return {"includes": sorted(includes)} | |
| 276 | + | |
| 277 | + @staticmethod | |
    def _has_style_intent(parsed_query: Optional[ParsedQuery]) -> bool:
        """True when the parsed query carries an active style-intent profile (safe on None)."""
        profile = getattr(parsed_query, "style_intent_profile", None)
        return bool(getattr(profile, "is_active", False))
| ... | ... | @@ -327,20 +351,30 @@ class Searcher: |
| 327 | 351 | index_langs = tenant_cfg.get("index_languages") or [] |
| 328 | 352 | enable_translation = len(index_langs) > 0 |
| 329 | 353 | enable_embedding = self.config.query_config.enable_text_embedding |
| 354 | + coarse_cfg = self.config.coarse_rank | |
| 355 | + fine_cfg = self.config.fine_rank | |
| 330 | 356 | rc = self.config.rerank |
| 331 | 357 | effective_query_template = rerank_query_template or rc.rerank_query_template |
| 332 | 358 | effective_doc_template = rerank_doc_template or rc.rerank_doc_template |
| 359 | + fine_query_template = fine_cfg.rerank_query_template or effective_query_template | |
| 360 | + fine_doc_template = fine_cfg.rerank_doc_template or effective_doc_template | |
| 333 | 361 | # 重排开关优先级:请求参数显式传值 > 服务端配置(默认开启) |
| 334 | 362 | rerank_enabled_by_config = bool(rc.enabled) |
| 335 | 363 | do_rerank = rerank_enabled_by_config if enable_rerank is None else bool(enable_rerank) |
| 336 | 364 | rerank_window = rc.rerank_window |
| 365 | + coarse_input_window = max(rerank_window, int(coarse_cfg.input_window)) | |
| 366 | + coarse_output_window = max(rerank_window, int(coarse_cfg.output_window)) | |
| 367 | + fine_input_window = max(rerank_window, int(fine_cfg.input_window)) | |
| 368 | + fine_output_window = max(rerank_window, int(fine_cfg.output_window)) | |
| 337 | 369 | # 若开启重排且请求范围在窗口内:从 ES 取前 rerank_window 条、重排后再按 from/size 分页;否则不重排,按原 from/size 查 ES |
| 338 | 370 | in_rerank_window = do_rerank and (from_ + size) <= rerank_window |
| 339 | 371 | es_fetch_from = 0 if in_rerank_window else from_ |
| 340 | - es_fetch_size = rerank_window if in_rerank_window else size | |
| 372 | + es_fetch_size = coarse_input_window if in_rerank_window else size | |
| 341 | 373 | |
| 342 | 374 | es_score_normalization_factor: Optional[float] = None |
| 343 | 375 | initial_ranks_by_doc: Dict[str, int] = {} |
| 376 | + coarse_debug_info: Optional[Dict[str, Any]] = None | |
| 377 | + fine_debug_info: Optional[Dict[str, Any]] = None | |
| 344 | 378 | rerank_debug_info: Optional[Dict[str, Any]] = None |
| 345 | 379 | |
| 346 | 380 | # Start timing |
| ... | ... | @@ -367,12 +401,19 @@ class Searcher: |
| 367 | 401 | 'enable_rerank_request': enable_rerank, |
| 368 | 402 | 'rerank_query_template': effective_query_template, |
| 369 | 403 | 'rerank_doc_template': effective_doc_template, |
| 404 | + 'fine_query_template': fine_query_template, | |
| 405 | + 'fine_doc_template': fine_doc_template, | |
| 370 | 406 | 'filters': filters, |
| 371 | 407 | 'range_filters': range_filters, |
| 372 | 408 | 'facets': facets, |
| 373 | 409 | 'enable_translation': enable_translation, |
| 374 | 410 | 'enable_embedding': enable_embedding, |
| 375 | 411 | 'enable_rerank': do_rerank, |
| 412 | + 'coarse_input_window': coarse_input_window, | |
| 413 | + 'coarse_output_window': coarse_output_window, | |
| 414 | + 'fine_input_window': fine_input_window, | |
| 415 | + 'fine_output_window': fine_output_window, | |
| 416 | + 'rerank_window': rerank_window, | |
| 376 | 417 | 'min_score': min_score, |
| 377 | 418 | 'sort_by': sort_by, |
| 378 | 419 | 'sort_order': sort_order |
| ... | ... | @@ -470,16 +511,12 @@ class Searcher: |
| 470 | 511 | # Keep requested response _source semantics for the final response fill. |
| 471 | 512 | response_source_spec = es_query.get("_source") |
| 472 | 513 | |
| 473 | - # In rerank window, first pass only fetches minimal fields required by rerank template. | |
| 514 | + # In multi-stage rank window, first pass only needs score signals for coarse rank. | |
| 474 | 515 | es_query_for_fetch = es_query |
| 475 | 516 | rerank_prefetch_source = None |
| 476 | 517 | if in_rerank_window: |
| 477 | - rerank_prefetch_source = self._resolve_rerank_source_filter( | |
| 478 | - effective_doc_template, | |
| 479 | - parsed_query=parsed_query, | |
| 480 | - ) | |
| 481 | 518 | es_query_for_fetch = dict(es_query) |
| 482 | - es_query_for_fetch["_source"] = rerank_prefetch_source | |
| 519 | + es_query_for_fetch["_source"] = False | |
| 483 | 520 | |
| 484 | 521 | # Extract size and from from body for ES client parameters |
| 485 | 522 | body_for_es = {k: v for k, v in es_query_for_fetch.items() if k not in ['size', 'from']} |
| ... | ... | @@ -587,26 +624,131 @@ class Searcher: |
| 587 | 624 | context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH_PRIMARY) |
| 588 | 625 | |
| 589 | 626 | style_intent_decisions: Dict[str, SkuSelectionDecision] = {} |
| 590 | - if self._has_style_intent(parsed_query) and in_rerank_window: | |
| 591 | - style_intent_decisions = self._apply_style_intent_to_hits( | |
| 592 | - es_response.get("hits", {}).get("hits") or [], | |
| 593 | - parsed_query, | |
| 594 | - context=context, | |
| 595 | - ) | |
| 596 | - if style_intent_decisions: | |
| 627 | + if do_rerank and in_rerank_window: | |
| 628 | + from dataclasses import asdict | |
| 629 | + from config.services_config import get_rerank_service_url | |
| 630 | + from .rerank_client import coarse_resort_hits, run_lightweight_rerank, run_rerank | |
| 631 | + | |
| 632 | + rerank_query = parsed_query.text_for_rerank() if parsed_query else query | |
| 633 | + hits = es_response.get("hits", {}).get("hits") or [] | |
| 634 | + | |
| 635 | + context.start_stage(RequestContextStage.COARSE_RANKING) | |
| 636 | + try: | |
| 637 | + coarse_debug = coarse_resort_hits( | |
| 638 | + hits, | |
| 639 | + fusion=coarse_cfg.fusion, | |
| 640 | + debug=debug, | |
| 641 | + ) | |
| 642 | + hits = hits[:coarse_output_window] | |
| 643 | + es_response.setdefault("hits", {})["hits"] = hits | |
| 644 | + if debug: | |
| 645 | + coarse_debug_info = { | |
| 646 | + "docs_in": es_fetch_size, | |
| 647 | + "docs_out": len(hits), | |
| 648 | + "fusion": asdict(coarse_cfg.fusion), | |
| 649 | + } | |
| 650 | + context.store_intermediate_result("coarse_rank_scores", coarse_debug) | |
| 597 | 651 | context.logger.info( |
| 598 | - "款式意图 SKU 预筛选完成 | hits=%s", | |
| 599 | - len(style_intent_decisions), | |
| 652 | + "粗排完成 | docs_in=%s | docs_out=%s", | |
| 653 | + es_fetch_size, | |
| 654 | + len(hits), | |
| 600 | 655 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 601 | 656 | ) |
| 657 | + finally: | |
| 658 | + context.end_stage(RequestContextStage.COARSE_RANKING) | |
| 659 | + | |
| 660 | + ranking_source_spec = self._merge_source_specs( | |
| 661 | + self._resolve_rerank_source_filter( | |
| 662 | + fine_doc_template, | |
| 663 | + parsed_query=parsed_query, | |
| 664 | + ), | |
| 665 | + self._resolve_rerank_source_filter( | |
| 666 | + effective_doc_template, | |
| 667 | + parsed_query=parsed_query, | |
| 668 | + ), | |
| 669 | + ) | |
| 670 | + candidate_ids = [str(h.get("_id")) for h in hits if h.get("_id") is not None] | |
| 671 | + if candidate_ids: | |
| 672 | + details_by_id, fill_took = self._fetch_hits_by_ids( | |
| 673 | + index_name=index_name, | |
| 674 | + doc_ids=candidate_ids, | |
| 675 | + source_spec=ranking_source_spec, | |
| 676 | + ) | |
| 677 | + for hit in hits: | |
| 678 | + hid = hit.get("_id") | |
| 679 | + if hid is None: | |
| 680 | + continue | |
| 681 | + detail_hit = details_by_id.get(str(hid)) | |
| 682 | + if detail_hit is not None and "_source" in detail_hit: | |
| 683 | + hit["_source"] = detail_hit.get("_source") or {} | |
| 684 | + if fill_took: | |
| 685 | + es_response["took"] = int((es_response.get("took", 0) or 0) + fill_took) | |
| 686 | + | |
| 687 | + if self._has_style_intent(parsed_query): | |
| 688 | + style_intent_decisions = self._apply_style_intent_to_hits( | |
| 689 | + es_response.get("hits", {}).get("hits") or [], | |
| 690 | + parsed_query, | |
| 691 | + context=context, | |
| 692 | + ) | |
| 693 | + if style_intent_decisions: | |
| 694 | + context.logger.info( | |
| 695 | + "款式意图 SKU 预筛选完成 | hits=%s", | |
| 696 | + len(style_intent_decisions), | |
| 697 | + extra={'reqid': context.reqid, 'uid': context.uid} | |
| 698 | + ) | |
| 699 | + | |
| 700 | + fine_scores: Optional[List[float]] = None | |
| 701 | + hits = es_response.get("hits", {}).get("hits") or [] | |
| 702 | + if fine_cfg.enabled and hits: | |
| 703 | + context.start_stage(RequestContextStage.FINE_RANKING) | |
| 704 | + try: | |
| 705 | + fine_scores, fine_meta, fine_debug_rows = run_lightweight_rerank( | |
| 706 | + query=rerank_query, | |
| 707 | + es_hits=hits[:fine_input_window], | |
| 708 | + language=language, | |
| 709 | + timeout_sec=fine_cfg.timeout_sec, | |
| 710 | + rerank_query_template=fine_query_template, | |
| 711 | + rerank_doc_template=fine_doc_template, | |
| 712 | + top_n=fine_output_window, | |
| 713 | + debug=debug, | |
| 714 | + service_profile=fine_cfg.service_profile, | |
| 715 | + ) | |
| 716 | + if fine_scores is not None: | |
| 717 | + hits = hits[:fine_output_window] | |
| 718 | + es_response["hits"]["hits"] = hits | |
| 719 | + if debug: | |
| 720 | + fine_debug_info = { | |
| 721 | + "service_url": get_rerank_service_url(profile=fine_cfg.service_profile), | |
| 722 | + "query_template": fine_query_template, | |
| 723 | + "doc_template": fine_doc_template, | |
| 724 | + "query_text": str(fine_query_template).format_map({"query": rerank_query}), | |
| 725 | + "docs": len(hits), | |
| 726 | + "top_n": fine_output_window, | |
| 727 | + "meta": fine_meta, | |
| 728 | + } | |
| 729 | + context.store_intermediate_result("fine_rank_scores", fine_debug_rows) | |
| 730 | + context.logger.info( | |
| 731 | + "精排完成 | docs=%s | top_n=%s | meta=%s", | |
| 732 | + len(hits), | |
| 733 | + fine_output_window, | |
| 734 | + fine_meta, | |
| 735 | + extra={'reqid': context.reqid, 'uid': context.uid} | |
| 736 | + ) | |
| 737 | + except Exception as e: | |
| 738 | + context.add_warning(f"Fine rerank failed: {e}") | |
| 739 | + context.logger.warning( | |
| 740 | + f"调用精排服务失败 | error: {e}", | |
| 741 | + extra={'reqid': context.reqid, 'uid': context.uid}, | |
| 742 | + exc_info=True, | |
| 743 | + ) | |
| 744 | + finally: | |
| 745 | + context.end_stage(RequestContextStage.FINE_RANKING) | |
| 602 | 746 | |
| 603 | - # Optional Step 4.5: AI reranking(仅当请求范围在重排窗口内时执行) | |
| 604 | - if do_rerank and in_rerank_window: | |
| 605 | 747 | context.start_stage(RequestContextStage.RERANKING) |
| 606 | 748 | try: |
| 607 | - from .rerank_client import run_rerank | |
| 608 | - | |
| 609 | - rerank_query = parsed_query.text_for_rerank() if parsed_query else query | |
| 749 | + final_hits = es_response.get("hits", {}).get("hits") or [] | |
| 750 | + final_input = final_hits[:rerank_window] | |
| 751 | + es_response["hits"]["hits"] = final_input | |
| 610 | 752 | es_response, rerank_meta, fused_debug = run_rerank( |
| 611 | 753 | query=rerank_query, |
| 612 | 754 | es_response=es_response, |
| ... | ... | @@ -619,15 +761,15 @@ class Searcher: |
| 619 | 761 | top_n=(from_ + size), |
| 620 | 762 | debug=debug, |
| 621 | 763 | fusion=rc.fusion, |
| 764 | + fine_scores=fine_scores[:len(final_input)] if fine_scores is not None else None, | |
| 765 | + service_profile=rc.service_profile, | |
| 622 | 766 | style_intent_selected_sku_boost=self.config.query_config.style_intent_selected_sku_boost, |
| 623 | 767 | ) |
| 624 | 768 | |
| 625 | 769 | if rerank_meta is not None: |
| 626 | 770 | if debug: |
| 627 | - from dataclasses import asdict | |
| 628 | - from config.services_config import get_rerank_service_url | |
| 629 | 771 | rerank_debug_info = { |
| 630 | - "service_url": get_rerank_service_url(), | |
| 772 | + "service_url": get_rerank_service_url(profile=rc.service_profile), | |
| 631 | 773 | "query_template": effective_query_template, |
| 632 | 774 | "doc_template": effective_doc_template, |
| 633 | 775 | "query_text": str(effective_query_template).format_map({"query": rerank_query}), |
| ... | ... | @@ -652,15 +794,17 @@ class Searcher: |
| 652 | 794 | finally: |
| 653 | 795 | context.end_stage(RequestContextStage.RERANKING) |
| 654 | 796 | |
| 655 | - # 当本次请求在重排窗口内时:已从 ES 取了 rerank_window 条并可能已重排,需按请求的 from/size 做分页切片 | |
| 797 | + # 当本次请求在重排窗口内时:已按多阶段排序产出前 rerank_window 条,需按请求的 from/size 做分页切片 | |
| 656 | 798 | if in_rerank_window: |
| 657 | 799 | hits = es_response.get("hits", {}).get("hits") or [] |
| 658 | 800 | sliced = hits[from_ : from_ + size] |
| 659 | 801 | es_response.setdefault("hits", {})["hits"] = sliced |
| 660 | 802 | if sliced: |
| 661 | - # 对于启用重排的结果,优先使用 _fused_score 计算 max_score;否则退回原始 _score | |
| 662 | 803 | slice_max = max( |
| 663 | - (h.get("_fused_score", h.get("_score", 0.0)) for h in sliced), | |
| 804 | + ( | |
| 805 | + h.get("_fused_score", h.get("_fine_score", h.get("_coarse_score", h.get("_score", 0.0)))) | |
| 806 | + for h in sliced | |
| 807 | + ), | |
| 664 | 808 | default=0.0, |
| 665 | 809 | ) |
| 666 | 810 | try: |
| ... | ... | @@ -670,7 +814,6 @@ class Searcher: |
| 670 | 814 | else: |
| 671 | 815 | es_response["hits"]["max_score"] = 0.0 |
| 672 | 816 | |
| 673 | - # Page fill: fetch detailed fields only for final page hits. | |
| 674 | 817 | if sliced: |
| 675 | 818 | if response_source_spec is False: |
| 676 | 819 | for hit in sliced: |
| ... | ... | @@ -754,6 +897,16 @@ class Searcher: |
| 754 | 897 | if doc_id is None: |
| 755 | 898 | continue |
| 756 | 899 | rerank_debug_by_doc[str(doc_id)] = item |
| 900 | + fine_debug_raw = context.get_intermediate_result('fine_rank_scores', None) | |
| 901 | + fine_debug_by_doc: Dict[str, Dict[str, Any]] = {} | |
| 902 | + if isinstance(fine_debug_raw, list): | |
| 903 | + for item in fine_debug_raw: | |
| 904 | + if not isinstance(item, dict): | |
| 905 | + continue | |
| 906 | + doc_id = item.get("doc_id") | |
| 907 | + if doc_id is None: | |
| 908 | + continue | |
| 909 | + fine_debug_by_doc[str(doc_id)] = item | |
| 757 | 910 | |
| 758 | 911 | if self._has_style_intent(parsed_query): |
| 759 | 912 | if style_intent_decisions: |
| ... | ... | @@ -784,6 +937,9 @@ class Searcher: |
| 784 | 937 | rerank_debug = None |
| 785 | 938 | if doc_id is not None: |
| 786 | 939 | rerank_debug = rerank_debug_by_doc.get(str(doc_id)) |
| 940 | + fine_debug = None | |
| 941 | + if doc_id is not None: | |
| 942 | + fine_debug = fine_debug_by_doc.get(str(doc_id)) | |
| 787 | 943 | style_intent_debug = None |
| 788 | 944 | if doc_id is not None and style_intent_decisions: |
| 789 | 945 | decision = style_intent_decisions.get(str(doc_id)) |
| ... | ... | @@ -823,6 +979,7 @@ class Searcher: |
| 823 | 979 | debug_entry["doc_id"] = rerank_debug.get("doc_id") |
| 824 | 980 | # 与 rerank_client 中字段保持一致,便于前端直接使用 |
| 825 | 981 | debug_entry["rerank_score"] = rerank_debug.get("rerank_score") |
| 982 | + debug_entry["fine_score"] = rerank_debug.get("fine_score") | |
| 826 | 983 | debug_entry["text_score"] = rerank_debug.get("text_score") |
| 827 | 984 | debug_entry["text_source_score"] = rerank_debug.get("text_source_score") |
| 828 | 985 | debug_entry["text_translation_score"] = rerank_debug.get("text_translation_score") |
| ... | ... | @@ -833,11 +990,16 @@ class Searcher: |
| 833 | 990 | debug_entry["text_score_fallback_to_es"] = rerank_debug.get("text_score_fallback_to_es") |
| 834 | 991 | debug_entry["knn_score"] = rerank_debug.get("knn_score") |
| 835 | 992 | debug_entry["rerank_factor"] = rerank_debug.get("rerank_factor") |
| 993 | + debug_entry["fine_factor"] = rerank_debug.get("fine_factor") | |
| 836 | 994 | debug_entry["text_factor"] = rerank_debug.get("text_factor") |
| 837 | 995 | debug_entry["knn_factor"] = rerank_debug.get("knn_factor") |
| 838 | 996 | debug_entry["fused_score"] = rerank_debug.get("fused_score") |
| 839 | 997 | debug_entry["rerank_input"] = rerank_debug.get("rerank_input") |
| 840 | 998 | debug_entry["matched_queries"] = rerank_debug.get("matched_queries") |
| 999 | + elif fine_debug: | |
| 1000 | + debug_entry["doc_id"] = fine_debug.get("doc_id") | |
| 1001 | + debug_entry["fine_score"] = fine_debug.get("fine_score") | |
| 1002 | + debug_entry["rerank_input"] = fine_debug.get("rerank_input") | |
| 841 | 1003 | |
| 842 | 1004 | if style_intent_debug: |
| 843 | 1005 | debug_entry["style_intent_sku"] = style_intent_debug |
| ... | ... | @@ -908,6 +1070,8 @@ class Searcher: |
| 908 | 1070 | "shards": es_response.get('_shards', {}), |
| 909 | 1071 | "es_score_normalization_factor": es_score_normalization_factor, |
| 910 | 1072 | }, |
| 1073 | + "coarse_rank": coarse_debug_info, | |
| 1074 | + "fine_rank": fine_debug_info, | |
| 911 | 1075 | "rerank": rerank_debug_info, |
| 912 | 1076 | "feature_flags": context.metadata.get('feature_flags', {}), |
| 913 | 1077 | "stage_timings": { | ... | ... |
| ... | ... | @@ -0,0 +1,43 @@ |
| 1 | +白色oversized T-shirt | |
| 2 | +falda negra oficina | |
| 3 | +red fitted tee | |
| 4 | +黒いミディ丈スカート | |
| 5 | +黑色中长半身裙 | |
| 6 | +فستان أسود متوسط الطول | |
| 7 | +чёрное летнее платье | |
| 8 | +修身牛仔裤 | |
| 9 | +date night dress | |
| 10 | +vacation outfit dress | |
| 11 | +minimalist top | |
| 12 | +streetwear t-shirt | |
| 13 | +office casual blouse | |
| 14 | +街头风T恤 | |
| 15 | +宽松T恤 | |
| 16 | +复古印花T恤 | |
| 17 | +Y2K上衣 | |
| 18 | +情侣T恤 | |
| 19 | +美式复古T恤 | |
| 20 | +重磅棉T恤 | |
| 21 | +修身打底衫 | |
| 22 | +辣妹风短袖 | |
| 23 | +纯欲上衣 | |
| 24 | +正肩白T恤 | |
| 25 | +波西米亚花朵衬衫 | |
| 26 | +泡泡袖短袖 | |
| 27 | +扎染字母T恤 | |
| 28 | +T-shirt Dress | |
| 29 | +Crop Top | |
| 30 | +Lace Undershirt | |
| 31 | +Leopard Print Ripped T-shirt | |
| 32 | +Breton Stripe T-shirt | |
| 33 | +V-Neck Cotton T-shirt | |
| 34 | +Sweet & Cool Bow T-shirt | |
| 35 | +Vacation Style T-shirt | |
| 36 | +Commuter Casual Top | |
| 37 | +Minimalist Solid T-shirt | |
| 38 | +Band T-shirt | |
| 39 | +Athletic Gym T-shirt | |
| 40 | +Plus Size Loose T-shirt | |
| 41 | +Korean Style Slim T-shirt | |
| 42 | +Basic Layering Top | |
| 43 | + | ... | ... |
tests/test_search_rerank_window.py
| ... | ... | @@ -311,11 +311,18 @@ def test_searcher_reranks_top_window_by_default(monkeypatch): |
| 311 | 311 | |
| 312 | 312 | called: Dict[str, Any] = {"count": 0, "docs": 0} |
| 313 | 313 | |
| 314 | + def _fake_run_lightweight_rerank(**kwargs): | |
| 315 | + hits = kwargs["es_hits"] | |
| 316 | + for idx, hit in enumerate(hits): | |
| 317 | + hit["_fine_score"] = float(len(hits) - idx) | |
| 318 | + return [hit["_fine_score"] for hit in hits], {"stage": "fine"}, [] | |
| 319 | + | |
| 314 | 320 | def _fake_run_rerank(**kwargs): |
| 315 | 321 | called["count"] += 1 |
| 316 | 322 | called["docs"] = len(kwargs["es_response"]["hits"]["hits"]) |
| 317 | 323 | return kwargs["es_response"], None, [] |
| 318 | 324 | |
| 325 | + monkeypatch.setattr("search.rerank_client.run_lightweight_rerank", _fake_run_lightweight_rerank) | |
| 319 | 326 | monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) |
| 320 | 327 | |
| 321 | 328 | result = searcher.search( |
| ... | ... | @@ -328,17 +335,20 @@ def test_searcher_reranks_top_window_by_default(monkeypatch): |
| 328 | 335 | ) |
| 329 | 336 | |
| 330 | 337 | assert called["count"] == 1 |
| 331 | - # 应当对配置的 rerank_window 条文档做重排预取 | |
| 332 | - window = searcher.config.rerank.rerank_window | |
| 333 | - assert called["docs"] == window | |
| 338 | + assert called["docs"] == searcher.config.rerank.rerank_window | |
| 334 | 339 | assert es_client.calls[0]["from_"] == 0 |
| 335 | - assert es_client.calls[0]["size"] == window | |
| 340 | + assert es_client.calls[0]["size"] == searcher.config.coarse_rank.input_window | |
| 336 | 341 | assert es_client.calls[0]["include_named_queries_score"] is True |
| 337 | - assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]} | |
| 338 | - assert len(es_client.calls) == 2 | |
| 339 | - assert es_client.calls[1]["size"] == 10 | |
| 342 | + assert es_client.calls[0]["body"]["_source"] is False | |
| 343 | + assert len(es_client.calls) == 3 | |
| 344 | + assert es_client.calls[1]["size"] == max( | |
| 345 | + searcher.config.coarse_rank.output_window, | |
| 346 | + searcher.config.rerank.rerank_window, | |
| 347 | + ) | |
| 340 | 348 | assert es_client.calls[1]["from_"] == 0 |
| 341 | - assert es_client.calls[1]["body"]["query"]["ids"]["values"] == [str(i) for i in range(20, 30)] | |
| 349 | + assert es_client.calls[2]["size"] == 10 | |
| 350 | + assert es_client.calls[2]["from_"] == 0 | |
| 351 | + assert es_client.calls[2]["body"]["query"]["ids"]["values"] == [str(i) for i in range(20, 30)] | |
| 342 | 352 | assert len(result.results) == 10 |
| 343 | 353 | assert result.results[0].spu_id == "20" |
| 344 | 354 | assert result.results[0].brief == "brief-20" |
| ... | ... | @@ -353,6 +363,10 @@ def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): |
| 353 | 363 | "search.searcher.get_tenant_config_loader", |
| 354 | 364 | lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), |
| 355 | 365 | ) |
| 366 | + monkeypatch.setattr( | |
| 367 | + "search.rerank_client.run_lightweight_rerank", | |
| 368 | + lambda **kwargs: ([1.0] * len(kwargs["es_hits"]), {"stage": "fine"}, []), | |
| 369 | + ) | |
| 356 | 370 | monkeypatch.setattr("search.rerank_client.run_rerank", lambda **kwargs: (kwargs["es_response"], None, [])) |
| 357 | 371 | |
| 358 | 372 | searcher.search( |
| ... | ... | @@ -365,7 +379,8 @@ def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): |
| 365 | 379 | rerank_doc_template="{title} {vendor} {brief}", |
| 366 | 380 | ) |
| 367 | 381 | |
| 368 | - assert es_client.calls[0]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} | |
| 382 | + assert es_client.calls[0]["body"]["_source"] is False | |
| 383 | + assert es_client.calls[1]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} | |
| 369 | 384 | |
| 370 | 385 | |
| 371 | 386 | def test_searcher_rerank_prefetch_source_includes_sku_fields_when_style_intent_active(monkeypatch): |
| ... | ... | @@ -378,6 +393,10 @@ def test_searcher_rerank_prefetch_source_includes_sku_fields_when_style_intent_a |
| 378 | 393 | lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), |
| 379 | 394 | ) |
| 380 | 395 | monkeypatch.setattr( |
| 396 | + "search.rerank_client.run_lightweight_rerank", | |
| 397 | + lambda **kwargs: ([1.0] * len(kwargs["es_hits"]), {"stage": "fine"}, []), | |
| 398 | + ) | |
| 399 | + monkeypatch.setattr( | |
| 381 | 400 | "search.rerank_client.run_rerank", |
| 382 | 401 | lambda **kwargs: (kwargs["es_response"], None, []), |
| 383 | 402 | ) |
| ... | ... | @@ -414,7 +433,8 @@ def test_searcher_rerank_prefetch_source_includes_sku_fields_when_style_intent_a |
| 414 | 433 | enable_rerank=None, |
| 415 | 434 | ) |
| 416 | 435 | |
| 417 | - assert es_client.calls[0]["body"]["_source"] == { | |
| 436 | + assert es_client.calls[0]["body"]["_source"] is False | |
| 437 | + assert es_client.calls[1]["body"]["_source"] == { | |
| 418 | 438 | "includes": ["option1_name", "option2_name", "option3_name", "skus", "title"] |
| 419 | 439 | } |
| 420 | 440 | ... | ... |