Compare View
Commits (9)
-
@config/dictionaries/style_intent_color.csv @config/dictionaries/style_intent_size.csv @query/style_intent.py @search/sku_intent_selector.py 1. 两个csv词典,分为三列, - 英文关键词 - 中文关键词 - 标准属性名称词 三列都可以允许逗号分割。补充的第三列使用在商品属性中,使用的是标准的英文名称 2. 判断意图的时候,中文词用中文翻译名去匹配,如果不存在中文翻译名,则用原始 query,英文词同理 3. SKU 选择的时候,用每一个 SKU 的属性名去匹配。 匹配规则要大幅度简化,并做性能优化: 1)文本匹配规则只需要看规范化后的属性值是否包含了词典配置的第三列"标准属性名称词",如果包含了,则认为匹配成功。 找到第一个匹配成功的即可。如果都没有成功,后面也不再需要用向量匹配。 暂时废弃向量匹配、双向匹配等复杂逻辑。
-
这两个配置、四种情况: backend: qwen3_vllm | qwen3_vllm_score instruction_format: compact | standard 调用 python scripts/benchmark_reranker_random_titles.py 100,200,400,600,800,1000 --repeat 5 产出性能测试报告 平均延迟(ms,客户端 POST /rerank 墙钟,--seed 99) backend instruction_format n=100 n=200 n=400 n=600 n=800 n=1000 qwen3_vllm compact 213.5 418.0 861.4 1263.4 1744.3 2162.2 qwen3_vllm standard 254.9 475.4 909.7 1353.2 1912.5 2406.7 qwen3_vllm_score compact 239.2 480.2 966.2 1433.5 1937.2 2428.4 qwen3_vllm_score standard 299.6 591.8 1178.9 1773.7 2341.6 2931.7 归纳: 在本机 T4、当前 vLLM 与上述 YAML(max_model_len=160、infer_batch_size=100 等)下,两种后端都是 compact 快于 standard;整体最快为 qwen3_vllm + compact(n=1000 ≈ 2.16 s),最慢为 qwen3_vllm_score + standard(≈ 2.93 s)。其他 GPU / vLLM 版本下排序可能变化。
Showing
45 changed files
Show diff stats
config/config.yaml
| ... | ... | @@ -114,10 +114,11 @@ query_config: |
| 114 | 114 | # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 |
| 115 | 115 | # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 |
| 116 | 116 | translation_embedding_wait_budget_ms_source_in_index: 500 # 80 |
| 117 | - translation_embedding_wait_budget_ms_source_not_in_index: 500 #200 | |
| 117 | + translation_embedding_wait_budget_ms_source_not_in_index: 700 #200 | |
| 118 | 118 | |
| 119 | 119 | style_intent: |
| 120 | 120 | enabled: true |
| 121 | + selected_sku_boost: 1.2 | |
| 121 | 122 | color_dictionary_path: "config/dictionaries/style_intent_color.csv" |
| 122 | 123 | size_dictionary_path: "config/dictionaries/style_intent_size.csv" |
| 123 | 124 | dimension_aliases: |
| ... | ... | @@ -230,7 +231,7 @@ rerank: |
| 230 | 231 | text_bias: 0.1 |
| 231 | 232 | text_exponent: 0.35 |
| 232 | 233 | knn_bias: 0.6 |
| 233 | - knn_exponent: 0.2 | |
| 234 | + knn_exponent: 0.0 | |
| 234 | 235 | |
| 235 | 236 | # 可扩展服务/provider 注册表(单一配置源) |
| 236 | 237 | services: |
| ... | ... | @@ -380,7 +381,7 @@ services: |
| 380 | 381 | max_docs: 1000 |
| 381 | 382 | normalize: true |
| 382 | 383 | # 服务内后端(reranker 进程启动时读取) |
| 383 | - backend: "bge" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank | |
| 384 | + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_transformers_packed | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank | |
| 384 | 385 | backends: |
| 385 | 386 | bge: |
| 386 | 387 | model_name: "BAAI/bge-reranker-v2-m3" |
| ... | ... | @@ -401,6 +402,9 @@ services: |
| 401 | 402 | enforce_eager: false |
| 402 | 403 | infer_batch_size: 100 |
| 403 | 404 | sort_by_doc_length: true |
| 405 | + # 与 reranker/backends/qwen3_vllm.py 一致:standard=_format_instruction__standard(固定 yes/no system);compact=_format_instruction(instruction 作 system 且 user 内重复 Instruct) | |
| 406 | + # instruction_format: compact | |
| 407 | + instruction_format: compact | |
| 404 | 408 | # instruction: "Given a query, score the product for relevance" |
| 405 | 409 | # "rank products by given query" 比 “Given a query, score the product for relevance” 更好点 |
| 406 | 410 | # instruction: "rank products by given query, category match first" |
| ... | ... | @@ -410,6 +414,32 @@ services: |
| 410 | 414 | # instruction: "Relevance ranking: category & style match first" |
| 411 | 415 | # instruction: "Score product relevance by query with category & style match prioritized" |
| 412 | 416 | instruction: "Rank products by query with category & style match prioritized" |
| 417 | + # vLLM LLM.score()(跨编码打分)。独立高性能环境 .venv-reranker-score(vllm 0.18 固定版):./scripts/setup_reranker_venv.sh qwen3_vllm_score | |
| 418 | + # 与 qwen3_vllm 可共用同一 model_name / HF 缓存;venv 分离以便升级 vLLM 而不影响 generate 后端。 | |
| 419 | + qwen3_vllm_score: | |
| 420 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 421 | + # 官方 Hub 原版需 true;若改用已转换的 seq-cls 权重(如 tomaarsen/...-seq-cls)则设为 false | |
| 422 | + use_original_qwen3_hf_overrides: true | |
| 423 | + # vLLM 0.18:算力 < 8(如 T4)默认自动用 TRITON_ATTN;Ampere+ 可省略或设 auto。也可设环境变量 RERANK_VLLM_ATTENTION_BACKEND | |
| 424 | + # vllm_attention_backend: "auto" | |
| 425 | + # 可选:与 vLLM 对齐;一般保持 auto | |
| 426 | + # vllm_runner: "auto" | |
| 427 | + # vllm_convert: "auto" | |
| 428 | + # 可选:在 use_original_qwen3_hf_overrides 为 true 时与内置 overrides 合并 | |
| 429 | + # hf_overrides: {} | |
| 430 | + engine: "vllm" | |
| 431 | + max_model_len: 160 | |
| 432 | + tensor_parallel_size: 1 | |
| 433 | + gpu_memory_utilization: 0.20 | |
| 434 | + dtype: "float16" | |
| 435 | + enable_prefix_caching: true | |
| 436 | + enforce_eager: false | |
| 437 | + infer_batch_size: 100 | |
| 438 | + sort_by_doc_length: true | |
| 439 | + # 与 qwen3_vllm 同名项语义一致;默认 standard 与 vLLM 官方 Qwen3 reranker 前缀一致 | |
| 440 | + # instruction_format: compact | |
| 441 | + instruction_format: standard | |
| 442 | + instruction: "Rank products by query with category & style match prioritized" | |
| 413 | 443 | qwen3_transformers: |
| 414 | 444 | model_name: "Qwen/Qwen3-Reranker-0.6B" |
| 415 | 445 | instruction: "rank products by given query" |
| ... | ... | @@ -419,6 +449,68 @@ services: |
| 419 | 449 | use_fp16: true |
| 420 | 450 | # sdpa:默认无需 flash-attn;若已安装 flash_attn 可改为 flash_attention_2 |
| 421 | 451 | attn_implementation: "sdpa" |
| 452 | + # Packed Transformers backend: shared query prefix + custom position_ids/attention_mask. | |
| 453 | + # For 1 query + many short docs (for example 400 product titles), this usually reduces | |
| 454 | + # repeated prefix work and padding waste compared with pairwise batching. | |
| 455 | + qwen3_transformers_packed: | |
| 456 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 457 | + instruction: "Rank products by query with category & style match prioritized" | |
| 458 | + max_model_len: 4096 | |
| 459 | + max_doc_len: 160 | |
| 460 | + max_docs_per_pack: 0 | |
| 461 | + use_fp16: true | |
| 462 | + sort_by_doc_length: true | |
| 463 | + # Packed mode relies on a custom 4D attention mask. "eager" is the safest default. | |
| 464 | + # If your torch/transformers stack validates it, you can benchmark "sdpa". | |
| 465 | + attn_implementation: "eager" | |
| 466 | + qwen3_gguf: | |
| 467 | + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF" | |
| 468 | + filename: "*Q8_0.gguf" | |
| 469 | + cache_dir: "./model_cache" | |
| 470 | + local_dir: "./models/reranker/qwen3-reranker-4b-gguf" | |
| 471 | + instruction: "Rank products by query with category & style match prioritized" | |
| 472 | + # T4 16GB / 性能优先配置:全量层 offload,实测比保守配置明显更快 | |
| 473 | + n_ctx: 512 | |
| 474 | + n_batch: 512 | |
| 475 | + n_ubatch: 512 | |
| 476 | + n_gpu_layers: 999 | |
| 477 | + main_gpu: 0 | |
| 478 | + n_threads: 2 | |
| 479 | + n_threads_batch: 4 | |
| 480 | + flash_attn: true | |
| 481 | + offload_kqv: true | |
| 482 | + use_mmap: true | |
| 483 | + use_mlock: false | |
| 484 | + infer_batch_size: 8 | |
| 485 | + sort_by_doc_length: true | |
| 486 | + length_sort_mode: "char" | |
| 487 | + enable_warmup: true | |
| 488 | + verbose: false | |
| 489 | + qwen3_gguf_06b: | |
| 490 | + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF" | |
| 491 | + filename: "qwen3-reranker-0.6b-q8_0.gguf" | |
| 492 | + cache_dir: "./model_cache" | |
| 493 | + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf" | |
| 494 | + instruction: "Rank products by query with category & style match prioritized" | |
| 495 | + # 0.6B GGUF / online rerank baseline: | |
| 496 | + # 实测 400 titles 单请求约 265s,因此它更适合作为低显存功能后备,不适合在线低延迟主路由。 | |
| 497 | + n_ctx: 256 | |
| 498 | + n_batch: 256 | |
| 499 | + n_ubatch: 256 | |
| 500 | + n_gpu_layers: 999 | |
| 501 | + main_gpu: 0 | |
| 502 | + n_threads: 2 | |
| 503 | + n_threads_batch: 4 | |
| 504 | + flash_attn: true | |
| 505 | + offload_kqv: true | |
| 506 | + use_mmap: true | |
| 507 | + use_mlock: false | |
| 508 | + infer_batch_size: 32 | |
| 509 | + sort_by_doc_length: true | |
| 510 | + length_sort_mode: "char" | |
| 511 | + reuse_query_state: false | |
| 512 | + enable_warmup: true | |
| 513 | + verbose: false | |
| 422 | 514 | dashscope_rerank: |
| 423 | 515 | model_name: "qwen3-rerank" |
| 424 | 516 | # 按地域选择 endpoint: | ... | ... |
config/dictionaries/product_title_exclusion.tsv
config/dictionaries/style_intent_color.csv
| 1 | -black,black,blk,黑,黑色 | |
| 2 | -white,white,wht,白,白色 | |
| 3 | -red,red,reddish,红,红色 | |
| 4 | -blue,blue,blu,蓝,蓝色 | |
| 5 | -green,green,grn,绿,绿色 | |
| 6 | -yellow,yellow,ylw,黄,黄色 | |
| 7 | -pink,pink,粉,粉色 | |
| 8 | -purple,purple,violet,紫,紫色 | |
| 9 | -gray,gray,grey,灰,灰色 | |
| 10 | -brown,brown,棕,棕色,咖啡色 | |
| 11 | -beige,beige,khaki,米色,卡其色 | |
| 12 | -navy,navy,navy blue,藏青,藏蓝,深蓝 | |
| 13 | -silver,silver,银,银色 | |
| 14 | -gold,gold,金,金色 | |
| 15 | -orange,orange,橙,橙色 | |
| 1 | +"black,blk","黑,黑色","black" | |
| 2 | +"white,wht","白,白色","white" | |
| 3 | +"red,reddish","红,红色","red" | |
| 4 | +"blue,blu","蓝,蓝色","blue" | |
| 5 | +"green,grn","绿,绿色","green" | |
| 6 | +"yellow,ylw","黄,黄色","yellow" | |
| 7 | +"pink","粉,粉色","pink" | |
| 8 | +"purple,violet","紫,紫色","purple" | |
| 9 | +"gray,grey","灰,灰色","gray,grey" | |
| 10 | +"brown","棕,棕色,咖啡色","brown" | |
| 11 | +"beige,khaki","米色,卡其色","beige,khaki" | |
| 12 | +"navy,navy blue","藏青,藏蓝,深蓝","navy" | |
| 13 | +"silver","银,银色","silver" | |
| 14 | +"gold","金,金色","gold" | |
| 15 | +"orange","橙,橙色","orange" | ... | ... |
config/dictionaries/style_intent_size.csv
| 1 | -xs,xs,extra small,x-small,加小码 | |
| 2 | -s,s,small,小码,小号 | |
| 3 | -m,m,medium,中码,中号 | |
| 4 | -l,l,large,大码,大号 | |
| 5 | -xl,xl,x-large,extra large,加大码 | |
| 6 | -xxl,xxl,2xl,xx-large,双加大码 | |
| 7 | -xxxl,xxxl,3xl,xxx-large,三加大码 | |
| 8 | -one size,one size,onesize,free size,均码 | |
| 1 | +"xs,extra small,x-small","加小码","xs,extra small,x-small" | |
| 2 | +"s,small","小码,小号","s,small" | |
| 3 | +"m,medium","中码,中号","m,medium" | |
| 4 | +"l,large","大码,大号","l,large" | |
| 5 | +"xl,x-large,extra large","加大码","xl,x-large,extra large" | |
| 6 | +"xxl,2xl,xx-large","双加大码","xxl,2xl,xx-large" | |
| 7 | +"xxxl,3xl,xxx-large","三加大码","xxxl,3xl,xxx-large" | |
| 8 | +"one size,onesize,free size","均码","one size" | ... | ... |
config/loader.py
| ... | ... | @@ -10,6 +10,7 @@ from __future__ import annotations |
| 10 | 10 | import hashlib |
| 11 | 11 | import json |
| 12 | 12 | import os |
| 13 | +import csv | |
| 13 | 14 | from copy import deepcopy |
| 14 | 15 | from dataclasses import asdict |
| 15 | 16 | from functools import lru_cache |
| ... | ... | @@ -96,20 +97,33 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]: |
| 96 | 97 | return rewrite_dict |
| 97 | 98 | |
| 98 | 99 | |
| 99 | -def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]: | |
| 100 | - rows: List[List[str]] = [] | |
| 100 | +def _read_synonym_csv_dictionary(path: Path) -> List[Dict[str, List[str]]]: | |
| 101 | + rows: List[Dict[str, List[str]]] = [] | |
| 101 | 102 | if not path.exists(): |
| 102 | 103 | return rows |
| 103 | 104 | |
| 105 | + def _split_terms(cell: str) -> List[str]: | |
| 106 | + return [item.strip() for item in str(cell or "").split(",") if item.strip()] | |
| 107 | + | |
| 104 | 108 | with open(path, "r", encoding="utf-8") as handle: |
| 105 | - for raw_line in handle: | |
| 106 | - line = raw_line.strip() | |
| 107 | - if not line or line.startswith("#"): | |
| 109 | + reader = csv.reader(handle) | |
| 110 | + for parts in reader: | |
| 111 | + if not parts: | |
| 112 | + continue | |
| 113 | + if parts[0].strip().startswith("#"): | |
| 108 | 114 | continue |
| 109 | - parts = [segment.strip() for segment in line.split(",")] | |
| 110 | - normalized = [segment for segment in parts if segment] | |
| 111 | - if normalized: | |
| 112 | - rows.append(normalized) | |
| 115 | + | |
| 116 | + normalized = [segment.strip() for segment in parts] | |
| 117 | + if len(normalized) < 3: | |
| 118 | + continue | |
| 119 | + | |
| 120 | + row = { | |
| 121 | + "en_terms": _split_terms(normalized[0]), | |
| 122 | + "zh_terms": _split_terms(normalized[1]), | |
| 123 | + "attribute_terms": _split_terms(normalized[2]), | |
| 124 | + } | |
| 125 | + if any(row.values()): | |
| 126 | + rows.append(row) | |
| 113 | 127 | return rows |
| 114 | 128 | |
| 115 | 129 | |
| ... | ... | @@ -425,6 +439,9 @@ class AppConfigLoader: |
| 425 | 439 | query_cfg.get("translation_embedding_wait_budget_ms_source_not_in_index", 200) |
| 426 | 440 | ), |
| 427 | 441 | style_intent_enabled=bool(style_intent_cfg.get("enabled", True)), |
| 442 | + style_intent_selected_sku_boost=float( | |
| 443 | + style_intent_cfg.get("selected_sku_boost", 1.2) | |
| 444 | + ), | |
| 428 | 445 | style_intent_terms=style_intent_terms, |
| 429 | 446 | style_intent_dimension_aliases=style_dimension_aliases, |
| 430 | 447 | product_title_exclusion_enabled=bool(product_title_exclusion_cfg.get("enabled", True)), | ... | ... |
config/schema.py
| ... | ... | @@ -65,7 +65,8 @@ class QueryConfig: |
| 65 | 65 | translation_embedding_wait_budget_ms_source_in_index: int = 80 |
| 66 | 66 | translation_embedding_wait_budget_ms_source_not_in_index: int = 200 |
| 67 | 67 | style_intent_enabled: bool = True |
| 68 | - style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict) | |
| 68 | + style_intent_selected_sku_boost: float = 1.2 | |
| 69 | + style_intent_terms: Dict[str, List[Dict[str, List[str]]]] = field(default_factory=dict) | |
| 69 | 70 | style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict) |
| 70 | 71 | product_title_exclusion_enabled: bool = True |
| 71 | 72 | product_title_exclusion_rules: List[Dict[str, List[str]]] = field(default_factory=list) | ... | ... |
config/services_config.py
| ... | ... | @@ -7,6 +7,7 @@ contains no independent parsing or precedence logic. |
| 7 | 7 | |
| 8 | 8 | from __future__ import annotations |
| 9 | 9 | |
| 10 | +import os | |
| 10 | 11 | from typing import Any, Dict, Tuple |
| 11 | 12 | |
| 12 | 13 | from config.loader import get_app_config |
| ... | ... | @@ -61,6 +62,12 @@ def get_embedding_image_backend_config() -> Tuple[str, Dict[str, Any]]: |
| 61 | 62 | |
| 62 | 63 | def get_rerank_backend_config() -> Tuple[str, Dict[str, Any]]: |
| 63 | 64 | cfg = get_app_config().services.rerank |
| 65 | + backend = str(os.getenv("RERANK_BACKEND") or cfg.backend).strip() | |
| 66 | + if backend != cfg.backend: | |
| 67 | + backend_cfg = cfg.backends.get(backend) | |
| 68 | + if backend_cfg is None: | |
| 69 | + raise ValueError(f"Unknown rerank backend override from RERANK_BACKEND: {backend!r}") | |
| 70 | + return backend, dict(backend_cfg) | |
| 64 | 71 | return cfg.backend, cfg.get_backend_config() |
| 65 | 72 | |
| 66 | 73 | ... | ... |
perf_reports/reranker_vllm_instruction/2026-03-25/RESULTS.md
0 → 100644
| ... | ... | @@ -0,0 +1,61 @@ |
| 1 | +# Reranker benchmark: `qwen3_vllm` vs `qwen3_vllm_score` × `instruction_format` | |
| 2 | + | |
| 3 | +**Date:** 2026-03-25 | |
| 4 | +**Host:** single GPU (Tesla T4, ~16 GiB), CUDA 12.8 (see `nvidia-smi` during run). | |
| 5 | + | |
| 6 | +## Configuration (from `config/config.yaml`) | |
| 7 | + | |
| 8 | +Shared across both backends for this run: | |
| 9 | + | |
| 10 | +| Key | Value | | |
| 11 | +|-----|-------| | |
| 12 | +| `model_name` | `Qwen/Qwen3-Reranker-0.6B` | | |
| 13 | +| `max_model_len` | 160 | | |
| 14 | +| `infer_batch_size` | 100 | | |
| 15 | +| `sort_by_doc_length` | true | | |
| 16 | +| `enable_prefix_caching` | true | | |
| 17 | +| `enforce_eager` | false | | |
| 18 | +| `dtype` | float16 | | |
| 19 | +| `tensor_parallel_size` | 1 | | |
| 20 | +| `gpu_memory_utilization` | 0.20 | | |
| 21 | +| `instruction` | `Rank products by query with category & style match prioritized` | | |
| 22 | + | |
| 23 | +`qwen3_vllm` uses vLLM **generate + logprobs** (`.venv-reranker`). | |
| 24 | +`qwen3_vllm_score` uses vLLM **`LLM.score()`** (`.venv-reranker-score`, pinned vLLM stack per `reranker/README.md`). | |
| 25 | + | |
| 26 | +## Methodology | |
| 27 | + | |
| 28 | +- Script: `python scripts/benchmark_reranker_random_titles.py 100,200,400,600,800,1000 --repeat 5` with **`--seed 99`** (see note below), **`--quiet-runs`**, **`--timeout 360`**. | |
| 29 | +- Titles: default file `/home/ubuntu/rerank_test/titles.1.8w` (one title per line). | |
| 30 | +- Query: default `健身女生T恤短袖`. | |
| 31 | +- Each scenario: **3 warm-up** requests at `n=400` (not timed), then **5 timed** runs per `n`. | |
| 32 | +- Metric: **client wall time** for `POST /rerank` (localhost), milliseconds. | |
| 33 | +- After each `services.rerank.backend` / `instruction_format` change: `./restart.sh reranker`, then **`GET /health`** until `backend` and `instruction_format` matched the intended scenario (extended `reranker/server.py` to expose `instruction_format` when the backend defines `_instruction_format`). | |
| 34 | + | |
| 35 | +**Note on RNG seed:** With `--seed 42`, some runs occasionally lost one sample at `n=600` (non-200 or transport error). All figures below use **`--seed 99`** so every cell has **5/5** successful runs and comparable sampled titles. | |
| 36 | + | |
| 37 | +## Raw artifacts | |
| 38 | + | |
| 39 | +JSON aggregates (means, stdev, raw `values_ms`): same directory, `qwen3_vllm_{compact,standard}.json`, `qwen3_vllm_score_{compact,standard}.json`. | |
| 40 | + | |
| 41 | +## Results — mean latency (ms) | |
| 42 | + | |
| 43 | +| backend | instruction_format | n=100 | n=200 | n=400 | n=600 | n=800 | n=1000 | | |
| 44 | +|---------|-------------------|------:|------:|------:|------:|------:|-------:| | |
| 45 | +| `qwen3_vllm` | `compact` | 213.5 | 418.0 | 861.4 | 1263.4 | 1744.3 | 2162.2 | | |
| 46 | +| `qwen3_vllm` | `standard` | 254.9 | 475.4 | 909.7 | 1353.2 | 1912.5 | 2406.7 | | |
| 47 | +| `qwen3_vllm_score` | `compact` | 239.2 | 480.2 | 966.2 | 1433.5 | 1937.2 | 2428.4 | | |
| 48 | +| `qwen3_vllm_score` | `standard` | 299.6 | 591.8 | 1178.9 | 1773.7 | 2341.6 | 2931.7 | | |
| 49 | + | |
| 50 | +## Short interpretation | |
| 51 | + | |
| 52 | +1. **`compact` vs `standard`:** For both backends, **`compact` is faster** on this setup (shorter / different chat template vs fixed yes/no system prompt + user block — see `reranker/backends/qwen3_vllm.py` / `qwen3_vllm_score.py`). | |
| 53 | +2. **`qwen3_vllm` vs `qwen3_vllm_score`:** At **`n=1000`**, **`qwen3_vllm` + `compact`** is the fastest row (~2162 ms mean); **`qwen3_vllm_score` + `standard`** is the slowest (~2932 ms). Ordering can change on other GPUs / vLLM versions / batching. | |
| 54 | +3. **Repo default** after tests (per `config/config.yaml` in this changeset): `services.rerank.backend: qwen3_vllm`, with `instruction_format: compact` on the `qwen3_vllm` block and `instruction_format: standard` on the `qwen3_vllm_score` block. | |
| 55 | + | |
| 56 | +## Tooling added / changed | |
| 57 | + | |
| 58 | +- `reranker/server.py`: `/health` includes `instruction_format` when the active backend sets `_instruction_format`. | |
| 59 | +- `scripts/benchmark_reranker_random_titles.py`: `--tag`, `--json-summary-out`, `--quiet-runs`. | |
| 60 | +- `scripts/patch_rerank_vllm_benchmark_config.py`: surgical YAML patch (preserves newlines). | |
| 61 | +- `scripts/run_reranker_vllm_instruction_benchmark.sh`: full matrix driver (continues if a benchmark exits non-zero; uses `--timeout 360`). | ... | ... |
query/style_intent.py
| ... | ... | @@ -11,38 +11,79 @@ from .tokenization import TokenizedText, normalize_query_text, tokenize_text |
| 11 | 11 | |
| 12 | 12 | |
| 13 | 13 | @dataclass(frozen=True) |
| 14 | +class StyleIntentTermDefinition: | |
| 15 | + canonical_value: str | |
| 16 | + en_terms: Tuple[str, ...] | |
| 17 | + zh_terms: Tuple[str, ...] | |
| 18 | + attribute_terms: Tuple[str, ...] | |
| 19 | + | |
| 20 | + | |
| 21 | +@dataclass(frozen=True) | |
| 14 | 22 | class StyleIntentDefinition: |
| 15 | 23 | intent_type: str |
| 16 | - term_groups: Tuple[Tuple[str, ...], ...] | |
| 24 | + terms: Tuple[StyleIntentTermDefinition, ...] | |
| 17 | 25 | dimension_aliases: Tuple[str, ...] |
| 18 | - synonym_to_canonical: Dict[str, str] | |
| 26 | + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] | |
| 27 | + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] | |
| 19 | 28 | max_term_ngram: int = 3 |
| 20 | 29 | |
| 21 | 30 | @classmethod |
| 22 | 31 | def from_rows( |
| 23 | 32 | cls, |
| 24 | 33 | intent_type: str, |
| 25 | - rows: Sequence[Sequence[str]], | |
| 34 | + rows: Sequence[Dict[str, List[str]]], | |
| 26 | 35 | dimension_aliases: Sequence[str], |
| 27 | 36 | ) -> "StyleIntentDefinition": |
| 28 | - term_groups: List[Tuple[str, ...]] = [] | |
| 29 | - synonym_to_canonical: Dict[str, str] = {} | |
| 37 | + terms: List[StyleIntentTermDefinition] = [] | |
| 38 | + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} | |
| 39 | + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} | |
| 30 | 40 | max_ngram = 1 |
| 31 | 41 | |
| 32 | 42 | for row in rows: |
| 33 | - normalized_terms: List[str] = [] | |
| 34 | - for raw_term in row: | |
| 35 | - term = normalize_query_text(raw_term) | |
| 36 | - if not term or term in normalized_terms: | |
| 37 | - continue | |
| 38 | - normalized_terms.append(term) | |
| 39 | - if not normalized_terms: | |
| 43 | + normalized_en = tuple( | |
| 44 | + dict.fromkeys( | |
| 45 | + term | |
| 46 | + for term in (normalize_query_text(raw) for raw in row.get("en_terms", [])) | |
| 47 | + if term | |
| 48 | + ) | |
| 49 | + ) | |
| 50 | + normalized_zh = tuple( | |
| 51 | + dict.fromkeys( | |
| 52 | + term | |
| 53 | + for term in (normalize_query_text(raw) for raw in row.get("zh_terms", [])) | |
| 54 | + if term | |
| 55 | + ) | |
| 56 | + ) | |
| 57 | + normalized_attribute = tuple( | |
| 58 | + dict.fromkeys( | |
| 59 | + term | |
| 60 | + for term in (normalize_query_text(raw) for raw in row.get("attribute_terms", [])) | |
| 61 | + if term | |
| 62 | + ) | |
| 63 | + ) | |
| 64 | + if not normalized_en and not normalized_zh and not normalized_attribute: | |
| 40 | 65 | continue |
| 41 | 66 | |
| 42 | - canonical = normalized_terms[0] | |
| 43 | - term_groups.append(tuple(normalized_terms)) | |
| 44 | - for term in normalized_terms: | |
| 45 | - synonym_to_canonical[term] = canonical | |
| 67 | + canonical = ( | |
| 68 | + normalized_attribute[0] | |
| 69 | + if normalized_attribute | |
| 70 | + else normalized_en[0] | |
| 71 | + if normalized_en | |
| 72 | + else normalized_zh[0] | |
| 73 | + ) | |
| 74 | + term_definition = StyleIntentTermDefinition( | |
| 75 | + canonical_value=canonical, | |
| 76 | + en_terms=normalized_en, | |
| 77 | + zh_terms=normalized_zh, | |
| 78 | + attribute_terms=normalized_attribute, | |
| 79 | + ) | |
| 80 | + terms.append(term_definition) | |
| 81 | + | |
| 82 | + for term in normalized_en: | |
| 83 | + en_synonym_to_term[term] = term_definition | |
| 84 | + max_ngram = max(max_ngram, len(term.split())) | |
| 85 | + for term in normalized_zh: | |
| 86 | + zh_synonym_to_term[term] = term_definition | |
| 46 | 87 | max_ngram = max(max_ngram, len(term.split())) |
| 47 | 88 | |
| 48 | 89 | aliases = tuple( |
| ... | ... | @@ -58,28 +99,31 @@ class StyleIntentDefinition: |
| 58 | 99 | |
| 59 | 100 | return cls( |
| 60 | 101 | intent_type=intent_type, |
| 61 | - term_groups=tuple(term_groups), | |
| 102 | + terms=tuple(terms), | |
| 62 | 103 | dimension_aliases=aliases, |
| 63 | - synonym_to_canonical=synonym_to_canonical, | |
| 104 | + en_synonym_to_term=en_synonym_to_term, | |
| 105 | + zh_synonym_to_term=zh_synonym_to_term, | |
| 64 | 106 | max_term_ngram=max_ngram, |
| 65 | 107 | ) |
| 66 | 108 | |
| 67 | - def match_candidates(self, candidates: Iterable[str]) -> Set[str]: | |
| 68 | - matched: Set[str] = set() | |
| 109 | + def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]: | |
| 110 | + mapping = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term | |
| 111 | + matched: Set[StyleIntentTermDefinition] = set() | |
| 69 | 112 | for candidate in candidates: |
| 70 | - canonical = self.synonym_to_canonical.get(normalize_query_text(candidate)) | |
| 71 | - if canonical: | |
| 72 | - matched.add(canonical) | |
| 113 | + term_definition = mapping.get(normalize_query_text(candidate)) | |
| 114 | + if term_definition: | |
| 115 | + matched.add(term_definition) | |
| 73 | 116 | return matched |
| 74 | 117 | |
| 75 | 118 | def match_text( |
| 76 | 119 | self, |
| 77 | 120 | text: str, |
| 78 | 121 | *, |
| 122 | + language: str, | |
| 79 | 123 | tokenizer: Optional[Callable[[str], Any]] = None, |
| 80 | - ) -> Set[str]: | |
| 124 | + ) -> Set[StyleIntentTermDefinition]: | |
| 81 | 125 | bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram) |
| 82 | - return self.match_candidates(bundle.candidates) | |
| 126 | + return self.match_candidates(bundle.candidates, language=language) | |
| 83 | 127 | |
| 84 | 128 | |
| 85 | 129 | @dataclass(frozen=True) |
| ... | ... | @@ -88,6 +132,7 @@ class DetectedStyleIntent: |
| 88 | 132 | canonical_value: str |
| 89 | 133 | matched_term: str |
| 90 | 134 | matched_query_text: str |
| 135 | + attribute_terms: Tuple[str, ...] | |
| 91 | 136 | dimension_aliases: Tuple[str, ...] |
| 92 | 137 | |
| 93 | 138 | def to_dict(self) -> Dict[str, Any]: |
| ... | ... | @@ -96,6 +141,7 @@ class DetectedStyleIntent: |
| 96 | 141 | "canonical_value": self.canonical_value, |
| 97 | 142 | "matched_term": self.matched_term, |
| 98 | 143 | "matched_query_text": self.matched_query_text, |
| 144 | + "attribute_terms": list(self.attribute_terms), | |
| 99 | 145 | "dimension_aliases": list(self.dimension_aliases), |
| 100 | 146 | } |
| 101 | 147 | |
| ... | ... | @@ -159,7 +205,7 @@ class StyleIntentRegistry: |
| 159 | 205 | rows=rows or [], |
| 160 | 206 | dimension_aliases=dimension_aliases.get(intent_type, []), |
| 161 | 207 | ) |
| 162 | - if definition.synonym_to_canonical: | |
| 208 | + if definition.terms: | |
| 163 | 209 | definitions[definition.intent_type] = definition |
| 164 | 210 | |
| 165 | 211 | return cls( |
| ... | ... | @@ -191,15 +237,10 @@ class StyleIntentDetector: |
| 191 | 237 | seen = set() |
| 192 | 238 | variants: List[TokenizedText] = [] |
| 193 | 239 | texts = [ |
| 194 | - getattr(parsed_query, "original_query", None), | |
| 195 | - getattr(parsed_query, "query_normalized", None), | |
| 196 | - getattr(parsed_query, "rewritten_query", None), | |
| 240 | + self._get_language_query_text(parsed_query, "zh"), | |
| 241 | + self._get_language_query_text(parsed_query, "en"), | |
| 197 | 242 | ] |
| 198 | 243 | |
| 199 | - translations = getattr(parsed_query, "translations", {}) or {} | |
| 200 | - if isinstance(translations, dict): | |
| 201 | - texts.extend(translations.values()) | |
| 202 | - | |
| 203 | 244 | for raw_text in texts: |
| 204 | 245 | text = str(raw_text or "").strip() |
| 205 | 246 | if not text: |
| ... | ... | @@ -221,35 +262,66 @@ class StyleIntentDetector: |
| 221 | 262 | |
| 222 | 263 | return tuple(variants) |
| 223 | 264 | |
| 265 | + @staticmethod | |
| 266 | + def _get_language_query_text(parsed_query: Any, language: str) -> str: | |
| 267 | + translations = getattr(parsed_query, "translations", {}) or {} | |
| 268 | + if isinstance(translations, dict): | |
| 269 | + translated = translations.get(language) | |
| 270 | + if translated: | |
| 271 | + return str(translated) | |
| 272 | + return str(getattr(parsed_query, "original_query", "") or "") | |
| 273 | + | |
| 274 | + def _tokenize_language_query(self, parsed_query: Any, language: str) -> Optional[TokenizedText]: | |
| 275 | + text = self._get_language_query_text(parsed_query, language).strip() | |
| 276 | + if not text: | |
| 277 | + return None | |
| 278 | + return tokenize_text( | |
| 279 | + text, | |
| 280 | + tokenizer=self.tokenizer, | |
| 281 | + max_ngram=max( | |
| 282 | + (definition.max_term_ngram for definition in self.registry.definitions.values()), | |
| 283 | + default=3, | |
| 284 | + ), | |
| 285 | + ) | |
| 286 | + | |
| 224 | 287 | def detect(self, parsed_query: Any) -> StyleIntentProfile: |
| 225 | 288 | if not self.registry.enabled or not self.registry.definitions: |
| 226 | 289 | return StyleIntentProfile() |
| 227 | 290 | |
| 228 | 291 | query_variants = self._build_query_variants(parsed_query) |
| 292 | + zh_variant = self._tokenize_language_query(parsed_query, "zh") | |
| 293 | + en_variant = self._tokenize_language_query(parsed_query, "en") | |
| 229 | 294 | detected: List[DetectedStyleIntent] = [] |
| 230 | 295 | seen_pairs = set() |
| 231 | 296 | |
| 232 | - for variant in query_variants: | |
| 233 | - for intent_type, definition in self.registry.definitions.items(): | |
| 234 | - matched_canonicals = definition.match_candidates(variant.candidates) | |
| 235 | - if not matched_canonicals: | |
| 297 | + for intent_type, definition in self.registry.definitions.items(): | |
| 298 | + for language, variant, mapping in ( | |
| 299 | + ("zh", zh_variant, definition.zh_synonym_to_term), | |
| 300 | + ("en", en_variant, definition.en_synonym_to_term), | |
| 301 | + ): | |
| 302 | + if variant is None or not mapping: | |
| 303 | + continue | |
| 304 | + | |
| 305 | + matched_terms = definition.match_candidates(variant.candidates, language=language) | |
| 306 | + if not matched_terms: | |
| 236 | 307 | continue |
| 237 | 308 | |
| 238 | 309 | for candidate in variant.candidates: |
| 239 | 310 | normalized_candidate = normalize_query_text(candidate) |
| 240 | - canonical = definition.synonym_to_canonical.get(normalized_candidate) | |
| 241 | - if not canonical or canonical not in matched_canonicals: | |
| 311 | + term_definition = mapping.get(normalized_candidate) | |
| 312 | + if term_definition is None or term_definition not in matched_terms: | |
| 242 | 313 | continue |
| 243 | - pair = (intent_type, canonical) | |
| 314 | + pair = (intent_type, term_definition.canonical_value) | |
| 244 | 315 | if pair in seen_pairs: |
| 245 | 316 | continue |
| 246 | 317 | seen_pairs.add(pair) |
| 247 | 318 | detected.append( |
| 248 | 319 | DetectedStyleIntent( |
| 249 | 320 | intent_type=intent_type, |
| 250 | - canonical_value=canonical, | |
| 321 | + canonical_value=term_definition.canonical_value, | |
| 251 | 322 | matched_term=normalized_candidate, |
| 252 | 323 | matched_query_text=variant.text, |
| 324 | + attribute_terms=term_definition.attribute_terms, | |
| 253 | 325 | dimension_aliases=definition.dimension_aliases, |
| 254 | 326 | ) |
| 255 | 327 | ) | ... | ... |
requirements_reranker_qwen3_transformers_packed.txt
0 → 100644
| ... | ... | @@ -0,0 +1,9 @@ |
| 1 | +# Isolated dependencies for qwen3_transformers_packed reranker backend. | |
| 2 | +# | |
| 3 | +# Keep this stack aligned with the validated CUDA runtime on our hosts. | |
| 4 | +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works. | |
| 5 | +# We also cap transformers <5 to stay on the same family as the working vLLM score env. | |
| 6 | + | |
| 7 | +-r requirements_reranker_qwen3_transformers.txt | |
| 8 | +torch==2.10.0 | |
| 9 | +transformers>=4.51.0,<5 | ... | ... |
| ... | ... | @@ -0,0 +1,14 @@ |
| 1 | +# Dedicated high-performance venv for qwen3_vllm_score: .venv-reranker-score | |
| 2 | +# | |
| 3 | +# Create / refresh: | |
| 4 | +# ./scripts/setup_reranker_venv.sh qwen3_vllm_score | |
| 5 | +# | |
| 6 | +# vLLM 0.17+ replaces LLM(task="score") with runner/convert auto + LLM.score(). | |
| 7 | +# Pin vLLM for reproducible perf baselines; bump after validating CUDA/driver on your hosts. | |
| 8 | +# If pip cannot find a wheel for your CUDA version, edit the vllm line or install from: | |
| 9 | +# https://docs.vllm.ai/en/latest/getting_started/installation.html | |
| 10 | + | |
| 11 | +-r requirements_reranker_base.txt | |
| 12 | +vllm==0.18.0 | |
| 13 | +# Match vLLM 0.18 stack; cap <5 to avoid pip prefetching incompatible transformers 5.x. | |
| 14 | +transformers>=4.51.0,<5 | ... | ... |
requirements_reranker_service.txt
| 1 | -# Isolated dependencies for reranker service (.venv-reranker) | |
| 1 | +# Legacy alias: qwen3_vllm reranker service env (.venv-reranker). | |
| 2 | 2 | # |
| 3 | -# Default backend is qwen3_vllm (Qwen3-Reranker-0.6B). | |
| 3 | +# Prefer backend-specific requirements files: | |
| 4 | +# - requirements_reranker_qwen3_vllm.txt | |
| 5 | +# - requirements_reranker_qwen3_vllm_score.txt | |
| 6 | +# - requirements_reranker_qwen3_gguf.txt | |
| 7 | +# - requirements_reranker_qwen3_transformers.txt | |
| 8 | +# - requirements_reranker_bge.txt | |
| 9 | +# - requirements_reranker_dashscope.txt | |
| 4 | 10 | |
| 5 | -fastapi>=0.100.0 | |
| 6 | -uvicorn[standard]>=0.23.0 | |
| 7 | -pydantic>=2.0.0 | |
| 8 | -numpy>=1.24.0 | |
| 9 | -pyyaml>=6.0 | |
| 10 | -transformers>=4.30.0 | |
| 11 | -vllm>=0.8.5 | |
| 11 | +-r requirements_reranker_qwen3_vllm.txt | ... | ... |
reranker/DEPLOYMENT_AND_TUNING.md
| 1 | -# Reranker 部署与性能调优手册(Qwen3-vLLM) | |
| 1 | +# Reranker 部署与性能调优手册(Qwen3-vLLM / Qwen3-GGUF) | |
| 2 | 2 | |
| 3 | 3 | 本文档沉淀当前项目在电商搜索重排场景下的可复用实践,覆盖: |
| 4 | 4 | |
| 5 | 5 | - 环境准备与安装部署 |
| 6 | -- `qwen3_vllm` 配置项与优化思路 | |
| 6 | +- `qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` 配置项与优化思路 | |
| 7 | 7 | - 1000-doc 场景压测流程 |
| 8 | 8 | - 关键结论与推荐默认参数 |
| 9 | 9 | - 常见故障排查 |
| 10 | 10 | |
| 11 | 11 | 适用范围: |
| 12 | 12 | |
| 13 | -- 重排后端:`services.rerank.backend: qwen3_vllm` | |
| 14 | -- 模型:`Qwen/Qwen3-Reranker-0.6B` | |
| 13 | +- 重排后端:`services.rerank.backend: qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` | |
| 14 | +- 模型:`Qwen/Qwen3-Reranker-0.6B` / `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` / `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF` | |
| 15 | 15 | - 场景:query 较短(通常 < 100 tokens),doc 为商品标题或标题+简短描述,单请求 docs 约 1000 条 |
| 16 | 16 | |
| 17 | 17 | ## 1. 环境基线 |
| 18 | 18 | |
| 19 | -当前验证环境(2026-03-11): | |
| 19 | +当前验证环境(2026-03-25): | |
| 20 | 20 | |
| 21 | 21 | - GPU:`Tesla T4 16GB` |
| 22 | 22 | - Driver / CUDA:`570.158.01 / 12.8` |
| 23 | 23 | - Python:`3.12.3` |
| 24 | -- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`fastapi==0.135.1`、`uvicorn==0.41.0` | |
| 24 | +- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`llama-cpp-python>=0.3.16`、`fastapi==0.135.1`、`uvicorn==0.41.0` | |
| 25 | 25 | |
| 26 | 26 | ## 2. 环境准备与安装 |
| 27 | 27 | |
| 28 | 28 | ### 2.1 准备 reranker 独立虚拟环境 |
| 29 | 29 | |
| 30 | 30 | ```bash |
| 31 | -./scripts/setup_reranker_venv.sh | |
| 31 | +./scripts/setup_reranker_venv.sh qwen3_vllm | |
| 32 | +``` | |
| 33 | + | |
| 34 | +若使用 GGUF 并需要 CUDA: | |
| 35 | + | |
| 36 | +```bash | |
| 37 | +./scripts/setup_reranker_venv.sh qwen3_gguf | |
| 38 | +PATH=/usr/local/cuda/bin:$PATH \ | |
| 39 | +CUDACXX=/usr/local/cuda/bin/nvcc \ | |
| 40 | +CMAKE_ARGS="-DGGML_CUDA=on" \ | |
| 41 | +FORCE_CMAKE=1 \ | |
| 42 | +./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18 | |
| 32 | 43 | ``` |
| 33 | 44 | |
| 34 | 45 | ### 2.2 基础检查 |
| ... | ... | @@ -37,6 +48,7 @@ |
| 37 | 48 | nvidia-smi |
| 38 | 49 | ./.venv-reranker/bin/python -c "import torch; print(torch.cuda.is_available())" |
| 39 | 50 | ./.venv-reranker/bin/python -c "import vllm, transformers; print(vllm.__version__, transformers.__version__)" |
| 51 | +./.venv-reranker-gguf/bin/python -c "import llama_cpp; print(llama_cpp.__version__)" | |
| 40 | 52 | ``` |
| 41 | 53 | |
| 42 | 54 | ## 3. 部署与运行 |
| ... | ... | @@ -64,6 +76,29 @@ services: |
| 64 | 76 | length_sort_mode: "char" # char | token |
| 65 | 77 | ``` |
| 66 | 78 | |
| 79 | +GGUF / T4 剩余显存约 `4.8~6GB` 时,推荐基线: | |
| 80 | + | |
| 81 | +```yaml | |
| 82 | +services: | |
| 83 | + rerank: | |
| 84 | + backend: "qwen3_gguf" | |
| 85 | + backends: | |
| 86 | + qwen3_gguf: | |
| 87 | + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF" | |
| 88 | + filename: "*Q8_0.gguf" | |
| 89 | + local_dir: "./models/reranker/qwen3-reranker-4b-gguf" | |
| 90 | + cache_dir: "./model_cache" | |
| 91 | + n_ctx: 384 | |
| 92 | + n_batch: 384 | |
| 93 | + n_ubatch: 128 | |
| 94 | + n_gpu_layers: 24 | |
| 95 | + flash_attn: true | |
| 96 | + offload_kqv: true | |
| 97 | + infer_batch_size: 8 | |
| 98 | + sort_by_doc_length: true | |
| 99 | + length_sort_mode: "char" | |
| 100 | +``` | |
| 101 | + | |
| 67 | 102 | ### 3.2 启停命令 |
| 68 | 103 | |
| 69 | 104 | 推荐统一使用: |
| ... | ... | @@ -105,6 +140,13 @@ curl -sS http://127.0.0.1:6007/health |
| 105 | 140 | - `service_ctl.sh` 对 reranker 使用独立启动路径 |
| 106 | 141 | - 增加“稳定健康检查”(连续健康探测)避免“刚 healthy 即退出”的假阳性 |
| 107 | 142 | |
| 143 | +### 4.4 GGUF / T4 小显存优化原则 | |
| 144 | + | |
| 145 | +- `Q8_0` 权重约 `4.28GB`,但还要给 KV cache、CUDA 工作区和运行时碎片预留空间,不能按“模型大小 < 剩余显存”直接判断可行。 | |
| 146 | +- 当前业务是短 query + 商品标题,优先压缩 `n_ctx`;`384` 通常比默认长上下文更划算。 | |
| 147 | +- T4 小显存下先扫 `n_gpu_layers`,再尝试提高 `n_ctx`;`infer_batch_size` 在当前 GGUF 接入里主要是服务侧 work chunk,不是 llama.cpp 的真实算子 batch。 | |
| 148 | +- `flash_attn: true`、`offload_kqv: true` 默认保持开启;若 OOM,优先降低 `n_gpu_layers`。 | |
| 149 | + | |
| 108 | 150 | ## 5. 性能调优流程(标准流程) |
| 109 | 151 | |
| 110 | 152 | ### 5.1 使用一键压测脚本 |
| ... | ... | @@ -125,6 +167,13 @@ curl -sS http://127.0.0.1:6007/health |
| 125 | 167 | - `infer_batch_size`: `24 32 48 64` |
| 126 | 168 | - 并发组:`c=1`(看单请求延迟)、`c=4`(看并发吞吐与尾延迟) |
| 127 | 169 | |
| 170 | +GGUF 建议扫描: | |
| 171 | + | |
| 172 | +- `n_gpu_layers`: `20 24 28` | |
| 173 | +- `n_ctx`: `320 384 448` | |
| 174 | +- `infer_batch_size`: `4 8 12`(次要,仅影响服务侧 work chunk) | |
| 175 | +- 扫描顺序:先固定 `n_ctx=384`,找能稳定启动的最大 `n_gpu_layers`;再在显存允许时尝试 `n_ctx=448`;最后才微调 `infer_batch_size` | |
| 176 | + | |
| 128 | 177 | 可通过环境变量覆盖: |
| 129 | 178 | |
| 130 | 179 | - `BATCH_SIZES` |
| ... | ... | @@ -140,23 +189,28 @@ curl -sS http://127.0.0.1:6007/health |
| 140 | 189 | - `RERANK_VLLM_INFER_BATCH_SIZE` |
| 141 | 190 | - `RERANK_VLLM_SORT_BY_DOC_LENGTH` |
| 142 | 191 | |
| 143 | -## 6. 本轮关键结论(2026-03-11) | |
| 144 | - | |
| 145 | -基于报告: | |
| 146 | - | |
| 147 | -- `perf_reports/20260311/reranker_1000docs/report.md` | |
| 192 | +## 6. 本轮关键结论 | |
| 148 | 193 | |
| 149 | -结论: | |
| 194 | +vLLM(2026-03-11,见 `perf_reports/20260311/reranker_1000docs/report.md`): | |
| 150 | 195 | |
| 151 | 196 | - 对在线重排更重要的单请求延迟(`c=1`)指标,`infer_batch_size=64` 最优 |
| 152 | 197 | - `infer_batch_size=96` 在更高并发下吞吐略高,但会牺牲单请求延迟稳定性 |
| 153 | 198 | - 当前默认选择 `infer_batch_size=64` 作为平衡点 |
| 154 | 199 | |
| 200 | +GGUF(2026-03-25,本次接入): | |
| 201 | + | |
| 202 | +- `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` 的 `Q8_0` 体积约 `4.28GB`,结合当前机器实测剩余显存约 `4823 MiB`,默认不采用激进的全量 GPU offload。 | |
| 203 | +- 当前推荐默认值:`n_ctx=384`、`n_batch=384`、`n_ubatch=128`、`n_gpu_layers=24`、`infer_batch_size=8`。 | |
| 204 | +- 若现场剩余显存更接近 `6GB` 且碎片较少,可优先尝试 `n_gpu_layers=28`;若启动失败,回退到 `24` 或 `20`。 | |
| 205 | +- 由于当前工作区尚未缓存该 GGUF 权重,本次尚未完成真实吞吐压测;上线前需在部署机复跑一轮参数扫描并归档报告。 | |
| 206 | + | |
| 155 | 207 | ## 7. 生产建议 |
| 156 | 208 | |
| 157 | 209 | - 默认保持:`infer_batch_size: 64`、`sort_by_doc_length: true` |
| 158 | 210 | - 满足以下条件时可考虑提高到 `96`:业务以吞吐优先、可接受更高单请求延迟、已通过同机同数据压测验证收益 |
| 159 | 211 | - 每次改动后都必须复跑 `benchmark_reranker_1000docs.sh` 并归档结果 |
| 212 | +- GGUF 默认保持:`n_ctx: 384`、`n_gpu_layers: 24`、`infer_batch_size: 8`、`flash_attn: true`、`offload_kqv: true` | |
| 213 | +- GGUF 若 OOM:先降 `n_gpu_layers`,再降 `n_ctx`,最后再降 `infer_batch_size` | |
| 160 | 214 | |
| 161 | 215 | ## 8. 故障排查 |
| 162 | 216 | |
| ... | ... | @@ -194,6 +248,13 @@ lsof -i :6007 -P -n |
| 194 | 248 | - 降低 `infer_batch_size` |
| 195 | 249 | - 检查是否有其他进程占用同卡 |
| 196 | 250 | |
| 251 | +GGUF 优先调整: | |
| 252 | + | |
| 253 | +- 降低 `n_gpu_layers` | |
| 254 | +- 降低 `n_ctx` | |
| 255 | +- 降低 `infer_batch_size` | |
| 256 | +- 检查是否有其他进程占用同卡 | |
| 257 | + | |
| 197 | 258 | ## 9. 变更与验证清单 |
| 198 | 259 | |
| 199 | 260 | 每次 reranker 调优改动后,至少完成: | ... | ... |
| ... | ... | @@ -0,0 +1,154 @@ |
| 1 | +# Qwen3-Reranker-0.6B GGUF 安装与调优 | |
| 2 | + | |
| 3 | +本文档覆盖 `qwen3_gguf_06b` 后端,对应模型: | |
| 4 | + | |
| 5 | +- Hugging Face: `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF` | |
| 6 | +- 文件: `qwen3-reranker-0.6b-q8_0.gguf` | |
| 7 | +- 本地目录: `./models/reranker/qwen3-reranker-0.6b-q8_0-gguf` | |
| 8 | + | |
| 9 | +## 结论先看 | |
| 10 | + | |
| 11 | +这个后端已经接入完成,也能正常使用 GPU offload,但不适合当前项目的在线主链路场景。 | |
| 12 | + | |
| 13 | +目标场景是: | |
| 14 | + | |
| 15 | +- 1 个 query | |
| 16 | +- 400 个商品标题 | |
| 17 | +- 追求最短响应时间 | |
| 18 | + | |
| 19 | +实测最优配置下: | |
| 20 | + | |
| 21 | +- GPU 显存占用约 `894 MiB` | |
| 22 | +- 400 titles 单请求延迟约 `265318 ms`(约 265 秒) | 
| 23 | + | |
| 24 | +因此它更适合作为: | |
| 25 | + | |
| 26 | +- 低显存 fallback | |
| 27 | +- 功能验证 | |
| 28 | +- 本地离线实验 | |
| 29 | + | |
| 30 | +不建议作为在线低延迟 reranker 主 backend。 | |
| 31 | + | |
| 32 | +## 独立环境 | |
| 33 | + | |
| 34 | +`qwen3_gguf_06b` 使用独立 venv: | |
| 35 | + | |
| 36 | +- backend: `qwen3_gguf_06b` | |
| 37 | +- venv: `.venv-reranker-gguf-06b` | |
| 38 | +- requirements: `requirements_reranker_qwen3_gguf_06b.txt` | |
| 39 | + | |
| 40 | +安装: | |
| 41 | + | |
| 42 | +```bash | |
| 43 | +./scripts/setup_reranker_venv.sh qwen3_gguf_06b | |
| 44 | +``` | |
| 45 | + | |
| 46 | +如果需要确认是 CUDA 版 `llama-cpp-python`: | |
| 47 | + | |
| 48 | +```bash | |
| 49 | +./.venv-reranker-gguf-06b/bin/python - <<'PY' | |
| 50 | +import llama_cpp | |
| 51 | +print(llama_cpp.llama_supports_gpu_offload()) | |
| 52 | +PY | |
| 53 | +``` | |
| 54 | + | |
| 55 | +预期输出: | |
| 56 | + | |
| 57 | +```text | 
| 58 | +True | |
| 59 | +``` | |
| 60 | + | |
| 61 | +## 模型下载 | |
| 62 | + | |
| 63 | +推荐预先下载到本地,避免首次服务启动时在线拉取: | |
| 64 | + | |
| 65 | +```bash | |
| 66 | +mkdir -p models/reranker/qwen3-reranker-0.6b-q8_0-gguf | |
| 67 | +curl -L --fail -C - \ | |
| 68 | + -o models/reranker/qwen3-reranker-0.6b-q8_0-gguf/qwen3-reranker-0.6b-q8_0.gguf \ | |
| 69 | + 'https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/resolve/main/qwen3-reranker-0.6b-q8_0.gguf?download=true' | |
| 70 | +``` | |
| 71 | + | |
| 72 | +当前实测文件大小: | |
| 73 | + | |
| 74 | +- `639153184` bytes | |
| 75 | + | |
| 76 | +## 推荐配置 | |
| 77 | + | |
| 78 | +`config/config.yaml` 中建议保留: | |
| 79 | + | |
| 80 | +```yaml | |
| 81 | +qwen3_gguf_06b: | |
| 82 | + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF" | |
| 83 | + filename: "qwen3-reranker-0.6b-q8_0.gguf" | |
| 84 | + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf" | |
| 85 | + cache_dir: "./model_cache" | |
| 86 | + instruction: "Rank products by query with category & style match prioritized" | |
| 87 | + n_ctx: 256 | |
| 88 | + n_batch: 256 | |
| 89 | + n_ubatch: 256 | |
| 90 | + n_gpu_layers: 999 | |
| 91 | + main_gpu: 0 | |
| 92 | + n_threads: 2 | |
| 93 | + n_threads_batch: 4 | |
| 94 | + flash_attn: true | |
| 95 | + offload_kqv: true | |
| 96 | + use_mmap: true | |
| 97 | + use_mlock: false | |
| 98 | + infer_batch_size: 32 | |
| 99 | + sort_by_doc_length: true | |
| 100 | + length_sort_mode: "char" | |
| 101 | + reuse_query_state: false | |
| 102 | + enable_warmup: true | |
| 103 | + verbose: false | |
| 104 | +``` | |
| 105 | + | |
| 106 | +## 调优结果 | |
| 107 | + | |
| 108 | +在当前机器上做了同机实测。标题文件来自 `/home/ubuntu/rerank_test/titles.1.8w`,查询为 `白色oversized T-shirt`。 | |
| 109 | + | |
| 110 | +80 titles: | |
| 111 | + | |
| 112 | +- `n_ctx=256, reuse_query_state=true` -> `60108 ms` | |
| 113 | +- `n_ctx=256, reuse_query_state=false` -> `53383~56893 ms` | |
| 114 | +- `n_ctx=320, reuse_query_state=true` -> `60961 ms` | |
| 115 | +- `n_ctx=384, reuse_query_state=true` -> `56578 ms` | |
| 116 | +- `n_ctx=384, reuse_query_state=false` -> `57272 ms` | |
| 117 | +- `n_ctx=512, reuse_query_state=false` -> `60542 ms` | |
| 118 | +- `n_ctx=256, reuse_query_state=false, n_threads=4, n_threads_batch=8` -> `61228 ms` | |
| 119 | + | |
| 120 | +400 titles: | |
| 121 | + | |
| 122 | +- `n_ctx=256, n_batch=256, n_ubatch=256, n_gpu_layers=999, reuse_query_state=false` | |
| 123 | + -> `265318 ms` | |
| 124 | + | |
| 125 | +## 经验沉淀 | |
| 126 | + | |
| 127 | +这次接入最重要的结论不是“哪个小参数更快”,而是: | |
| 128 | + | |
| 129 | +1. 这个 0.6B GGUF 权重虽然小,但当前后端实现仍是逐 doc 顺序打分。 | |
| 130 | +2. 对在线 400-title 请求来说,串行打分本身就是主瓶颈。 | |
| 131 | +3. `reuse_query_state` 在这个模型上没有带来收益,反而更慢。 | |
| 132 | +4. `n_ctx` 拉大到 `384/512` 也没有带来实质收益,反而更慢或持平。 | |
| 133 | +5. 这个 backend 的优势是低显存,不是低延迟。 | |
| 134 | + | |
| 135 | +如果目标是在线最短响应时间,优先级建议是: | |
| 136 | + | |
| 137 | +1. `qwen3_vllm` | |
| 138 | +2. 其他真正支持高吞吐批处理的后端 | |
| 139 | +3. `qwen3_gguf_06b` 仅作为低显存 fallback | |
| 140 | + | |
| 141 | +## 验证命令 | |
| 142 | + | |
| 143 | +本地直连 backend 调优: | |
| 144 | + | |
| 145 | +```bash | |
| 146 | +PYTHONPATH=/data/saas-search ./.venv-reranker-gguf/bin/python \ | |
| 147 | + scripts/benchmark_reranker_gguf_local.py --backend-name qwen3_gguf_06b --docs 400 | |
| 148 | +``` | |
| 149 | + | |
| 150 | +按服务方式启动: | |
| 151 | + | |
| 152 | +```bash | |
| 153 | +RERANK_BACKEND=qwen3_gguf_06b ./scripts/start_reranker.sh | |
| 154 | +``` | ... | ... |
| ... | ... | @@ -0,0 +1,280 @@ |
| 1 | +# Qwen3 GGUF 安装与调优手册 | |
| 2 | + | |
| 3 | +本文档只覆盖 `qwen3_gguf` 后端,目标机器为当前项目实测环境: | |
| 4 | + | |
| 5 | +- GPU: `Tesla T4 16GB` | |
| 6 | +- CUDA: `12.8` | |
| 7 | +- 模型: `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` | |
| 8 | +- 量化: `Q8_0` | |
| 9 | + | |
| 10 | +--- | |
| 11 | + | |
| 12 | +## 1. 结论先看 | |
| 13 | + | |
| 14 | +当前这套代码里,GGUF 后端的主要瓶颈不是“显存没吃满”,而是 **llama.cpp 按 doc 顺序逐条打分**。因此最有效的优化策略是: | |
| 15 | + | |
| 16 | +- 让模型层尽可能全部 offload 到 GPU | |
| 17 | +- 打开 `flash_attn` / `offload_kqv` | |
| 18 | +- 把 `n_ctx / n_batch / n_ubatch` 调到一个对短标题重排更合适的高效点 | |
| 19 | + | |
| 20 | +本轮在当前机器上的推荐配置是: | |
| 21 | + | |
| 22 | +```yaml | |
| 23 | +qwen3_gguf: | |
| 24 | + n_ctx: 512 | |
| 25 | + n_batch: 512 | |
| 26 | + n_ubatch: 512 | |
| 27 | + n_gpu_layers: 999 | |
| 28 | + n_threads: 2 | |
| 29 | + n_threads_batch: 4 | |
| 30 | + flash_attn: true | |
| 31 | + offload_kqv: true | |
| 32 | + infer_batch_size: 8 | |
| 33 | + sort_by_doc_length: true | |
| 34 | + length_sort_mode: "char" | |
| 35 | +``` | |
| 36 | + | |
| 37 | +说明: | |
| 38 | + | |
| 39 | +- `n_gpu_layers: 999` 在 llama.cpp 中等价于“尽可能全部层都 offload” | |
| 40 | +- 这台 T4 上,**即使全量 offload,当前模型也只占到约 `4.5 GiB` GPU 显存** | |
| 41 | +- 所以“允许 8G 显存”并不会自动带来更高速度;这个模型/后端在当前工作负载下已经接近“该用到的权重都上 GPU 了” | |
| 42 | + | |
| 43 | +--- | |
| 44 | + | |
| 45 | +## 2. 独立环境 | |
| 46 | + | |
| 47 | +`qwen3_gguf` 必须使用自己的独立 venv: | |
| 48 | + | |
| 49 | +- `qwen3_vllm` -> `.venv-reranker` | |
| 50 | +- `qwen3_gguf` -> `.venv-reranker-gguf` | |
| 51 | + | |
| 52 | +安装命令: | |
| 53 | + | |
| 54 | +```bash | |
| 55 | +./scripts/setup_reranker_venv.sh qwen3_gguf | |
| 56 | +``` | |
| 57 | + | |
| 58 | +脚本现在会自动做两件事: | |
| 59 | + | |
| 60 | +1. 安装 GGUF 后端所需 Python 依赖 | |
| 61 | +2. 在检测到 `/usr/local/cuda/bin/nvcc` 时,把 `llama-cpp-python` **重编译成 CUDA 版** | |
| 62 | + | |
| 63 | +--- | |
| 64 | + | |
| 65 | +## 3. GPU 版验证 | |
| 66 | + | |
| 67 | +必须验证不是 CPU-only 版: | |
| 68 | + | |
| 69 | +```bash | |
| 70 | +./.venv-reranker-gguf/bin/python - <<'PY' | |
| 71 | +import llama_cpp | |
| 72 | +print("supports_gpu_offload =", llama_cpp.llama_supports_gpu_offload()) | |
| 73 | +PY | |
| 74 | +``` | |
| 75 | + | |
| 76 | +正确结果应为: | |
| 77 | + | |
| 78 | +```text | |
| 79 | +supports_gpu_offload = True | |
| 80 | +``` | |
| 81 | + | |
| 82 | +还可以看动态库: | |
| 83 | + | |
| 84 | +```bash | |
| 85 | +ldd .venv-reranker-gguf/lib/python3.12/site-packages/llama_cpp/lib/libllama.so | rg 'cuda|cublas|ggml-cuda' | |
| 86 | +``` | |
| 87 | + | |
| 88 | +应能看到: | |
| 89 | + | |
| 90 | +- `libggml-cuda.so` | |
| 91 | +- `libcudart.so` | |
| 92 | +- `libcublas.so` | |
| 93 | + | |
| 94 | +--- | |
| 95 | + | |
| 96 | +## 4. 模型下载 | |
| 97 | + | |
| 98 | +当前使用本地文件优先策略,模型放在: | |
| 99 | + | |
| 100 | +```text | |
| 101 | +models/reranker/qwen3-reranker-4b-gguf/Qwen.Qwen3-Reranker-4B.Q8_0.gguf | |
| 102 | +``` | |
| 103 | + | |
| 104 | +若本地文件存在,后端会直接加载本地 GGUF,不再依赖启动时在线下载。 | |
| 105 | + | |
| 106 | +为了避免当前机器上 Hugging Face Xet 下载的 `416 Range Not Satisfiable` 问题,`start_reranker.sh` 已对 `qwen3_gguf` 默认设置: | |
| 107 | + | |
| 108 | +```bash | |
| 109 | +HF_HUB_DISABLE_XET=1 | |
| 110 | +``` | |
| 111 | + | |
| 112 | +--- | |
| 113 | + | |
| 114 | +## 5. 本地调优脚本 | |
| 115 | + | |
| 116 | +新增本地基准脚本: | |
| 117 | + | |
| 118 | +```bash | |
| 119 | +PYTHONPATH=/data/saas-search ./.venv-reranker-gguf/bin/python \ | |
| 120 | + scripts/benchmark_reranker_gguf_local.py --docs 64 --repeat 1 | |
| 121 | +``` | |
| 122 | + | |
| 123 | +它会直接实例化 GGUF backend,输出: | |
| 124 | + | |
| 125 | +- 模型加载耗时 | |
| 126 | +- 当前进程 GPU 显存占用 | |
| 127 | +- 单次 rerank 延迟 | |
| 128 | + | |
| 129 | +--- | |
| 130 | + | |
| 131 | +## 6. 本轮实测结果 | |
| 132 | + | |
| 133 | +测试条件: | |
| 134 | + | |
| 135 | +- Query: `白色oversized T-shirt` | |
| 136 | +- Docs: `64` 条商品标题 | |
| 137 | +- 本地脚本:`scripts/benchmark_reranker_gguf_local.py` | |
| 138 | +- 每组 1 次,重点比较相对趋势 | |
| 139 | + | |
| 140 | +结果: | |
| 141 | + | |
| 142 | +### 6.1 保守配置 | |
| 143 | + | |
| 144 | +```text | |
| 145 | +n_ctx=384 | |
| 146 | +n_batch=384 | |
| 147 | +n_ubatch=128 | |
| 148 | +n_gpu_layers=24 | |
| 149 | +``` | |
| 150 | + | |
| 151 | +- GPU 显存:`2984 MiB` | |
| 152 | +- 64 docs 延迟:`74347.91 ms` | |
| 153 | + | |
| 154 | +### 6.2 全量 offload | |
| 155 | + | |
| 156 | +```text | |
| 157 | +n_ctx=384 | |
| 158 | +n_batch=384 | |
| 159 | +n_ubatch=128 | |
| 160 | +n_gpu_layers=999 | |
| 161 | +``` | |
| 162 | + | |
| 163 | +- GPU 显存:`4338 MiB` | |
| 164 | +- 64 docs 延迟:`51401.77 ms` | |
| 165 | + | |
| 166 | +### 6.3 最优配置 | |
| 167 | + | |
| 168 | +```text | |
| 169 | +n_ctx=512 | |
| 170 | +n_batch=512 | |
| 171 | +n_ubatch=512 | |
| 172 | +n_gpu_layers=999 | |
| 173 | +``` | |
| 174 | + | |
| 175 | +- GPU 显存:`4564 MiB` | |
| 176 | +- 64 docs 延迟:`49116.10 ms` | |
| 177 | + | |
| 178 | +### 6.4 其它尝试 | |
| 179 | + | |
| 180 | +`n_threads=4 / n_threads_batch=8`: | |
| 181 | + | |
| 182 | +- GPU 显存:`4564 MiB` | |
| 183 | +- 64 docs 延迟:`49895.88 ms` | |
| 184 | +- 比推荐值略慢 | |
| 185 | + | |
| 186 | +`infer_batch_size=64`: | |
| 187 | + | |
| 188 | +- GPU 显存:`4564 MiB` | |
| 189 | +- 64 docs 延迟:`50723.36 ms` | |
| 190 | +- 也略慢 | |
| 191 | + | |
| 192 | +### 6.5 API 级验证 | |
| 193 | + | |
| 194 | +在把推荐配置写入 `config/config.yaml` 并重启服务后,使用: | |
| 195 | + | |
| 196 | +```bash | |
| 197 | +RERANK_BASE=http://127.0.0.1:6007 \ | |
| 198 | + ./.venv/bin/python scripts/benchmark_reranker_random_titles.py 64 --repeat 1 --query '白色oversized T-shirt' | |
| 199 | +``` | |
| 200 | + | |
| 201 | +得到: | |
| 202 | + | |
| 203 | +- `64 docs`:`50177.22 ms` | |
| 204 | + | |
| 205 | +再用: | |
| 206 | + | |
| 207 | +```bash | |
| 208 | +RERANK_BASE=http://127.0.0.1:6007 \ | |
| 209 | + ./.venv/bin/python scripts/benchmark_reranker_random_titles.py 153 --repeat 1 --query '白色oversized T-shirt' | |
| 210 | +``` | |
| 211 | + | |
| 212 | +得到: | |
| 213 | + | |
| 214 | +- `153 docs`:`115328.60 ms` | |
| 215 | + | |
| 216 | +对比旧日志中的保守配置: | |
| 217 | + | |
| 218 | +- 旧配置 `153 docs`:`153435.37 ms` | |
| 219 | +- 新配置 `153 docs`:`115328.60 ms` | |
| 220 | + | |
| 221 | +改善幅度约: | |
| 222 | + | |
| 223 | +- `24.8%` | |
| 224 | + | |
| 225 | +--- | |
| 226 | + | |
| 227 | +## 7. 为什么没有吃到 8G | |
| 228 | + | |
| 229 | +结论很重要: | |
| 230 | + | |
| 231 | +- 当前最优配置已经是“尽可能全量层 offload” | |
| 232 | +- 该 `Q8_0` 模型在这套 llama.cpp / T4 / 短文本重排场景下,**实测只需要约 `4.5 GiB` GPU 显存** | |
| 233 | +- 继续为了“吃满 8G”去增大 `n_ctx`,不会明显提升吞吐,反而可能带来额外开销 | |
| 234 | + | |
| 235 | +所以本轮不是“显存太保守”,而是: | |
| 236 | + | |
| 237 | +- 可 offload 的权重已经基本 offload 完了 | |
| 238 | +- 真正拖慢响应的是 **逐 doc 顺序推理** 这一后端实现路径 | |
| 239 | + | |
| 240 | +--- | |
| 241 | + | |
| 242 | +## 8. 生产建议 | |
| 243 | + | |
| 244 | +### 8.1 当前建议 | |
| 245 | + | |
| 246 | +保留以下参数: | |
| 247 | + | |
| 248 | +```yaml | |
| 249 | +n_ctx: 512 | |
| 250 | +n_batch: 512 | |
| 251 | +n_ubatch: 512 | |
| 252 | +n_gpu_layers: 999 | |
| 253 | +n_threads: 2 | |
| 254 | +n_threads_batch: 4 | |
| 255 | +flash_attn: true | |
| 256 | +offload_kqv: true | |
| 257 | +``` | |
| 258 | + | |
| 259 | +### 8.2 如果还嫌慢 | |
| 260 | + | |
| 261 | +优先级建议: | |
| 262 | + | |
| 263 | +1. 缩小 `rerank_window` | |
| 264 | +2. 减少传入 doc 数 | |
| 265 | +3. 若业务允许,切换到更适合高吞吐的后端 | |
| 266 | + | |
| 267 | +原因: | |
| 268 | + | |
| 269 | +- 当前 GGUF 后端是本地单进程、逐 doc 打分 | |
| 270 | +- 对长列表重排,它天然不如 vLLM / 云端 rerank API 擅长吞吐 | |
| 271 | + | |
| 272 | +--- | |
| 273 | + | |
| 274 | +## 9. 本轮落地文件 | |
| 275 | + | |
| 276 | +- `config/config.yaml` | |
| 277 | +- `scripts/setup_reranker_venv.sh` | |
| 278 | +- `scripts/start_reranker.sh` | |
| 279 | +- `scripts/benchmark_reranker_gguf_local.py` | |
| 280 | +- `reranker/GGUF_INSTALL_AND_TUNING.md` | ... | ... |
reranker/README.md
| 1 | 1 | # Reranker 模块 |
| 2 | 2 | |
| 3 | -**请求示例**见 `docs/QUICKSTART.md` §3.5。扩展规范见 `docs/DEVELOPER_GUIDE.md` §7。部署与调优实战见 `reranker/DEPLOYMENT_AND_TUNING.md`。 | |
| 3 | +**请求示例**见 `docs/QUICKSTART.md` §3.5。扩展规范见 `docs/DEVELOPER_GUIDE.md` §7。部署与调优实战见 `reranker/DEPLOYMENT_AND_TUNING.md`。`ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF` 的专项接入与调优结论见 `reranker/GGUF_0_6B_INSTALL_AND_TUNING.md`。 | |
| 4 | 4 | |
| 5 | 5 | --- |
| 6 | 6 | |
| 7 | -Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers、DashScope 云重排)。调用方通过 HTTP 访问,不关心具体后端。 | |
| 7 | +Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwen3-vLLM、Qwen3-Transformers、Qwen3-GGUF、DashScope 云重排)。调用方通过 HTTP 访问,不关心具体后端。 | |
| 8 | 8 | |
| 9 | 9 | **特性** |
| 10 | -- 多后端:`qwen3_vllm`(默认,Qwen3-Reranker-0.6B + vLLM)、`qwen3_transformers`(纯 Transformers,无需 vLLM)、`bge`(兼容保留) | |
| 10 | +- 多后端:`qwen3_vllm`、`qwen3_vllm_score`(同模型,vLLM ``LLM.score()`` + 独立 `.venv-reranker-score`)、`qwen3_transformers`、`qwen3_transformers_packed`(共享前缀 + packed attention mask)、`qwen3_gguf`(Qwen3-Reranker-4B GGUF + llama.cpp)、`qwen3_gguf_06b`(Qwen3-Reranker-0.6B Q8_0 GGUF + llama.cpp)、`bge`(兼容保留) | |
| 11 | 11 | - 云后端:`dashscope_rerank`(调用 DashScope `/compatible-api/v1/reranks`,支持按地域切换 endpoint) |
| 12 | 12 | - 统一配置:`config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.<name>` |
| 13 | 13 | - 文档去重、分数与输入顺序一致、FP16/GPU 支持(视后端) |
| ... | ... | @@ -17,28 +17,51 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe |
| 17 | 17 | - `reranker/backends/`:后端实现与工厂 |
| 18 | 18 | - `backends/__init__.py`:`get_rerank_backend(name, config)` |
| 19 | 19 | - `backends/bge.py`:BGE 后端 |
| 20 | - - `backends/qwen3_vllm.py`:Qwen3-Reranker-0.6B + vLLM 后端 | |
| 20 | + - `backends/qwen3_vllm.py`:Qwen3-Reranker-0.6B + vLLM(generate + logprobs) | |
| 21 | + - `backends/qwen3_vllm_score.py`:同上模型 + vLLM ``LLM.score()``(`requirements_reranker_qwen3_vllm_score.txt` / `.venv-reranker-score`) | |
| 21 | 22 | - `backends/qwen3_transformers.py`:Qwen3-Reranker-0.6B 纯 Transformers 后端(官方 Usage 方式) |
| 23 | + - `backends/qwen3_transformers_packed.py`:Qwen3-Reranker-0.6B + Transformers packed 推理(共享 query prefix,适合 `1 query + 400 docs`) | |
| 24 | + - `backends/qwen3_gguf.py`:Qwen3-Reranker GGUF + llama.cpp 后端(支持 `qwen3_gguf` / `qwen3_gguf_06b`) | |
| 22 | 25 | - `backends/dashscope_rerank.py`:DashScope 云重排后端(HTTP 调用) |
| 23 | 26 | - `reranker/bge_reranker.py`:BGE 核心推理(被 bge 后端封装) |
| 24 | 27 | - `reranker/config.py`:服务端口、MAX_DOCS、NORMALIZE 等(后端参数在 config.yaml) |
| 25 | 28 | |
| 26 | 29 | ## 依赖 |
| 27 | 30 | - 通用:`torch`、`transformers`、`fastapi`、`uvicorn`(隔离环境见 `requirements_reranker_service.txt`;全量 ML 环境另见 `requirements_ml.txt`) |
| 28 | -- **Qwen3-vLLM 后端**:`vllm>=0.8.5`、`transformers>=4.51.0`(仅当使用 `backend: qwen3_vllm` 时需 vLLM) | |
| 31 | +- **Qwen3-vLLM 后端**:`vllm>=0.8.5`、`transformers>=4.51.0`(`qwen3_vllm` → `.venv-reranker`) | |
| 32 | +- **Qwen3-vLLM-score 后端**:固定 `vllm==0.18.0`(`qwen3_vllm_score` → `.venv-reranker-score`,见 `requirements_reranker_qwen3_vllm_score.txt`) | |
| 29 | 33 | - **Qwen3-Transformers 后端**:`transformers>=4.51.0`、`torch`(无需 vLLM,适合 CPU 或小显存) |
| 34 | +- **Qwen3-Transformers-Packed 后端**:复用 Transformers 依赖(`qwen3_transformers_packed` → `.venv-reranker-transformers-packed`) | |
| 35 | +- **Qwen3-GGUF 后端**:`llama-cpp-python>=0.3.16` | |
| 36 | +- 现在按 backend 使用独立 venv: | |
| 37 | + - `qwen3_vllm` -> `.venv-reranker` | |
| 38 | + - `qwen3_vllm_score` -> `.venv-reranker-score` | |
| 39 | + - `qwen3_gguf` -> `.venv-reranker-gguf` | |
| 40 | + - `qwen3_gguf_06b` -> `.venv-reranker-gguf-06b` | |
| 41 | + - `qwen3_transformers` -> `.venv-reranker-transformers` | |
| 42 | + - `qwen3_transformers_packed` -> `.venv-reranker-transformers-packed` | |
| 43 | + - `bge` -> `.venv-reranker-bge` | |
| 44 | + - `dashscope_rerank` -> `.venv-reranker-dashscope` | |
| 30 | 45 | ```bash |
| 31 | - ./scripts/setup_reranker_venv.sh | |
| 46 | + ./scripts/setup_reranker_venv.sh qwen3_gguf_06b | |
| 47 | + ``` | |
| 48 | + CUDA 构建建议: | |
| 49 | + ```bash | |
| 50 | + PATH=/usr/local/cuda/bin:$PATH \ | |
| 51 | + CUDACXX=/usr/local/cuda/bin/nvcc \ | |
| 52 | + CMAKE_ARGS="-DGGML_CUDA=on" \ | |
| 53 | + FORCE_CMAKE=1 \ | |
| 54 | + ./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18 | |
| 32 | 55 | ``` |
| 33 | 56 | |
| 34 | 57 | ## 配置 |
| 35 | -- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_transformers` | `bge` | `dashscope_rerank`),或环境变量 `RERANK_BACKEND`。 | |
| 58 | +- **后端选择**:`config/config.yaml` 中 `services.rerank.backend`(`qwen3_vllm` | `qwen3_vllm_score` | `qwen3_transformers` | `qwen3_transformers_packed` | `qwen3_gguf` | `qwen3_gguf_06b` | `bge` | `dashscope_rerank`),或环境变量 `RERANK_BACKEND`。 | |
| 36 | 59 | - **后端参数**:`services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`,例如: |
| 37 | 60 | |
| 38 | 61 | ```yaml |
| 39 | 62 | services: |
| 40 | 63 | rerank: |
| 41 | - backend: "qwen3_vllm" # 或 bge | |
| 64 | + backend: "qwen3_gguf" # 或 qwen3_vllm / bge | |
| 42 | 65 | backends: |
| 43 | 66 | bge: |
| 44 | 67 | model_name: "BAAI/bge-reranker-v2-m3" |
| ... | ... | @@ -65,6 +88,44 @@ services: |
| 65 | 88 | tensor_parallel_size: 1 |
| 66 | 89 | gpu_memory_utilization: 0.8 |
| 67 | 90 | instruction: "Given a shopping query, rank product titles by relevance" |
| 91 | + qwen3_transformers_packed: | |
| 92 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 93 | + instruction: "Rank products by query with category & style match prioritized" | |
| 94 | + max_model_len: 4096 | |
| 95 | + max_doc_len: 160 | |
| 96 | + max_docs_per_pack: 0 | |
| 97 | + use_fp16: true | |
| 98 | + sort_by_doc_length: true | |
| 99 | + attn_implementation: "eager" | |
| 100 | + qwen3_gguf: | |
| 101 | + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF" | |
| 102 | + filename: "*Q8_0.gguf" | |
| 103 | + local_dir: "./models/reranker/qwen3-reranker-4b-gguf" | |
| 104 | + cache_dir: "./model_cache" | |
| 105 | + instruction: "Rank products by query with category & style match prioritized" | |
| 106 | + n_ctx: 384 | |
| 107 | + n_batch: 384 | |
| 108 | + n_ubatch: 128 | |
| 109 | + n_gpu_layers: 24 | |
| 110 | + flash_attn: true | |
| 111 | + offload_kqv: true | |
| 112 | + infer_batch_size: 8 | |
| 113 | + sort_by_doc_length: true | |
| 114 | + length_sort_mode: "char" | |
| 115 | + qwen3_gguf_06b: | |
| 116 | + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF" | |
| 117 | + filename: "qwen3-reranker-0.6b-q8_0.gguf" | |
| 118 | + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf" | |
| 119 | + cache_dir: "./model_cache" | |
| 120 | + instruction: "Rank products by query with category & style match prioritized" | |
| 121 | + n_ctx: 256 | |
| 122 | + n_batch: 256 | |
| 123 | + n_ubatch: 256 | |
| 124 | + n_gpu_layers: 999 | |
| 125 | + infer_batch_size: 32 | |
| 126 | + sort_by_doc_length: true | |
| 127 | + length_sort_mode: "char" | |
| 128 | + reuse_query_state: false | |
| 68 | 129 | dashscope_rerank: |
| 69 | 130 | model_name: "qwen3-rerank" |
| 70 | 131 | endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks" |
| ... | ... | @@ -94,7 +155,7 @@ DashScope 认证: |
| 94 | 155 | ```bash |
| 95 | 156 | ./scripts/start_reranker.sh |
| 96 | 157 | ``` |
| 97 | -该脚本会使用隔离环境 `.venv-reranker`;首次请先执行 `./scripts/setup_reranker_venv.sh`。 | |
| 158 | +该脚本会按当前 `services.rerank.backend` 自动选择对应的独立 venv;首次请先执行 `./scripts/setup_reranker_venv.sh <backend>`。 | |
| 98 | 159 | |
| 99 | 160 | ## 性能压测(1000 docs) |
| 100 | 161 | ```bash |
| ... | ... | @@ -122,7 +183,7 @@ Content-Type: application/json |
| 122 | 183 | ``` |
| 123 | 184 | |
| 124 | 185 | `top_n` 为可选字段: |
| 125 | -- 对本地后端(`qwen3_vllm` / `qwen3_transformers` / `bge`)通常会忽略,仍返回全量分数。 | |
| 186 | +- 对本地后端(`qwen3_vllm` / `qwen3_transformers` / `qwen3_transformers_packed` / `qwen3_gguf` / `qwen3_gguf_06b` / `bge`)通常会忽略,仍返回全量分数。 | |
| 126 | 187 | - 对 `dashscope_rerank` 可用于控制云端返回的候选量,建议设置为 `page+size`(例如分页 `from=20,size=10` 时传 `30`)。 |
| 127 | 188 | |
| 128 | 189 | Response: |
| ... | ... | @@ -160,3 +221,6 @@ uvicorn reranker.server:app --host 0.0.0.0 --port 6007 --log-level info |
| 160 | 221 | - 运行时可用环境变量临时覆盖批量参数:`RERANK_VLLM_INFER_BATCH_SIZE`、`RERANK_VLLM_SORT_BY_DOC_LENGTH`。 |
| 161 | 222 | - **Qwen3-vLLM**:参考 [Qwen3-Reranker-0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B),需 GPU 与较多显存;与 BGE 相比适合长文本、高吞吐场景(vLLM 前缀缓存)。 |
| 162 | 223 | - **Qwen3-Transformers**:官方 Transformers Usage 方式,无需 vLLM;适合 CPU 或小显存。默认 `attn_implementation: "sdpa"`;若已安装 `flash_attn` 可设 `flash_attention_2`(未安装时服务会自动回退到 sdpa)。 |
| 224 | +- **Qwen3-Transformers-Packed**:仍使用 Hugging Face Transformers 与 PyTorch CUDA 内核,只定制 packed 输入、`position_ids` 和 4D `attention_mask`。它更适合在线检索里的“一个 query 对几百个短 doc”场景;默认 `attn_implementation: "eager"` 以保证自定义 mask 兼容性,若你的 `torch/transformers` 版本已验证支持,可再压测 `"sdpa"`。 | |
| 225 | +- **Qwen3-GGUF**:参考 [DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF](https://huggingface.co/DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF)。单卡 T4 且仅剩约 `4.8~6GB` 显存时,推荐 `Q8_0 + n_ctx=384 + n_gpu_layers=24 + flash_attn=true + offload_kqv=true` 起步;若启动 OOM,优先把 `n_gpu_layers` 下调到 `20`,再把 `n_ctx` 下调到 `320`。`infer_batch_size` 在 GGUF 后端是服务侧 work chunk,大多不如 `n_gpu_layers` / `n_ctx` 关键。 | |
| 226 | +- **Qwen3-GGUF-0.6B**:参考 [ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF](https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF)。它的优点是权重小、显存占用低,单进程实测约 `0.9~1.1 GiB`;但在当前 llama.cpp 串行打分接法下,`1 query + 400 titles` 的实测延迟仍约 `265s`。因此它更适合低显存功能后备,不适合作为在线低延迟主 reranker。 | ... | ... |
reranker/backends/__init__.py
| ... | ... | @@ -43,14 +43,32 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc |
| 43 | 43 | if name == "qwen3_vllm": |
| 44 | 44 | from reranker.backends.qwen3_vllm import Qwen3VLLMRerankerBackend |
| 45 | 45 | return Qwen3VLLMRerankerBackend(config) |
| 46 | + if name == "qwen3_vllm_score": | |
| 47 | + from reranker.backends.qwen3_vllm_score import Qwen3VLLMScoreRerankerBackend | |
| 48 | + return Qwen3VLLMScoreRerankerBackend(config) | |
| 46 | 49 | if name == "qwen3_transformers": |
| 47 | 50 | from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend |
| 48 | 51 | return Qwen3TransformersRerankerBackend(config) |
| 52 | + if name == "qwen3_transformers_packed": | |
| 53 | + from reranker.backends.qwen3_transformers_packed import ( | |
| 54 | + Qwen3TransformersPackedRerankerBackend, | |
| 55 | + ) | |
| 56 | + return Qwen3TransformersPackedRerankerBackend(config) | |
| 57 | + if name == "qwen3_gguf": | |
| 58 | + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend | |
| 59 | + gguf_config = dict(config or {}) | |
| 60 | + gguf_config.setdefault("_backend_name", "qwen3_gguf") | |
| 61 | + return Qwen3GGUFRerankerBackend(gguf_config) | |
| 62 | + if name == "qwen3_gguf_06b": | |
| 63 | + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend | |
| 64 | + gguf_config = dict(config or {}) | |
| 65 | + gguf_config.setdefault("_backend_name", "qwen3_gguf_06b") | |
| 66 | + return Qwen3GGUFRerankerBackend(gguf_config) | |
| 49 | 67 | if name == "dashscope_rerank": |
| 50 | 68 | from reranker.backends.dashscope_rerank import DashScopeRerankBackend |
| 51 | 69 | return DashScopeRerankBackend(config) |
| 52 | 70 | raise ValueError( |
| 53 | - f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers, dashscope_rerank" | |
| 71 | + f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_vllm_score, qwen3_transformers, qwen3_transformers_packed, qwen3_gguf, qwen3_gguf_06b, dashscope_rerank" | |
| 54 | 72 | ) |
| 55 | 73 | |
| 56 | 74 | ... | ... |
| ... | ... | @@ -0,0 +1,408 @@ |
| 1 | +""" | |
| 2 | +Qwen3-Reranker GGUF backend using llama-cpp-python. | |
| 3 | + | |
| 4 | +Reference: | |
| 5 | +- https://huggingface.co/DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF | |
| 6 | +- https://huggingface.co/Qwen/Qwen3-Reranker-4B | |
| 7 | +- https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF | |
| 8 | +- https://huggingface.co/Qwen/Qwen3-Reranker-0.6B | |
| 9 | +""" | |
| 10 | + | |
| 11 | +from __future__ import annotations | |
| 12 | + | |
| 13 | +import logging | |
| 14 | +import math | |
| 15 | +import os | |
| 16 | +import threading | |
| 17 | +import time | |
| 18 | +from pathlib import Path | |
| 19 | +from typing import Any, Dict, List, Tuple | |
| 20 | + | |
| 21 | + | |
| 22 | +logger = logging.getLogger("reranker.backends.qwen3_gguf") | |
| 23 | + | |
| 24 | + | |
# Per-backend-name defaults for locating the GGUF weights.
#   repo_id   - Hugging Face repo used by Llama.from_pretrained when no local file exists.
#   filename  - file name or glob pattern; the 4B entry uses a glob ("*Q8_0.gguf")
#               resolved against local_dir (and by from_pretrained's matching).
#   local_dir - preferred on-disk location checked before any download.
_BACKEND_DEFAULTS: Dict[str, Dict[str, str]] = {
    "qwen3_gguf": {
        "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
        "filename": "*Q8_0.gguf",
        "local_dir": "./models/reranker/qwen3-reranker-4b-gguf",
    },
    "qwen3_gguf_06b": {
        "repo_id": "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF",
        "filename": "qwen3-reranker-0.6b-q8_0.gguf",
        "local_dir": "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf",
    },
}
| 38 | + | |
def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]:
    """Deduplicate texts globally while preserving first-seen order.

    Returns ``(unique_texts, position_to_unique)`` where the second list maps
    every original position to the index of its text inside ``unique_texts``.
    """
    first_index: Dict[str, int] = {}
    uniques: List[str] = []
    mapping: List[int] = []

    for item in texts:
        if item not in first_index:
            first_index[item] = len(uniques)
            uniques.append(item)
        mapping.append(first_index[item])

    return uniques, mapping
| 54 | + | |
| 55 | + | |
| 56 | +def _format_instruction(instruction: str, query: str, doc: str) -> str: | |
| 57 | + return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format( | |
| 58 | + instruction=instruction, | |
| 59 | + query=query, | |
| 60 | + doc=doc, | |
| 61 | + ) | |
| 62 | + | |
| 63 | + | |
class Qwen3GGUFRerankerBackend:
    """
    Qwen3-Reranker GGUF backend using llama.cpp through llama-cpp-python.

    Tuned for short-query / short-doc reranking on a single GPU.
    Config from services.rerank.backends.<backend_name>.

    Scoring is strictly serial per document (llama.cpp holds one context);
    ``_infer_lock`` guards all evals so concurrent callers queue up.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Load a GGUF model per config and prepare prompt/token scaffolding.

        Raises:
            ValueError: on non-positive batch/context sizes.
            RuntimeError: when llama-cpp-python is missing, or n_ctx leaves
                too little room after the fixed prompt prefix/suffix.
        """
        self._config = config or {}
        # Backend name (injected by the factory as "_backend_name") picks
        # the default repo/filename/local_dir set.
        self._backend_name = str(self._config.get("_backend_name") or "qwen3_gguf").strip()
        defaults = _BACKEND_DEFAULTS.get(self._backend_name, _BACKEND_DEFAULTS["qwen3_gguf"])
        self._repo_id = str(self._config.get("repo_id") or defaults["repo_id"]).strip()
        self._filename = str(self._config.get("filename") or defaults["filename"]).strip()
        self._model_path = str(self._config.get("model_path") or "").strip()
        self._cache_dir = str(self._config.get("cache_dir") or "").strip() or None
        self._local_dir = str(self._config.get("local_dir") or defaults["local_dir"]).strip() or None
        self._instruction = str(
            self._config.get("instruction")
            or "Rank products by query with category & style match prioritized"
        )
        # Env var overrides config; batching here only shapes the loop/meta,
        # actual inference is still one prompt at a time.
        self._infer_batch_size = int(
            os.getenv("RERANK_GGUF_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 8)
        )
        sort_by_doc_length = os.getenv("RERANK_GGUF_SORT_BY_DOC_LENGTH")
        if sort_by_doc_length is None:
            sort_by_doc_length = self._config.get("sort_by_doc_length", True)
        # Accept truthy strings from either env or YAML config.
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        # "char" (cheap) or "token" (accurate but re-tokenizes every doc).
        self._length_sort_mode = str(self._config.get("length_sort_mode") or "char").strip().lower()
        # Opt-in: cache the prefix+query KV state and replay it per doc.
        self._reuse_query_state = bool(self._config.get("reuse_query_state", False))

        # llama.cpp context parameters; max_model_len is accepted as an
        # alias for n_ctx for parity with the vLLM backends.
        n_ctx = int(self._config.get("n_ctx", self._config.get("max_model_len", 384)))
        n_batch = int(self._config.get("n_batch", min(n_ctx, 384)))
        n_ubatch = int(self._config.get("n_ubatch", min(n_batch, 128)))
        n_gpu_layers = int(self._config.get("n_gpu_layers", 24))
        main_gpu = int(self._config.get("main_gpu", 0))
        n_threads = int(self._config.get("n_threads", 2))
        n_threads_batch = int(self._config.get("n_threads_batch", 4))
        flash_attn = bool(self._config.get("flash_attn", True))
        offload_kqv = bool(self._config.get("offload_kqv", True))
        use_mmap = bool(self._config.get("use_mmap", True))
        use_mlock = bool(self._config.get("use_mlock", False))
        verbose = bool(self._config.get("verbose", False))
        enable_warmup = bool(self._config.get("enable_warmup", True))

        if self._infer_batch_size <= 0:
            raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
        if n_ctx <= 0:
            raise ValueError(f"n_ctx must be > 0, got {n_ctx}")
        if n_batch <= 0 or n_ubatch <= 0:
            raise ValueError(f"n_batch/n_ubatch must be > 0, got {n_batch}/{n_ubatch}")

        # llama-cpp-python is an optional, backend-specific dependency.
        try:
            from llama_cpp import Llama
        except Exception as exc:  # pragma: no cover - depends on optional dependency
            raise RuntimeError(
                f"{self._backend_name} backend requires llama-cpp-python. "
                f"Install the {self._backend_name} backend venv first via "
                f"scripts/setup_reranker_venv.sh {self._backend_name}."
            ) from exc

        self._llama_class = Llama
        self._n_ctx = n_ctx
        self._n_batch = n_batch
        self._n_ubatch = n_ubatch
        self._n_gpu_layers = n_gpu_layers
        self._enable_warmup = enable_warmup
        # Serializes all llama.cpp evals; the single context is not thread-safe.
        self._infer_lock = threading.Lock()

        logger.info(
            "[Qwen3_GGUF] Loading backend=%s repo=%s filename=%s model_path=%s n_ctx=%s n_batch=%s n_ubatch=%s n_gpu_layers=%s flash_attn=%s offload_kqv=%s reuse_query_state=%s",
            self._backend_name,
            self._repo_id,
            self._filename,
            self._model_path or None,
            n_ctx,
            n_batch,
            n_ubatch,
            n_gpu_layers,
            flash_attn,
            offload_kqv,
            self._reuse_query_state,
        )

        # logits_all=True is required: _eval_logits reads per-token logits
        # from Llama.eval_logits after eval().
        llm_kwargs = {
            "n_ctx": n_ctx,
            "n_batch": n_batch,
            "n_ubatch": n_ubatch,
            "n_gpu_layers": n_gpu_layers,
            "main_gpu": main_gpu,
            "n_threads": n_threads,
            "n_threads_batch": n_threads_batch,
            "logits_all": True,
            "offload_kqv": offload_kqv,
            "flash_attn": flash_attn,
            "use_mmap": use_mmap,
            "use_mlock": use_mlock,
            "verbose": verbose,
        }
        llm_kwargs = {key: value for key, value in llm_kwargs.items() if value is not None}
        self._llm = self._load_model(llm_kwargs)
        self._model_name = self._model_path or f"{self._repo_id}:{self._filename}"

        # Fixed chat-template wrapper around every scored pair; the model is
        # prompted to answer only "yes" or "no".
        self._prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
            'Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n<|im_start|>user\n"
        )
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._prefix_tokens = self._tokenize(self._prefix, special=True)
        self._suffix_tokens = self._tokenize(self._suffix, special=True)
        self._request_prefix_template = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "
        # Token budget left for instruction+query+document between the fixed
        # prefix and suffix.
        self._effective_max_len = self._n_ctx - len(self._prefix_tokens) - len(self._suffix_tokens)
        if self._effective_max_len <= 16:
            raise RuntimeError(
                f"n_ctx={self._n_ctx} is too small after prompt overhead; effective={self._effective_max_len}"
            )

        # Classifier token ids; score = P("yes") from a 2-way softmax.
        self._true_token = self._single_token_id("yes")
        self._false_token = self._single_token_id("no")

        if self._enable_warmup:
            self._warmup()

        logger.info(
            "[Qwen3_GGUF] Model ready | backend=%s model=%s effective_max_len=%s infer_batch_size=%s sort_by_doc_length=%s",
            self._backend_name,
            self._model_name,
            self._effective_max_len,
            self._infer_batch_size,
            self._sort_by_doc_length,
        )

    def _load_model(self, llm_kwargs: Dict[str, Any]) -> Any:
        """Resolve and load the GGUF model.

        Priority: explicit model_path > first glob match in local_dir >
        download via Llama.from_pretrained(repo_id, filename).
        """
        if self._model_path:
            return self._llama_class(model_path=self._model_path, **llm_kwargs)
        if self._local_dir:
            # sorted() makes the glob choice deterministic across runs.
            matches = sorted(
                path for path in Path(self._local_dir).glob(self._filename) if path.is_file()
            )
            if matches:
                local_model_path = str(matches[0].resolve())
                logger.info("[Qwen3_GGUF] Using local GGUF file: %s", local_model_path)
                return self._llama_class(model_path=local_model_path, **llm_kwargs)
        return self._llama_class.from_pretrained(
            repo_id=self._repo_id,
            filename=self._filename,
            local_dir=self._local_dir,
            cache_dir=self._cache_dir,
            **llm_kwargs,
        )

    def _tokenize(self, text: str, *, special: bool) -> List[int]:
        """Tokenize UTF-8 text without BOS; ``special`` parses template tokens
        like ``<|im_start|>`` as single ids."""
        return list(
            self._llm.tokenize(
                text.encode("utf-8"),
                add_bos=False,
                special=special,
            )
        )

    def _single_token_id(self, text: str) -> int:
        """Return the single token id for ``text``; raise if it splits."""
        token_ids = self._tokenize(text, special=False)
        if len(token_ids) != 1:
            raise RuntimeError(f"Expected {text!r} to be one token, got {token_ids}")
        return int(token_ids[0])

    def _warmup(self) -> None:
        """Run one throwaway eval so first real request avoids cold-start cost.
        Failures are logged, never raised."""
        try:
            prompt = self._build_prompt_tokens("warmup query", "warmup document")
            with self._infer_lock:
                self._eval_logits(prompt)
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning("[Qwen3_GGUF] Warmup failed: %s", exc)

    def _build_request_prefix_tokens(self, query: str) -> List[int]:
        """Tokenize the instruction+query part shared by all docs of a request
        (used only on the state-reuse path)."""
        request_prefix = self._request_prefix_template.format(
            instruction=self._instruction,
            query=query,
        )
        return self._tokenize(request_prefix, special=False)

    def _build_prompt_tokens(self, query: str, doc: str) -> List[int]:
        """Build one full prompt: fixed prefix + truncated pair + fixed suffix.

        NOTE(review): truncation is applied to the combined instruction+query+doc
        token run, so a very long query can eat the document budget.
        """
        pair = _format_instruction(self._instruction, query, doc)
        pair_tokens = self._tokenize(pair, special=False)
        pair_tokens = pair_tokens[: self._effective_max_len]
        return self._prefix_tokens + pair_tokens + self._suffix_tokens

    def _eval_logits(self, prompt_tokens: List[int]) -> List[float]:
        """Reset the context, eval the prompt, and return the last token's
        logits (requires logits_all=True). Caller must hold _infer_lock."""
        self._llm.reset()
        self._llm.eval(prompt_tokens)
        logits = self._llm.eval_logits
        if not logits:
            raise RuntimeError("llama.cpp returned empty logits")
        return list(logits[-1])

    def _score_prompt(self, prompt_tokens: List[int]) -> float:
        """Score one prompt as P("yes") via a numerically stable 2-way softmax
        over the yes/no logits. Caller must hold _infer_lock."""
        logits = self._eval_logits(prompt_tokens)
        true_logit = float(logits[self._true_token])
        false_logit = float(logits[self._false_token])
        # Subtract the max before exp() to avoid overflow.
        max_logit = max(true_logit, false_logit)
        true_exp = math.exp(true_logit - max_logit)
        false_exp = math.exp(false_logit - max_logit)
        return float(true_exp / (true_exp + false_exp))

    def _supports_query_state_reuse(self) -> bool:
        """True when state reuse is enabled and the llama binding exposes
        save_state/load_state."""
        return (
            self._reuse_query_state
            and hasattr(self._llm, "save_state")
            and hasattr(self._llm, "load_state")
        )

    def _build_query_state_locked(self, query: str) -> Tuple[Any, int]:
        """Eval prefix+instruction+query once and snapshot the KV state.

        Returns ``(state, max_doc_tokens)``; ``(None, 0)`` when the query
        alone exhausts the budget (caller falls back to per-pair prompts).
        Caller must hold _infer_lock.
        """
        request_prefix_tokens = self._build_request_prefix_tokens(query)
        max_doc_tokens = self._effective_max_len - len(request_prefix_tokens)
        if max_doc_tokens <= 0:
            return None, 0
        self._llm.reset()
        self._llm.eval(self._prefix_tokens + request_prefix_tokens)
        return self._llm.save_state(), max_doc_tokens

    def _score_doc_with_state_locked(self, state, doc_tokens: List[int], max_doc_tokens: int) -> float:
        """Restore the cached query state, eval only doc+suffix, and return
        P("yes"). Caller must hold _infer_lock.

        NOTE(review): load_state presumably copies the whole KV cache per doc —
        confirm it is actually cheaper than re-evaluating the prefix.
        """
        self._llm.load_state(state)
        self._llm.eval(doc_tokens[:max_doc_tokens] + self._suffix_tokens)
        logits = self._llm.eval_logits
        if not logits:
            raise RuntimeError("llama.cpp returned empty logits")
        final_logits = list(logits[-1])
        true_logit = float(final_logits[self._true_token])
        false_logit = float(final_logits[self._false_token])
        max_logit = max(true_logit, false_logit)
        true_exp = math.exp(true_logit - max_logit)
        false_exp = math.exp(false_logit - max_logit)
        return float(true_exp / (true_exp + false_exp))

    def _estimate_doc_lengths(self, docs: List[str]) -> List[int]:
        """Length proxy for sorting: token counts in "token" mode (slower),
        otherwise character counts."""
        if self._length_sort_mode == "token":
            return [len(self._tokenize(text, special=False)) for text in docs]
        return [len(text) for text in docs]

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query``; return (scores, meta).

        Scores align 1:1 with the input list; None/blank docs score 0.0.
        Duplicates are scored once. ``normalize`` is echoed in meta only —
        scores are already probabilities.
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        # Keep (original_index, cleaned_text) for usable docs only.
        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        # Empty query or no usable docs: all-zero scores, minimal meta.
        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": self._backend_name,
                "normalize": normalize,
                "infer_batch_size": self._infer_batch_size,
                "inference_batches": 0,
                "sort_by_doc_length": self._sort_by_doc_length,
                "n_ctx": self._n_ctx,
                "n_batch": self._n_batch,
                "n_ubatch": self._n_ubatch,
                "n_gpu_layers": self._n_gpu_layers,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        # Optionally score shortest docs first (stable llama.cpp batching cost).
        lengths = self._estimate_doc_lengths(unique_texts)
        order = list(range(len(unique_texts)))
        if self._sort_by_doc_length and len(unique_texts) > 1:
            order = sorted(order, key=lambda i: lengths[i])

        unique_scores: List[float] = [0.0] * len(unique_texts)
        unique_doc_tokens = [self._tokenize(text, special=False) for text in unique_texts]
        inference_batches = 0
        with self._infer_lock:
            query_state = None
            max_doc_tokens = self._effective_max_len
            if self._supports_query_state_reuse():
                query_state, max_doc_tokens = self._build_query_state_locked(query)
            # "Batches" only chunk the loop for meta reporting; evaluation is
            # still one prompt at a time under the lock.
            for start in range(0, len(order), self._infer_batch_size):
                batch_indices = order[start : start + self._infer_batch_size]
                inference_batches += 1
                for idx in batch_indices:
                    if query_state is not None:
                        unique_scores[idx] = self._score_doc_with_state_locked(
                            query_state,
                            unique_doc_tokens[idx],
                            max_doc_tokens,
                        )
                    else:
                        prompt = self._build_prompt_tokens(query, unique_texts[idx])
                        unique_scores[idx] = self._score_prompt(prompt)

        # Fan unique scores back out to the original positions.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": self._backend_name,
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
            "length_sort_mode": self._length_sort_mode,
            "n_ctx": self._n_ctx,
            "n_batch": self._n_batch,
            "n_ubatch": self._n_ubatch,
            "n_gpu_layers": self._n_gpu_layers,
            "reuse_query_state": query_state is not None,
        }
        return output_scores, meta
| ... | ... | @@ -0,0 +1,398 @@ |
| 1 | +""" | |
| 2 | +Qwen3-Reranker backend using packed inference with Transformers. | |
| 3 | + | |
| 4 | +This backend implements the sequence stitching optimization described in | |
| 5 | +Qwen3-Reranker packed inference examples: | |
| 6 | +1. Share the query/instruction prefix across many documents. | |
| 7 | +2. Reset document ``position_ids`` relative to the shared prefix. | |
| 8 | +3. Use a custom causal attention mask so each document can attend to the | |
| 9 | + prefix and itself, but never to other documents. | |
| 10 | + | |
| 11 | +Compared with the standard per-pair batching path, this reduces repeated | |
| 12 | +prefix computation and removes inter-sample padding waste. For online search | |
| 13 | +requests like ``1 query + 400 docs``, the backend further packs documents into | |
| 14 | +multiple chunks under a configurable total token budget. | |
| 15 | +""" | |
| 16 | + | |
| 17 | +from __future__ import annotations | |
| 18 | + | |
| 19 | +import logging | |
| 20 | +import threading | |
| 21 | +import time | |
| 22 | +from typing import Any, Dict, List, Sequence, Tuple | |
| 23 | + | |
| 24 | +import torch | |
| 25 | +from transformers import AutoModelForCausalLM, AutoTokenizer | |
| 26 | + | |
| 27 | +logger = logging.getLogger("reranker.backends.qwen3_transformers_packed") | |
| 28 | + | |
# Fixed chat-template wrapper shared by every packed request: system turn
# instructing a strict yes/no judgement, then the opening of the user turn.
_DEFAULT_PREFIX = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct "
    'provided. Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Assistant turn with an empty <think> block — appended after every document.
_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Shared per-request prefix: prefix + instruction + query, ending right before
# the document text that each packed span supplies.
_DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "
| 37 | + | |
| 38 | + | |
| 39 | +def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]: | |
| 40 | + unique_texts: List[str] = [] | |
| 41 | + position_to_unique: List[int] = [] | |
| 42 | + seen: Dict[str, int] = {} | |
| 43 | + | |
| 44 | + for text in texts: | |
| 45 | + idx = seen.get(text) | |
| 46 | + if idx is None: | |
| 47 | + idx = len(unique_texts) | |
| 48 | + seen[text] = idx | |
| 49 | + unique_texts.append(text) | |
| 50 | + position_to_unique.append(idx) | |
| 51 | + | |
| 52 | + return unique_texts, position_to_unique | |
| 53 | + | |
| 54 | + | |
| 55 | +class Qwen3TransformersPackedRerankerBackend: | |
| 56 | + """ | |
| 57 | + Qwen3-Reranker packed inference backend using Transformers. | |
| 58 | + | |
| 59 | + Config from ``services.rerank.backends.qwen3_transformers_packed``. | |
| 60 | + """ | |
| 61 | + | |
    def __init__(self, config: Dict[str, Any]) -> None:
        """Load tokenizer + causal-LM and validate the packing budget.

        Raises:
            RuntimeError: no CUDA device, model not placed on CUDA, or
                yes/no classifier token ids unresolvable.
            ValueError: non-CUDA device requested, or max_model_len /
                max_doc_len / max_docs_per_pack budgets are inconsistent.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        self._instruction = str(
            self._config.get("instruction")
            or "Rank products by query with category & style match prioritized"
        )
        # Prompt scaffolding is overridable per deployment.
        self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._pair_prefix_template = str(
            self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE
        )

        # Packing budgets: total tokens per forward pass, per-doc truncation
        # cap, and optional hard cap on docs per pack (0 = unlimited).
        max_model_len = int(self._config.get("max_model_len", 4096))
        max_doc_len = int(self._config.get("max_doc_len", 160))
        max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0))
        use_fp16 = bool(self._config.get("use_fp16", True))
        device = self._config.get("device")
        # "eager" default: the custom 4-D float attention mask is fed directly
        # to the model, which eager attention accepts.
        attn_impl = str(self._config.get("attn_implementation") or "eager").strip()
        sort_by_doc_length = self._config.get("sort_by_doc_length", True)

        self._model_name = model_name
        self._max_model_len = max_model_len
        self._max_doc_len = max_doc_len
        self._max_docs_per_pack = max_docs_per_pack
        # Accept booleans or truthy strings from YAML/env-style config.
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        self._attn_impl = attn_impl

        logger.info(
            "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, "
            "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)",
            model_name,
            max_model_len,
            max_doc_len,
            max_docs_per_pack,
            use_fp16,
            attn_impl,
        )

        self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Pre-tokenize the fixed prefix/suffix once; reused on every request.
        self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
        self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
        self._suffix_len = len(self._suffix_tokens)

        # This backend is GPU-only by design; fail fast rather than run on CPU.
        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_transformers_packed backend requires CUDA GPU, "
                "but torch.cuda.is_available() is False"
            )

        kwargs: Dict[str, Any] = {}
        if use_fp16:
            kwargs["torch_dtype"] = torch.float16
        if attn_impl:
            kwargs["attn_implementation"] = attn_impl

        self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
        target_device = str(device).strip() if device is not None else "cuda"
        if not target_device.startswith("cuda"):
            raise ValueError(
                "qwen3_transformers_packed backend is GPU-only. "
                f"Unsupported device setting: {target_device!r}"
            )
        self._model = self._model.to(target_device)
        # Read back the actual placement instead of trusting the request.
        self._device = next(self._model.parameters()).device
        if self._device.type != "cuda":
            raise RuntimeError(
                "qwen3_transformers_packed backend failed to place model on CUDA. "
                f"Current device: {self._device}"
            )

        # yes/no token ids drive the 2-way classifier head in _score_pack.
        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
        if self._token_true_id is None or self._token_false_id is None:
            raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no")

        # The budget must leave room for at least one doc token (+1).
        prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1
        if self._max_model_len <= prefix_budget:
            raise ValueError(
                "max_model_len is too small for packed reranking. "
                f"Need > {prefix_budget}, got {self._max_model_len}."
            )
        if self._max_doc_len <= 0:
            raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}")
        if self._max_docs_per_pack < 0:
            raise ValueError(
                f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}"
            )

        # Serializes forward passes; a single model instance is shared.
        self._infer_lock = threading.Lock()

        logger.info(
            "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s",
            model_name,
            self._device,
        )
| 166 | + | |
| 167 | + def _build_pair_prefix_tokens(self, query: str) -> List[int]: | |
| 168 | + pair_prefix = self._pair_prefix_template.format( | |
| 169 | + prefix=self._prefix, | |
| 170 | + instruction=self._instruction, | |
| 171 | + query=query, | |
| 172 | + ) | |
| 173 | + return self._tokenizer.encode(pair_prefix, add_special_tokens=False) | |
| 174 | + | |
| 175 | + def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]: | |
| 176 | + max_doc_tokens = min( | |
| 177 | + self._max_doc_len, | |
| 178 | + max(1, self._max_model_len - query_prefix_len - self._suffix_len), | |
| 179 | + ) | |
| 180 | + tokenized = self._tokenizer( | |
| 181 | + list(docs), | |
| 182 | + padding=False, | |
| 183 | + truncation=True, | |
| 184 | + max_length=max_doc_tokens, | |
| 185 | + add_special_tokens=False, | |
| 186 | + return_attention_mask=False, | |
| 187 | + ) | |
| 188 | + return [list(ids) for ids in tokenized["input_ids"]] | |
| 189 | + | |
| 190 | + def _build_pack_plan( | |
| 191 | + self, | |
| 192 | + query_prefix_len: int, | |
| 193 | + doc_tokens: Sequence[Sequence[int]], | |
| 194 | + ) -> List[List[int]]: | |
| 195 | + order = list(range(len(doc_tokens))) | |
| 196 | + if self._sort_by_doc_length and len(order) > 1: | |
| 197 | + order.sort(key=lambda idx: len(doc_tokens[idx])) | |
| 198 | + | |
| 199 | + packs: List[List[int]] = [] | |
| 200 | + current_pack: List[int] = [] | |
| 201 | + current_len = query_prefix_len | |
| 202 | + for idx in order: | |
| 203 | + packed_doc_len = len(doc_tokens[idx]) + self._suffix_len | |
| 204 | + if packed_doc_len <= 0: | |
| 205 | + continue | |
| 206 | + | |
| 207 | + over_docs_cap = self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack | |
| 208 | + over_token_cap = current_pack and (current_len + packed_doc_len > self._max_model_len) | |
| 209 | + if over_docs_cap or over_token_cap: | |
| 210 | + packs.append(current_pack) | |
| 211 | + current_pack = [] | |
| 212 | + current_len = query_prefix_len | |
| 213 | + | |
| 214 | + if query_prefix_len + packed_doc_len > self._max_model_len: | |
| 215 | + raise ValueError( | |
| 216 | + "Packed doc still exceeds max_model_len after truncation. " | |
| 217 | + f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, " | |
| 218 | + f"max_model_len={self._max_model_len}" | |
| 219 | + ) | |
| 220 | + | |
| 221 | + current_pack.append(idx) | |
| 222 | + current_len += packed_doc_len | |
| 223 | + | |
| 224 | + if current_pack: | |
| 225 | + packs.append(current_pack) | |
| 226 | + return packs | |
| 227 | + | |
    def _build_pack_inputs(
        self,
        query_prefix_tokens: Sequence[int],
        doc_tokens: Sequence[Sequence[int]],
        doc_indices: Sequence[int],
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Stitch one pack into a single model input.

        Layout: [shared prefix][doc0+suffix][doc1+suffix]... with
        * position_ids reset per doc to continue from the prefix, and
        * a (1, 1, T, T) additive float mask (0 allowed / -inf blocked) so
          each doc attends causally to the prefix and to itself only.
        Returns (model inputs, indices of each doc's final suffix token,
        where the yes/no logits are read).
        """
        prefix_len = len(query_prefix_tokens)
        input_ids_list = list(query_prefix_tokens)
        position_ids_list = list(range(prefix_len))
        # Per-doc (start, end) spans inside the stitched sequence.
        spans: List[Tuple[int, int]] = []
        current_len = prefix_len

        for idx in doc_indices:
            doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens
            start = current_len
            end = start + len(doc_with_suffix)
            spans.append((start, end))
            input_ids_list.extend(doc_with_suffix)
            # Positions restart right after the prefix for every doc, so each
            # doc sees the same relative distances as an unpacked prompt.
            position_ids_list.extend(range(prefix_len, prefix_len + len(doc_with_suffix)))
            current_len = end

        total_len = len(input_ids_list)
        device = self._device
        neg_inf = torch.finfo(torch.float32).min

        # Boolean "may attend" matrix, converted to an additive mask below.
        allowed = torch.zeros((total_len, total_len), dtype=torch.bool, device=device)
        prefix_causal = torch.tril(
            torch.ones((prefix_len, prefix_len), dtype=torch.bool, device=device)
        )
        allowed[:prefix_len, :prefix_len] = prefix_causal
        for start, end in spans:
            # Every doc token can see the whole prefix...
            allowed[start:end, :prefix_len] = True
            doc_len = end - start
            # ...and causally its own doc span, but never sibling docs.
            allowed[start:end, start:end] = torch.tril(
                torch.ones((doc_len, doc_len), dtype=torch.bool, device=device)
            )

        attention_mask = torch.full(
            (total_len, total_len),
            neg_inf,
            dtype=torch.float32,
            device=device,
        )
        attention_mask.masked_fill_(allowed, 0.0)

        inputs = {
            "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device),
            "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device),
            "attention_mask": attention_mask.view(1, 1, total_len, total_len),
        }
        # Last token of each doc's suffix: the position scored by _score_pack.
        logits_ids = torch.tensor(
            [end - 1 for _, end in spans],
            dtype=torch.long,
            device=device,
        )
        return inputs, logits_ids
| 284 | + | |
| 285 | + @torch.no_grad() | |
| 286 | + def _score_pack( | |
| 287 | + self, | |
| 288 | + query_prefix_tokens: Sequence[int], | |
| 289 | + doc_tokens: Sequence[Sequence[int]], | |
| 290 | + doc_indices: Sequence[int], | |
| 291 | + ) -> Tuple[List[float], int]: | |
| 292 | + inputs, logits_ids = self._build_pack_inputs( | |
| 293 | + query_prefix_tokens=query_prefix_tokens, | |
| 294 | + doc_tokens=doc_tokens, | |
| 295 | + doc_indices=doc_indices, | |
| 296 | + ) | |
| 297 | + outputs = self._model(**inputs) | |
| 298 | + scores = outputs.logits[0, logits_ids, :] | |
| 299 | + true_vector = scores[:, self._token_true_id] | |
| 300 | + false_vector = scores[:, self._token_false_id] | |
| 301 | + pair_scores = torch.stack([false_vector, true_vector], dim=1) | |
| 302 | + pair_scores = torch.nn.functional.log_softmax(pair_scores, dim=1) | |
| 303 | + return pair_scores[:, 1].exp().tolist(), int(inputs["input_ids"].shape[1]) | |
| 304 | + | |
| 305 | + def score_with_meta( | |
| 306 | + self, | |
| 307 | + query: str, | |
| 308 | + docs: List[str], | |
| 309 | + normalize: bool = True, | |
| 310 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 311 | + start_ts = time.time() | |
| 312 | + total_docs = len(docs) if docs else 0 | |
| 313 | + output_scores: List[float] = [0.0] * total_docs | |
| 314 | + | |
| 315 | + query = "" if query is None else str(query).strip() | |
| 316 | + indexed: List[Tuple[int, str]] = [] | |
| 317 | + for i, doc in enumerate(docs or []): | |
| 318 | + if doc is None: | |
| 319 | + continue | |
| 320 | + text = str(doc).strip() | |
| 321 | + if not text: | |
| 322 | + continue | |
| 323 | + indexed.append((i, text)) | |
| 324 | + | |
| 325 | + if not query or not indexed: | |
| 326 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 327 | + return output_scores, { | |
| 328 | + "input_docs": total_docs, | |
| 329 | + "usable_docs": len(indexed), | |
| 330 | + "unique_docs": 0, | |
| 331 | + "dedup_ratio": 0.0, | |
| 332 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 333 | + "model": self._model_name, | |
| 334 | + "backend": "qwen3_transformers_packed", | |
| 335 | + "normalize": normalize, | |
| 336 | + "packed_batches": 0, | |
| 337 | + "max_model_len": self._max_model_len, | |
| 338 | + "max_doc_len": self._max_doc_len, | |
| 339 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 340 | + } | |
| 341 | + | |
| 342 | + indexed_texts = [text for _, text in indexed] | |
| 343 | + unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts) | |
| 344 | + | |
| 345 | + query_prefix_tokens = self._build_pair_prefix_tokens(query) | |
| 346 | + doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens)) | |
| 347 | + pack_plan = self._build_pack_plan( | |
| 348 | + query_prefix_len=len(query_prefix_tokens), | |
| 349 | + doc_tokens=doc_tokens, | |
| 350 | + ) | |
| 351 | + | |
| 352 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 353 | + pack_lengths: List[int] = [] | |
| 354 | + with self._infer_lock: | |
| 355 | + for pack_doc_indices in pack_plan: | |
| 356 | + batch_scores, pack_seq_len = self._score_pack( | |
| 357 | + query_prefix_tokens=query_prefix_tokens, | |
| 358 | + doc_tokens=doc_tokens, | |
| 359 | + doc_indices=pack_doc_indices, | |
| 360 | + ) | |
| 361 | + if len(batch_scores) != len(pack_doc_indices): | |
| 362 | + raise RuntimeError( | |
| 363 | + "Packed reranker score size mismatch: " | |
| 364 | + f"expected {len(pack_doc_indices)}, got {len(batch_scores)}" | |
| 365 | + ) | |
| 366 | + for idx, score in zip(pack_doc_indices, batch_scores): | |
| 367 | + unique_scores[idx] = float(score) | |
| 368 | + pack_lengths.append(pack_seq_len) | |
| 369 | + | |
| 370 | + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | |
| 371 | + output_scores[orig_idx] = float(unique_scores[unique_idx]) | |
| 372 | + | |
| 373 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 374 | + dedup_ratio = 0.0 | |
| 375 | + if indexed: | |
| 376 | + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) | |
| 377 | + | |
| 378 | + meta = { | |
| 379 | + "input_docs": total_docs, | |
| 380 | + "usable_docs": len(indexed), | |
| 381 | + "unique_docs": len(unique_texts), | |
| 382 | + "dedup_ratio": round(dedup_ratio, 4), | |
| 383 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 384 | + "model": self._model_name, | |
| 385 | + "backend": "qwen3_transformers_packed", | |
| 386 | + "normalize": normalize, | |
| 387 | + "packed_batches": len(pack_plan), | |
| 388 | + "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0, | |
| 389 | + "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3) | |
| 390 | + if pack_lengths | |
| 391 | + else 0.0, | |
| 392 | + "max_model_len": self._max_model_len, | |
| 393 | + "max_doc_len": self._max_doc_len, | |
| 394 | + "max_docs_per_pack": self._max_docs_per_pack, | |
| 395 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 396 | + "attn_implementation": self._attn_impl, | |
| 397 | + } | |
| 398 | + return output_scores, meta | ... | ... |
reranker/backends/qwen3_vllm.py
| ... | ... | @@ -45,7 +45,7 @@ def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: |
| 45 | 45 | return unique_texts, position_to_unique |
| 46 | 46 | |
| 47 | 47 | |
| 48 | -def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: | |
| 48 | +def _format_instruction__standard(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: | |
| 49 | 49 | """Build chat messages for one (query, doc) pair.""" |
| 50 | 50 | return [ |
| 51 | 51 | { |
| ... | ... | @@ -58,6 +58,18 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str |
| 58 | 58 | }, |
| 59 | 59 | ] |
| 60 | 60 | |
| 61 | +def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: | |
| 62 | + """Build compact chat messages for one (query, doc) pair (the instruction doubles as the system prompt).""" | 
| 63 | + return [ | |
| 64 | + { | |
| 65 | + "role": "system", | |
| 66 | + "content": instruction, | |
| 67 | + }, | |
| 68 | + { | |
| 69 | + "role": "user", | |
| 70 | + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", | |
| 71 | + }, | |
| 72 | + ] | |
| 61 | 73 | |
| 62 | 74 | class Qwen3VLLMRerankerBackend: |
| 63 | 75 | """ |
| ... | ... | @@ -78,6 +90,17 @@ class Qwen3VLLMRerankerBackend: |
| 78 | 90 | self._config.get("instruction") |
| 79 | 91 | or "Given a query, score the product for relevance" |
| 80 | 92 | ) |
| 93 | + _fmt = str(self._config.get("instruction_format") or "compact").strip().lower() | |
| 94 | + if _fmt not in {"standard", "compact"}: | |
| 95 | + raise ValueError( | |
| 96 | + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" | |
| 97 | + ) | |
| 98 | + self._instruction_format = _fmt | |
| 99 | + self._format_messages = ( | |
| 100 | + _format_instruction__standard | |
| 101 | + if self._instruction_format == "standard" | |
| 102 | + else _format_instruction | |
| 103 | + ) | |
| 81 | 104 | infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64) |
| 82 | 105 | sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") |
| 83 | 106 | if sort_by_doc_length is None: |
| ... | ... | @@ -95,13 +118,15 @@ class Qwen3VLLMRerankerBackend: |
| 95 | 118 | ) |
| 96 | 119 | |
| 97 | 120 | logger.info( |
| 98 | - "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", | |
| 121 | + "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " | |
| 122 | + "instruction_format=%s)", | |
| 99 | 123 | model_name, |
| 100 | 124 | max_model_len, |
| 101 | 125 | tensor_parallel_size, |
| 102 | 126 | gpu_memory_utilization, |
| 103 | 127 | dtype, |
| 104 | 128 | enable_prefix_caching, |
| 129 | + self._instruction_format, | |
| 105 | 130 | ) |
| 106 | 131 | |
| 107 | 132 | self._llm = LLM( |
| ... | ... | @@ -145,7 +170,7 @@ class Qwen3VLLMRerankerBackend: |
| 145 | 170 | ) -> List[TokensPrompt]: |
| 146 | 171 | """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template.""" |
| 147 | 172 | messages_batch = [ |
| 148 | - _format_instruction(self._instruction, q, d) for q, d in pairs | |
| 173 | + self._format_messages(self._instruction, q, d) for q, d in pairs | |
| 149 | 174 | ] |
| 150 | 175 | tokenized = self._tokenizer.apply_chat_template( |
| 151 | 176 | messages_batch, |
| ... | ... | @@ -242,6 +267,7 @@ class Qwen3VLLMRerankerBackend: |
| 242 | 267 | "infer_batch_size": self._infer_batch_size, |
| 243 | 268 | "inference_batches": 0, |
| 244 | 269 | "sort_by_doc_length": self._sort_by_doc_length, |
| 270 | + "instruction_format": self._instruction_format, | |
| 245 | 271 | } |
| 246 | 272 | |
| 247 | 273 | # Deduplicate globally by text, keep mapping to original indices. |
| ... | ... | @@ -289,6 +315,7 @@ class Qwen3VLLMRerankerBackend: |
| 289 | 315 | "normalize": normalize, |
| 290 | 316 | "infer_batch_size": self._infer_batch_size, |
| 291 | 317 | "inference_batches": inference_batches, |
| 292 | - "sort_by_doc_length": self._sort_by_doc_length | |
| 318 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 319 | + "instruction_format": self._instruction_format, | |
| 293 | 320 | } |
| 294 | 321 | return output_scores, meta | ... | ... |
| ... | ... | @@ -0,0 +1,323 @@ |
| 1 | +""" | |
| 2 | +Qwen3-Reranker via vLLM ``LLM.score()`` (pooling / cross-encoder score API). | |
| 3 | + | |
| 4 | +Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired | |
| 5 | +``llm.score(query_texts, doc_texts)`` with the recommended prefix/suffix templates. | |
| 6 | +Requires vLLM >= 0.17 (uses ``runner``/``convert`` auto, not legacy ``task="score"``). | |
| 7 | + | |
| 8 | +Dedicated venv: ``.venv-reranker-score`` + ``requirements_reranker_qwen3_vllm_score.txt`` | |
| 9 | +(see ``./scripts/setup_reranker_venv.sh qwen3_vllm_score``). Default ``model_name`` can match | |
| 10 | +``qwen3_vllm``; only the Python env differs for pinned high-performance vLLM. | |
| 11 | + | |
| 12 | +Reference: https://docs.vllm.ai/ — Qwen3 reranker example | |
| 13 | +""" | |
| 14 | + | |
| 15 | +from __future__ import annotations | |
| 16 | + | |
| 17 | +import logging | |
| 18 | +import os | |
| 19 | +import threading | |
| 20 | +import time | |
| 21 | +from typing import Any, Dict, List, Tuple | |
| 22 | + | |
| 23 | +logger = logging.getLogger("reranker.backends.qwen3_vllm_score") | |
| 24 | + | |
| 25 | +import torch | |
| 26 | +from vllm import LLM | |
| 27 | + | |
| 28 | +from reranker.backends.qwen3_vllm import deduplicate_with_positions | |
| 29 | + | |
| 30 | +# Official vLLM Qwen3 reranker prompt layout (im_start blocks + assistant suffix). | |
| 31 | +_DEFAULT_PREFIX = ( | |
| 32 | + "<|im_start|>system\n" | |
| 33 | + "Judge whether the Document meets the requirements based on the Query and the Instruct " | |
| 34 | + 'provided. Note that the answer can only be "yes" or "no".' | |
| 35 | + "<|im_end|>\n<|im_start|>user\n" | |
| 36 | +) | |
| 37 | +_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" | |
| 38 | +_DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" | |
| 39 | +_DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" | |
| 40 | +# compact:与 qwen3_vllm._format_instruction 一致(instruction 作 system,user 内重复 Instruct) | |
| 41 | +_IM_USER_START = "<|im_end|>\n<|im_start|>user\n" | |
| 42 | + | |
| 43 | + | |
| 44 | +def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None: | |
| 45 | + """ | |
| 46 | + vLLM 0.18 defaults to Flash-Attention paths that require compute capability >= 8 (Ampere+). | |
| 47 | + Turing / Volta (e.g. T4 sm_75) must use a non-FA backend such as TRITON_ATTN. | |
| 48 | + """ | |
| 49 | + env = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip() | |
| 50 | + raw = config.get("vllm_attention_backend") | |
| 51 | + if env: | |
| 52 | + choice = env | |
| 53 | + elif raw is not None and str(raw).strip() and str(raw).strip().lower() != "auto": | |
| 54 | + choice = str(raw).strip() | |
| 55 | + else: | |
| 56 | + choice = "" | |
| 57 | + if choice: | |
| 58 | + backend = choice.strip().upper() | |
| 59 | + if backend == "AUTO": | |
| 60 | + choice = "" | |
| 61 | + else: | |
| 62 | + logger.info("[Qwen3_VLLM_SCORE] attention_config.backend=%s (from config/env)", backend) | |
| 63 | + return {"backend": backend} | |
| 64 | + | |
| 65 | + major, minor = torch.cuda.get_device_capability() | |
| 66 | + if major < 8: | |
| 67 | + logger.info( | |
| 68 | + "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend " | |
| 69 | + "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). " | |
| 70 | + "Override with services.rerank.backends.qwen3_vllm_score.vllm_attention_backend " | |
| 71 | + "or RERANK_VLLM_ATTENTION_BACKEND.", | |
| 72 | + major, | |
| 73 | + minor, | |
| 74 | + ) | |
| 75 | + return {"backend": "TRITON_ATTN"} | |
| 76 | + return None | |
| 77 | + | |
| 78 | + | |
| 79 | +class Qwen3VLLMScoreRerankerBackend: | |
| 80 | + """ | |
| 81 | + Qwen3 reranker using vLLM ``LLM.score()`` (pooling runner) for cross-encoder scores. | |
| 82 | + | |
| 83 | + Config from ``services.rerank.backends.qwen3_vllm_score``. | |
| 84 | + """ | |
| 85 | + | |
| 86 | + def __init__(self, config: Dict[str, Any]) -> None: | |
| 87 | + self._config = config or {} | |
| 88 | + model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") | |
| 89 | + max_model_len = int(self._config.get("max_model_len", 2048)) | |
| 90 | + tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1)) | |
| 91 | + gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4)) | |
| 92 | + enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False)) | |
| 93 | + enforce_eager = bool(self._config.get("enforce_eager", True)) | |
| 94 | + dtype = str(self._config.get("dtype", "float16")).strip().lower() | |
| 95 | + use_hf_overrides = self._config.get("use_original_qwen3_hf_overrides") | |
| 96 | + if use_hf_overrides is None: | |
| 97 | + use_hf_overrides = True | |
| 98 | + use_hf_overrides = bool(use_hf_overrides) | |
| 99 | + | |
| 100 | + self._instruction = str( | |
| 101 | + self._config.get("instruction") | |
| 102 | + or "Given a query, score the product for relevance" | |
| 103 | + ) | |
| 104 | + _fmt = str(self._config.get("instruction_format") or "standard").strip().lower() | |
| 105 | + if _fmt not in {"standard", "compact"}: | |
| 106 | + raise ValueError( | |
| 107 | + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" | |
| 108 | + ) | |
| 109 | + self._instruction_format = _fmt | |
| 110 | + self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) | |
| 111 | + self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) | |
| 112 | + self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE) | |
| 113 | + self._document_template = str( | |
| 114 | + self._config.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE | |
| 115 | + ) | |
| 116 | + | |
| 117 | + infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get( | |
| 118 | + "infer_batch_size", 64 | |
| 119 | + ) | |
| 120 | + sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") | |
| 121 | + if sort_by_doc_length is None: | |
| 122 | + sort_by_doc_length = self._config.get("sort_by_doc_length", True) | |
| 123 | + | |
| 124 | + self._infer_batch_size = int(infer_batch_size) | |
| 125 | + self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in { | |
| 126 | + "1", | |
| 127 | + "true", | |
| 128 | + "yes", | |
| 129 | + "y", | |
| 130 | + "on", | |
| 131 | + } | |
| 132 | + | |
| 133 | + if not torch.cuda.is_available(): | |
| 134 | + raise RuntimeError( | |
| 135 | + "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False" | |
| 136 | + ) | |
| 137 | + if dtype not in {"float16", "half", "auto"}: | |
| 138 | + raise ValueError( | |
| 139 | + f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto." | |
| 140 | + ) | |
| 141 | + if self._infer_batch_size <= 0: | |
| 142 | + raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}") | |
| 143 | + | |
| 144 | + runner = str(self._config.get("vllm_runner") or "auto").strip().lower() | |
| 145 | + convert = str(self._config.get("vllm_convert") or "auto").strip().lower() | |
| 146 | + if runner not in {"auto", "generate", "pooling", "draft"}: | |
| 147 | + raise ValueError(f"Invalid vllm_runner: {runner!r}") | |
| 148 | + if convert not in {"auto", "none", "embed", "classify"}: | |
| 149 | + raise ValueError(f"Invalid vllm_convert: {convert!r}") | |
| 150 | + | |
| 151 | + logger.info( | |
| 152 | + "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, " | |
| 153 | + "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " | |
| 154 | + "instruction_format=%s)", | |
| 155 | + model_name, | |
| 156 | + runner, | |
| 157 | + convert, | |
| 158 | + use_hf_overrides, | |
| 159 | + max_model_len, | |
| 160 | + tensor_parallel_size, | |
| 161 | + gpu_memory_utilization, | |
| 162 | + dtype, | |
| 163 | + enable_prefix_caching, | |
| 164 | + self._instruction_format, | |
| 165 | + ) | |
| 166 | + | |
| 167 | + # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official | |
| 168 | + # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify. | |
| 169 | + llm_kwargs: Dict[str, Any] = { | |
| 170 | + "model": model_name, | |
| 171 | + "runner": runner, | |
| 172 | + "convert": convert, | |
| 173 | + "tensor_parallel_size": tensor_parallel_size, | |
| 174 | + "max_model_len": max_model_len, | |
| 175 | + "gpu_memory_utilization": gpu_memory_utilization, | |
| 176 | + "enable_prefix_caching": enable_prefix_caching, | |
| 177 | + "enforce_eager": enforce_eager, | |
| 178 | + "dtype": dtype, | |
| 179 | + } | |
| 180 | + hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {}) | |
| 181 | + if use_hf_overrides: | |
| 182 | + hf_overrides = { | |
| 183 | + **hf_overrides, | |
| 184 | + "architectures": ["Qwen3ForSequenceClassification"], | |
| 185 | + "classifier_from_token": ["no", "yes"], | |
| 186 | + "is_original_qwen3_reranker": True, | |
| 187 | + } | |
| 188 | + if hf_overrides: | |
| 189 | + llm_kwargs["hf_overrides"] = hf_overrides | |
| 190 | + | |
| 191 | + attn_cfg = _resolve_vllm_attention_config(self._config) | |
| 192 | + if attn_cfg is not None: | |
| 193 | + llm_kwargs["attention_config"] = attn_cfg | |
| 194 | + | |
| 195 | + self._llm = LLM(**llm_kwargs) | |
| 196 | + # vLLM score path: single-process safety (mirrors generate backend until verified). | |
| 197 | + self._infer_lock = threading.Lock() | |
| 198 | + | |
| 199 | + self._model_name = model_name | |
| 200 | + logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name) | |
| 201 | + | |
| 202 | + def _format_pair(self, query: str, doc: str) -> Tuple[str, str]: | |
| 203 | + if self._instruction_format == "compact": | |
| 204 | + # Mirror the query/doc split used by reranker.backends.qwen3_vllm._format_instruction so LLM.score() sees equivalent prompts. | 
| 205 | + compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}" | |
| 206 | + q_text = ( | |
| 207 | + f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n" | |
| 208 | + ) | |
| 209 | + d_text = f"\n<Document>: {doc}{self._suffix}" | |
| 210 | + return q_text, d_text | |
| 211 | + q_text = self._query_template.format( | |
| 212 | + prefix=self._prefix, | |
| 213 | + instruction=self._instruction, | |
| 214 | + query=query, | |
| 215 | + ) | |
| 216 | + d_text = self._document_template.format(doc=doc, suffix=self._suffix) | |
| 217 | + return q_text, d_text | |
| 218 | + | |
| 219 | + def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]: | |
| 220 | + if not pairs: | |
| 221 | + return [] | |
| 222 | + queries: List[str] = [] | |
| 223 | + documents: List[str] = [] | |
| 224 | + for q, d in pairs: | |
| 225 | + qt, dt = self._format_pair(q, d) | |
| 226 | + queries.append(qt) | |
| 227 | + documents.append(dt) | |
| 228 | + with self._infer_lock: | |
| 229 | + outputs = self._llm.score(queries, documents, use_tqdm=False) | |
| 230 | + scores: List[float] = [] | |
| 231 | + for out in outputs: | |
| 232 | + so = out.outputs | |
| 233 | + scores.append(float(so.score)) | |
| 234 | + return scores | |
| 235 | + | |
| 236 | + @staticmethod | |
| 237 | + def _estimate_doc_lengths(docs: List[str]) -> List[int]: | |
| 238 | + if not docs: | |
| 239 | + return [] | |
| 240 | + return [len(text) for text in docs] | |
| 241 | + | |
| 242 | + def score_with_meta( | |
| 243 | + self, | |
| 244 | + query: str, | |
| 245 | + docs: List[str], | |
| 246 | + normalize: bool = True, | |
| 247 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 248 | + start_ts = time.time() | |
| 249 | + total_docs = len(docs) if docs else 0 | |
| 250 | + output_scores: List[float] = [0.0] * total_docs | |
| 251 | + | |
| 252 | + query = "" if query is None else str(query).strip() | |
| 253 | + indexed: List[Tuple[int, str]] = [] | |
| 254 | + for i, doc in enumerate(docs or []): | |
| 255 | + if doc is None: | |
| 256 | + continue | |
| 257 | + text = str(doc).strip() | |
| 258 | + if not text: | |
| 259 | + continue | |
| 260 | + indexed.append((i, text)) | |
| 261 | + | |
| 262 | + if not query or not indexed: | |
| 263 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 264 | + return output_scores, { | |
| 265 | + "input_docs": total_docs, | |
| 266 | + "usable_docs": len(indexed), | |
| 267 | + "unique_docs": 0, | |
| 268 | + "dedup_ratio": 0.0, | |
| 269 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 270 | + "model": self._model_name, | |
| 271 | + "backend": "qwen3_vllm_score", | |
| 272 | + "normalize": normalize, | |
| 273 | + "infer_batch_size": self._infer_batch_size, | |
| 274 | + "inference_batches": 0, | |
| 275 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 276 | + "instruction_format": self._instruction_format, | |
| 277 | + } | |
| 278 | + | |
| 279 | + indexed_texts = [text for _, text in indexed] | |
| 280 | + unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts) | |
| 281 | + | |
| 282 | + lengths = self._estimate_doc_lengths(unique_texts) | |
| 283 | + order = list(range(len(unique_texts))) | |
| 284 | + if self._sort_by_doc_length and len(unique_texts) > 1: | |
| 285 | + order = sorted(order, key=lambda i: lengths[i]) | |
| 286 | + | |
| 287 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 288 | + inference_batches = 0 | |
| 289 | + for start in range(0, len(order), self._infer_batch_size): | |
| 290 | + batch_indices = order[start : start + self._infer_batch_size] | |
| 291 | + inference_batches += 1 | |
| 292 | + pairs = [(query, unique_texts[i]) for i in batch_indices] | |
| 293 | + batch_scores = self._score_batch(pairs) | |
| 294 | + if len(batch_scores) != len(batch_indices): | |
| 295 | + raise RuntimeError( | |
| 296 | + f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}" | |
| 297 | + ) | |
| 298 | + for idx, score in zip(batch_indices, batch_scores): | |
| 299 | + unique_scores[idx] = float(score) | |
| 300 | + | |
| 301 | + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | |
| 302 | + output_scores[orig_idx] = float(unique_scores[unique_idx]) | |
| 303 | + | |
| 304 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 305 | + dedup_ratio = 0.0 | |
| 306 | + if indexed: | |
| 307 | + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) | |
| 308 | + | |
| 309 | + meta = { | |
| 310 | + "input_docs": total_docs, | |
| 311 | + "usable_docs": len(indexed), | |
| 312 | + "unique_docs": len(unique_texts), | |
| 313 | + "dedup_ratio": round(dedup_ratio, 4), | |
| 314 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 315 | + "model": self._model_name, | |
| 316 | + "backend": "qwen3_vllm_score", | |
| 317 | + "normalize": normalize, | |
| 318 | + "infer_batch_size": self._infer_batch_size, | |
| 319 | + "inference_batches": inference_batches, | |
| 320 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 321 | + "instruction_format": self._instruction_format, | |
| 322 | + } | |
| 323 | + return output_scores, meta | ... | ... |
reranker/server.py
| ... | ... | @@ -7,7 +7,7 @@ Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional |
| 7 | 7 | Response: { "scores": [float], "meta": {...} } |
| 8 | 8 | |
| 9 | 9 | Backend selected via config: services.rerank.backend |
| 10 | -(bge | qwen3_vllm | qwen3_transformers | dashscope_rerank), env RERANK_BACKEND. | |
| 10 | +(bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_transformers_packed | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank), env RERANK_BACKEND. | |
| 11 | 11 | """ |
| 12 | 12 | |
| 13 | 13 | import logging |
| ... | ... | @@ -99,12 +99,17 @@ def health() -> Dict[str, Any]: |
| 99 | 99 | model_info = getattr(_reranker, "_model_name", None) or getattr( |
| 100 | 100 | _reranker, "_config", {} |
| 101 | 101 | ).get("model_name", _backend_name) |
| 102 | - return { | |
| 102 | + payload: Dict[str, Any] = { | |
| 103 | 103 | "status": "ok" if _reranker is not None else "unavailable", |
| 104 | 104 | "model_loaded": _reranker is not None, |
| 105 | 105 | "model": model_info, |
| 106 | 106 | "backend": _backend_name, |
| 107 | 107 | } |
| 108 | + if _reranker is not None: | |
| 109 | + _fmt = getattr(_reranker, "_instruction_format", None) | |
| 110 | + if _fmt is not None: | |
| 111 | + payload["instruction_format"] = _fmt | |
| 112 | + return payload | |
| 108 | 113 | |
| 109 | 114 | |
| 110 | 115 | @app.post("/rerank", response_model=RerankResponse) | ... | ... |
| ... | ... | @@ -0,0 +1,198 @@ |
| 1 | +#!/usr/bin/env python3 | |
| 2 | +""" | |
| 3 | +Local tuning probe for GGUF reranker backends. | |
| 4 | + | |
| 5 | +Runs the backend directly in a fresh process per config to measure: | |
| 6 | +- load time | |
| 7 | +- GPU memory used by this process | |
| 8 | +- single-request rerank latency | |
| 9 | + | |
| 10 | +Example: | |
| 11 | + ./.venv-reranker-gguf/bin/python scripts/benchmark_reranker_gguf_local.py | |
| 12 | + ./.venv-reranker-gguf-06b/bin/python scripts/benchmark_reranker_gguf_local.py --backend-name qwen3_gguf_06b --docs 400 | |
| 13 | +""" | |
| 14 | + | |
| 15 | +from __future__ import annotations | |
| 16 | + | |
| 17 | +import argparse | |
| 18 | +import json | |
| 19 | +import os | |
| 20 | +import random | |
| 21 | +import statistics | |
| 22 | +import subprocess | |
| 23 | +import sys | |
| 24 | +import time | |
| 25 | +from pathlib import Path | |
| 26 | +from typing import Any | |
| 27 | + | |
| 28 | + | |
| 29 | +DEFAULT_TITLES = Path("/home/ubuntu/rerank_test/titles.1.8w") | |
| 30 | + | |
| 31 | + | |
| 32 | +def load_titles(path: Path) -> list[str]: | |
| 33 | + items: list[str] = [] | |
| 34 | + with path.open(encoding="utf-8", errors="replace") as fh: | |
| 35 | + for line in fh: | |
| 36 | + text = line.strip() | |
| 37 | + if text: | |
| 38 | + items.append(text) | |
| 39 | + return items | |
| 40 | + | |
| 41 | + | |
| 42 | +def gpu_mem_for_pid(pid: int) -> int: | |
| 43 | + try: | |
| 44 | + out = subprocess.check_output( | |
| 45 | + [ | |
| 46 | + "nvidia-smi", | |
| 47 | + "--query-compute-apps=pid,used_gpu_memory", | |
| 48 | + "--format=csv,noheader,nounits", | |
| 49 | + ], | |
| 50 | + text=True, | |
| 51 | + ) | |
| 52 | + except Exception: | |
| 53 | + return -1 | |
| 54 | + for raw in out.splitlines(): | |
| 55 | + parts = [p.strip() for p in raw.split(",")] | |
| 56 | + if len(parts) != 2: | |
| 57 | + continue | |
| 58 | + try: | |
| 59 | + row_pid = int(parts[0]) | |
| 60 | + row_mem = int(parts[1]) | |
| 61 | + except ValueError: | |
| 62 | + continue | |
| 63 | + if row_pid == pid: | |
| 64 | + return row_mem | |
| 65 | + return -1 | |
| 66 | + | |
| 67 | + | |
| 68 | +def main() -> int: | |
| 69 | + parser = argparse.ArgumentParser() | |
| 70 | + parser.add_argument("--backend-name", type=str, default="qwen3_gguf") | |
| 71 | + parser.add_argument("--titles-file", type=Path, default=DEFAULT_TITLES) | |
| 72 | + parser.add_argument("--query", type=str, default="白色oversized T-shirt") | |
| 73 | + parser.add_argument("--docs", type=int, default=160) | |
| 74 | + parser.add_argument("--repeat", type=int, default=1) | |
| 75 | + parser.add_argument("--seed", type=int, default=42) | |
| 76 | + parser.add_argument( | |
| 77 | + "--configs-json", | |
| 78 | + type=str, | |
| 79 | + default="", | |
| 80 | + help="JSON array of config objects; when omitted, uses built-in scan set.", | |
| 81 | + ) | |
| 82 | + args = parser.parse_args() | |
| 83 | + | |
| 84 | + if not args.titles_file.is_file(): | |
| 85 | + print(f"missing titles file: {args.titles_file}", file=sys.stderr) | |
| 86 | + return 2 | |
| 87 | + | |
| 88 | + titles = load_titles(args.titles_file) | |
| 89 | + if len(titles) < args.docs: | |
| 90 | + print(f"not enough titles: need {args.docs}, got {len(titles)}", file=sys.stderr) | |
| 91 | + return 2 | |
| 92 | + | |
| 93 | + random.seed(args.seed) | |
| 94 | + docs = random.sample(titles, args.docs) | |
| 95 | + | |
| 96 | + if args.configs_json: | |
| 97 | + configs = json.loads(args.configs_json) | |
| 98 | + elif args.backend_name == "qwen3_gguf_06b": | |
| 99 | + configs = [ | |
| 100 | + {"name": "gguf_06b_full_256", "n_ctx": 256, "n_batch": 256, "n_ubatch": 256, "n_gpu_layers": 999}, | |
| 101 | + {"name": "gguf_06b_full_320", "n_ctx": 320, "n_batch": 320, "n_ubatch": 320, "n_gpu_layers": 999}, | |
| 102 | + {"name": "gguf_06b_full_384", "n_ctx": 384, "n_batch": 384, "n_ubatch": 384, "n_gpu_layers": 999}, | |
| 103 | + {"name": "gguf_06b_full_512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 512, "n_gpu_layers": 999}, | |
| 104 | + ] | |
| 105 | + else: | |
| 106 | + configs = [ | |
| 107 | + {"name": "gguf_t4_24g", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 24}, | |
| 108 | + {"name": "gguf_t4_40g", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 40}, | |
| 109 | + {"name": "gguf_t4_full", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 999}, | |
| 110 | + {"name": "gguf_t4_full_512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 256, "n_gpu_layers": 999}, | |
| 111 | + {"name": "gguf_t4_full_512_u512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 512, "n_gpu_layers": 999}, | |
| 112 | + {"name": "gguf_t4_full_768", "n_ctx": 768, "n_batch": 768, "n_ubatch": 256, "n_gpu_layers": 999}, | |
| 113 | + ] | |
| 114 | + | |
| 115 | + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend | |
| 116 | + | |
| 117 | + default_cfg_by_backend: dict[str, dict[str, Any]] = { | |
| 118 | + "qwen3_gguf": { | |
| 119 | + "_backend_name": "qwen3_gguf", | |
| 120 | + "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF", | |
| 121 | + "filename": "*Q8_0.gguf", | |
| 122 | + "local_dir": "./models/reranker/qwen3-reranker-4b-gguf", | |
| 123 | + "infer_batch_size": 8, | |
| 124 | + }, | |
| 125 | + "qwen3_gguf_06b": { | |
| 126 | + "_backend_name": "qwen3_gguf_06b", | |
| 127 | + "repo_id": "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF", | |
| 128 | + "filename": "qwen3-reranker-0.6b-q8_0.gguf", | |
| 129 | + "local_dir": "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf", | |
| 130 | + "infer_batch_size": 32, | |
| 131 | + }, | |
| 132 | + } | |
| 133 | + if args.backend_name not in default_cfg_by_backend: | |
| 134 | + print(f"unsupported backend: {args.backend_name}", file=sys.stderr) | |
| 135 | + return 2 | |
| 136 | + | |
| 137 | + base_cfg: dict[str, Any] = { | |
| 138 | + **default_cfg_by_backend[args.backend_name], | |
| 139 | + "instruction": "Rank products by query with category & style match prioritized", | |
| 140 | + "cache_dir": "./model_cache", | |
| 141 | + "main_gpu": 0, | |
| 142 | + "n_threads": 2, | |
| 143 | + "n_threads_batch": 4, | |
| 144 | + "flash_attn": True, | |
| 145 | + "offload_kqv": True, | |
| 146 | + "use_mmap": True, | |
| 147 | + "use_mlock": False, | |
| 148 | + "sort_by_doc_length": True, | |
| 149 | + "length_sort_mode": "char", | |
| 150 | + "enable_warmup": True, | |
| 151 | + "verbose": False, | |
| 152 | + "reuse_query_state": True, | |
| 153 | + } | |
| 154 | + | |
| 155 | + all_results: list[dict[str, Any]] = [] | |
| 156 | + for cfg in configs: | |
| 157 | + merged = dict(base_cfg) | |
| 158 | + merged.update(cfg) | |
| 159 | + name = str(merged.pop("name")) | |
| 160 | + | |
| 161 | + t0 = time.perf_counter() | |
| 162 | + backend = Qwen3GGUFRerankerBackend(merged) | |
| 163 | + load_ms = (time.perf_counter() - t0) * 1000.0 | |
| 164 | + gpu_mem_mib = gpu_mem_for_pid(os.getpid()) | |
| 165 | + | |
| 166 | + runs: list[float] = [] | |
| 167 | + last_meta: dict[str, Any] = {} | |
| 168 | + for _ in range(args.repeat): | |
| 169 | + t1 = time.perf_counter() | |
| 170 | + _scores, meta = backend.score_with_meta(args.query, docs, normalize=True) | |
| 171 | + runs.append((time.perf_counter() - t1) * 1000.0) | |
| 172 | + last_meta = dict(meta) | |
| 173 | + | |
| 174 | + result = { | |
| 175 | + "name": name, | |
| 176 | + "config": merged, | |
| 177 | + "load_ms": round(load_ms, 2), | |
| 178 | + "gpu_mem_mib": gpu_mem_mib, | |
| 179 | + "latency_ms_min": round(min(runs), 2), | |
| 180 | + "latency_ms_avg": round(statistics.mean(runs), 2), | |
| 181 | + "latency_ms_max": round(max(runs), 2), | |
| 182 | + "meta": last_meta, | |
| 183 | + } | |
| 184 | + all_results.append(result) | |
| 185 | + print(json.dumps(result, ensure_ascii=False)) | |
| 186 | + del backend | |
| 187 | + | |
| 188 | + print("SUMMARY") | |
| 189 | + for item in sorted(all_results, key=lambda x: x["latency_ms_avg"]): | |
| 190 | + print( | |
| 191 | + f'{item["name"]}: avg={item["latency_ms_avg"]}ms ' | |
| 192 | + f'gpu={item["gpu_mem_mib"]}MiB load={item["load_ms"]}ms' | |
| 193 | + ) | |
| 194 | + return 0 | |
| 195 | + | |
| 196 | + | |
| 197 | +if __name__ == "__main__": | |
| 198 | + raise SystemExit(main()) | ... | ... |
scripts/benchmark_reranker_random_titles.py
| ... | ... | @@ -6,6 +6,7 @@ Randomly samples N titles from a text file (one title per line), POSTs to the |
| 6 | 6 | rerank HTTP API, prints wall-clock latency. |
| 7 | 7 | |
| 8 | 8 | Supports multiple N values (comma-separated) and multiple repeats per N. |
| 9 | +Each invocation runs 3 warmup requests with n=400 first; those are not timed for summaries. | |
| 9 | 10 | |
| 10 | 11 | Example: |
| 11 | 12 | source activate.sh |
| ... | ... | @@ -149,6 +150,23 @@ def main() -> int: |
| 149 | 150 | action="store_true", |
| 150 | 151 | help="Print first ~500 chars of response body on success (last run only).", |
| 151 | 152 | ) |
| 153 | + parser.add_argument( | |
| 154 | + "--tag", | |
| 155 | + type=str, | |
| 156 | + default=os.environ.get("BENCH_TAG", ""), | |
| 157 | + help="Optional label stored in --json-summary-out (default: env BENCH_TAG or empty).", | |
| 158 | + ) | |
| 159 | + parser.add_argument( | |
| 160 | + "--json-summary-out", | |
| 161 | + type=Path, | |
| 162 | + default=None, | |
| 163 | + help="Write one JSON object with per-n latencies and aggregates for downstream tables.", | |
| 164 | + ) | |
| 165 | + parser.add_argument( | |
| 166 | + "--quiet-runs", | |
| 167 | + action="store_true", | |
| 168 | + help="Suppress per-run lines; still prints warmup lines and text summaries.", | |
| 169 | + ) | |
| 152 | 170 | args = parser.parse_args() |
| 153 | 171 | |
| 154 | 172 | try: |
| ... | ... | @@ -167,7 +185,9 @@ def main() -> int: |
| 167 | 185 | return 2 |
| 168 | 186 | |
| 169 | 187 | titles = _load_titles(args.titles_file) |
| 170 | - max_n = max(doc_counts) | |
| 188 | + warmup_n = 400 | |
| 189 | + warmup_runs = 3 | |
| 190 | + max_n = max(max(doc_counts), warmup_n) | |
| 171 | 191 | if len(titles) < max_n: |
| 172 | 192 | print( |
| 173 | 193 | f"error: file has only {len(titles)} non-empty lines, need at least {max_n}", |
| ... | ... | @@ -181,6 +201,33 @@ def main() -> int: |
| 181 | 201 | summary: dict[int, List[float]] = {n: [] for n in doc_counts} |
| 182 | 202 | |
| 183 | 203 | with httpx.Client(timeout=args.timeout) as client: |
| 204 | + for w in range(warmup_runs): | |
| 205 | + if args.seed is not None: | |
| 206 | + random.seed(args.seed + 8_000_000 + w) | |
| 207 | + docs_w = random.sample(titles, warmup_n) | |
| 208 | + try: | |
| 209 | + ok_w, status_w, _elapsed_w, scores_len_w, _text_w = _do_rerank( | |
| 210 | + client, | |
| 211 | + args.url, | |
| 212 | + args.query, | |
| 213 | + docs_w, | |
| 214 | + top_n=top_n, | |
| 215 | + normalize=normalize, | |
| 216 | + ) | |
| 217 | + except httpx.HTTPError as exc: | |
| 218 | + print( | |
| 219 | + f"warmup n={warmup_n} {w + 1}/{warmup_runs} error: request failed: {exc}", | |
| 220 | + file=sys.stderr, | |
| 221 | + ) | |
| 222 | + any_fail = True | |
| 223 | + continue | |
| 224 | + if not ok_w: | |
| 225 | + any_fail = True | |
| 226 | + print( | |
| 227 | + f"warmup n={warmup_n} {w + 1}/{warmup_runs} status={status_w} " | |
| 228 | + f"scores={scores_len_w if scores_len_w is not None else 'n/a'} (not timed)" | |
| 229 | + ) | |
| 230 | + | |
| 184 | 231 | for n in doc_counts: |
| 185 | 232 | for run_idx in range(repeat): |
| 186 | 233 | if args.seed is not None: |
| ... | ... | @@ -208,10 +255,11 @@ def main() -> int: |
| 208 | 255 | else: |
| 209 | 256 | any_fail = True |
| 210 | 257 | |
| 211 | - print( | |
| 212 | - f"n={n} run={run_idx + 1}/{repeat} status={status} " | |
| 213 | - f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}" | |
| 214 | - ) | |
| 258 | + if not args.quiet_runs: | |
| 259 | + print( | |
| 260 | + f"n={n} run={run_idx + 1}/{repeat} status={status} " | |
| 261 | + f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}" | |
| 262 | + ) | |
| 215 | 263 | if args.print_body_preview and text and run_idx == repeat - 1 and n == doc_counts[-1]: |
| 216 | 264 | preview = text[:500] + ("…" if len(text) > 500 else "") |
| 217 | 265 | print(preview) |
| ... | ... | @@ -230,6 +278,33 @@ def main() -> int: |
| 230 | 278 | f"summary n={n} runs={len(lat)} min_ms={lo:.2f} max_ms={hi:.2f} avg_ms={avg:.2f}{extra}" |
| 231 | 279 | ) |
| 232 | 280 | |
| 281 | + if args.json_summary_out is not None: | |
| 282 | + per_n: dict = {} | |
| 283 | + for n in doc_counts: | |
| 284 | + lat = summary[n] | |
| 285 | + row: dict = {"values_ms": lat, "runs": len(lat)} | |
| 286 | + if lat: | |
| 287 | + row["mean_ms"] = statistics.mean(lat) | |
| 288 | + row["min_ms"] = min(lat) | |
| 289 | + row["max_ms"] = max(lat) | |
| 290 | + if len(lat) >= 2: | |
| 291 | + row["stdev_ms"] = statistics.stdev(lat) | |
| 292 | + per_n[str(n)] = row | |
| 293 | + out_obj = { | |
| 294 | + "tag": args.tag or None, | |
| 295 | + "doc_counts": doc_counts, | |
| 296 | + "repeat": repeat, | |
| 297 | + "url": args.url, | |
| 298 | + "per_n": per_n, | |
| 299 | + "failed": bool(any_fail), | |
| 300 | + } | |
| 301 | + args.json_summary_out.parent.mkdir(parents=True, exist_ok=True) | |
| 302 | + args.json_summary_out.write_text( | |
| 303 | + json.dumps(out_obj, ensure_ascii=False, indent=2) + "\n", | |
| 304 | + encoding="utf-8", | |
| 305 | + ) | |
| 306 | + print(f"wrote json summary -> {args.json_summary_out}") | |
| 307 | + | |
| 233 | 308 | return 1 if any_fail else 0 |
| 234 | 309 | |
| 235 | 310 | ... | ... |
| ... | ... | @@ -0,0 +1,68 @@ |
| 1 | +#!/bin/bash | |
| 2 | +# | |
| 3 | +# Shared helpers for mapping reranker backends to isolated virtualenvs. | |
| 4 | +# | |
| 5 | + | |
| 6 | +set -euo pipefail | |
| 7 | + | |
# Resolve the active rerank backend name.
# Precedence: $RERANK_BACKEND env var > services.rerank.backend in
# config/config.yaml > hard default "qwen3_vllm".
# NOTE(review): the awk patterns are indentation-sensitive and assume the
# standard 2-space YAML nesting ("  rerank:" / "    backend:") — confirm
# against the actual config/config.yaml widths.
detect_rerank_backend() {
    local project_root="$1"
    local backend="${RERANK_BACKEND:-}"

    # Explicit environment override wins outright.
    if [[ -n "${backend}" ]]; then
        printf '%s\n' "${backend}"
        return 0
    fi

    # Scan config.yaml: enter the rerank block, leave it on the next sibling
    # key, and print the first quoted backend value found inside it.
    backend="$(
        awk '
            /^  rerank:$/ { in_rerank = 1; next }
            in_rerank && /^  [^ ]/ { in_rerank = 0 }
            in_rerank && /^    backend:/ {
                gsub(/"/, "", $2)
                print $2
                exit
            }
        ' "${project_root}/config/config.yaml"
    )"

    printf '%s\n' "${backend:-qwen3_vllm}"
}
| 35 | + | |
# Print the virtualenv directory for a reranker backend.
# Known backends map to fixed suffixes; unknown backends fall back to
# ".venv-reranker-<backend>".
reranker_backend_venv_dir() {
    local root="$1"
    local backend="$2"
    local suffix

    case "${backend}" in
        qwen3_vllm)                suffix="" ;;
        qwen3_vllm_score)          suffix="-score" ;;
        qwen3_gguf)                suffix="-gguf" ;;
        qwen3_gguf_06b)            suffix="-gguf-06b" ;;
        qwen3_transformers)        suffix="-transformers" ;;
        qwen3_transformers_packed) suffix="-transformers-packed" ;;
        bge)                       suffix="-bge" ;;
        dashscope_rerank)          suffix="-dashscope" ;;
        *)                         suffix="-${backend}" ;;
    esac

    printf '%s/.venv-reranker%s\n' "${root}" "${suffix}"
}
| 52 | + | |
# Print the requirements file for a reranker backend; return 1 (and print
# nothing) for backends with no known requirements file.
reranker_backend_requirements_file() {
    local root="$1"
    local backend="$2"
    local name

    case "${backend}" in
        # Most backends use their own name verbatim in the filename.
        qwen3_vllm|qwen3_vllm_score|qwen3_gguf|qwen3_gguf_06b|qwen3_transformers|qwen3_transformers_packed|bge)
            name="${backend}" ;;
        # dashscope_rerank is the one irregular mapping.
        dashscope_rerank)
            name="dashscope" ;;
        *)
            return 1 ;;
    esac

    printf '%s/requirements_reranker_%s.txt\n' "${root}" "${name}"
}
| ... | ... | @@ -0,0 +1,100 @@ |
| 1 | +#!/usr/bin/env python3 | |
| 2 | +""" | |
| 3 | +Surgically patch config/config.yaml: | |
| 4 | + services.rerank.backend | |
| 5 | + services.rerank.backends.qwen3_vllm.instruction_format | |
| 6 | + services.rerank.backends.qwen3_vllm_score.instruction_format | |
| 7 | + | |
| 8 | +Preserves comments and unrelated lines. Used for benchmark matrix runs. | |
| 9 | +""" | |
| 10 | + | |
| 11 | +from __future__ import annotations | |
| 12 | + | |
| 13 | +import argparse | |
| 14 | +import re | |
| 15 | +import sys | |
| 16 | +from pathlib import Path | |
| 17 | + | |
| 18 | + | |
| 19 | +def _with_stripped_body(line: str) -> tuple[str, str]: | |
| 20 | + """Return (body without end newline, newline suffix including '' if none).""" | |
| 21 | + if line.endswith("\r\n"): | |
| 22 | + return line[:-2], "\r\n" | |
| 23 | + if line.endswith("\n"): | |
| 24 | + return line[:-1], "\n" | |
| 25 | + return line, "" | |
| 26 | + | |
| 27 | + | |
def _patch_backend_in_rerank_block(lines: list[str], backend: str) -> None:
    """Rewrite the quoted value of ``backend:`` inside the ``  rerank:`` block.

    Mutates *lines* in place, preserving each line's original newline style.
    Raises RuntimeError when no backend line is found inside the block.
    NOTE(review): assumes 2-space YAML nesting ("  rerank:" at two spaces,
    sibling keys at two spaces, children at four) — confirm against config.yaml.
    """
    inside = False
    for idx, raw in enumerate(lines):
        if raw.startswith("  rerank:"):
            inside = True
            continue
        if not inside:
            continue
        # A non-blank sibling key at two-space indent ends the rerank block.
        if raw.startswith("  ") and not raw.startswith("    ") and raw.strip():
            inside = False
            continue
        body, newline = _with_stripped_body(raw)
        match = re.match(r'^(\s*backend:\s*")[^"]+(".*)$', body)
        if match is not None:
            lines[idx] = f"{match.group(1)}{backend}{match.group(2)}{newline}"
            return
    raise RuntimeError("services.rerank.backend line not found")
| 44 | + | |
| 45 | + | |
def _patch_instruction_format_under_backend(
    lines: list[str], section: str, fmt: str
) -> None:
    """Set ``instruction_format:`` under one backend section (e.g. 'qwen3_vllm').

    Mutates *lines* in place, preserving newline style. Raises RuntimeError
    when the section header or the instruction_format key cannot be found.
    NOTE(review): assumes backend sections sit at six-space indent under
    services.rerank.backends — confirm against config.yaml's nesting.
    """
    header = f"      {section}:"
    start = next(
        (i for i, line in enumerate(lines) if line.rstrip() == header), None
    )
    if start is None:
        raise RuntimeError(f"section {section!r} not found")

    for idx in range(start + 1, len(lines)):
        body, newline = _with_stripped_body(lines[idx])
        # Stop at the next sibling section header at the same indent level.
        if re.match(r"^      [a-zA-Z0-9_]+:\s*$", body):
            break
        match = re.match(r"^(\s*instruction_format:\s*)\S+", body)
        if match is not None:
            lines[idx] = f"{match.group(1)}{fmt}{newline}"
            return
    raise RuntimeError(f"instruction_format not found under {section!r}")
| 69 | + | |
| 70 | + | |
def main() -> int:
    """CLI entry point: parse args, patch config.yaml in place, report result.

    Returns 0 on success, 2 when the config file is empty; helper functions
    raise RuntimeError when the expected keys are missing.
    """
    parser = argparse.ArgumentParser()
    default_config = Path(__file__).resolve().parent.parent / "config" / "config.yaml"
    parser.add_argument("--config", type=Path, default=default_config)
    parser.add_argument("--backend", choices=("qwen3_vllm", "qwen3_vllm_score"), required=True)
    parser.add_argument(
        "--instruction-format",
        dest="instruction_format",
        choices=("compact", "standard"),
        required=True,
    )
    args = parser.parse_args()

    # Keep per-line newlines so unrelated lines round-trip byte-identically.
    lines = args.config.read_text(encoding="utf-8").splitlines(keepends=True)
    if not lines:
        print("empty config", file=sys.stderr)
        return 2

    _patch_backend_in_rerank_block(lines, args.backend)
    # Both vLLM backend sections are kept in sync regardless of which is active.
    for section in ("qwen3_vllm", "qwen3_vllm_score"):
        _patch_instruction_format_under_backend(lines, section, args.instruction_format)

    args.config.write_text("".join(lines), encoding="utf-8")
    print(
        f"patched {args.config}: backend={args.backend} "
        f"instruction_format={args.instruction_format} (both vLLM blocks)"
    )
    return 0
| 97 | + | |
| 98 | + | |
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
| ... | ... | @@ -0,0 +1,89 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# Patch config, restart reranker, wait for /health, run benchmark_reranker_random_titles.py. | |
| 3 | +# Requires: curl. PyYAML is not needed (the config patch script is standalone Python). | |
| 4 | + | |
| 5 | +set -euo pipefail | |
| 6 | +ROOT="$(cd "$(dirname "$0")/.." && pwd)" | |
| 7 | +cd "$ROOT" | |
| 8 | + | |
| 9 | +PYTHON="${ROOT}/.venv/bin/python" | |
| 10 | +DAY="$(date +%F)" | |
| 11 | +OUT_DIR="${ROOT}/perf_reports/reranker_vllm_instruction/${DAY}" | |
| 12 | +mkdir -p "$OUT_DIR" | |
| 13 | + | |
# Return 0 iff the reranker /health endpoint responds and reports status=ok,
# model_loaded=true, and exactly the expected backend + instruction_format.
# $1 = expected backend name, $2 = expected instruction_format.
health_ok() {
    local want_backend="$1"
    local want_fmt="$2"
    local body
    # Short timeouts: this runs inside a tight polling loop (see wait_health).
    if ! body="$(curl -sS --connect-timeout 2 --max-time 5 "http://127.0.0.1:6007/health" 2>/dev/null)"; then
        return 1
    fi
    # Delegate the JSON field checks to the project venv's Python; its exit
    # status becomes this function's return value.
    echo "$body" | "$PYTHON" -c "
import json, sys
want_b, want_f = sys.argv[1], sys.argv[2]
d = json.load(sys.stdin)
if d.get('status') != 'ok' or not d.get('model_loaded'):
    sys.exit(1)
if d.get('backend') != want_b:
    sys.exit(1)
if d.get('instruction_format') != want_f:
    sys.exit(1)
sys.exit(0)
" "$want_backend" "$want_fmt"
}
| 34 | + | |
# Poll health_ok up to 180 times, 3 s apart (~9 min budget), until the
# service reports the desired backend + instruction_format; on success dump
# the health JSON. Returns 1 if the deadline passes without a match.
wait_health() {
    local want_backend="$1"
    local want_fmt="$2"
    local i
    for i in $(seq 1 180); do
        if health_ok "$want_backend" "$want_fmt"; then
            curl -sS "http://127.0.0.1:6007/health" | "$PYTHON" -m json.tool
            return 0
        fi
        echo "[wait] ${i}/180 backend=${want_backend} instruction_format=${want_fmt} ..."
        sleep 3
    done
    echo "[error] health did not match in time" >&2
    return 1
}
| 50 | + | |
# Run one benchmark matrix cell: patch config to (backend, fmt), restart the
# reranker, wait for a matching /health, then benchmark and write a JSON
# summary artifact to $OUT_DIR.
run_one() {
    local backend="$1"
    local fmt="$2"
    local tag="${backend}|${fmt}"
    local jf="${OUT_DIR}/${backend}_${fmt}.json"

    echo "========== ${tag} =========="
    "$PYTHON" "${ROOT}/scripts/patch_rerank_vllm_benchmark_config.py" \
        --backend "$backend" --instruction-format "$fmt"

    "${ROOT}/restart.sh" reranker
    wait_health "$backend" "$fmt"

    # A failing benchmark is non-fatal so the rest of the matrix still runs;
    # the JSON artifact carries a "failed" flag for downstream tables.
    if ! "$PYTHON" "${ROOT}/scripts/benchmark_reranker_random_titles.py" \
        100,200,400,600,800,1000 \
        --repeat 5 \
        --seed 42 \
        --quiet-runs \
        --timeout 360 \
        --tag "$tag" \
        --json-summary-out "$jf"
    then
        echo "[warn] benchmark exited non-zero for ${tag} (see ${jf} failed flag / partial runs)" >&2
    fi

    echo "artifact: $jf"
}
| 78 | + | |
| 79 | +run_one qwen3_vllm compact | |
| 80 | +run_one qwen3_vllm standard | |
| 81 | +run_one qwen3_vllm_score compact | |
| 82 | +run_one qwen3_vllm_score standard | |
| 83 | + | |
| 84 | +# Restore repo-default-style rerank settings (score + compact). | |
| 85 | +"$PYTHON" "${ROOT}/scripts/patch_rerank_vllm_benchmark_config.py" \ | |
| 86 | + --backend qwen3_vllm_score --instruction-format compact | |
| 87 | +"${ROOT}/restart.sh" reranker | |
| 88 | +wait_health qwen3_vllm_score compact | |
| 89 | +echo "Restored config: qwen3_vllm_score + compact. Done. Artifacts under ${OUT_DIR}" | ... | ... |
scripts/setup_reranker_venv.sh
| 1 | 1 | #!/bin/bash |
| 2 | 2 | # |
| 3 | -# Create isolated venv for reranker service (.venv-reranker). | |
| 3 | +# Create isolated venv for one reranker backend. | |
| 4 | 4 | # |
| 5 | 5 | set -euo pipefail |
| 6 | 6 | |
| 7 | 7 | PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" |
| 8 | 8 | cd "${PROJECT_ROOT}" |
| 9 | 9 | |
| 10 | -VENV_DIR="${PROJECT_ROOT}/.venv-reranker" | |
| 11 | 10 | PYTHON_BIN="${PYTHON_BIN:-python3}" |
| 12 | 11 | TMP_DIR="${RERANKER_PIP_TMPDIR:-${PROJECT_ROOT}/.tmp/reranker-pip}" |
| 13 | 12 | |
| 13 | +# shellcheck source=scripts/lib/load_env.sh | |
| 14 | +source "${PROJECT_ROOT}/scripts/lib/load_env.sh" | |
| 15 | +load_env_file "${PROJECT_ROOT}/.env" | |
| 16 | +# shellcheck source=scripts/lib/reranker_backend_env.sh | |
| 17 | +source "${PROJECT_ROOT}/scripts/lib/reranker_backend_env.sh" | |
| 18 | + | |
| 19 | +BACKEND="${1:-$(detect_rerank_backend "${PROJECT_ROOT}")}" | |
| 20 | +VENV_DIR="${RERANKER_VENV:-$(reranker_backend_venv_dir "${PROJECT_ROOT}" "${BACKEND}")}" | |
| 21 | +REQ_FILE="$(reranker_backend_requirements_file "${PROJECT_ROOT}" "${BACKEND}")" | |
| 22 | + | |
| 23 | +if [[ ! -f "${REQ_FILE}" ]]; then | |
| 24 | + echo "ERROR: requirements file not found for reranker backend ${BACKEND}: ${REQ_FILE}" >&2 | |
| 25 | + exit 1 | |
| 26 | +fi | |
| 27 | + | |
| 14 | 28 | if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then |
| 15 | 29 | echo "ERROR: python not found: ${PYTHON_BIN}" >&2 |
| 16 | 30 | exit 1 |
| ... | ... | @@ -34,9 +48,35 @@ PIP_ARGS=(--no-cache-dir) |
| 34 | 48 | |
| 35 | 49 | echo "Using TMPDIR=${TMPDIR}" |
| 36 | 50 | "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" --upgrade pip wheel |
| 37 | -"${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" -r requirements_reranker_service.txt | |
| 51 | +"${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" -r "${REQ_FILE}" | |
| 52 | + | |
| 53 | +if [[ "${BACKEND}" == qwen3_gguf* ]]; then | |
| 54 | + if [[ -x "/usr/local/cuda/bin/nvcc" ]]; then | |
| 55 | + "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" \ | |
| 56 | + cmake \ | |
| 57 | + ninja \ | |
| 58 | + scikit-build-core \ | |
| 59 | + flit_core \ | |
| 60 | + setuptools-scm | |
| 61 | + echo "Rebuilding llama-cpp-python with CUDA support for ${BACKEND}" | |
| 62 | + PATH="/usr/local/cuda/bin:/usr/bin:/bin" \ | |
| 63 | + CC="/usr/bin/x86_64-linux-gnu-gcc" \ | |
| 64 | + CXX="/usr/bin/x86_64-linux-gnu-g++" \ | |
| 65 | + CUDACXX="/usr/local/cuda/bin/nvcc" \ | |
| 66 | + CMAKE_ARGS="-DGGML_CUDA=on" \ | |
| 67 | + FORCE_CMAKE=1 \ | |
| 68 | + "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" \ | |
| 69 | + --force-reinstall \ | |
| 70 | + --no-build-isolation \ | |
| 71 | + "llama-cpp-python==0.3.18" | |
| 72 | + else | |
| 73 | + echo "WARNING: /usr/local/cuda/bin/nvcc not found; ${BACKEND} will be installed without CUDA support." >&2 | |
| 74 | + fi | |
| 75 | +fi | |
| 38 | 76 | |
| 39 | 77 | echo |
| 40 | 78 | echo "Done." |
| 79 | +echo "Backend: ${BACKEND}" | |
| 41 | 80 | echo "Reranker venv: ${VENV_DIR}" |
| 81 | +echo "Requirements: ${REQ_FILE}" | |
| 42 | 82 | echo "Start service: ./scripts/start_reranker.sh" | ... | ... |
scripts/start_reranker.sh
| 1 | 1 | #!/bin/bash |
| 2 | 2 | # |
| 3 | -# Start reranker service from isolated venv (.venv-reranker). | |
| 3 | +# Start reranker service from its backend-specific isolated venv. | |
| 4 | 4 | # |
| 5 | 5 | set -euo pipefail |
| 6 | 6 | |
| 7 | 7 | PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" |
| 8 | 8 | cd "${PROJECT_ROOT}" |
| 9 | 9 | |
| 10 | -RERANKER_VENV="${RERANKER_VENV:-${PROJECT_ROOT}/.venv-reranker}" | |
| 11 | -PYTHON_BIN="${RERANKER_VENV}/bin/python" | |
| 12 | - | |
| 13 | -if [[ ! -x "${PYTHON_BIN}" ]]; then | |
| 14 | - echo "ERROR: reranker venv not found: ${RERANKER_VENV}" >&2 | |
| 15 | - echo "Please run: ./scripts/setup_reranker_venv.sh" >&2 | |
| 16 | - exit 1 | |
| 17 | -fi | |
| 18 | - | |
| 19 | 10 | # Load .env without activating main venv. |
| 20 | 11 | # shellcheck source=scripts/lib/load_env.sh |
| 21 | 12 | source "${PROJECT_ROOT}/scripts/lib/load_env.sh" |
| 22 | 13 | load_env_file "${PROJECT_ROOT}/.env" |
| 14 | +# shellcheck source=scripts/lib/reranker_backend_env.sh | |
| 15 | +source "${PROJECT_ROOT}/scripts/lib/reranker_backend_env.sh" | |
| 23 | 16 | |
| 24 | 17 | RERANKER_HOST="${RERANKER_HOST:-0.0.0.0}" |
| 25 | 18 | RERANKER_PORT="${RERANKER_PORT:-6007}" |
| 26 | -RERANK_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_rerank_backend_config; print(get_rerank_backend_config()[0])") | |
| 19 | +RERANK_BACKEND="${RERANK_BACKEND:-$(detect_rerank_backend "${PROJECT_ROOT}")}" | |
| 20 | +RERANKER_VENV="${RERANKER_VENV:-$(reranker_backend_venv_dir "${PROJECT_ROOT}" "${RERANK_BACKEND}")}" | |
| 21 | +PYTHON_BIN="${RERANKER_VENV}/bin/python" | |
| 22 | + | |
| 23 | +if [[ ! -x "${PYTHON_BIN}" ]]; then | |
| 24 | + echo "ERROR: reranker venv not found for backend ${RERANK_BACKEND}: ${RERANKER_VENV}" >&2 | |
| 25 | + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2 | |
| 26 | + exit 1 | |
| 27 | +fi | |
| 27 | 28 | |
| 28 | 29 | # Keep vLLM/triton/torch caches out of system disk. |
| 29 | 30 | RERANKER_RUNTIME_DIR="${RERANKER_RUNTIME_DIR:-${PROJECT_ROOT}/.runtime/reranker}" |
| ... | ... | @@ -42,23 +43,56 @@ export TMPDIR="${RERANKER_RUNTIME_DIR}/tmp" |
| 42 | 43 | export VLLM_NO_USAGE_STATS="${VLLM_NO_USAGE_STATS:-1}" |
| 43 | 44 | export PATH="${RERANKER_VENV}/bin:${PATH}" |
| 44 | 45 | |
| 45 | -if [[ "${RERANK_BACKEND}" == "qwen3_vllm" ]]; then | |
| 46 | +if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then | |
| 47 | + export HF_HUB_DISABLE_XET="${HF_HUB_DISABLE_XET:-1}" | |
| 48 | +fi | |
| 49 | + | |
| 50 | +if [[ "${RERANK_BACKEND}" == "qwen3_vllm" || "${RERANK_BACKEND}" == "qwen3_vllm_score" || "${RERANK_BACKEND}" == "qwen3_transformers_packed" ]]; then | |
| 46 | 51 | if ! command -v nvidia-smi >/dev/null 2>&1 || ! nvidia-smi >/dev/null 2>&1; then |
| 47 | - echo "ERROR: qwen3_vllm backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2 | |
| 52 | + echo "ERROR: ${RERANK_BACKEND} backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2 | |
| 48 | 53 | exit 1 |
| 49 | 54 | fi |
| 50 | 55 | if ! "${PYTHON_BIN}" - <<'PY' |
| 51 | 56 | try: |
| 52 | - import vllm # noqa: F401 | |
| 53 | 57 | import torch |
| 58 | + try: | |
| 59 | + import vllm # noqa: F401 | |
| 60 | + except Exception: | |
| 61 | + pass | |
| 54 | 62 | if not torch.cuda.is_available(): |
| 55 | 63 | raise SystemExit(1) |
| 56 | 64 | except Exception: |
| 57 | 65 | raise SystemExit(1) |
| 58 | 66 | PY |
| 59 | 67 | then |
| 60 | - echo "ERROR: qwen3_vllm backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2 | |
| 61 | - echo "Please run: ./scripts/setup_reranker_venv.sh and verify CUDA is available." >&2 | |
| 68 | + if [[ "${RERANK_BACKEND}" == "qwen3_transformers_packed" ]]; then | |
| 69 | + echo "ERROR: ${RERANK_BACKEND} backend requires torch + CUDA runtime in ${RERANKER_VENV}." >&2 | |
| 70 | + else | |
| 71 | + echo "ERROR: ${RERANK_BACKEND} backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2 | |
| 72 | + fi | |
| 73 | + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND} and verify CUDA is available." >&2 | |
| 74 | + exit 1 | |
| 75 | + fi | |
| 76 | +fi | |
| 77 | + | |
| 78 | +if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then | |
| 79 | + gguf_check_status=0 | |
| 80 | + "${PYTHON_BIN}" - <<'PY' || gguf_check_status=$? | |
| 81 | +try: | |
| 82 | + import llama_cpp | |
| 83 | + if hasattr(llama_cpp, "llama_supports_gpu_offload") and not llama_cpp.llama_supports_gpu_offload(): | |
| 84 | + raise SystemExit(2) | |
| 85 | +except Exception: | |
| 86 | + raise SystemExit(1) | |
| 87 | +PY | |
| 88 | + if [[ "${gguf_check_status}" != "0" ]]; then | |
| 89 | + if [[ "${gguf_check_status}" == "2" ]]; then | |
| 90 | + echo "ERROR: ${RERANK_BACKEND} backend detected a CPU-only llama-cpp-python build in ${RERANKER_VENV}." >&2 | |
| 91 | + echo "Please rerun: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2 | |
| 92 | + else | |
| 93 | + echo "ERROR: ${RERANK_BACKEND} backend requires llama-cpp-python in ${RERANKER_VENV}." >&2 | |
| 94 | + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2 | |
| 95 | + fi | |
| 62 | 96 | exit 1 |
| 63 | 97 | fi |
| 64 | 98 | fi | ... | ... |
search/rerank_client.py
| ... | ... | @@ -200,19 +200,24 @@ def _multiply_fusion_factors( |
| 200 | 200 | knn_score: float, |
| 201 | 201 | fusion: RerankFusionConfig, |
| 202 | 202 | ) -> Tuple[float, float, float, float]: |
| 203 | - """(rerank_factor, text_factor, knn_factor, fused).""" | |
| 203 | + """(rerank_factor, text_factor, knn_factor, fused_without_style_boost).""" | |
| 204 | 204 | r = (max(rerank_score, 0.0) + fusion.rerank_bias) ** fusion.rerank_exponent |
| 205 | 205 | t = (max(text_score, 0.0) + fusion.text_bias) ** fusion.text_exponent |
| 206 | 206 | k = (max(knn_score, 0.0) + fusion.knn_bias) ** fusion.knn_exponent |
| 207 | 207 | return r, t, k, r * t * k |
| 208 | 208 | |
| 209 | 209 | |
| 210 | +def _has_selected_sku(hit: Dict[str, Any]) -> bool: | |
| 211 | + return bool(str(hit.get("_style_rerank_suffix") or "").strip()) | |
| 212 | + | |
| 213 | + | |
| 210 | 214 | def fuse_scores_and_resort( |
| 211 | 215 | es_hits: List[Dict[str, Any]], |
| 212 | 216 | rerank_scores: List[float], |
| 213 | 217 | weight_es: float = DEFAULT_WEIGHT_ES, |
| 214 | 218 | weight_ai: float = DEFAULT_WEIGHT_AI, |
| 215 | 219 | fusion: Optional[RerankFusionConfig] = None, |
| 220 | + style_intent_selected_sku_boost: float = 1.2, | |
| 216 | 221 | debug: bool = False, |
| 217 | 222 | rerank_debug_rows: Optional[List[Dict[str, Any]]] = None, |
| 218 | 223 | ) -> List[Dict[str, Any]]: |
| ... | ... | @@ -220,7 +225,10 @@ def fuse_scores_and_resort( |
| 220 | 225 | 将 ES 分数与重排分数按乘法公式融合(不修改原始 _score),并按融合分数降序重排。 |
| 221 | 226 | |
| 222 | 227 | 融合形式(由 ``fusion`` 配置 bias / exponent):: |
| 223 | - fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k | |
| 228 | + fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost | |
| 229 | + | |
| 230 | + 其中 sku_boost 仅在当前 hit 已选中 SKU 时生效,默认值为 1.2,可通过 | |
| 231 | + ``query.style_intent.selected_sku_boost`` 配置。 | |
| 224 | 232 | |
| 225 | 233 | 对每条 hit 会写入: |
| 226 | 234 | - _original_score: 原始 ES 分数 |
| ... | ... | @@ -252,12 +260,16 @@ def fuse_scores_and_resort( |
| 252 | 260 | rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( |
| 253 | 261 | rerank_score, text_score, knn_score, f |
| 254 | 262 | ) |
| 263 | + sku_selected = _has_selected_sku(hit) | |
| 264 | + style_boost = style_intent_selected_sku_boost if sku_selected else 1.0 | |
| 265 | + fused *= style_boost | |
| 255 | 266 | |
| 256 | 267 | hit["_original_score"] = hit.get("_score") |
| 257 | 268 | hit["_rerank_score"] = rerank_score |
| 258 | 269 | hit["_text_score"] = text_score |
| 259 | 270 | hit["_knn_score"] = knn_score |
| 260 | 271 | hit["_fused_score"] = fused |
| 272 | + hit["_style_intent_selected_sku_boost"] = style_boost | |
| 261 | 273 | if debug: |
| 262 | 274 | hit["_text_source_score"] = text_components["source_score"] |
| 263 | 275 | hit["_text_translation_score"] = text_components["translation_score"] |
| ... | ... | @@ -285,6 +297,8 @@ def fuse_scores_and_resort( |
| 285 | 297 | "rerank_factor": rerank_factor, |
| 286 | 298 | "text_factor": text_factor, |
| 287 | 299 | "knn_factor": knn_factor, |
| 300 | + "style_intent_selected_sku": sku_selected, | |
| 301 | + "style_intent_selected_sku_boost": style_boost, | |
| 288 | 302 | "matched_queries": matched_queries, |
| 289 | 303 | "fused_score": fused, |
| 290 | 304 | } |
| ... | ... | @@ -311,6 +325,7 @@ def run_rerank( |
| 311 | 325 | top_n: Optional[int] = None, |
| 312 | 326 | debug: bool = False, |
| 313 | 327 | fusion: Optional[RerankFusionConfig] = None, |
| 328 | + style_intent_selected_sku_boost: float = 1.2, | |
| 314 | 329 | ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: |
| 315 | 330 | """ |
| 316 | 331 | 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 |
| ... | ... | @@ -345,6 +360,7 @@ def run_rerank( |
| 345 | 360 | weight_es=weight_es, |
| 346 | 361 | weight_ai=weight_ai, |
| 347 | 362 | fusion=fusion, |
| 363 | + style_intent_selected_sku_boost=style_intent_selected_sku_boost, | |
| 348 | 364 | debug=debug, |
| 349 | 365 | rerank_debug_rows=rerank_debug_rows, |
| 350 | 366 | ) | ... | ... |
search/searcher.py
| ... | ... | @@ -594,6 +594,7 @@ class Searcher: |
| 594 | 594 | top_n=(from_ + size), |
| 595 | 595 | debug=debug, |
| 596 | 596 | fusion=rc.fusion, |
| 597 | + style_intent_selected_sku_boost=self.config.query_config.style_intent_selected_sku_boost, | |
| 597 | 598 | ) |
| 598 | 599 | |
| 599 | 600 | if rerank_meta is not None: |
| ... | ... | @@ -1055,4 +1056,3 @@ class Searcher: |
| 1055 | 1056 | except Exception as e: |
| 1056 | 1057 | logger.error(f"Failed to get document {doc_id} from tenant {tenant_id}: {e}", exc_info=True) |
| 1057 | 1058 | return None |
| 1058 | - | ... | ... |
search/sku_intent_selector.py
| ... | ... | @@ -5,12 +5,10 @@ SKU selection for style-intent-aware search results. |
| 5 | 5 | from __future__ import annotations |
| 6 | 6 | |
| 7 | 7 | from dataclasses import dataclass, field |
| 8 | -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | |
| 9 | - | |
| 10 | -import numpy as np | |
| 8 | +from typing import Any, Callable, Dict, List, Optional, Tuple | |
| 11 | 9 | |
| 12 | 10 | from query.style_intent import StyleIntentProfile, StyleIntentRegistry |
| 13 | -from query.tokenization import normalize_query_text | |
| 11 | +from query.tokenization import normalize_query_text, simple_tokenize_query | |
| 14 | 12 | |
| 15 | 13 | |
| 16 | 14 | @dataclass(frozen=True) |
| ... | ... | @@ -34,24 +32,11 @@ class SkuSelectionDecision: |
| 34 | 32 | |
| 35 | 33 | |
| 36 | 34 | @dataclass |
| 37 | -class _SkuCandidate: | |
| 38 | - index: int | |
| 39 | - sku_id: str | |
| 40 | - sku: Dict[str, Any] | |
| 41 | - selection_text: str | |
| 42 | - normalized_selection_text: str | |
| 43 | - intent_values: Dict[str, str] | |
| 44 | - normalized_intent_values: Dict[str, str] | |
| 45 | - | |
| 46 | - | |
| 47 | -@dataclass | |
| 48 | 35 | class _SelectionContext: |
| 49 | - query_texts: Tuple[str, ...] | |
| 50 | - matched_terms_by_intent: Dict[str, Tuple[str, ...]] | |
| 51 | - query_vector: Optional[np.ndarray] | |
| 36 | + attribute_terms_by_intent: Dict[str, Tuple[str, ...]] | |
| 37 | + normalized_text_cache: Dict[str, str] = field(default_factory=dict) | |
| 38 | + tokenized_text_cache: Dict[str, Tuple[str, ...]] = field(default_factory=dict) | |
| 52 | 39 | text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict) |
| 53 | - selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict) | |
| 54 | - similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict) | |
| 55 | 40 | |
| 56 | 41 | |
| 57 | 42 | class StyleSkuSelector: |
| ... | ... | @@ -76,7 +61,7 @@ class StyleSkuSelector: |
| 76 | 61 | if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: |
| 77 | 62 | return decisions |
| 78 | 63 | |
| 79 | - selection_context = self._build_selection_context(parsed_query, style_profile) | |
| 64 | + selection_context = self._build_selection_context(style_profile) | |
| 80 | 65 | |
| 81 | 66 | for hit in es_hits: |
| 82 | 67 | source = hit.get("_source") |
| ... | ... | @@ -126,81 +111,37 @@ class StyleSkuSelector: |
| 126 | 111 | else: |
| 127 | 112 | hit.pop("_style_rerank_suffix", None) |
| 128 | 113 | |
| 129 | - def _build_query_texts( | |
| 130 | - self, | |
| 131 | - parsed_query: Any, | |
| 132 | - style_profile: StyleIntentProfile, | |
| 133 | - ) -> List[str]: | |
| 134 | - texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] | |
| 135 | - if texts: | |
| 136 | - return list(dict.fromkeys(texts)) | |
| 137 | - | |
| 138 | - fallbacks: List[str] = [] | |
| 139 | - for value in ( | |
| 140 | - getattr(parsed_query, "original_query", None), | |
| 141 | - getattr(parsed_query, "query_normalized", None), | |
| 142 | - getattr(parsed_query, "rewritten_query", None), | |
| 143 | - ): | |
| 144 | - normalized = normalize_query_text(value) | |
| 145 | - if normalized: | |
| 146 | - fallbacks.append(normalized) | |
| 147 | - translations = getattr(parsed_query, "translations", {}) or {} | |
| 148 | - if isinstance(translations, dict): | |
| 149 | - for value in translations.values(): | |
| 150 | - normalized = normalize_query_text(value) | |
| 151 | - if normalized: | |
| 152 | - fallbacks.append(normalized) | |
| 153 | - return list(dict.fromkeys(fallbacks)) | |
| 154 | - | |
| 155 | - def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: | |
| 156 | - query_vector = getattr(parsed_query, "query_vector", None) | |
| 157 | - if query_vector is not None: | |
| 158 | - return np.asarray(query_vector, dtype=np.float32) | |
| 159 | - | |
| 160 | - text_encoder = self._get_text_encoder() | |
| 161 | - if text_encoder is None: | |
| 162 | - return None | |
| 163 | - | |
| 164 | - query_text = ( | |
| 165 | - getattr(parsed_query, "rewritten_query", None) | |
| 166 | - or getattr(parsed_query, "query_normalized", None) | |
| 167 | - or getattr(parsed_query, "original_query", None) | |
| 168 | - ) | |
| 169 | - if not query_text: | |
| 170 | - return None | |
| 171 | - | |
| 172 | - vectors = text_encoder.encode([query_text], priority=1) | |
| 173 | - if vectors is None or len(vectors) == 0 or vectors[0] is None: | |
| 174 | - return None | |
| 175 | - return np.asarray(vectors[0], dtype=np.float32) | |
| 176 | - | |
| 177 | 114 | def _build_selection_context( |
| 178 | 115 | self, |
| 179 | - parsed_query: Any, | |
| 180 | 116 | style_profile: StyleIntentProfile, |
| 181 | 117 | ) -> _SelectionContext: |
| 182 | - matched_terms_by_intent: Dict[str, List[str]] = {} | |
| 118 | + attribute_terms_by_intent: Dict[str, List[str]] = {} | |
| 183 | 119 | for intent in style_profile.intents: |
| 184 | - normalized_term = normalize_query_text(intent.matched_term) | |
| 185 | - if not normalized_term: | |
| 186 | - continue | |
| 187 | - matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) | |
| 188 | - if normalized_term not in matched_terms: | |
| 189 | - matched_terms.append(normalized_term) | |
| 120 | + terms = attribute_terms_by_intent.setdefault(intent.intent_type, []) | |
| 121 | + for raw_term in intent.attribute_terms: | |
| 122 | + normalized_term = normalize_query_text(raw_term) | |
| 123 | + if not normalized_term or normalized_term in terms: | |
| 124 | + continue | |
| 125 | + terms.append(normalized_term) | |
| 190 | 126 | |
| 191 | 127 | return _SelectionContext( |
| 192 | - query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), | |
| 193 | - matched_terms_by_intent={ | |
| 128 | + attribute_terms_by_intent={ | |
| 194 | 129 | intent_type: tuple(terms) |
| 195 | - for intent_type, terms in matched_terms_by_intent.items() | |
| 130 | + for intent_type, terms in attribute_terms_by_intent.items() | |
| 196 | 131 | }, |
| 197 | - query_vector=self._get_query_vector(parsed_query), | |
| 198 | 132 | ) |
| 199 | 133 | |
| 200 | - def _get_text_encoder(self) -> Any: | |
| 201 | - if self._text_encoder_getter is None: | |
| 202 | - return None | |
| 203 | - return self._text_encoder_getter() | |
| 134 | + @staticmethod | |
| 135 | + def _normalize_cached(selection_context: _SelectionContext, value: Any) -> str: | |
| 136 | + raw = str(value or "").strip() | |
| 137 | + if not raw: | |
| 138 | + return "" | |
| 139 | + cached = selection_context.normalized_text_cache.get(raw) | |
| 140 | + if cached is not None: | |
| 141 | + return cached | |
| 142 | + normalized = normalize_query_text(raw) | |
| 143 | + selection_context.normalized_text_cache[raw] = normalized | |
| 144 | + return normalized | |
| 204 | 145 | |
| 205 | 146 | def _resolve_dimensions( |
| 206 | 147 | self, |
| ... | ... | @@ -225,51 +166,6 @@ class StyleSkuSelector: |
| 225 | 166 | resolved[intent.intent_type] = matched_field |
| 226 | 167 | return resolved |
| 227 | 168 | |
| 228 | - def _build_candidates( | |
| 229 | - self, | |
| 230 | - skus: List[Dict[str, Any]], | |
| 231 | - resolved_dimensions: Dict[str, Optional[str]], | |
| 232 | - ) -> List[_SkuCandidate]: | |
| 233 | - if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): | |
| 234 | - return [] | |
| 235 | - | |
| 236 | - candidates: List[_SkuCandidate] = [] | |
| 237 | - for index, sku in enumerate(skus): | |
| 238 | - intent_values: Dict[str, str] = {} | |
| 239 | - normalized_intent_values: Dict[str, str] = {} | |
| 240 | - for intent_type, field_name in resolved_dimensions.items(): | |
| 241 | - if not field_name: | |
| 242 | - continue | |
| 243 | - raw = str(sku.get(field_name) or "").strip() | |
| 244 | - intent_values[intent_type] = raw | |
| 245 | - normalized_intent_values[intent_type] = normalize_query_text(raw) | |
| 246 | - | |
| 247 | - selection_parts: List[str] = [] | |
| 248 | - norm_parts: List[str] = [] | |
| 249 | - seen: set[str] = set() | |
| 250 | - for intent_type, raw in intent_values.items(): | |
| 251 | - nv = normalized_intent_values[intent_type] | |
| 252 | - if not nv or nv in seen: | |
| 253 | - continue | |
| 254 | - seen.add(nv) | |
| 255 | - selection_parts.append(raw) | |
| 256 | - norm_parts.append(nv) | |
| 257 | - | |
| 258 | - selection_text = " ".join(selection_parts).strip() | |
| 259 | - normalized_selection_text = " ".join(norm_parts).strip() | |
| 260 | - candidates.append( | |
| 261 | - _SkuCandidate( | |
| 262 | - index=index, | |
| 263 | - sku_id=str(sku.get("sku_id") or ""), | |
| 264 | - sku=sku, | |
| 265 | - selection_text=selection_text, | |
| 266 | - normalized_selection_text=normalized_selection_text, | |
| 267 | - intent_values=intent_values, | |
| 268 | - normalized_intent_values=normalized_intent_values, | |
| 269 | - ) | |
| 270 | - ) | |
| 271 | - return candidates | |
| 272 | - | |
| 273 | 169 | @staticmethod |
| 274 | 170 | def _empty_decision( |
| 275 | 171 | resolved_dimensions: Dict[str, Optional[str]], |
| ... | ... | @@ -286,13 +182,10 @@ class StyleSkuSelector: |
| 286 | 182 | def _is_text_match( |
| 287 | 183 | self, |
| 288 | 184 | intent_type: str, |
| 289 | - value: str, | |
| 290 | 185 | selection_context: _SelectionContext, |
| 291 | 186 | *, |
| 292 | - normalized_value: Optional[str] = None, | |
| 187 | + normalized_value: str, | |
| 293 | 188 | ) -> bool: |
| 294 | - if normalized_value is None: | |
| 295 | - normalized_value = normalize_query_text(value) | |
| 296 | 189 | if not normalized_value: |
| 297 | 190 | return False |
| 298 | 191 | |
| ... | ... | @@ -301,84 +194,94 @@ class StyleSkuSelector: |
| 301 | 194 | if cached is not None: |
| 302 | 195 | return cached |
| 303 | 196 | |
| 304 | - matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ()) | |
| 305 | - has_term_match = any(term in normalized_value for term in matched_terms if term) | |
| 306 | - query_contains_value = any( | |
| 307 | - normalized_value in query_text | |
| 308 | - for query_text in selection_context.query_texts | |
| 197 | + attribute_terms = selection_context.attribute_terms_by_intent.get(intent_type, ()) | |
| 198 | + value_tokens = self._tokenize_cached(selection_context, normalized_value) | |
| 199 | + matched = any( | |
| 200 | + self._matches_term_tokens( | |
| 201 | + term=term, | |
| 202 | + value_tokens=value_tokens, | |
| 203 | + selection_context=selection_context, | |
| 204 | + normalized_value=normalized_value, | |
| 205 | + ) | |
| 206 | + for term in attribute_terms | |
| 207 | + if term | |
| 309 | 208 | ) |
| 310 | - matched = bool(has_term_match or query_contains_value) | |
| 311 | 209 | selection_context.text_match_cache[cache_key] = matched |
| 312 | 210 | return matched |
| 313 | 211 | |
| 314 | - def _find_first_text_match( | |
| 212 | + @staticmethod | |
| 213 | + def _tokenize_cached(selection_context: _SelectionContext, value: str) -> Tuple[str, ...]: | |
| 214 | + normalized_value = normalize_query_text(value) | |
| 215 | + if not normalized_value: | |
| 216 | + return () | |
| 217 | + cached = selection_context.tokenized_text_cache.get(normalized_value) | |
| 218 | + if cached is not None: | |
| 219 | + return cached | |
| 220 | + tokens = tuple(normalize_query_text(token) for token in simple_tokenize_query(normalized_value) if token) | |
| 221 | + selection_context.tokenized_text_cache[normalized_value] = tokens | |
| 222 | + return tokens | |
| 223 | + | |
| 224 | + def _matches_term_tokens( | |
| 315 | 225 | self, |
| 316 | - candidates: Sequence[_SkuCandidate], | |
| 226 | + *, | |
| 227 | + term: str, | |
| 228 | + value_tokens: Tuple[str, ...], | |
| 317 | 229 | selection_context: _SelectionContext, |
| 318 | - ) -> Optional[_SkuCandidate]: | |
| 319 | - for candidate in candidates: | |
| 320 | - if candidate.intent_values and all( | |
| 321 | - self._is_text_match( | |
| 322 | - intent_type, | |
| 323 | - value, | |
| 324 | - selection_context, | |
| 325 | - normalized_value=candidate.normalized_intent_values[intent_type], | |
| 326 | - ) | |
| 327 | - for intent_type, value in candidate.intent_values.items() | |
| 328 | - ): | |
| 329 | - return candidate | |
| 330 | - return None | |
| 230 | + normalized_value: str, | |
| 231 | + ) -> bool: | |
| 232 | + normalized_term = normalize_query_text(term) | |
| 233 | + if not normalized_term: | |
| 234 | + return False | |
| 235 | + if normalized_term == normalized_value: | |
| 236 | + return True | |
| 331 | 237 | |
| 332 | - def _select_by_embedding( | |
| 238 | + term_tokens = self._tokenize_cached(selection_context, normalized_term) | |
| 239 | + if not term_tokens or not value_tokens: | |
| 240 | + return normalized_term in normalized_value | |
| 241 | + | |
| 242 | + term_length = len(term_tokens) | |
| 243 | + value_length = len(value_tokens) | |
| 244 | + if term_length > value_length: | |
| 245 | + return False | |
| 246 | + | |
| 247 | + for start in range(value_length - term_length + 1): | |
| 248 | + if value_tokens[start:start + term_length] == term_tokens: | |
| 249 | + return True | |
| 250 | + return False | |
| 251 | + | |
| 252 | + def _find_first_text_match( | |
| 333 | 253 | self, |
| 334 | - candidates: Sequence[_SkuCandidate], | |
| 254 | + skus: List[Dict[str, Any]], | |
| 255 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 335 | 256 | selection_context: _SelectionContext, |
| 336 | - ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: | |
| 337 | - if not candidates: | |
| 338 | - return None, None | |
| 339 | - text_encoder = self._get_text_encoder() | |
| 340 | - if selection_context.query_vector is None or text_encoder is None: | |
| 341 | - return None, None | |
| 342 | - | |
| 343 | - unique_texts = list( | |
| 344 | - dict.fromkeys( | |
| 345 | - candidate.normalized_selection_text | |
| 346 | - for candidate in candidates | |
| 347 | - if candidate.normalized_selection_text | |
| 348 | - and candidate.normalized_selection_text not in selection_context.selection_vector_cache | |
| 349 | - ) | |
| 350 | - ) | |
| 351 | - if unique_texts: | |
| 352 | - vectors = text_encoder.encode(unique_texts, priority=1) | |
| 353 | - for key, vector in zip(unique_texts, vectors): | |
| 354 | - selection_context.selection_vector_cache[key] = ( | |
| 355 | - np.asarray(vector, dtype=np.float32) if vector is not None else None | |
| 356 | - ) | |
| 357 | - | |
| 358 | - best_candidate: Optional[_SkuCandidate] = None | |
| 359 | - best_score: Optional[float] = None | |
| 360 | - query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32) | |
| 361 | - for candidate in candidates: | |
| 362 | - normalized_text = candidate.normalized_selection_text | |
| 363 | - if not normalized_text: | |
| 364 | - continue | |
| 257 | + ) -> Optional[Tuple[str, str]]: | |
| 258 | + for sku in skus: | |
| 259 | + selection_parts: List[str] = [] | |
| 260 | + seen_parts: set[str] = set() | |
| 261 | + matched = True | |
| 365 | 262 | |
| 366 | - score = selection_context.similarity_cache.get(normalized_text) | |
| 367 | - if score is None: | |
| 368 | - candidate_vector = selection_context.selection_vector_cache.get(normalized_text) | |
| 369 | - if candidate_vector is None: | |
| 370 | - selection_context.similarity_cache[normalized_text] = None | |
| 371 | - continue | |
| 372 | - score = float(np.inner(query_vector_array, candidate_vector)) | |
| 373 | - selection_context.similarity_cache[normalized_text] = score | |
| 263 | + for intent_type, field_name in resolved_dimensions.items(): | |
| 264 | + if not field_name: | |
| 265 | + matched = False | |
| 266 | + break | |
| 374 | 267 | |
| 375 | - if score is None: | |
| 376 | - continue | |
| 377 | - if best_score is None or score > best_score: | |
| 378 | - best_candidate = candidate | |
| 379 | - best_score = score | |
| 268 | + raw_value = str(sku.get(field_name) or "").strip() | |
| 269 | + normalized_value = self._normalize_cached(selection_context, raw_value) | |
| 270 | + if not self._is_text_match( | |
| 271 | + intent_type, | |
| 272 | + selection_context, | |
| 273 | + normalized_value=normalized_value, | |
| 274 | + ): | |
| 275 | + matched = False | |
| 276 | + break | |
| 380 | 277 | |
| 381 | - return best_candidate, best_score | |
| 278 | + if raw_value and normalized_value not in seen_parts: | |
| 279 | + seen_parts.add(normalized_value) | |
| 280 | + selection_parts.append(raw_value) | |
| 281 | + | |
| 282 | + if matched: | |
| 283 | + return str(sku.get("sku_id") or ""), " ".join(selection_parts).strip() | |
| 284 | + return None | |
| 382 | 285 | |
| 383 | 286 | def _select_for_source( |
| 384 | 287 | self, |
| ... | ... | @@ -395,36 +298,29 @@ class StyleSkuSelector: |
| 395 | 298 | if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): |
| 396 | 299 | return self._empty_decision(resolved_dimensions, matched_stage="unresolved") |
| 397 | 300 | |
| 398 | - candidates = self._build_candidates(skus, resolved_dimensions) | |
| 399 | - if not candidates: | |
| 400 | - return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") | |
| 401 | - | |
| 402 | - text_match = self._find_first_text_match(candidates, selection_context) | |
| 403 | - if text_match is not None: | |
| 404 | - return self._build_decision(text_match, resolved_dimensions, matched_stage="text") | |
| 405 | - | |
| 406 | - chosen, similarity_score = self._select_by_embedding(candidates, selection_context) | |
| 407 | - if chosen is None: | |
| 301 | + text_match = self._find_first_text_match(skus, resolved_dimensions, selection_context) | |
| 302 | + if text_match is None: | |
| 408 | 303 | return self._empty_decision(resolved_dimensions, matched_stage="no_match") |
| 409 | 304 | return self._build_decision( |
| 410 | - chosen, | |
| 411 | - resolved_dimensions, | |
| 412 | - matched_stage="embedding", | |
| 413 | - similarity_score=similarity_score, | |
| 305 | + selected_sku_id=text_match[0], | |
| 306 | + selected_text=text_match[1], | |
| 307 | + resolved_dimensions=resolved_dimensions, | |
| 308 | + matched_stage="text", | |
| 414 | 309 | ) |
| 415 | 310 | |
| 416 | 311 | @staticmethod |
| 417 | 312 | def _build_decision( |
| 418 | - candidate: _SkuCandidate, | |
| 313 | + selected_sku_id: str, | |
| 314 | + selected_text: str, | |
| 419 | 315 | resolved_dimensions: Dict[str, Optional[str]], |
| 420 | 316 | *, |
| 421 | 317 | matched_stage: str, |
| 422 | 318 | similarity_score: Optional[float] = None, |
| 423 | 319 | ) -> SkuSelectionDecision: |
| 424 | 320 | return SkuSelectionDecision( |
| 425 | - selected_sku_id=candidate.sku_id or None, | |
| 426 | - rerank_suffix=str(candidate.selection_text or "").strip(), | |
| 427 | - selected_text=str(candidate.selection_text or "").strip(), | |
| 321 | + selected_sku_id=selected_sku_id or None, | |
| 322 | + rerank_suffix=str(selected_text or "").strip(), | |
| 323 | + selected_text=str(selected_text or "").strip(), | |
| 428 | 324 | matched_stage=matched_stage, |
| 429 | 325 | similarity_score=similarity_score, |
| 430 | 326 | resolved_dimensions=dict(resolved_dimensions), | ... | ... |
| ... | ... | @@ -0,0 +1,452 @@ |
| 1 | +""" | |
| 2 | +SKU selection for style-intent-aware search results. | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from __future__ import annotations | |
| 6 | + | |
| 7 | +from dataclasses import dataclass, field | |
| 8 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | |
| 9 | + | |
| 10 | +import numpy as np | |
| 11 | + | |
| 12 | +from query.style_intent import StyleIntentProfile, StyleIntentRegistry | |
| 13 | +from query.tokenization import normalize_query_text | |
| 14 | + | |
| 15 | + | |
@dataclass(frozen=True)
class SkuSelectionDecision:
    """Immutable outcome of selecting one SKU for an SPU hit.

    Attributes:
        selected_sku_id: id of the chosen SKU, or None when nothing matched.
        rerank_suffix: text appended to the rerank document for the hit.
        selected_text: human-readable selected attribute values.
        matched_stage: which stage produced the decision (e.g. "text").
        similarity_score: optional score from an embedding-based stage.
        resolved_dimensions: intent type -> resolved SKU option field name.
    """

    selected_sku_id: Optional[str]
    rerank_suffix: str
    selected_text: str
    matched_stage: str
    similarity_score: Optional[float] = None
    resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; ``resolved_dimensions`` is copied so
        callers cannot mutate the decision's mapping through the payload."""
        return {
            "selected_sku_id": self.selected_sku_id,
            "rerank_suffix": self.rerank_suffix,
            "selected_text": self.selected_text,
            "matched_stage": self.matched_stage,
            "similarity_score": self.similarity_score,
            "resolved_dimensions": dict(self.resolved_dimensions),
        }
| 34 | + | |
| 35 | + | |
@dataclass
class _SkuCandidate:
    """Precomputed per-SKU matching data used by the selection stages."""

    index: int  # position of the SKU in the source list
    sku_id: str
    sku: Dict[str, Any]  # the raw SKU mapping
    selection_text: str  # joined raw attribute values for display/rerank
    normalized_selection_text: str  # normalized form of selection_text
    intent_values: Dict[str, str]  # intent type -> raw attribute value
    normalized_intent_values: Dict[str, str]  # intent type -> normalized value
| 45 | + | |
| 46 | + | |
@dataclass
class _SelectionContext:
    """Per-request matching state shared across all hits."""

    query_texts: Tuple[str, ...]  # normalized query variants/fallbacks
    matched_terms_by_intent: Dict[str, Tuple[str, ...]]  # intent -> matched dictionary terms
    query_vector: Optional[np.ndarray]  # query embedding, if available
    text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict)
    selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict)
    similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict)
| 55 | + | |
| 56 | + | |
| 57 | +class StyleSkuSelector: | |
| 58 | + """Selects the best SKU for an SPU based on detected style intent.""" | |
| 59 | + | |
    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        text_encoder_getter: Optional[Callable[[], Any]] = None,
    ) -> None:
        """Create a selector.

        Args:
            registry: style-intent dictionary registry (dimension aliases, terms).
            text_encoder_getter: optional lazy factory returning a text encoder;
                used by the embedding fallback when no text match succeeds.
        """
        self.registry = registry
        self._text_encoder_getter = text_encoder_getter
| 68 | + | |
| 69 | + def prepare_hits( | |
| 70 | + self, | |
| 71 | + es_hits: List[Dict[str, Any]], | |
| 72 | + parsed_query: Any, | |
| 73 | + ) -> Dict[str, SkuSelectionDecision]: | |
| 74 | + decisions: Dict[str, SkuSelectionDecision] = {} | |
| 75 | + style_profile = getattr(parsed_query, "style_intent_profile", None) | |
| 76 | + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: | |
| 77 | + return decisions | |
| 78 | + | |
| 79 | + selection_context = self._build_selection_context(parsed_query, style_profile) | |
| 80 | + | |
| 81 | + for hit in es_hits: | |
| 82 | + source = hit.get("_source") | |
| 83 | + if not isinstance(source, dict): | |
| 84 | + continue | |
| 85 | + | |
| 86 | + decision = self._select_for_source( | |
| 87 | + source, | |
| 88 | + style_profile=style_profile, | |
| 89 | + selection_context=selection_context, | |
| 90 | + ) | |
| 91 | + if decision is None: | |
| 92 | + continue | |
| 93 | + | |
| 94 | + if decision.rerank_suffix: | |
| 95 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 96 | + else: | |
| 97 | + hit.pop("_style_rerank_suffix", None) | |
| 98 | + | |
| 99 | + doc_id = hit.get("_id") | |
| 100 | + if doc_id is not None: | |
| 101 | + decisions[str(doc_id)] = decision | |
| 102 | + | |
| 103 | + return decisions | |
| 104 | + | |
| 105 | + def apply_precomputed_decisions( | |
| 106 | + self, | |
| 107 | + es_hits: List[Dict[str, Any]], | |
| 108 | + decisions: Dict[str, SkuSelectionDecision], | |
| 109 | + ) -> None: | |
| 110 | + if not es_hits or not decisions: | |
| 111 | + return | |
| 112 | + | |
| 113 | + for hit in es_hits: | |
| 114 | + doc_id = hit.get("_id") | |
| 115 | + if doc_id is None: | |
| 116 | + continue | |
| 117 | + decision = decisions.get(str(doc_id)) | |
| 118 | + if decision is None: | |
| 119 | + continue | |
| 120 | + source = hit.get("_source") | |
| 121 | + if not isinstance(source, dict): | |
| 122 | + continue | |
| 123 | + self._apply_decision_to_source(source, decision) | |
| 124 | + if decision.rerank_suffix: | |
| 125 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 126 | + else: | |
| 127 | + hit.pop("_style_rerank_suffix", None) | |
| 128 | + | |
| 129 | + def _build_query_texts( | |
| 130 | + self, | |
| 131 | + parsed_query: Any, | |
| 132 | + style_profile: StyleIntentProfile, | |
| 133 | + ) -> List[str]: | |
| 134 | + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] | |
| 135 | + if texts: | |
| 136 | + return list(dict.fromkeys(texts)) | |
| 137 | + | |
| 138 | + fallbacks: List[str] = [] | |
| 139 | + for value in ( | |
| 140 | + getattr(parsed_query, "original_query", None), | |
| 141 | + getattr(parsed_query, "query_normalized", None), | |
| 142 | + getattr(parsed_query, "rewritten_query", None), | |
| 143 | + ): | |
| 144 | + normalized = normalize_query_text(value) | |
| 145 | + if normalized: | |
| 146 | + fallbacks.append(normalized) | |
| 147 | + translations = getattr(parsed_query, "translations", {}) or {} | |
| 148 | + if isinstance(translations, dict): | |
| 149 | + for value in translations.values(): | |
| 150 | + normalized = normalize_query_text(value) | |
| 151 | + if normalized: | |
| 152 | + fallbacks.append(normalized) | |
| 153 | + return list(dict.fromkeys(fallbacks)) | |
| 154 | + | |
| 155 | + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: | |
| 156 | + query_vector = getattr(parsed_query, "query_vector", None) | |
| 157 | + if query_vector is not None: | |
| 158 | + return np.asarray(query_vector, dtype=np.float32) | |
| 159 | + | |
| 160 | + text_encoder = self._get_text_encoder() | |
| 161 | + if text_encoder is None: | |
| 162 | + return None | |
| 163 | + | |
| 164 | + query_text = ( | |
| 165 | + getattr(parsed_query, "rewritten_query", None) | |
| 166 | + or getattr(parsed_query, "query_normalized", None) | |
| 167 | + or getattr(parsed_query, "original_query", None) | |
| 168 | + ) | |
| 169 | + if not query_text: | |
| 170 | + return None | |
| 171 | + | |
| 172 | + vectors = text_encoder.encode([query_text], priority=1) | |
| 173 | + if vectors is None or len(vectors) == 0 or vectors[0] is None: | |
| 174 | + return None | |
| 175 | + return np.asarray(vectors[0], dtype=np.float32) | |
| 176 | + | |
| 177 | + def _build_selection_context( | |
| 178 | + self, | |
| 179 | + parsed_query: Any, | |
| 180 | + style_profile: StyleIntentProfile, | |
| 181 | + ) -> _SelectionContext: | |
| 182 | + matched_terms_by_intent: Dict[str, List[str]] = {} | |
| 183 | + for intent in style_profile.intents: | |
| 184 | + normalized_term = normalize_query_text(intent.matched_term) | |
| 185 | + if not normalized_term: | |
| 186 | + continue | |
| 187 | + matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) | |
| 188 | + if normalized_term not in matched_terms: | |
| 189 | + matched_terms.append(normalized_term) | |
| 190 | + | |
| 191 | + return _SelectionContext( | |
| 192 | + query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), | |
| 193 | + matched_terms_by_intent={ | |
| 194 | + intent_type: tuple(terms) | |
| 195 | + for intent_type, terms in matched_terms_by_intent.items() | |
| 196 | + }, | |
| 197 | + query_vector=self._get_query_vector(parsed_query), | |
| 198 | + ) | |
| 199 | + | |
| 200 | + def _get_text_encoder(self) -> Any: | |
| 201 | + if self._text_encoder_getter is None: | |
| 202 | + return None | |
| 203 | + return self._text_encoder_getter() | |
| 204 | + | |
| 205 | + def _resolve_dimensions( | |
| 206 | + self, | |
| 207 | + source: Dict[str, Any], | |
| 208 | + style_profile: StyleIntentProfile, | |
| 209 | + ) -> Dict[str, Optional[str]]: | |
| 210 | + option_names = { | |
| 211 | + "option1_value": normalize_query_text(source.get("option1_name")), | |
| 212 | + "option2_value": normalize_query_text(source.get("option2_name")), | |
| 213 | + "option3_value": normalize_query_text(source.get("option3_name")), | |
| 214 | + } | |
| 215 | + resolved: Dict[str, Optional[str]] = {} | |
| 216 | + for intent in style_profile.intents: | |
| 217 | + if intent.intent_type in resolved: | |
| 218 | + continue | |
| 219 | + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) | |
| 220 | + matched_field = None | |
| 221 | + for field_name, option_name in option_names.items(): | |
| 222 | + if option_name and option_name in aliases: | |
| 223 | + matched_field = field_name | |
| 224 | + break | |
| 225 | + resolved[intent.intent_type] = matched_field | |
| 226 | + return resolved | |
| 227 | + | |
    def _build_candidates(
        self,
        skus: List[Dict[str, Any]],
        resolved_dimensions: Dict[str, Optional[str]],
    ) -> List[_SkuCandidate]:
        """Build one matching candidate per SKU.

        Returns [] unless every intent resolved to a concrete option field.
        Each candidate carries the raw and normalized per-intent values plus a
        joined selection text deduplicated by normalized form.
        """
        if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
            return []

        candidates: List[_SkuCandidate] = []
        for index, sku in enumerate(skus):
            intent_values: Dict[str, str] = {}
            normalized_intent_values: Dict[str, str] = {}
            for intent_type, field_name in resolved_dimensions.items():
                if not field_name:
                    continue
                raw = str(sku.get(field_name) or "").strip()
                intent_values[intent_type] = raw
                normalized_intent_values[intent_type] = normalize_query_text(raw)

            # Join distinct values (by normalized form) into the selection text.
            selection_parts: List[str] = []
            norm_parts: List[str] = []
            seen: set[str] = set()
            for intent_type, raw in intent_values.items():
                nv = normalized_intent_values[intent_type]
                if not nv or nv in seen:
                    continue
                seen.add(nv)
                selection_parts.append(raw)
                norm_parts.append(nv)

            selection_text = " ".join(selection_parts).strip()
            normalized_selection_text = " ".join(norm_parts).strip()
            candidates.append(
                _SkuCandidate(
                    index=index,
                    sku_id=str(sku.get("sku_id") or ""),
                    sku=sku,
                    selection_text=selection_text,
                    normalized_selection_text=normalized_selection_text,
                    intent_values=intent_values,
                    normalized_intent_values=normalized_intent_values,
                )
            )
        return candidates
| 272 | + | |
    @staticmethod
    def _empty_decision(
        resolved_dimensions: Dict[str, Optional[str]],
        matched_stage: str,
    ) -> SkuSelectionDecision:
        """Build a no-selection decision that still records the resolved
        dimensions and the stage at which selection stopped."""
        return SkuSelectionDecision(
            selected_sku_id=None,
            rerank_suffix="",
            selected_text="",
            matched_stage=matched_stage,
            resolved_dimensions=dict(resolved_dimensions),
        )
| 285 | + | |
    def _is_text_match(
        self,
        intent_type: str,
        value: str,
        selection_context: _SelectionContext,
        *,
        normalized_value: Optional[str] = None,
    ) -> bool:
        """Bidirectional text match, memoized per (intent, normalized value).

        Matches when either a matched dictionary term is contained in the value,
        or the value is contained in one of the query texts.
        """
        if normalized_value is None:
            normalized_value = normalize_query_text(value)
        if not normalized_value:
            return False

        cache_key = (intent_type, normalized_value)
        cached = selection_context.text_match_cache.get(cache_key)
        if cached is not None:
            return cached

        matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ())
        has_term_match = any(term in normalized_value for term in matched_terms if term)
        query_contains_value = any(
            normalized_value in query_text
            for query_text in selection_context.query_texts
        )
        matched = bool(has_term_match or query_contains_value)
        selection_context.text_match_cache[cache_key] = matched
        return matched
| 313 | + | |
| 314 | + def _find_first_text_match( | |
| 315 | + self, | |
| 316 | + candidates: Sequence[_SkuCandidate], | |
| 317 | + selection_context: _SelectionContext, | |
| 318 | + ) -> Optional[_SkuCandidate]: | |
| 319 | + for candidate in candidates: | |
| 320 | + if candidate.intent_values and all( | |
| 321 | + self._is_text_match( | |
| 322 | + intent_type, | |
| 323 | + value, | |
| 324 | + selection_context, | |
| 325 | + normalized_value=candidate.normalized_intent_values[intent_type], | |
| 326 | + ) | |
| 327 | + for intent_type, value in candidate.intent_values.items() | |
| 328 | + ): | |
| 329 | + return candidate | |
| 330 | + return None | |
| 331 | + | |
    def _select_by_embedding(
        self,
        candidates: Sequence[_SkuCandidate],
        selection_context: _SelectionContext,
    ) -> Tuple[Optional[_SkuCandidate], Optional[float]]:
        """Pick the candidate whose selection text is most similar to the query vector.

        Returns ``(candidate, score)``, or ``(None, None)`` when there are no
        candidates, no query vector, or no text encoder. Encoded vectors and
        per-text similarity scores are cached on the selection context so
        repeated texts across hits share the work.
        """
        if not candidates:
            return None, None
        text_encoder = self._get_text_encoder()
        if selection_context.query_vector is None or text_encoder is None:
            return None, None

        # Encode only texts not already cached; dict.fromkeys dedupes while
        # preserving the candidates' order.
        unique_texts = list(
            dict.fromkeys(
                candidate.normalized_selection_text
                for candidate in candidates
                if candidate.normalized_selection_text
                and candidate.normalized_selection_text not in selection_context.selection_vector_cache
            )
        )
        if unique_texts:
            # NOTE(review): priority=1 presumably requests elevated queue
            # priority from the encoder — confirm against the encoder API.
            vectors = text_encoder.encode(unique_texts, priority=1)
            for key, vector in zip(unique_texts, vectors):
                # Failed encodings are cached as None so they are not retried.
                selection_context.selection_vector_cache[key] = (
                    np.asarray(vector, dtype=np.float32) if vector is not None else None
                )

        best_candidate: Optional[_SkuCandidate] = None
        best_score: Optional[float] = None
        query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32)
        for candidate in candidates:
            normalized_text = candidate.normalized_selection_text
            if not normalized_text:
                continue

            score = selection_context.similarity_cache.get(normalized_text)
            if score is None:
                candidate_vector = selection_context.selection_vector_cache.get(normalized_text)
                if candidate_vector is None:
                    # No vector for this text: remember the miss and skip it
                    # on future passes too.
                    selection_context.similarity_cache[normalized_text] = None
                    continue
                # Dot product of the two vectors; assumes they are
                # pre-normalized so this is cosine similarity — TODO confirm
                # with the encoder contract.
                score = float(np.inner(query_vector_array, candidate_vector))
                selection_context.similarity_cache[normalized_text] = score

            if score is None:
                continue
            if best_score is None or score > best_score:
                best_candidate = candidate
                best_score = score

        return best_candidate, best_score
| 382 | + | |
| 383 | + def _select_for_source( | |
| 384 | + self, | |
| 385 | + source: Dict[str, Any], | |
| 386 | + *, | |
| 387 | + style_profile: StyleIntentProfile, | |
| 388 | + selection_context: _SelectionContext, | |
| 389 | + ) -> Optional[SkuSelectionDecision]: | |
| 390 | + skus = source.get("skus") | |
| 391 | + if not isinstance(skus, list) or not skus: | |
| 392 | + return None | |
| 393 | + | |
| 394 | + resolved_dimensions = self._resolve_dimensions(source, style_profile) | |
| 395 | + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): | |
| 396 | + return self._empty_decision(resolved_dimensions, matched_stage="unresolved") | |
| 397 | + | |
| 398 | + candidates = self._build_candidates(skus, resolved_dimensions) | |
| 399 | + if not candidates: | |
| 400 | + return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") | |
| 401 | + | |
| 402 | + text_match = self._find_first_text_match(candidates, selection_context) | |
| 403 | + if text_match is not None: | |
| 404 | + return self._build_decision(text_match, resolved_dimensions, matched_stage="text") | |
| 405 | + | |
| 406 | + chosen, similarity_score = self._select_by_embedding(candidates, selection_context) | |
| 407 | + if chosen is None: | |
| 408 | + return self._empty_decision(resolved_dimensions, matched_stage="no_match") | |
| 409 | + return self._build_decision( | |
| 410 | + chosen, | |
| 411 | + resolved_dimensions, | |
| 412 | + matched_stage="embedding", | |
| 413 | + similarity_score=similarity_score, | |
| 414 | + ) | |
| 415 | + | |
| 416 | + @staticmethod | |
| 417 | + def _build_decision( | |
| 418 | + candidate: _SkuCandidate, | |
| 419 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 420 | + *, | |
| 421 | + matched_stage: str, | |
| 422 | + similarity_score: Optional[float] = None, | |
| 423 | + ) -> SkuSelectionDecision: | |
| 424 | + return SkuSelectionDecision( | |
| 425 | + selected_sku_id=candidate.sku_id or None, | |
| 426 | + rerank_suffix=str(candidate.selection_text or "").strip(), | |
| 427 | + selected_text=str(candidate.selection_text or "").strip(), | |
| 428 | + matched_stage=matched_stage, | |
| 429 | + similarity_score=similarity_score, | |
| 430 | + resolved_dimensions=dict(resolved_dimensions), | |
| 431 | + ) | |
| 432 | + | |
| 433 | + @staticmethod | |
| 434 | + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: | |
| 435 | + skus = source.get("skus") | |
| 436 | + if not isinstance(skus, list) or not skus or not decision.selected_sku_id: | |
| 437 | + return | |
| 438 | + | |
| 439 | + selected_index = None | |
| 440 | + for index, sku in enumerate(skus): | |
| 441 | + if str(sku.get("sku_id") or "") == decision.selected_sku_id: | |
| 442 | + selected_index = index | |
| 443 | + break | |
| 444 | + if selected_index is None: | |
| 445 | + return | |
| 446 | + | |
| 447 | + selected_sku = skus.pop(selected_index) | |
| 448 | + skus.insert(0, selected_sku) | |
| 449 | + | |
| 450 | + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") | |
| 451 | + if image_src: | |
| 452 | + source["image_url"] = image_src | ... | ... |
tests/test_rerank_client.py
| ... | ... | @@ -118,3 +118,34 @@ def test_fuse_scores_and_resort_uses_configurable_fusion_params(): |
| 118 | 118 | by_id = {h["_id"]: h for h in hits} |
| 119 | 119 | assert isclose(by_id["a"]["_fused_score"], 1.0, rel_tol=1e-9) |
| 120 | 120 | assert isclose(by_id["b"]["_fused_score"], 0.0, rel_tol=1e-9) |
| 121 | + | |
| 122 | + | |
def test_fuse_scores_and_resort_boosts_hits_with_selected_sku():
    """A hit carrying a style-intent SKU suffix gets its fused score boosted.

    Both hits start with identical raw scores; only the hit with
    ``_style_rerank_suffix`` should receive the 1.2x boost and be ranked
    first after the resort.
    """
    hits = [
        {
            "_id": "style-selected",
            "_score": 1.0,
            # Presumably the presence of this marker is what flags the hit
            # as having a style-intent selected SKU — see fuse_scores_and_resort.
            "_style_rerank_suffix": "Blue XL",
            "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
        },
        {
            "_id": "plain",
            "_score": 1.0,
            "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
        },
    ]

    debug = fuse_scores_and_resort(
        hits,
        [1.0, 1.0],
        style_intent_selected_sku_boost=1.2,
        debug=True,
    )

    by_id = {h["_id"]: h for h in hits}
    # Boosted hit scores exactly 1.2x the un-boosted one.
    assert isclose(by_id["style-selected"]["_fused_score"], by_id["plain"]["_fused_score"] * 1.2, rel_tol=1e-9)
    assert by_id["style-selected"]["_style_intent_selected_sku_boost"] == 1.2
    assert by_id["plain"]["_style_intent_selected_sku_boost"] == 1.0
    # Resort places the boosted hit first.
    assert [h["_id"] for h in hits] == ["style-selected", "plain"]
    # Debug payload records the boost decision for the top hit.
    assert debug[0]["style_intent_selected_sku"] is True
    assert debug[0]["style_intent_selected_sku_boost"] == 1.2
| ... | ... | @@ -0,0 +1,119 @@ |
| 1 | +from __future__ import annotations | |
| 2 | + | |
| 3 | +import sys | |
| 4 | +import types | |
| 5 | + | |
| 6 | +from reranker.backends import get_rerank_backend | |
| 7 | +from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend | |
| 8 | + | |
| 9 | + | |
class _FakeLlama:
    """Minimal stand-in for ``llama_cpp.Llama`` used by the GGUF backend tests.

    Deterministically maps "yes"/"no" to token ids 1/2 and fabricates logits
    where the "yes" logit always dominates, so the backend's yes/no scoring
    path can run without a real model.
    """

    def __init__(self, model_path: str | None = None, **kwargs):
        self.model_path = model_path
        self.kwargs = kwargs
        self.eval_logits = []
        self._tokens = []
        self.eval_call_count = 0

    @classmethod
    def from_pretrained(cls, repo_id: str, filename: str, local_dir=None, cache_dir=None, **kwargs):
        """Mirror llama_cpp's hub-download API and record the arguments.

        Fix: build the fake model path from the requested filename instead of
        a literal "(unknown)" placeholder, so path-based assertions work.
        """
        inst = cls(model_path=f"{repo_id}/{filename}", **kwargs)
        inst.repo_id = repo_id
        inst.filename = filename
        inst.local_dir = local_dir
        inst.cache_dir = cache_dir
        return inst

    def tokenize(self, text: bytes, add_bos: bool = False, special: bool = False):
        raw = text.decode("utf-8")
        if raw == "yes":
            return [1]
        if raw == "no":
            return [2]
        # Any other text maps to deterministic pseudo-token ids.
        return [10 + (ord(ch) % 17) for ch in raw]

    def reset(self):
        # Drop the accumulated token history, like a real context reset.
        self._tokens = []
        return None

    def eval(self, prompt_tokens):
        self.eval_call_count += 1
        self._tokens.extend(prompt_tokens)
        # "yes" logit (index 1) varies with the token history; "no" logit
        # (index 2) stays constant and smaller, so "yes" always wins.
        pos = float(sum(self._tokens) % 11) + 3.0
        neg = 1.0
        logits = [0.0] * 64
        logits[1] = pos
        logits[2] = neg
        self.eval_logits = [logits]

    def save_state(self):
        # State is just the token history; copying makes it snapshot-safe.
        return list(self._tokens)

    def load_state(self, state):
        self._tokens = list(state)
| 55 | + | |
def _install_fake_llama_cpp(monkeypatch):
    """Register the fake ``llama_cpp`` module so backend imports resolve to it."""
    monkeypatch.setitem(
        sys.modules,
        "llama_cpp",
        types.SimpleNamespace(Llama=_FakeLlama),
    )
| 59 | + | |
| 60 | + | |
def test_qwen3_gguf_backend_factory_loads(monkeypatch):
    """The factory resolves "qwen3_gguf" to the GGUF reranker backend."""
    _install_fake_llama_cpp(monkeypatch)
    backend_config = {
        "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
        "filename": "*Q8_0.gguf",
        "enable_warmup": False,
    }

    backend = get_rerank_backend("qwen3_gguf", backend_config)

    assert isinstance(backend, Qwen3GGUFRerankerBackend)
    assert backend._backend_name == "qwen3_gguf"
| 73 | + | |
| 74 | + | |
def test_qwen3_gguf_06b_backend_factory_loads(monkeypatch):
    """The "qwen3_gguf_06b" alias presets the 0.6B repo and filename defaults."""
    _install_fake_llama_cpp(monkeypatch)

    backend = get_rerank_backend("qwen3_gguf_06b", {"enable_warmup": False})

    assert isinstance(backend, Qwen3GGUFRerankerBackend)
    assert backend._backend_name == "qwen3_gguf_06b"
    assert backend._repo_id == "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF"
    assert backend._filename == "qwen3-reranker-0.6b-q8_0.gguf"
| 87 | + | |
| 88 | + | |
def test_qwen3_gguf_backend_score_with_meta_dedup_and_restore(monkeypatch):
    """Duplicate docs score once, unusable docs score 0, and order is restored."""
    _install_fake_llama_cpp(monkeypatch)
    backend_config = {
        "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
        "filename": "*Q8_0.gguf",
        "enable_warmup": False,
        "infer_batch_size": 2,
        "sort_by_doc_length": True,
        "reuse_query_state": True,
    }
    backend = Qwen3GGUFRerankerBackend(backend_config)

    docs = ["doc-a", "doc-b", "doc-a", "", " ", None]
    scores, meta = backend.score_with_meta(
        query="wireless mouse",
        docs=docs,
        normalize=True,
    )

    # Scores come back in input order; the duplicate shares one result.
    assert len(scores) == 6
    assert scores[0] == scores[2]
    assert scores[0] > 0.5
    assert scores[1] > 0.5
    # Empty / whitespace / None docs never reach the model and score 0.
    assert scores[3:] == [0.0, 0.0, 0.0]

    assert meta["input_docs"] == 6
    assert meta["usable_docs"] == 3
    assert meta["unique_docs"] == 2
    assert meta["backend"] == "qwen3_gguf"
    assert meta["inference_batches"] == 1
    assert meta["reuse_query_state"] is True
    # Presumably one eval for the shared query prefix plus one per unique doc.
    assert backend._llm.eval_call_count == 3
tests/test_search_rerank_window.py
| ... | ... | @@ -63,6 +63,7 @@ def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimensi |
| 63 | 63 | canonical_value=canonical_value, |
| 64 | 64 | matched_term=canonical_value, |
| 65 | 65 | matched_query_text=canonical_value, |
| 66 | + attribute_terms=(canonical_value,), | |
| 66 | 67 | dimension_aliases=tuple(aliases), |
| 67 | 68 | ), |
| 68 | 69 | ) | ... | ... |
| ... | ... | @@ -0,0 +1,197 @@ |
| 1 | +from types import SimpleNamespace | |
| 2 | + | |
| 3 | +from config import QueryConfig | |
| 4 | +from query.style_intent import DetectedStyleIntent, StyleIntentProfile, StyleIntentRegistry | |
| 5 | +from search.sku_intent_selector import StyleSkuSelector | |
| 6 | + | |
| 7 | + | |
def test_style_sku_selector_matches_first_sku_by_attribute_terms():
    """The first SKU whose option values contain every intent's attribute term wins."""
    registry = StyleIntentRegistry.from_query_config(
        QueryConfig(
            style_intent_terms={
                "color": [{"en_terms": ["navy"], "zh_terms": ["藏青"], "attribute_terms": ["navy"]}],
                "size": [{"en_terms": ["xl"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}],
            },
            style_intent_dimension_aliases={
                "color": ["color", "颜色"],
                "size": ["size", "尺码"],
            },
        )
    )
    selector = StyleSkuSelector(registry)

    # (intent_type, canonical/attribute value, matched term, aliases)
    intent_specs = [
        ("color", "navy", "藏青", ("color", "颜色")),
        ("size", "x-large", "xl", ("size", "尺码")),
    ]
    parsed_query = SimpleNamespace(
        style_intent_profile=StyleIntentProfile(
            intents=tuple(
                DetectedStyleIntent(
                    intent_type=intent_type,
                    canonical_value=canonical,
                    matched_term=term,
                    matched_query_text=term,
                    attribute_terms=(canonical,),
                    dimension_aliases=aliases,
                )
                for intent_type, canonical, term, aliases in intent_specs
            ),
        )
    )

    source = {
        "option1_name": "Color",
        "option2_name": "Size",
        "skus": [
            {"sku_id": "1", "option1_value": "Black", "option2_value": "M"},
            {"sku_id": "2", "option1_value": "Navy Blue", "option2_value": "X-Large", "image_src": "matched.jpg"},
            {"sku_id": "3", "option1_value": "Navy", "option2_value": "XL"},
        ],
    }
    hits = [{"_id": "spu-1", "_source": source}]

    decisions = selector.prepare_hits(hits, parsed_query)
    decision = decisions["spu-1"]

    # SKU "2" is the first whose values contain both "navy" and "x-large".
    assert decision.selected_sku_id == "2"
    assert decision.selected_text == "Navy Blue X-Large"
    assert decision.matched_stage == "text"

    selector.apply_precomputed_decisions(hits, decisions)

    # Applying the decision reorders the SKUs and promotes the SKU image.
    assert source["skus"][0]["sku_id"] == "2"
    assert source["image_url"] == "matched.jpg"
| 66 | + | |
| 67 | + | |
def test_style_sku_selector_returns_no_match_without_attribute_contains():
    """No SKU attribute value contains the configured term, so nothing is selected."""
    query_config = QueryConfig(
        style_intent_terms={
            "color": [{"en_terms": ["beige"], "zh_terms": ["米色"], "attribute_terms": ["beige"]}],
        },
        style_intent_dimension_aliases={"color": ["color", "颜色"]},
    )
    selector = StyleSkuSelector(StyleIntentRegistry.from_query_config(query_config))

    beige_intent = DetectedStyleIntent(
        intent_type="color",
        canonical_value="beige",
        matched_term="米色",
        matched_query_text="米色",
        attribute_terms=("beige",),
        dimension_aliases=("color", "颜色"),
    )
    parsed_query = SimpleNamespace(
        style_intent_profile=StyleIntentProfile(intents=(beige_intent,))
    )

    source = {
        "option1_name": "Color",
        "skus": [
            {"sku_id": "1", "option1_value": "Khaki"},
            {"sku_id": "2", "option1_value": "Light Brown"},
        ],
    }
    hits = [{"_id": "spu-1", "_source": source}]

    decisions = selector.prepare_hits(hits, parsed_query)

    decision = decisions["spu-1"]
    assert decision.selected_sku_id is None
    assert decision.matched_stage == "no_match"
| 107 | + | |
| 108 | + | |
def test_is_text_match_uses_token_boundaries_for_sizes():
    """A bare "l" size term must not match inside "xl"/"xxl" attribute values."""
    query_config = QueryConfig(
        style_intent_terms={
            "size": [{"en_terms": ["l"], "zh_terms": ["大码"], "attribute_terms": ["l"]}],
        },
        style_intent_dimension_aliases={"size": ["size", "尺码"]},
    )
    selector = StyleSkuSelector(StyleIntentRegistry.from_query_config(query_config))

    profile = StyleIntentProfile(
        intents=(
            DetectedStyleIntent(
                intent_type="size",
                canonical_value="l",
                matched_term="l",
                matched_query_text="l",
                attribute_terms=("l",),
                dimension_aliases=("size", "尺码"),
            ),
        ),
    )
    context = selector._build_selection_context(profile)

    assert selector._is_text_match("size", context, normalized_value="l")
    for non_matching_value in ("xl", "xxl"):
        assert not selector._is_text_match("size", context, normalized_value=non_matching_value)
| 136 | + | |
| 137 | + | |
def test_is_text_match_handles_punctuation_and_descriptive_attribute_values():
    """Attribute terms must match inside punctuated / descriptive option values."""
    # (intent_type, en/attribute term, zh term, aliases, attribute value to test)
    specs = [
        ("color", "blue", "蓝色", ("color", "颜色"), "gray blue"),
        ("style", "off-white", "米白", ("style", "风格"), "off-white/lined"),
        ("accessory", "headscarf", "头巾", ("accessory", "配饰"), "army green + headscarf"),
        ("size", "2xl", "2xl", ("size", "尺码"), "2xl recommended 65-70kg"),
    ]

    registry = StyleIntentRegistry.from_query_config(
        QueryConfig(
            style_intent_terms={
                intent_type: [{"en_terms": [term], "zh_terms": [zh_term], "attribute_terms": [term]}]
                for intent_type, term, zh_term, _aliases, _value in specs
            },
            style_intent_dimension_aliases={
                intent_type: list(aliases)
                for intent_type, _term, _zh_term, aliases, _value in specs
            },
        )
    )
    selector = StyleSkuSelector(registry)

    profile = StyleIntentProfile(
        intents=tuple(
            DetectedStyleIntent(
                intent_type=intent_type,
                canonical_value=term,
                matched_term=term,
                matched_query_text=term,
                attribute_terms=(term,),
                dimension_aliases=aliases,
            )
            for intent_type, term, _zh_term, aliases, _value in specs
        ),
    )
    context = selector._build_selection_context(profile)

    for intent_type, _term, _zh_term, _aliases, attribute_value in specs:
        assert selector._is_text_match(intent_type, context, normalized_value=attribute_value)
tests/test_style_intent.py
| ... | ... | @@ -7,8 +7,8 @@ from query.style_intent import StyleIntentDetector, StyleIntentRegistry |
| 7 | 7 | def test_style_intent_detector_matches_original_and_translated_queries(): |
| 8 | 8 | query_config = QueryConfig( |
| 9 | 9 | style_intent_terms={ |
| 10 | - "color": [["black", "黑色", "black"]], | |
| 11 | - "size": [["xl", "x-large", "加大码"]], | |
| 10 | + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], | |
| 11 | + "size": [{"en_terms": ["xl", "x-large"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}], | |
| 12 | 12 | }, |
| 13 | 13 | style_intent_dimension_aliases={ |
| 14 | 14 | "color": ["color", "颜色"], |
| ... | ... | @@ -31,5 +31,30 @@ def test_style_intent_detector_matches_original_and_translated_queries(): |
| 31 | 31 | |
| 32 | 32 | assert profile.is_active is True |
| 33 | 33 | assert profile.get_canonical_values("color") == {"black"} |
| 34 | - assert profile.get_canonical_values("size") == {"xl"} | |
| 34 | + assert profile.get_canonical_values("size") == {"x-large"} | |
| 35 | 35 | assert len(profile.query_variants) == 2 |
| 36 | + | |
| 37 | + | |
def test_style_intent_detector_uses_original_query_when_language_translation_missing():
    """When the zh translation lacks the term, the original query still matches."""
    query_config = QueryConfig(
        style_intent_terms={
            "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}],
        },
        style_intent_dimension_aliases={"color": ["color", "颜色"]},
    )
    registry = StyleIntentRegistry.from_query_config(query_config)
    detector = StyleIntentDetector(registry, tokenizer=lambda text: text.split())

    # zh translation has no "black" term; detection should fall back to the
    # original English query.
    parsed_query = SimpleNamespace(
        original_query="black dress",
        query_normalized="black dress",
        rewritten_query="black dress",
        translations={"zh": "连衣裙"},
    )

    profile = detector.detect(parsed_query)

    assert profile.get_canonical_values("color") == {"black"}
    assert profile.intents[0].attribute_terms == ("black",)