Compare View

Commits (9)
Showing 45 changed files
config/config.yaml
... ... @@ -114,10 +114,11 @@ query_config:
114 114 # Query-parsing stage: translation and the query embedding run concurrently and share one wait budget (ms).
115 115 # Detected language already in the tenant's index_languages: shorter budget; not in the index languages: longer (translation matters more for recall).
116 116 translation_embedding_wait_budget_ms_source_in_index: 500 # 80
117   - translation_embedding_wait_budget_ms_source_not_in_index: 500 #200
  117 + translation_embedding_wait_budget_ms_source_not_in_index: 700 #200
118 118  
119 119 style_intent:
120 120 enabled: true
  121 + selected_sku_boost: 1.2
121 122 color_dictionary_path: "config/dictionaries/style_intent_color.csv"
122 123 size_dictionary_path: "config/dictionaries/style_intent_size.csv"
123 124 dimension_aliases:
... ... @@ -230,7 +231,7 @@ rerank:
230 231 text_bias: 0.1
231 232 text_exponent: 0.35
232 233 knn_bias: 0.6
233   - knn_exponent: 0.2
  234 + knn_exponent: 0.0
234 235  
235 236 # 可扩展服务/provider 注册表(单一配置源)
236 237 services:
... ... @@ -380,7 +381,7 @@ services:
380 381 max_docs: 1000
381 382 normalize: true
382 383 # 服务内后端(reranker 进程启动时读取)
383   - backend: "bge" # bge | qwen3_vllm | qwen3_transformers | dashscope_rerank
  384 + backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_transformers_packed | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank
384 385 backends:
385 386 bge:
386 387 model_name: "BAAI/bge-reranker-v2-m3"
... ... @@ -401,6 +402,9 @@ services:
401 402 enforce_eager: false
402 403 infer_batch_size: 100
403 404 sort_by_doc_length: true
  405 + # Matches reranker/backends/qwen3_vllm.py: standard=_format_instruction__standard (fixed yes/no system prompt); compact=_format_instruction (instruction used as the system prompt, with the Instruct repeated in the user block)
  406 + # instruction_format: compact
  407 + instruction_format: compact
404 408 # instruction: "Given a query, score the product for relevance"
405 409 # "rank products by given query" 比 “Given a query, score the product for relevance” 更好点
406 410 # instruction: "rank products by given query, category match first"
... ... @@ -410,6 +414,32 @@ services:
410 414 # instruction: "Relevance ranking: category & style match first"
411 415 # instruction: "Score product relevance by query with category & style match prioritized"
412 416 instruction: "Rank products by query with category & style match prioritized"
  417 + # vLLM LLM.score() (cross-encoder scoring). Dedicated high-performance env .venv-reranker-score (vLLM pinned at 0.18): ./scripts/setup_reranker_venv.sh qwen3_vllm_score
  418 + # Can share the same model_name / HF cache as qwen3_vllm; the venvs are kept separate so vLLM can be upgraded without touching the generate backend.
  419 + qwen3_vllm_score:
  420 + model_name: "Qwen/Qwen3-Reranker-0.6B"
  421 + # The original Hub weights require true; set false if you switch to converted seq-cls weights (e.g. tomaarsen/...-seq-cls)
  422 + use_original_qwen3_hf_overrides: true
  423 + # vLLM 0.18: compute capability < 8 (e.g. T4) automatically uses TRITON_ATTN by default; on Ampere+ omit this or set auto. Can also be set via the RERANK_VLLM_ATTENTION_BACKEND env var
  424 + # vllm_attention_backend: "auto"
  425 + # Optional: keep aligned with vLLM; usually leave as auto
  426 + # vllm_runner: "auto"
  427 + # vllm_convert: "auto"
  428 + # Optional: merged with the built-in overrides when use_original_qwen3_hf_overrides is true
  429 + # hf_overrides: {}
  430 + engine: "vllm"
  431 + max_model_len: 160
  432 + tensor_parallel_size: 1
  433 + gpu_memory_utilization: 0.20
  434 + dtype: "float16"
  435 + enable_prefix_caching: true
  436 + enforce_eager: false
  437 + infer_batch_size: 100
  438 + sort_by_doc_length: true
  439 + # Same semantics as the identically named key under qwen3_vllm; the default standard matches the official vLLM Qwen3 reranker prefix
  440 + # instruction_format: compact
  441 + instruction_format: standard
  442 + instruction: "Rank products by query with category & style match prioritized"
413 443 qwen3_transformers:
414 444 model_name: "Qwen/Qwen3-Reranker-0.6B"
415 445 instruction: "rank products by given query"
... ... @@ -419,6 +449,68 @@ services:
419 449 use_fp16: true
420 450 # sdpa:默认无需 flash-attn;若已安装 flash_attn 可改为 flash_attention_2
421 451 attn_implementation: "sdpa"
  452 + # Packed Transformers backend: shared query prefix + custom position_ids/attention_mask.
  453 + # For 1 query + many short docs (for example 400 product titles), this usually reduces
  454 + # repeated prefix work and padding waste compared with pairwise batching.
  455 + qwen3_transformers_packed:
  456 + model_name: "Qwen/Qwen3-Reranker-0.6B"
  457 + instruction: "Rank products by query with category & style match prioritized"
  458 + max_model_len: 4096
  459 + max_doc_len: 160
  460 + max_docs_per_pack: 0
  461 + use_fp16: true
  462 + sort_by_doc_length: true
  463 + # Packed mode relies on a custom 4D attention mask. "eager" is the safest default.
  464 + # If your torch/transformers stack validates it, you can benchmark "sdpa".
  465 + attn_implementation: "eager"
  466 + qwen3_gguf:
  467 + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF"
  468 + filename: "*Q8_0.gguf"
  469 + cache_dir: "./model_cache"
  470 + local_dir: "./models/reranker/qwen3-reranker-4b-gguf"
  471 + instruction: "Rank products by query with category & style match prioritized"
  472 + # T4 16GB / performance-first config: offload all layers; measurably faster than the conservative config
  473 + n_ctx: 512
  474 + n_batch: 512
  475 + n_ubatch: 512
  476 + n_gpu_layers: 999
  477 + main_gpu: 0
  478 + n_threads: 2
  479 + n_threads_batch: 4
  480 + flash_attn: true
  481 + offload_kqv: true
  482 + use_mmap: true
  483 + use_mlock: false
  484 + infer_batch_size: 8
  485 + sort_by_doc_length: true
  486 + length_sort_mode: "char"
  487 + enable_warmup: true
  488 + verbose: false
  489 + qwen3_gguf_06b:
  490 + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF"
  491 + filename: "qwen3-reranker-0.6b-q8_0.gguf"
  492 + cache_dir: "./model_cache"
  493 + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf"
  494 + instruction: "Rank products by query with category & style match prioritized"
  495 + # 0.6B GGUF / online rerank baseline:
  496 + # Measured ~265 s for a single 400-title request, so it is better used as a low-VRAM functional fallback, not as the online low-latency primary route.
  497 + n_ctx: 256
  498 + n_batch: 256
  499 + n_ubatch: 256
  500 + n_gpu_layers: 999
  501 + main_gpu: 0
  502 + n_threads: 2
  503 + n_threads_batch: 4
  504 + flash_attn: true
  505 + offload_kqv: true
  506 + use_mmap: true
  507 + use_mlock: false
  508 + infer_batch_size: 32
  509 + sort_by_doc_length: true
  510 + length_sort_mode: "char"
  511 + reuse_query_state: false
  512 + enable_warmup: true
  513 + verbose: false
422 514 dashscope_rerank:
423 515 model_name: "qwen3-rerank"
424 516 # Choose the endpoint by region:
... ...
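The `instruction_format` switch above only changes how the (instruction, query, doc) triple is laid out before scoring. A minimal sketch of the two layouts, assuming the shapes described in the comment; the authoritative templates live in `reranker/backends/qwen3_vllm.py` / `qwen3_vllm_score.py`:

```python
# Illustrative only: exact wording and chat templating are defined by the backend.
def format_standard(instruction: str, query: str, doc: str) -> list[dict]:
    # "standard": fixed yes/no judging system prompt; the instruction travels in the user block.
    system = ('Judge whether the Document meets the requirements based on the Query and the '
              'Instruct provided. Note that the answer can only be "yes" or "no".')
    user = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
    return [{"role": "system", "content": system}, {"role": "user", "content": user}]

def format_compact(instruction: str, query: str, doc: str) -> list[dict]:
    # "compact": the instruction itself becomes the system prompt and is repeated as <Instruct>.
    user = f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
    return [{"role": "system", "content": instruction}, {"role": "user", "content": user}]
```

The compact variant keeps the shared prefix shorter, which is consistent with the latency gap reported in `perf_reports/reranker_vllm_instruction/2026-03-25/RESULTS.md`.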
config/dictionaries/product_title_exclusion.tsv
1 1 # zh triggers en triggers zh title exclusions en title exclusions
2   -修身 fitted 宽松 loose,relaxed,oversized,baggy,slouchy
  2 +修身,紧身 fitted,tight 宽松 loose,relaxed,oversized,baggy,slouchy
  3 +宽松 loose,relaxed,oversized,baggy,slouchy 修身,紧身 fitted,tight
... ...
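The second row makes the exclusion symmetric: "loose"-style triggers now also exclude fitted/tight titles, mirroring the existing fitted-vs-loose row. A hedged sketch of the intended rule semantics (the matching code is not part of this diff; field names here are illustrative, loosely following `product_title_exclusion_rules` in `config/schema.py`):

```python
# Illustrative only: field names and matching granularity are assumptions.
rule = {
    "zh_triggers": ["修身", "紧身"],
    "en_triggers": ["fitted", "tight"],
    "zh_title_exclusions": ["宽松"],
    "en_title_exclusions": ["loose", "relaxed", "oversized", "baggy", "slouchy"],
}

def title_excluded(query: str, title: str) -> bool:
    """If the query hits a trigger term, titles containing any exclusion term are filtered out."""
    q, t = query.lower(), title.lower()
    triggered = any(term in q for term in rule["zh_triggers"] + rule["en_triggers"])
    excluded = any(term in t for term in rule["zh_title_exclusions"] + rule["en_title_exclusions"])
    return triggered and excluded

print(title_excluded("fitted black tee", "Oversized cotton t-shirt"))  # True
```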
config/dictionaries/style_intent_color.csv
1   -black,black,blk,黑,黑色
2   -white,white,wht,白,白色
3   -red,red,reddish,红,红色
4   -blue,blue,blu,蓝,蓝色
5   -green,green,grn,绿,绿色
6   -yellow,yellow,ylw,黄,黄色
7   -pink,pink,粉,粉色
8   -purple,purple,violet,紫,紫色
9   -gray,gray,grey,灰,灰色
10   -brown,brown,棕,棕色,咖啡色
11   -beige,beige,khaki,米色,卡其色
12   -navy,navy,navy blue,藏青,藏蓝,深蓝
13   -silver,silver,银,银色
14   -gold,gold,金,金色
15   -orange,orange,橙,橙色
  1 +"black,blk","黑,黑色","black"
  2 +"white,wht","白,白色","white"
  3 +"red,reddish","红,红色","red"
  4 +"blue,blu","蓝,蓝色","blue"
  5 +"green,grn","绿,绿色","green"
  6 +"yellow,ylw","黄,黄色","yellow"
  7 +"pink","粉,粉色","pink"
  8 +"purple,violet","紫,紫色","purple"
  9 +"gray,grey","灰,灰色","gray,grey"
  10 +"brown","棕,棕色,咖啡色","brown"
  11 +"beige,khaki","米色,卡其色","beige,khaki"
  12 +"navy,navy blue","藏青,藏蓝,深蓝","navy"
  13 +"silver","银,银色","silver"
  14 +"gold","金,金色","gold"
  15 +"orange","橙,橙色","orange"
... ...
config/dictionaries/style_intent_size.csv
1   -xs,xs,extra small,x-small,加小码
2   -s,s,small,小码,小号
3   -m,m,medium,中码,中号
4   -l,l,large,大码,大号
5   -xl,xl,x-large,extra large,加大码
6   -xxl,xxl,2xl,xx-large,双加大码
7   -xxxl,xxxl,3xl,xxx-large,三加大码
8   -one size,one size,onesize,free size,均码
  1 +"xs,extra small,x-small","加小码","xs,extra small,x-small"
  2 +"s,small","小码,小号","s,small"
  3 +"m,medium","中码,中号","m,medium"
  4 +"l,large","大码,大号","l,large"
  5 +"xl,x-large,extra large","加大码","xl,x-large,extra large"
  6 +"xxl,2xl,xx-large","双加大码","xxl,2xl,xx-large"
  7 +"xxxl,3xl,xxx-large","三加大码","xxxl,3xl,xxx-large"
... ...
config/loader.py
... ... @@ -10,6 +10,7 @@ from __future__ import annotations
10 10 import hashlib
11 11 import json
12 12 import os
  13 +import csv
13 14 from copy import deepcopy
14 15 from dataclasses import asdict
15 16 from functools import lru_cache
... ... @@ -96,20 +97,33 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]:
96 97 return rewrite_dict
97 98  
98 99  
99   -def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]:
100   - rows: List[List[str]] = []
  100 +def _read_synonym_csv_dictionary(path: Path) -> List[Dict[str, List[str]]]:
  101 + rows: List[Dict[str, List[str]]] = []
101 102 if not path.exists():
102 103 return rows
103 104  
  105 + def _split_terms(cell: str) -> List[str]:
  106 + return [item.strip() for item in str(cell or "").split(",") if item.strip()]
  107 +
104 108 with open(path, "r", encoding="utf-8") as handle:
105   - for raw_line in handle:
106   - line = raw_line.strip()
107   - if not line or line.startswith("#"):
  109 + reader = csv.reader(handle)
  110 + for parts in reader:
  111 + if not parts:
  112 + continue
  113 + if parts[0].strip().startswith("#"):
108 114 continue
109   - parts = [segment.strip() for segment in line.split(",")]
110   - normalized = [segment for segment in parts if segment]
111   - if normalized:
112   - rows.append(normalized)
  115 +
  116 + normalized = [segment.strip() for segment in parts]
  117 + if len(normalized) < 3:
  118 + continue
  119 +
  120 + row = {
  121 + "en_terms": _split_terms(normalized[0]),
  122 + "zh_terms": _split_terms(normalized[1]),
  123 + "attribute_terms": _split_terms(normalized[2]),
  124 + }
  125 + if any(row.values()):
  126 + rows.append(row)
113 127 return rows
114 128  
115 129  
... ... @@ -425,6 +439,9 @@ class AppConfigLoader:
425 439 query_cfg.get("translation_embedding_wait_budget_ms_source_not_in_index", 200)
426 440 ),
427 441 style_intent_enabled=bool(style_intent_cfg.get("enabled", True)),
  442 + style_intent_selected_sku_boost=float(
  443 + style_intent_cfg.get("selected_sku_boost", 1.2)
  444 + ),
428 445 style_intent_terms=style_intent_terms,
429 446 style_intent_dimension_aliases=style_dimension_aliases,
430 447 product_title_exclusion_enabled=bool(product_title_exclusion_cfg.get("enabled", True)),
... ...
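With this loader change, each style-intent dictionary row is no longer a flat synonym list but a three-column record: English query synonyms, Chinese query synonyms, and attribute/title terms. A small sketch of what the first row of the updated `style_intent_color.csv` parses into, mirroring the `csv.reader` / `_split_terms` logic above:

```python
import csv
import io

# First data row of the updated style_intent_color.csv.
sample = '"black,blk","黑,黑色","black"\n'

for parts in csv.reader(io.StringIO(sample)):
    row = {
        "en_terms": [t.strip() for t in parts[0].split(",") if t.strip()],
        "zh_terms": [t.strip() for t in parts[1].split(",") if t.strip()],
        "attribute_terms": [t.strip() for t in parts[2].split(",") if t.strip()],
    }
    print(row)
# {'en_terms': ['black', 'blk'], 'zh_terms': ['黑', '黑色'], 'attribute_terms': ['black']}
```

Quoted cells let a single column hold comma-separated synonyms, which is why the dictionaries moved from plain comma lines to quoted CSV.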
config/schema.py
... ... @@ -65,7 +65,8 @@ class QueryConfig:
65 65 translation_embedding_wait_budget_ms_source_in_index: int = 80
66 66 translation_embedding_wait_budget_ms_source_not_in_index: int = 200
67 67 style_intent_enabled: bool = True
68   - style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict)
  68 + style_intent_selected_sku_boost: float = 1.2
  69 + style_intent_terms: Dict[str, List[Dict[str, List[str]]]] = field(default_factory=dict)
69 70 style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict)
70 71 product_title_exclusion_enabled: bool = True
71 72 product_title_exclusion_rules: List[Dict[str, List[str]]] = field(default_factory=list)
... ...
config/services_config.py
... ... @@ -7,6 +7,7 @@ contains no independent parsing or precedence logic.
7 7  
8 8 from __future__ import annotations
9 9  
  10 +import os
10 11 from typing import Any, Dict, Tuple
11 12  
12 13 from config.loader import get_app_config
... ... @@ -61,6 +62,12 @@ def get_embedding_image_backend_config() -> Tuple[str, Dict[str, Any]]:
61 62  
62 63 def get_rerank_backend_config() -> Tuple[str, Dict[str, Any]]:
63 64 cfg = get_app_config().services.rerank
  65 + backend = str(os.getenv("RERANK_BACKEND") or cfg.backend).strip()
  66 + if backend != cfg.backend:
  67 + backend_cfg = cfg.backends.get(backend)
  68 + if backend_cfg is None:
  69 + raise ValueError(f"Unknown rerank backend override from RERANK_BACKEND: {backend!r}")
  70 + return backend, dict(backend_cfg)
64 71 return cfg.backend, cfg.get_backend_config()
65 72  
66 73  
... ...
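A minimal usage sketch of the new `RERANK_BACKEND` override; the override must name a key under `services.rerank.backends`, otherwise the loader raises instead of silently falling back:

```python
import os

from config.services_config import get_rerank_backend_config

# Without an override, the configured services.rerank.backend is returned.
backend, backend_cfg = get_rerank_backend_config()

# With an override, the named backend's config block is returned instead.
os.environ["RERANK_BACKEND"] = "qwen3_gguf_06b"
backend, backend_cfg = get_rerank_backend_config()
assert backend == "qwen3_gguf_06b"

# An unknown name fails fast with ValueError on the next call.
os.environ["RERANK_BACKEND"] = "no_such_backend"
```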
perf_reports/reranker_vllm_instruction/2026-03-25/RESULTS.md 0 → 100644
... ... @@ -0,0 +1,61 @@
  1 +# Reranker benchmark: `qwen3_vllm` vs `qwen3_vllm_score` × `instruction_format`
  2 +
  3 +**Date:** 2026-03-25
  4 +**Host:** single GPU (Tesla T4, ~16 GiB), CUDA 12.8 (see `nvidia-smi` during run).
  5 +
  6 +## Configuration (from `config/config.yaml`)
  7 +
  8 +Shared across both backends for this run:
  9 +
  10 +| Key | Value |
  11 +|-----|-------|
  12 +| `model_name` | `Qwen/Qwen3-Reranker-0.6B` |
  13 +| `max_model_len` | 160 |
  14 +| `infer_batch_size` | 100 |
  15 +| `sort_by_doc_length` | true |
  16 +| `enable_prefix_caching` | true |
  17 +| `enforce_eager` | false |
  18 +| `dtype` | float16 |
  19 +| `tensor_parallel_size` | 1 |
  20 +| `gpu_memory_utilization` | 0.20 |
  21 +| `instruction` | `Rank products by query with category & style match prioritized` |
  22 +
  23 +`qwen3_vllm` uses vLLM **generate + logprobs** (`.venv-reranker`).
  24 +`qwen3_vllm_score` uses vLLM **`LLM.score()`** (`.venv-reranker-score`, pinned vLLM stack per `reranker/README.md`).
  25 +
  26 +## Methodology
  27 +
  28 +- Script: `python scripts/benchmark_reranker_random_titles.py 100,200,400,600,800,1000 --repeat 5` with **`--seed 99`** (see note below), **`--quiet-runs`**, **`--timeout 360`**.
  29 +- Titles: default file `/home/ubuntu/rerank_test/titles.1.8w` (one title per line).
  30 +- Query: default `健身女生T恤短袖`.
  31 +- Each scenario: **3 warm-up** requests at `n=400` (not timed), then **5 timed** runs per `n`.
  32 +- Metric: **client wall time** for `POST /rerank` (localhost), milliseconds.
  33 +- After each `services.rerank.backend` / `instruction_format` change: `./restart.sh reranker`, then poll **`GET /health`** until `backend` and `instruction_format` matched the intended scenario (`reranker/server.py` was extended to expose `instruction_format` when the backend defines `_instruction_format`).
  34 +
  35 +**Note on RNG seed:** With `--seed 42`, some runs occasionally lost one sample at `n=600` (non-200 or transport error). All figures below use **`--seed 99`** so every cell has **5/5** successful runs and comparable sampled titles.
  36 +
  37 +## Raw artifacts
  38 +
  39 +JSON aggregates (means, stdev, raw `values_ms`): in the same directory as this report, `qwen3_vllm_{compact,standard}.json` and `qwen3_vllm_score_{compact,standard}.json`.
  40 +
  41 +## Results — mean latency (ms)
  42 +
  43 +| backend | instruction_format | n=100 | n=200 | n=400 | n=600 | n=800 | n=1000 |
  44 +|---------|-------------------|------:|------:|------:|------:|------:|-------:|
  45 +| `qwen3_vllm` | `compact` | 213.5 | 418.0 | 861.4 | 1263.4 | 1744.3 | 2162.2 |
  46 +| `qwen3_vllm` | `standard` | 254.9 | 475.4 | 909.7 | 1353.2 | 1912.5 | 2406.7 |
  47 +| `qwen3_vllm_score` | `compact` | 239.2 | 480.2 | 966.2 | 1433.5 | 1937.2 | 2428.4 |
  48 +| `qwen3_vllm_score` | `standard` | 299.6 | 591.8 | 1178.9 | 1773.7 | 2341.6 | 2931.7 |
  49 +
  50 +## Short interpretation
  51 +
  52 +1. **`compact` vs `standard`:** For both backends, **`compact` is faster** on this setup (shorter / different chat template vs fixed yes/no system prompt + user block — see `reranker/backends/qwen3_vllm.py` / `qwen3_vllm_score.py`).
  53 +2. **`qwen3_vllm` vs `qwen3_vllm_score`:** At **`n=1000`**, **`qwen3_vllm` + `compact`** is the fastest row (~2162 ms mean); **`qwen3_vllm_score` + `standard`** is the slowest (~2932 ms). Ordering can change on other GPUs / vLLM versions / batching.
  54 +3. **Repo default** after tests: `services.rerank.backend: qwen3_vllm_score`, `instruction_format: compact` on **both** `qwen3_vllm` and `qwen3_vllm_score` blocks (patch script keeps them aligned).
  55 +
  56 +## Tooling added / changed
  57 +
  58 +- `reranker/server.py`: `/health` includes `instruction_format` when the active backend sets `_instruction_format`.
  59 +- `scripts/benchmark_reranker_random_titles.py`: `--tag`, `--json-summary-out`, `--quiet-runs`.
  60 +- `scripts/patch_rerank_vllm_benchmark_config.py`: surgical YAML patch (preserves newlines).
  61 +- `scripts/run_reranker_vllm_instruction_benchmark.sh`: full matrix driver (continues if a benchmark exits non-zero; uses `--timeout 360`).
... ...
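For the restart step in the methodology, a small sketch of the health poll used between scenarios (stdlib only; `backend` and `instruction_format` are the fields `/health` exposes per the note above, the rest of the payload shape is an assumption):

```python
import json
import time
import urllib.request

def wait_for_backend(expected_backend: str, expected_format: str, timeout_s: float = 120.0) -> None:
    """Poll GET /health until the reranker reports the intended benchmark scenario."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen("http://127.0.0.1:6007/health", timeout=5) as resp:
                payload = json.loads(resp.read().decode("utf-8"))
            if (payload.get("backend") == expected_backend
                    and payload.get("instruction_format") == expected_format):
                return
        except OSError:
            pass  # service still restarting
        time.sleep(2)
    raise TimeoutError(f"reranker did not come up as {expected_backend}/{expected_format}")

# Example: wait_for_backend("qwen3_vllm", "compact")
```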
query/style_intent.py
... ... @@ -11,38 +11,79 @@ from .tokenization import TokenizedText, normalize_query_text, tokenize_text
11 11  
12 12  
13 13 @dataclass(frozen=True)
  14 +class StyleIntentTermDefinition:
  15 + canonical_value: str
  16 + en_terms: Tuple[str, ...]
  17 + zh_terms: Tuple[str, ...]
  18 + attribute_terms: Tuple[str, ...]
  19 +
  20 +
  21 +@dataclass(frozen=True)
14 22 class StyleIntentDefinition:
15 23 intent_type: str
16   - term_groups: Tuple[Tuple[str, ...], ...]
  24 + terms: Tuple[StyleIntentTermDefinition, ...]
17 25 dimension_aliases: Tuple[str, ...]
18   - synonym_to_canonical: Dict[str, str]
  26 + en_synonym_to_term: Dict[str, StyleIntentTermDefinition]
  27 + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition]
19 28 max_term_ngram: int = 3
20 29  
21 30 @classmethod
22 31 def from_rows(
23 32 cls,
24 33 intent_type: str,
25   - rows: Sequence[Sequence[str]],
  34 + rows: Sequence[Dict[str, List[str]]],
26 35 dimension_aliases: Sequence[str],
27 36 ) -> "StyleIntentDefinition":
28   - term_groups: List[Tuple[str, ...]] = []
29   - synonym_to_canonical: Dict[str, str] = {}
  37 + terms: List[StyleIntentTermDefinition] = []
  38 + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {}
  39 + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {}
30 40 max_ngram = 1
31 41  
32 42 for row in rows:
33   - normalized_terms: List[str] = []
34   - for raw_term in row:
35   - term = normalize_query_text(raw_term)
36   - if not term or term in normalized_terms:
37   - continue
38   - normalized_terms.append(term)
39   - if not normalized_terms:
  43 + normalized_en = tuple(
  44 + dict.fromkeys(
  45 + term
  46 + for term in (normalize_query_text(raw) for raw in row.get("en_terms", []))
  47 + if term
  48 + )
  49 + )
  50 + normalized_zh = tuple(
  51 + dict.fromkeys(
  52 + term
  53 + for term in (normalize_query_text(raw) for raw in row.get("zh_terms", []))
  54 + if term
  55 + )
  56 + )
  57 + normalized_attribute = tuple(
  58 + dict.fromkeys(
  59 + term
  60 + for term in (normalize_query_text(raw) for raw in row.get("attribute_terms", []))
  61 + if term
  62 + )
  63 + )
  64 + if not normalized_en and not normalized_zh and not normalized_attribute:
40 65 continue
41 66  
42   - canonical = normalized_terms[0]
43   - term_groups.append(tuple(normalized_terms))
44   - for term in normalized_terms:
45   - synonym_to_canonical[term] = canonical
  67 + canonical = (
  68 + normalized_attribute[0]
  69 + if normalized_attribute
  70 + else normalized_en[0]
  71 + if normalized_en
  72 + else normalized_zh[0]
  73 + )
  74 + term_definition = StyleIntentTermDefinition(
  75 + canonical_value=canonical,
  76 + en_terms=normalized_en,
  77 + zh_terms=normalized_zh,
  78 + attribute_terms=normalized_attribute,
  79 + )
  80 + terms.append(term_definition)
  81 +
  82 + for term in normalized_en:
  83 + en_synonym_to_term[term] = term_definition
  84 + max_ngram = max(max_ngram, len(term.split()))
  85 + for term in normalized_zh:
  86 + zh_synonym_to_term[term] = term_definition
46 87 max_ngram = max(max_ngram, len(term.split()))
47 88  
48 89 aliases = tuple(
... ... @@ -58,28 +99,31 @@ class StyleIntentDefinition:
58 99  
59 100 return cls(
60 101 intent_type=intent_type,
61   - term_groups=tuple(term_groups),
  102 + terms=tuple(terms),
62 103 dimension_aliases=aliases,
63   - synonym_to_canonical=synonym_to_canonical,
  104 + en_synonym_to_term=en_synonym_to_term,
  105 + zh_synonym_to_term=zh_synonym_to_term,
64 106 max_term_ngram=max_ngram,
65 107 )
66 108  
67   - def match_candidates(self, candidates: Iterable[str]) -> Set[str]:
68   - matched: Set[str] = set()
  109 + def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]:
  110 + mapping = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term
  111 + matched: Set[StyleIntentTermDefinition] = set()
69 112 for candidate in candidates:
70   - canonical = self.synonym_to_canonical.get(normalize_query_text(candidate))
71   - if canonical:
72   - matched.add(canonical)
  113 + term_definition = mapping.get(normalize_query_text(candidate))
  114 + if term_definition:
  115 + matched.add(term_definition)
73 116 return matched
74 117  
75 118 def match_text(
76 119 self,
77 120 text: str,
78 121 *,
  122 + language: str,
79 123 tokenizer: Optional[Callable[[str], Any]] = None,
80   - ) -> Set[str]:
  124 + ) -> Set[StyleIntentTermDefinition]:
81 125 bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram)
82   - return self.match_candidates(bundle.candidates)
  126 + return self.match_candidates(bundle.candidates, language=language)
83 127  
84 128  
85 129 @dataclass(frozen=True)
... ... @@ -88,6 +132,7 @@ class DetectedStyleIntent:
88 132 canonical_value: str
89 133 matched_term: str
90 134 matched_query_text: str
  135 + attribute_terms: Tuple[str, ...]
91 136 dimension_aliases: Tuple[str, ...]
92 137  
93 138 def to_dict(self) -> Dict[str, Any]:
... ... @@ -96,6 +141,7 @@ class DetectedStyleIntent:
96 141 "canonical_value": self.canonical_value,
97 142 "matched_term": self.matched_term,
98 143 "matched_query_text": self.matched_query_text,
  144 + "attribute_terms": list(self.attribute_terms),
99 145 "dimension_aliases": list(self.dimension_aliases),
100 146 }
101 147  
... ... @@ -159,7 +205,7 @@ class StyleIntentRegistry:
159 205 rows=rows or [],
160 206 dimension_aliases=dimension_aliases.get(intent_type, []),
161 207 )
162   - if definition.synonym_to_canonical:
  208 + if definition.terms:
163 209 definitions[definition.intent_type] = definition
164 210  
165 211 return cls(
... ... @@ -191,15 +237,10 @@ class StyleIntentDetector:
191 237 seen = set()
192 238 variants: List[TokenizedText] = []
193 239 texts = [
194   - getattr(parsed_query, "original_query", None),
195   - getattr(parsed_query, "query_normalized", None),
196   - getattr(parsed_query, "rewritten_query", None),
  240 + self._get_language_query_text(parsed_query, "zh"),
  241 + self._get_language_query_text(parsed_query, "en"),
197 242 ]
198 243  
199   - translations = getattr(parsed_query, "translations", {}) or {}
200   - if isinstance(translations, dict):
201   - texts.extend(translations.values())
202   -
203 244 for raw_text in texts:
204 245 text = str(raw_text or "").strip()
205 246 if not text:
... ... @@ -221,35 +262,66 @@ class StyleIntentDetector:
221 262  
222 263 return tuple(variants)
223 264  
  265 + @staticmethod
  266 + def _get_language_query_text(parsed_query: Any, language: str) -> str:
  267 + translations = getattr(parsed_query, "translations", {}) or {}
  268 + if isinstance(translations, dict):
  269 + translated = translations.get(language)
  270 + if translated:
  271 + return str(translated)
  272 + return str(getattr(parsed_query, "original_query", "") or "")
  273 +
  274 + def _tokenize_language_query(self, parsed_query: Any, language: str) -> Optional[TokenizedText]:
  275 + text = self._get_language_query_text(parsed_query, language).strip()
  276 + if not text:
  277 + return None
  278 + return tokenize_text(
  279 + text,
  280 + tokenizer=self.tokenizer,
  281 + max_ngram=max(
  282 + (definition.max_term_ngram for definition in self.registry.definitions.values()),
  283 + default=3,
  284 + ),
  285 + )
  286 +
224 287 def detect(self, parsed_query: Any) -> StyleIntentProfile:
225 288 if not self.registry.enabled or not self.registry.definitions:
226 289 return StyleIntentProfile()
227 290  
228 291 query_variants = self._build_query_variants(parsed_query)
  292 + zh_variant = self._tokenize_language_query(parsed_query, "zh")
  293 + en_variant = self._tokenize_language_query(parsed_query, "en")
229 294 detected: List[DetectedStyleIntent] = []
230 295 seen_pairs = set()
231 296  
232   - for variant in query_variants:
233   - for intent_type, definition in self.registry.definitions.items():
234   - matched_canonicals = definition.match_candidates(variant.candidates)
235   - if not matched_canonicals:
  297 + for intent_type, definition in self.registry.definitions.items():
  298 + for language, variant, mapping in (
  299 + ("zh", zh_variant, definition.zh_synonym_to_term),
  300 + ("en", en_variant, definition.en_synonym_to_term),
  301 + ):
  302 + if variant is None or not mapping:
  303 + continue
  304 +
  305 + matched_terms = definition.match_candidates(variant.candidates, language=language)
  306 + if not matched_terms:
236 307 continue
237 308  
238 309 for candidate in variant.candidates:
239 310 normalized_candidate = normalize_query_text(candidate)
240   - canonical = definition.synonym_to_canonical.get(normalized_candidate)
241   - if not canonical or canonical not in matched_canonicals:
  311 + term_definition = mapping.get(normalized_candidate)
  312 + if term_definition is None or term_definition not in matched_terms:
242 313 continue
243   - pair = (intent_type, canonical)
  314 + pair = (intent_type, term_definition.canonical_value)
244 315 if pair in seen_pairs:
245 316 continue
246 317 seen_pairs.add(pair)
247 318 detected.append(
248 319 DetectedStyleIntent(
249 320 intent_type=intent_type,
250   - canonical_value=canonical,
  321 + canonical_value=term_definition.canonical_value,
251 322 matched_term=normalized_candidate,
252 323 matched_query_text=variant.text,
  324 + attribute_terms=term_definition.attribute_terms,
253 325 dimension_aliases=definition.dimension_aliases,
254 326 )
255 327 )
... ...
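A short usage sketch of the reworked definition: rows now use the new dictionary shape, matching is language-scoped, and matches return `StyleIntentTermDefinition` objects instead of bare canonicals. The `dimension_aliases` value here is illustrative, and the expected output assumes the default whitespace tokenization in `query/tokenization.py`:

```python
from query.style_intent import StyleIntentDefinition

rows = [
    {"en_terms": ["black", "blk"], "zh_terms": ["黑", "黑色"], "attribute_terms": ["black"]},
    {"en_terms": ["navy", "navy blue"], "zh_terms": ["藏青", "深蓝"], "attribute_terms": ["navy"]},
]
definition = StyleIntentDefinition.from_rows("color", rows, dimension_aliases=["color"])

# zh synonyms only match the zh query text, en synonyms only the en query text.
matched = definition.match_text("黑色 短袖", language="zh")
print(sorted(term.canonical_value for term in matched))        # ['black']
print(definition.match_text("navy blue dress", language="zh")) # set(): en terms not used for zh
```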
requirements_reranker_base.txt 0 → 100644
... ... @@ -0,0 +1,7 @@
  1 +# Shared base dependencies for reranker service venvs.
  2 +
  3 +fastapi>=0.100.0
  4 +uvicorn[standard]>=0.23.0
  5 +pydantic>=2.0.0
  6 +numpy>=1.24.0
  7 +pyyaml>=6.0
... ...
requirements_reranker_bge.txt 0 → 100644
... ... @@ -0,0 +1,7 @@
  1 +# Isolated dependencies for bge reranker backend.
  2 +
  3 +-r requirements_reranker_base.txt
  4 +torch>=2.0.0
  5 +transformers>=4.30.0
  6 +sentence-transformers>=2.2.0
  7 +modelscope>=1.9.0
... ...
requirements_reranker_dashscope.txt 0 → 100644
... ... @@ -0,0 +1,3 @@
  1 +# Isolated dependencies for dashscope_rerank backend.
  2 +
  3 +-r requirements_reranker_base.txt
... ...
requirements_reranker_qwen3_gguf.txt 0 → 100644
... ... @@ -0,0 +1,5 @@
  1 +# Isolated dependencies for qwen3_gguf reranker backend (.venv-reranker-gguf).
  2 +
  3 +-r requirements_reranker_base.txt
  4 +huggingface-hub>=0.32.0
  5 +llama-cpp-python>=0.3.16
... ...
requirements_reranker_qwen3_gguf_06b.txt 0 → 100644
... ... @@ -0,0 +1,3 @@
  1 +# Isolated dependencies for qwen3_gguf_06b reranker backend (.venv-reranker-gguf-06b).
  2 +
  3 +-r requirements_reranker_qwen3_gguf.txt
... ...
requirements_reranker_qwen3_transformers.txt 0 → 100644
... ... @@ -0,0 +1,5 @@
  1 +# Isolated dependencies for qwen3_transformers reranker backend.
  2 +
  3 +-r requirements_reranker_base.txt
  4 +torch>=2.0.0
  5 +transformers>=4.51.0
... ...
requirements_reranker_qwen3_transformers_packed.txt 0 → 100644
... ... @@ -0,0 +1,9 @@
  1 +# Isolated dependencies for qwen3_transformers_packed reranker backend.
  2 +#
  3 +# Keep this stack aligned with the validated CUDA runtime on our hosts.
  4 +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works.
  5 +# We also cap transformers <5 to stay on the same family as the working vLLM score env.
  6 +
  7 +-r requirements_reranker_qwen3_transformers.txt
  8 +torch==2.10.0
  9 +transformers>=4.51.0,<5
... ...
requirements_reranker_qwen3_vllm.txt 0 → 100644
... ... @@ -0,0 +1,5 @@
  1 +# Isolated dependencies for qwen3_vllm reranker backend (.venv-reranker).
  2 +
  3 +-r requirements_reranker_base.txt
  4 +transformers>=4.30.0
  5 +vllm>=0.8.5
... ...
requirements_reranker_qwen3_vllm_score.txt 0 → 100644
... ... @@ -0,0 +1,14 @@
  1 +# Dedicated high-performance venv for qwen3_vllm_score: .venv-reranker-score
  2 +#
  3 +# Create / refresh:
  4 +# ./scripts/setup_reranker_venv.sh qwen3_vllm_score
  5 +#
  6 +# vLLM 0.17+ replaces LLM(task="score") with runner/convert auto + LLM.score().
  7 +# Pin vLLM for reproducible perf baselines; bump after validating CUDA/driver on your hosts.
  8 +# If pip cannot find a wheel for your CUDA version, edit the vllm line or install from:
  9 +# https://docs.vllm.ai/en/latest/getting_started/installation.html
  10 +
  11 +-r requirements_reranker_base.txt
  12 +vllm==0.18.0
  13 +# Match vLLM 0.18 stack; cap <5 to avoid pip prefetching incompatible transformers 5.x.
  14 +transformers>=4.51.0,<5
... ...
requirements_reranker_service.txt
1   -# Isolated dependencies for reranker service (.venv-reranker)
  1 +# Legacy alias: qwen3_vllm reranker service env (.venv-reranker).
2 2 #
3   -# Default backend is qwen3_vllm (Qwen3-Reranker-0.6B).
  3 +# Prefer backend-specific requirements files:
  4 +# - requirements_reranker_qwen3_vllm.txt
  5 +# - requirements_reranker_qwen3_vllm_score.txt
  6 +# - requirements_reranker_qwen3_gguf.txt
  7 +# - requirements_reranker_qwen3_transformers.txt
  8 +# - requirements_reranker_bge.txt
  9 +# - requirements_reranker_dashscope.txt
4 10  
5   -fastapi>=0.100.0
6   -uvicorn[standard]>=0.23.0
7   -pydantic>=2.0.0
8   -numpy>=1.24.0
9   -pyyaml>=6.0
10   -transformers>=4.30.0
11   -vllm>=0.8.5
  11 +-r requirements_reranker_qwen3_vllm.txt
... ...
reranker/DEPLOYMENT_AND_TUNING.md
1   -# Reranker deployment and performance tuning guide (Qwen3-vLLM)
  1 +# Reranker deployment and performance tuning guide (Qwen3-vLLM / Qwen3-GGUF)
2 2  
3 3 This document captures reusable practice from this project for e-commerce search reranking, covering:
4 4  
5 5 - Environment preparation, installation, and deployment
6   - - `qwen3_vllm` configuration options and optimization approach
  6 + - `qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` configuration options and optimization approach
7 7 - Benchmarking workflow for the 1000-doc scenario
8 8 - Key conclusions and recommended defaults
9 9 - Common troubleshooting
10 10  
11 11 Scope:
12 12  
13   - - Rerank backend: `services.rerank.backend: qwen3_vllm`
14   - - Model: `Qwen/Qwen3-Reranker-0.6B`
  13 + - Rerank backends: `services.rerank.backend: qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b`
  14 + - Models: `Qwen/Qwen3-Reranker-0.6B` / `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` / `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF`
15 15 - Scenario: short queries (usually < 100 tokens), docs are product titles or title + short description, roughly 1000 docs per request
16 16  
17 17 ## 1. Environment baseline
18 18  
19   -Currently validated environment (2026-03-11):
  19 +Currently validated environment (2026-03-25):
20 20  
21 21 - GPU: `Tesla T4 16GB`
22 22 - Driver / CUDA: `570.158.01 / 12.8`
23 23 - Python: `3.12.3`
24   -- Key dependencies: `vllm==0.17.0`, `torch==2.10.0+cu128`, `transformers==4.57.6`, `fastapi==0.135.1`, `uvicorn==0.41.0`
  24 +- Key dependencies: `vllm==0.17.0`, `torch==2.10.0+cu128`, `transformers==4.57.6`, `llama-cpp-python>=0.3.16`, `fastapi==0.135.1`, `uvicorn==0.41.0`
25 25  
26 26 ## 2. Environment preparation and installation
27 27  
28 28 ### 2.1 Prepare the dedicated reranker virtual environment
29 29  
30 30 ```bash
31   -./scripts/setup_reranker_venv.sh
  31 +./scripts/setup_reranker_venv.sh qwen3_vllm
  32 +```
  33 +
  34 +If you use GGUF and need CUDA:
  35 +
  36 +```bash
  37 +./scripts/setup_reranker_venv.sh qwen3_gguf
  38 +PATH=/usr/local/cuda/bin:$PATH \
  39 +CUDACXX=/usr/local/cuda/bin/nvcc \
  40 +CMAKE_ARGS="-DGGML_CUDA=on" \
  41 +FORCE_CMAKE=1 \
  42 +./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18
32 43 ```
33 44  
34 45 ### 2.2 Basic checks
... ... @@ -37,6 +48,7 @@
37 48 nvidia-smi
38 49 ./.venv-reranker/bin/python -c "import torch; print(torch.cuda.is_available())"
39 50 ./.venv-reranker/bin/python -c "import vllm, transformers; print(vllm.__version__, transformers.__version__)"
  51 +./.venv-reranker-gguf/bin/python -c "import llama_cpp; print(llama_cpp.__version__)"
40 52 ```
41 53  
42 54 ## 3. Deployment and operation
... ... @@ -64,6 +76,29 @@ services:
64 76 length_sort_mode: "char" # char | token
65 77 ```
66 78  
  79 +Recommended baseline for GGUF on a T4 with roughly `4.8~6GB` of free VRAM:
  80 +
  81 +```yaml
  82 +services:
  83 + rerank:
  84 + backend: "qwen3_gguf"
  85 + backends:
  86 + qwen3_gguf:
  87 + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF"
  88 + filename: "*Q8_0.gguf"
  89 + local_dir: "./models/reranker/qwen3-reranker-4b-gguf"
  90 + cache_dir: "./model_cache"
  91 + n_ctx: 384
  92 + n_batch: 384
  93 + n_ubatch: 128
  94 + n_gpu_layers: 24
  95 + flash_attn: true
  96 + offload_kqv: true
  97 + infer_batch_size: 8
  98 + sort_by_doc_length: true
  99 + length_sort_mode: "char"
  100 +```
  101 +
67 102 ### 3.2 Start/stop commands
68 103  
69 104 Recommended to use consistently:
... ... @@ -105,6 +140,13 @@ curl -sS http://127.0.0.1:6007/health
105 140 - `service_ctl.sh` uses a dedicated startup path for the reranker
106 141 - Adds a "stable health check" (consecutive health probes) to avoid the false positive of a service exiting right after it first reports healthy
107 142  
  143 +### 4.4 GGUF / T4 low-VRAM optimization principles
  144 +
  145 +- The `Q8_0` weights are about `4.28GB`, but headroom is also needed for the KV cache, CUDA workspace, and runtime fragmentation, so feasibility cannot be judged by "model size < free VRAM" alone.
  146 +- The current workload is short queries + product titles, so shrink `n_ctx` first; `384` is usually a better trade-off than the default long context.
  147 +- With limited T4 VRAM, sweep `n_gpu_layers` first and only then try raising `n_ctx`; in the current GGUF integration `infer_batch_size` is mainly a service-side work chunk, not a real llama.cpp operator batch.
  148 +- Keep `flash_attn: true` and `offload_kqv: true` enabled by default; on OOM, lower `n_gpu_layers` first.
  149 +
108 150 ## 5. Performance tuning workflow (standard procedure)
109 151  
110 152 ### 5.1 Use the one-shot benchmark script
... ... @@ -125,6 +167,13 @@ curl -sS http://127.0.0.1:6007/health
125 167 - `infer_batch_size`: `24 32 48 64`
126 168 - Concurrency groups: `c=1` (single-request latency), `c=4` (concurrent throughput and tail latency)
127 169  
  170 +Suggested GGUF sweep:
  171 +
  172 +- `n_gpu_layers`: `20 24 28`
  173 +- `n_ctx`: `320 384 448`
  174 +- `infer_batch_size`: `4 8 12` (secondary; only affects the service-side work chunk)
  175 +- Sweep order: fix `n_ctx=384` first and find the largest `n_gpu_layers` that starts reliably; then try `n_ctx=448` if VRAM allows; fine-tune `infer_batch_size` last
  176 +
128 177 Can be overridden via environment variables:
129 178  
130 179 - `BATCH_SIZES`
... ... @@ -140,23 +189,28 @@ curl -sS http://127.0.0.1:6007/health
140 189 - `RERANK_VLLM_INFER_BATCH_SIZE`
141 190 - `RERANK_VLLM_SORT_BY_DOC_LENGTH`
142 191  
143   -## 6. Key conclusions from this round (2026-03-11)
144   -
145   -Based on the report:
146   -
147   -- `perf_reports/20260311/reranker_1000docs/report.md`
  192 +## 6. Key conclusions from this round
148 193  
149   -Conclusions
  194 +vLLM (2026-03-11, see `perf_reports/20260311/reranker_1000docs/report.md`)
150 195  
151 196 - For the single-request latency (`c=1`) metric that matters most for online reranking, `infer_batch_size=64` is best
152 197 - `infer_batch_size=96` gives slightly higher throughput under higher concurrency but sacrifices single-request latency stability
153 198 - `infer_batch_size=64` is kept as the current default balance point
154 199  
  200 +GGUF (2026-03-25, this integration):
  201 +
  202 +- The `Q8_0` file of `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` is about `4.28GB`; given the measured ~`4823 MiB` of free VRAM on the current machine, the default does not use aggressive full GPU offload.
  203 +- Current recommended defaults: `n_ctx=384`, `n_batch=384`, `n_ubatch=128`, `n_gpu_layers=24`, `infer_batch_size=8`.
  204 +- If free VRAM on site is closer to `6GB` with little fragmentation, try `n_gpu_layers=28` first; if startup fails, fall back to `24` or `20`.
  205 +- Because the GGUF weights are not yet cached in the current workspace, no real throughput benchmark was run this round; before going live, rerun a parameter sweep on the deployment machine and archive the report.
  206 +
155 207 ## 7. Production recommendations
156 208  
157 209 - Keep the defaults: `infer_batch_size: 64`, `sort_by_doc_length: true`
158 210 - Consider raising it to `96` only when all of the following hold: the business prioritizes throughput, higher single-request latency is acceptable, and the gain has been verified by benchmarks on the same machine with the same data
159 211 - After every change, rerun `benchmark_reranker_1000docs.sh` and archive the results
  212 +- GGUF defaults to keep: `n_ctx: 384`, `n_gpu_layers: 24`, `infer_batch_size: 8`, `flash_attn: true`, `offload_kqv: true`
  213 +- On GGUF OOM: lower `n_gpu_layers` first, then `n_ctx`, and only then `infer_batch_size`
160 214  
161 215 ## 8. Troubleshooting
162 216  
... ... @@ -194,6 +248,13 @@ lsof -i :6007 -P -n
194 248 - Lower `infer_batch_size`
195 249 - Check whether other processes are using the same GPU
196 250  
  251 +For GGUF, adjust in this order:
  252 +
  253 +- Lower `n_gpu_layers`
  254 +- Lower `n_ctx`
  255 +- Lower `infer_batch_size`
  256 +- Check whether other processes are using the same GPU
  257 +
197 258 ## 9. Change and verification checklist
198 259  
199 260 After every reranker tuning change, complete at least:
... ...
reranker/GGUF_0_6B_INSTALL_AND_TUNING.md 0 → 100644
... ... @@ -0,0 +1,154 @@
  1 +# Qwen3-Reranker-0.6B GGUF installation and tuning
  2 +
  3 +This document covers the `qwen3_gguf_06b` backend, corresponding to:
  4 +
  5 +- Hugging Face: `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF`
  6 +- File: `qwen3-reranker-0.6b-q8_0.gguf`
  7 +- Local directory: `./models/reranker/qwen3-reranker-0.6b-q8_0-gguf`
  8 +
  9 +## Conclusions first
  10 +
  11 +This backend is fully integrated and GPU offload works, but it is not a fit for this project's online primary path.
  12 +
  13 +The target scenario is:
  14 +
  15 +- 1 query
  16 +- 400 product titles
  17 +- Shortest possible response time
  18 +
  19 +Measured with the best configuration:
  20 +
  21 +- GPU memory usage about `894 MiB`
  22 +- Single-request latency for 400 titles about `265318 ms`
  23 +
  24 +It is therefore better suited as:
  25 +
  26 +- A low-VRAM fallback
  27 +- Functional validation
  28 +- Local offline experiments
  29 +
  30 +It is not recommended as the online low-latency primary reranker backend.
  31 +
  32 +## Dedicated environment
  33 +
  34 +`qwen3_gguf_06b` uses its own venv:
  35 +
  36 +- backend: `qwen3_gguf_06b`
  37 +- venv: `.venv-reranker-gguf-06b`
  38 +- requirements: `requirements_reranker_qwen3_gguf_06b.txt`
  39 +
  40 +Install:
  41 +
  42 +```bash
  43 +./scripts/setup_reranker_venv.sh qwen3_gguf_06b
  44 +```
  45 +
  46 +To confirm you have the CUDA build of `llama-cpp-python`:
  47 +
  48 +```bash
  49 +./.venv-reranker-gguf-06b/bin/python - <<'PY'
  50 +import llama_cpp
  51 +print(llama_cpp.llama_supports_gpu_offload())
  52 +PY
  53 +```
  54 +
  55 +Expected output:
  56 +
  57 +```python
  58 +True
  59 +```
  60 +
  61 +## Model download
  62 +
  63 +Download the model locally ahead of time to avoid pulling it online during the first service startup:
  64 +
  65 +```bash
  66 +mkdir -p models/reranker/qwen3-reranker-0.6b-q8_0-gguf
  67 +curl -L --fail -C - \
  68 + -o models/reranker/qwen3-reranker-0.6b-q8_0-gguf/qwen3-reranker-0.6b-q8_0.gguf \
  69 + 'https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/resolve/main/qwen3-reranker-0.6b-q8_0.gguf?download=true'
  70 +```
  71 +
  72 +Observed file size:
  73 +
  74 +- `639153184` bytes
  75 +
  76 +## Recommended configuration
  77 +
  78 +Recommended to keep in `config/config.yaml`:
  79 +
  80 +```yaml
  81 +qwen3_gguf_06b:
  82 + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF"
  83 + filename: "qwen3-reranker-0.6b-q8_0.gguf"
  84 + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf"
  85 + cache_dir: "./model_cache"
  86 + instruction: "Rank products by query with category & style match prioritized"
  87 + n_ctx: 256
  88 + n_batch: 256
  89 + n_ubatch: 256
  90 + n_gpu_layers: 999
  91 + main_gpu: 0
  92 + n_threads: 2
  93 + n_threads_batch: 4
  94 + flash_attn: true
  95 + offload_kqv: true
  96 + use_mmap: true
  97 + use_mlock: false
  98 + infer_batch_size: 32
  99 + sort_by_doc_length: true
  100 + length_sort_mode: "char"
  101 + reuse_query_state: false
  102 + enable_warmup: true
  103 + verbose: false
  104 +```
  105 +
  106 +## Tuning results
  107 +
  108 +Measured on the same machine. Titles come from `/home/ubuntu/rerank_test/titles.1.8w`; the query is `白色oversized T-shirt`.
  109 +
  110 +80 titles:
  111 +
  112 +- `n_ctx=256, reuse_query_state=true` -> `60108 ms`
  113 +- `n_ctx=256, reuse_query_state=false` -> `53383~56893 ms`
  114 +- `n_ctx=320, reuse_query_state=true` -> `60961 ms`
  115 +- `n_ctx=384, reuse_query_state=true` -> `56578 ms`
  116 +- `n_ctx=384, reuse_query_state=false` -> `57272 ms`
  117 +- `n_ctx=512, reuse_query_state=false` -> `60542 ms`
  118 +- `n_ctx=256, reuse_query_state=false, n_threads=4, n_threads_batch=8` -> `61228 ms`
  119 +
  120 +400 titles:
  121 +
  122 +- `n_ctx=256, n_batch=256, n_ubatch=256, n_gpu_layers=999, reuse_query_state=false`
  123 + -> `265318 ms`
  124 +
  125 +## Lessons learned
  126 +
  127 +The most important conclusion from this integration is not "which small parameter is faster", but:
  128 +
  129 +1. Although the 0.6B GGUF weights are small, the current backend implementation still scores docs sequentially, one by one.
  130 +2. For an online 400-title request, that serial scoring is itself the main bottleneck.
  131 +3. `reuse_query_state` brings no benefit on this model; it is actually slower.
  132 +4. Raising `n_ctx` to `384/512` brings no real benefit either; it is slower or at best flat.
  133 +5. The strength of this backend is low VRAM usage, not low latency.
  134 +
  135 +If the goal is the shortest online response time, the suggested priority is:
  136 +
  137 +1. `qwen3_vllm`
  138 +2. Other backends with real high-throughput batching
  139 +3. `qwen3_gguf_06b` only as a low-VRAM fallback
  140 +
  141 +## Verification commands
  142 +
  143 +Tuning against the backend directly (local, no service):
  144 +
  145 +```bash
  146 +PYTHONPATH=/data/saas-search ./.venv-reranker-gguf/bin/python \
  147 + scripts/benchmark_reranker_gguf_local.py --backend-name qwen3_gguf_06b --docs 400
  148 +```
  149 +
  150 +Start it as a service:
  151 +
  152 +```bash
  153 +RERANK_BACKEND=qwen3_gguf_06b ./scripts/start_reranker.sh
  154 +```
... ...
reranker/GGUF_INSTALL_AND_TUNING.md 0 → 100644
... ... @@ -0,0 +1,280 @@
  1 +# Qwen3 GGUF installation and tuning guide
  2 +
  3 +This document only covers the `qwen3_gguf` backend; the target machine is the project's measured environment:
  4 +
  5 +- GPU: `Tesla T4 16GB`
  6 +- CUDA: `12.8`
  7 +- Model: `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF`
  8 +- Quantization: `Q8_0`
  9 +
  10 +---
  11 +
  12 +## 1. Conclusions first
  13 +
  14 +In the current codebase, the main bottleneck of the GGUF backend is not "VRAM left unused" but the fact that **llama.cpp scores docs sequentially, one at a time**. The most effective optimization strategy is therefore:
  15 +
  16 +- Offload as many model layers to the GPU as possible
  17 +- Enable `flash_attn` / `offload_kqv`
  18 +- Tune `n_ctx / n_batch / n_ubatch` to an efficient point better suited to short-title reranking
  19 +
  20 +The recommended configuration for this round on the current machine:
  21 +
  22 +```yaml
  23 +qwen3_gguf:
  24 + n_ctx: 512
  25 + n_batch: 512
  26 + n_ubatch: 512
  27 + n_gpu_layers: 999
  28 + n_threads: 2
  29 + n_threads_batch: 4
  30 + flash_attn: true
  31 + offload_kqv: true
  32 + infer_batch_size: 8
  33 + sort_by_doc_length: true
  34 + length_sort_mode: "char"
  35 +```
  36 +
  37 +Notes:
  38 +
  39 +- In llama.cpp, `n_gpu_layers: 999` is equivalent to "offload as many layers as possible"
  40 +- On this T4, **even with full offload the current model only uses about `4.5 GiB` of GPU memory**
  41 +- So "allowing 8G of VRAM" does not automatically make it faster; under the current workload this model/backend is already close to "everything that should be on the GPU is on the GPU"
  42 +
  43 +---
  44 +
  45 +## 2. Dedicated environment
  46 +
  47 +`qwen3_gguf` must use its own dedicated venv:
  48 +
  49 +- `qwen3_vllm` -> `.venv-reranker`
  50 +- `qwen3_gguf` -> `.venv-reranker-gguf`
  51 +
  52 +Install command:
  53 +
  54 +```bash
  55 +./scripts/setup_reranker_venv.sh qwen3_gguf
  56 +```
  57 +
  58 +The script now does two things automatically:
  59 +
  60 +1. Installs the Python dependencies the GGUF backend needs
  61 +2. When `/usr/local/cuda/bin/nvcc` is detected, **rebuilds `llama-cpp-python` as the CUDA version**
  62 +
  63 +---
  64 +
  65 +## 3. Verifying the GPU build
  66 +
  67 +You must verify it is not the CPU-only build:
  68 +
  69 +```bash
  70 +./.venv-reranker-gguf/bin/python - <<'PY'
  71 +import llama_cpp
  72 +print("supports_gpu_offload =", llama_cpp.llama_supports_gpu_offload())
  73 +PY
  74 +```
  75 +
  76 +The correct result should be:
  77 +
  78 +```text
  79 +supports_gpu_offload = True
  80 +```
  81 +
  82 +You can also inspect the shared libraries:
  83 +
  84 +```bash
  85 +ldd .venv-reranker-gguf/lib/python3.12/site-packages/llama_cpp/lib/libllama.so | rg 'cuda|cublas|ggml-cuda'
  86 +```
  87 +
  88 +You should see:
  89 +
  90 +- `libggml-cuda.so`
  91 +- `libcudart.so`
  92 +- `libcublas.so`
  93 +
  94 +---
  95 +
  96 +## 4. Model download
  97 +
  98 +A local-file-first strategy is used; the model lives at:
  99 +
  100 +```text
  101 +models/reranker/qwen3-reranker-4b-gguf/Qwen.Qwen3-Reranker-4B.Q8_0.gguf
  102 +```
  103 +
  104 +If the local file exists, the backend loads the local GGUF directly and no longer depends on an online download at startup.
  105 +
  106 +To avoid the `416 Range Not Satisfiable` issue seen with Hugging Face Xet downloads on this machine, `start_reranker.sh` now sets the following by default for `qwen3_gguf`:
  107 +
  108 +```bash
  109 +HF_HUB_DISABLE_XET=1
  110 +```
  111 +
  112 +---
  113 +
  114 +## 5. Local tuning script
  115 +
  116 +A new local benchmark script:
  117 +
  118 +```bash
  119 +PYTHONPATH=/data/saas-search ./.venv-reranker-gguf/bin/python \
  120 + scripts/benchmark_reranker_gguf_local.py --docs 64 --repeat 1
  121 +```
  122 +
  123 +It instantiates the GGUF backend directly and reports:
  124 +
  125 +- Model load time
  126 +- GPU memory usage of the current process
  127 +- Single rerank latency
  128 +
  129 +---
  130 +
  131 +## 6. Measured results for this round
  132 +
  133 +Test conditions:
  134 +
  135 +- Query: `白色oversized T-shirt`
  136 +- Docs: `64` product titles
  137 +- Local script: `scripts/benchmark_reranker_gguf_local.py`
  138 +- 1 run per group; the focus is on relative trends
  139 +
  140 +Results:
  141 +
  142 +### 6.1 Conservative configuration
  143 +
  144 +```text
  145 +n_ctx=384
  146 +n_batch=384
  147 +n_ubatch=128
  148 +n_gpu_layers=24
  149 +```
  150 +
  151 +- GPU memory: `2984 MiB`
  152 +- 64-doc latency: `74347.91 ms`
  153 +
  154 +### 6.2 Full offload
  155 +
  156 +```text
  157 +n_ctx=384
  158 +n_batch=384
  159 +n_ubatch=128
  160 +n_gpu_layers=999
  161 +```
  162 +
  163 +- GPU memory: `4338 MiB`
  164 +- 64-doc latency: `51401.77 ms`
  165 +
  166 +### 6.3 Best configuration
  167 +
  168 +```text
  169 +n_ctx=512
  170 +n_batch=512
  171 +n_ubatch=512
  172 +n_gpu_layers=999
  173 +```
  174 +
  175 +- GPU memory: `4564 MiB`
  176 +- 64-doc latency: `49116.10 ms`
  177 +
  178 +### 6.4 Other attempts
  179 +
  180 +`n_threads=4 / n_threads_batch=8`:
  181 +
  182 +- GPU memory: `4564 MiB`
  183 +- 64-doc latency: `49895.88 ms`
  184 +- Slightly slower than the recommended values
  185 +
  186 +`infer_batch_size=64`:
  187 +
  188 +- GPU memory: `4564 MiB`
  189 +- 64-doc latency: `50723.36 ms`
  190 +- Also slightly slower
  191 +
  192 +### 6.5 API-level verification
  193 +
  194 +After writing the recommended configuration into `config/config.yaml` and restarting the service, run:
  195 +
  196 +```bash
  197 +RERANK_BASE=http://127.0.0.1:6007 \
  198 + ./.venv/bin/python scripts/benchmark_reranker_random_titles.py 64 --repeat 1 --query '白色oversized T-shirt'
  199 +```
  200 +
  201 +Result:
  202 +
  203 +- `64 docs`: `50177.22 ms`
  204 +
  205 +Then run:
  206 +
  207 +```bash
  208 +RERANK_BASE=http://127.0.0.1:6007 \
  209 + ./.venv/bin/python scripts/benchmark_reranker_random_titles.py 153 --repeat 1 --query '白色oversized T-shirt'
  210 +```
  211 +
  212 +Result:
  213 +
  214 +- `153 docs`: `115328.60 ms`
  215 +
  216 +Compared with the conservative configuration from the old logs:
  217 +
  218 +- Old config, `153 docs`: `153435.37 ms`
  219 +- New config, `153 docs`: `115328.60 ms`
  220 +
  221 +An improvement of roughly:
  222 +
  223 +- `24.8%`
  224 +
  225 +---
  226 +
  227 +## 7. Why it does not fill the 8G
  228 +
  229 +The conclusion matters:
  230 +
  231 +- The current best configuration already offloads as many layers as possible
  232 +- In this llama.cpp / T4 / short-text reranking setup, the `Q8_0` model **only needs about `4.5 GiB` of GPU memory in practice**
  233 +- Raising `n_ctx` further just to "fill up 8G" will not noticeably improve throughput and may add overhead
  234 +
  235 +So this round is not a case of "being too conservative with VRAM"; rather:
  236 +
  237 +- The offloadable weights are essentially all offloaded already
  238 +- What really slows responses down is the backend's **per-doc sequential inference** path
  239 +
  240 +---
  241 +
  242 +## 8. Production recommendations
  243 +
  244 +### 8.1 Current recommendations
  245 +
  246 +Keep the following parameters:
  247 +
  248 +```yaml
  249 +n_ctx: 512
  250 +n_batch: 512
  251 +n_ubatch: 512
  252 +n_gpu_layers: 999
  253 +n_threads: 2
  254 +n_threads_batch: 4
  255 +flash_attn: true
  256 +offload_kqv: true
  257 +```
  258 +
  259 +### 8.2 If it is still too slow
  260 +
  261 +Suggested priority:
  262 +
  263 +1. Shrink `rerank_window`
  264 +2. Reduce the number of docs passed in
  265 +3. If the business allows it, switch to a backend better suited to high throughput
  266 +
  267 +Reason:
  268 +
  269 +- The current GGUF backend is a local single process scoring docs one by one
  270 +- For long-list reranking it is inherently worse at throughput than vLLM or a cloud rerank API
  271 +
  272 +---
  273 +
  274 +## 9. Files touched in this round
  275 +
  276 +- `config/config.yaml`
  277 +- `scripts/setup_reranker_venv.sh`
  278 +- `scripts/start_reranker.sh`
  279 +- `scripts/benchmark_reranker_gguf_local.py`
  280 +- `reranker/GGUF_INSTALL_AND_TUNING.md`
... ...
reranker/README.md
1 1 # Reranker module
2 2  
3   -**Request examples**: see `docs/QUICKSTART.md` §3.5. Extension conventions: `docs/DEVELOPER_GUIDE.md` §7. Deployment and tuning practice: `reranker/DEPLOYMENT_AND_TUNING.md`.
  3 +**Request examples**: see `docs/QUICKSTART.md` §3.5. Extension conventions: `docs/DEVELOPER_GUIDE.md` §7. Deployment and tuning practice: `reranker/DEPLOYMENT_AND_TUNING.md`. For the dedicated integration and tuning conclusions for `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF`, see `reranker/GGUF_0_6B_INSTALL_AND_TUNING.md`.
4 4  
5 5 ---
6 6  
7   -The reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM, Qwen3-Transformers, DashScope cloud rerank). Callers access it over HTTP and do not depend on the concrete backend.
  7 +The reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM, Qwen3-Transformers, Qwen3-GGUF, DashScope cloud rerank). Callers access it over HTTP and do not depend on the concrete backend.
8 8  
9 9 **Features**
10   -- Multiple backends: `qwen3_vllm` (default, Qwen3-Reranker-0.6B + vLLM), `qwen3_transformers` (pure Transformers, no vLLM needed), `bge` (kept for compatibility)
  10 +- Multiple backends: `qwen3_vllm`, `qwen3_vllm_score` (same model, vLLM ``LLM.score()`` + dedicated `.venv-reranker-score`), `qwen3_transformers`, `qwen3_transformers_packed` (shared prefix + packed attention mask), `qwen3_gguf` (Qwen3-Reranker-4B GGUF + llama.cpp), `qwen3_gguf_06b` (Qwen3-Reranker-0.6B Q8_0 GGUF + llama.cpp), `bge` (kept for compatibility)
11 11 - Cloud backend: `dashscope_rerank` (calls DashScope `/compatible-api/v1/reranks`, supports per-region endpoints)
12 12 - Unified configuration: `config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.<name>`
13 13 - Document deduplication, scores aligned with input order, FP16/GPU support (backend-dependent)
... ... @@ -17,28 +17,51 @@ Reranker 服务提供统一的 `/rerank` API,支持可插拔后端(BGE、Qwe
17 17 - `reranker/backends/`: backend implementations and factory
18 18 - `backends/__init__.py`: `get_rerank_backend(name, config)`
19 19 - `backends/bge.py`: BGE backend
20   - - `backends/qwen3_vllm.py`: Qwen3-Reranker-0.6B + vLLM backend
  20 + - `backends/qwen3_vllm.py`: Qwen3-Reranker-0.6B + vLLM (generate + logprobs)
  21 + - `backends/qwen3_vllm_score.py`: same model + vLLM ``LLM.score()`` (`requirements_reranker_qwen3_vllm_score.txt` / `.venv-reranker-score`)
21 22 - `backends/qwen3_transformers.py`: Qwen3-Reranker-0.6B pure Transformers backend (official Usage style)
  23 + - `backends/qwen3_transformers_packed.py`: Qwen3-Reranker-0.6B + Transformers packed inference (shared query prefix, suited to `1 query + 400 docs`)
  24 + - `backends/qwen3_gguf.py`: Qwen3-Reranker GGUF + llama.cpp backend (supports `qwen3_gguf` / `qwen3_gguf_06b`)
22 25 - `backends/dashscope_rerank.py`: DashScope cloud rerank backend (HTTP calls)
23 26 - `reranker/bge_reranker.py`: core BGE inference (wrapped by the bge backend)
24 27 - `reranker/config.py`: service port, MAX_DOCS, NORMALIZE, etc. (backend parameters live in config.yaml)
25 28  
26 29 ## Dependencies
27 30 - Common: `torch`, `transformers`, `fastapi`, `uvicorn` (isolated env: see `requirements_reranker_service.txt`; full ML env: see `requirements_ml.txt`)
28   -- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers>=4.51.0` (vLLM is only needed with `backend: qwen3_vllm`)
  31 +- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers>=4.51.0` (`qwen3_vllm` → `.venv-reranker`)
  32 +- **Qwen3-vLLM-score backend**: pinned `vllm==0.18.0` (`qwen3_vllm_score` → `.venv-reranker-score`, see `requirements_reranker_qwen3_vllm_score.txt`)
29 33 - **Qwen3-Transformers backend**: `transformers>=4.51.0`, `torch` (no vLLM needed; suited to CPU or small VRAM)
  34 +- **Qwen3-Transformers-Packed backend**: reuses the Transformers dependencies (`qwen3_transformers_packed` → `.venv-reranker-transformers-packed`)
  35 +- **Qwen3-GGUF backend**: `llama-cpp-python>=0.3.16`
  36 +- Each backend now uses its own dedicated venv:
  37 + - `qwen3_vllm` -> `.venv-reranker`
  38 + - `qwen3_vllm_score` -> `.venv-reranker-score`
  39 + - `qwen3_gguf` -> `.venv-reranker-gguf`
  40 + - `qwen3_gguf_06b` -> `.venv-reranker-gguf-06b`
  41 + - `qwen3_transformers` -> `.venv-reranker-transformers`
  42 + - `qwen3_transformers_packed` -> `.venv-reranker-transformers-packed`
  43 + - `bge` -> `.venv-reranker-bge`
  44 + - `dashscope_rerank` -> `.venv-reranker-dashscope`
30 45 ```bash
31   - ./scripts/setup_reranker_venv.sh
  46 + ./scripts/setup_reranker_venv.sh qwen3_gguf_06b
  47 + ```
  48 + Suggested CUDA build:
  49 + ```bash
  50 + PATH=/usr/local/cuda/bin:$PATH \
  51 + CUDACXX=/usr/local/cuda/bin/nvcc \
  52 + CMAKE_ARGS="-DGGML_CUDA=on" \
  53 + FORCE_CMAKE=1 \
  54 + ./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18
32 55 ```
33 56  
34 57 ## Configuration
35   -- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `qwen3_transformers` | `bge` | `dashscope_rerank`), or the `RERANK_BACKEND` environment variable.
  58 +- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `qwen3_vllm_score` | `qwen3_transformers` | `qwen3_transformers_packed` | `qwen3_gguf` | `qwen3_gguf_06b` | `bge` | `dashscope_rerank`), or the `RERANK_BACKEND` environment variable.
36 59 - **Backend parameters**: `services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`, for example:
37 60  
38 61 ```yaml
39 62 services:
40 63 rerank:
41   - backend: "qwen3_vllm" # or bge
  64 + backend: "qwen3_gguf" # or qwen3_vllm / bge
42 65 backends:
43 66 bge:
44 67 model_name: "BAAI/bge-reranker-v2-m3"
... ... @@ -65,6 +88,44 @@ services:
65 88 tensor_parallel_size: 1
66 89 gpu_memory_utilization: 0.8
67 90 instruction: "Given a shopping query, rank product titles by relevance"
  91 + qwen3_transformers_packed:
  92 + model_name: "Qwen/Qwen3-Reranker-0.6B"
  93 + instruction: "Rank products by query with category & style match prioritized"
  94 + max_model_len: 4096
  95 + max_doc_len: 160
  96 + max_docs_per_pack: 0
  97 + use_fp16: true
  98 + sort_by_doc_length: true
  99 + attn_implementation: "eager"
  100 + qwen3_gguf:
  101 + repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF"
  102 + filename: "*Q8_0.gguf"
  103 + local_dir: "./models/reranker/qwen3-reranker-4b-gguf"
  104 + cache_dir: "./model_cache"
  105 + instruction: "Rank products by query with category & style match prioritized"
  106 + n_ctx: 384
  107 + n_batch: 384
  108 + n_ubatch: 128
  109 + n_gpu_layers: 24
  110 + flash_attn: true
  111 + offload_kqv: true
  112 + infer_batch_size: 8
  113 + sort_by_doc_length: true
  114 + length_sort_mode: "char"
  115 + qwen3_gguf_06b:
  116 + repo_id: "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF"
  117 + filename: "qwen3-reranker-0.6b-q8_0.gguf"
  118 + local_dir: "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf"
  119 + cache_dir: "./model_cache"
  120 + instruction: "Rank products by query with category & style match prioritized"
  121 + n_ctx: 256
  122 + n_batch: 256
  123 + n_ubatch: 256
  124 + n_gpu_layers: 999
  125 + infer_batch_size: 32
  126 + sort_by_doc_length: true
  127 + length_sort_mode: "char"
  128 + reuse_query_state: false
68 129 dashscope_rerank:
69 130 model_name: "qwen3-rerank"
70 131 endpoint: "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
... ... @@ -94,7 +155,7 @@ DashScope 认证:
94 155 ```bash
95 156 ./scripts/start_reranker.sh
96 157 ```
97   -The script uses the isolated env `.venv-reranker`; run `./scripts/setup_reranker_venv.sh` first on initial setup.
  158 +The script automatically picks the dedicated venv matching the current `services.rerank.backend`; run `./scripts/setup_reranker_venv.sh <backend>` first on initial setup.
98 159  
99 160 ## Performance benchmark (1000 docs)
100 161 ```bash
... ... @@ -122,7 +183,7 @@ Content-Type: application/json
122 183 ```
123 184  
124 185 `top_n` is an optional field:
125   -- Local backends (`qwen3_vllm` / `qwen3_transformers` / `bge`) usually ignore it and still return scores for all docs.
  186 +- Local backends (`qwen3_vllm` / `qwen3_transformers` / `qwen3_transformers_packed` / `qwen3_gguf` / `qwen3_gguf_06b` / `bge`) usually ignore it and still return scores for all docs.
126 187 - For `dashscope_rerank` it limits how many candidates the cloud returns; set it to `page+size` (for example, pass `30` when paging with `from=20,size=10`).
127 188  
128 189 Response:
... ... @@ -160,3 +221,6 @@ uvicorn reranker.server:app --host 0.0.0.0 --port 6007 --log-level info
160 221 - Batch parameters can be temporarily overridden at runtime via env vars: `RERANK_VLLM_INFER_BATCH_SIZE`, `RERANK_VLLM_SORT_BY_DOC_LENGTH`.
161 222 - **Qwen3-vLLM**: see [Qwen3-Reranker-0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B); needs a GPU and more VRAM; compared with BGE it suits long-text, high-throughput scenarios (vLLM prefix caching).
162 223 - **Qwen3-Transformers**: the official Transformers Usage style, no vLLM needed; suited to CPU or small VRAM. Defaults to `attn_implementation: "sdpa"`; if `flash_attn` is installed you can set `flash_attention_2` (the service falls back to sdpa automatically when it is not installed).
  224 +- **Qwen3-Transformers-Packed**: still uses Hugging Face Transformers and PyTorch CUDA kernels, customizing only the packed input, `position_ids`, and the 4D `attention_mask`. It fits the online-retrieval case of "one query against a few hundred short docs"; the default `attn_implementation: "eager"` keeps the custom mask compatible, and you can benchmark `"sdpa"` if your `torch/transformers` versions are validated to support it.
  225 +- **Qwen3-GGUF**: see [DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF](https://huggingface.co/DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF). On a single T4 with only about `4.8~6GB` of free VRAM, start from `Q8_0 + n_ctx=384 + n_gpu_layers=24 + flash_attn=true + offload_kqv=true`; on startup OOM, drop `n_gpu_layers` to `20` first, then `n_ctx` to `320`. In the GGUF backend `infer_batch_size` is a service-side work chunk and usually matters less than `n_gpu_layers` / `n_ctx`.
  226 +- **Qwen3-GGUF-0.6B**: see [ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF](https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF). Its advantage is small weights and low VRAM usage, roughly `0.9~1.1 GiB` per process as measured; but with the current serial llama.cpp scoring path, measured latency for `1 query + 400 titles` is still about `265s`. It is therefore better suited as a low-VRAM functional fallback than as the online low-latency primary reranker.
... ...
reranker/backends/__init__.py
... ... @@ -43,14 +43,32 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc
43 43 if name == "qwen3_vllm":
44 44 from reranker.backends.qwen3_vllm import Qwen3VLLMRerankerBackend
45 45 return Qwen3VLLMRerankerBackend(config)
  46 + if name == "qwen3_vllm_score":
  47 + from reranker.backends.qwen3_vllm_score import Qwen3VLLMScoreRerankerBackend
  48 + return Qwen3VLLMScoreRerankerBackend(config)
46 49 if name == "qwen3_transformers":
47 50 from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend
48 51 return Qwen3TransformersRerankerBackend(config)
  52 + if name == "qwen3_transformers_packed":
  53 + from reranker.backends.qwen3_transformers_packed import (
  54 + Qwen3TransformersPackedRerankerBackend,
  55 + )
  56 + return Qwen3TransformersPackedRerankerBackend(config)
  57 + if name == "qwen3_gguf":
  58 + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend
  59 + gguf_config = dict(config or {})
  60 + gguf_config.setdefault("_backend_name", "qwen3_gguf")
  61 + return Qwen3GGUFRerankerBackend(gguf_config)
  62 + if name == "qwen3_gguf_06b":
  63 + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend
  64 + gguf_config = dict(config or {})
  65 + gguf_config.setdefault("_backend_name", "qwen3_gguf_06b")
  66 + return Qwen3GGUFRerankerBackend(gguf_config)
49 67 if name == "dashscope_rerank":
50 68 from reranker.backends.dashscope_rerank import DashScopeRerankBackend
51 69 return DashScopeRerankBackend(config)
52 70 raise ValueError(
53   - f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers, dashscope_rerank"
  71 + f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_vllm_score, qwen3_transformers, qwen3_transformers_packed, qwen3_gguf, qwen3_gguf_06b, dashscope_rerank"
54 72 )
55 73  
56 74  
... ...
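A minimal sketch of resolving one of the new backends through the factory; the GGUF class is shared and `_backend_name` selects the per-variant repo/filename defaults. Instantiation loads the model, so this assumes the matching venv and cached weights:

```python
from reranker.backends import get_rerank_backend

config = {"n_ctx": 256, "infer_batch_size": 32, "sort_by_doc_length": True}
backend = get_rerank_backend("qwen3_gguf_06b", config)  # picks the 0.6B Q8_0 defaults

# Unknown names raise ValueError listing the supported backends.
```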
reranker/backends/qwen3_gguf.py 0 → 100644
... ... @@ -0,0 +1,408 @@
  1 +"""
  2 +Qwen3-Reranker GGUF backend using llama-cpp-python.
  3 +
  4 +Reference:
  5 +- https://huggingface.co/DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF
  6 +- https://huggingface.co/Qwen/Qwen3-Reranker-4B
  7 +- https://huggingface.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF
  8 +- https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
  9 +"""
  10 +
  11 +from __future__ import annotations
  12 +
  13 +import logging
  14 +import math
  15 +import os
  16 +import threading
  17 +import time
  18 +from pathlib import Path
  19 +from typing import Any, Dict, List, Tuple
  20 +
  21 +
  22 +logger = logging.getLogger("reranker.backends.qwen3_gguf")
  23 +
  24 +
  25 +_BACKEND_DEFAULTS: Dict[str, Dict[str, str]] = {
  26 + "qwen3_gguf": {
  27 + "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
  28 + "filename": "*Q8_0.gguf",
  29 + "local_dir": "./models/reranker/qwen3-reranker-4b-gguf",
  30 + },
  31 + "qwen3_gguf_06b": {
  32 + "repo_id": "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF",
  33 + "filename": "qwen3-reranker-0.6b-q8_0.gguf",
  34 + "local_dir": "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf",
  35 + },
  36 +}
  37 +
  38 +
  39 +def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]:
  40 + """Deduplicate texts globally while preserving first-seen order."""
  41 + unique_texts: List[str] = []
  42 + position_to_unique: List[int] = []
  43 + seen: Dict[str, int] = {}
  44 +
  45 + for text in texts:
  46 + idx = seen.get(text)
  47 + if idx is None:
  48 + idx = len(unique_texts)
  49 + seen[text] = idx
  50 + unique_texts.append(text)
  51 + position_to_unique.append(idx)
  52 +
  53 + return unique_texts, position_to_unique
  54 +
  55 +
  56 +def _format_instruction(instruction: str, query: str, doc: str) -> str:
  57 + return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
  58 + instruction=instruction,
  59 + query=query,
  60 + doc=doc,
  61 + )
  62 +
  63 +
  64 +class Qwen3GGUFRerankerBackend:
  65 + """
  66 + Qwen3-Reranker GGUF backend using llama.cpp through llama-cpp-python.
  67 +
  68 + Tuned for short-query / short-doc reranking on a single GPU.
  69 + Config from services.rerank.backends.<backend_name>.
  70 + """
  71 +
  72 + def __init__(self, config: Dict[str, Any]) -> None:
  73 + self._config = config or {}
  74 + self._backend_name = str(self._config.get("_backend_name") or "qwen3_gguf").strip()
  75 + defaults = _BACKEND_DEFAULTS.get(self._backend_name, _BACKEND_DEFAULTS["qwen3_gguf"])
  76 + self._repo_id = str(self._config.get("repo_id") or defaults["repo_id"]).strip()
  77 + self._filename = str(self._config.get("filename") or defaults["filename"]).strip()
  78 + self._model_path = str(self._config.get("model_path") or "").strip()
  79 + self._cache_dir = str(self._config.get("cache_dir") or "").strip() or None
  80 + self._local_dir = str(self._config.get("local_dir") or defaults["local_dir"]).strip() or None
  81 + self._instruction = str(
  82 + self._config.get("instruction")
  83 + or "Rank products by query with category & style match prioritized"
  84 + )
  85 + self._infer_batch_size = int(
  86 + os.getenv("RERANK_GGUF_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 8)
  87 + )
  88 + sort_by_doc_length = os.getenv("RERANK_GGUF_SORT_BY_DOC_LENGTH")
  89 + if sort_by_doc_length is None:
  90 + sort_by_doc_length = self._config.get("sort_by_doc_length", True)
  91 + self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
  92 + "1",
  93 + "true",
  94 + "yes",
  95 + "y",
  96 + "on",
  97 + }
  98 + self._length_sort_mode = str(self._config.get("length_sort_mode") or "char").strip().lower()
  99 + self._reuse_query_state = bool(self._config.get("reuse_query_state", False))
  100 +
  101 + n_ctx = int(self._config.get("n_ctx", self._config.get("max_model_len", 384)))
  102 + n_batch = int(self._config.get("n_batch", min(n_ctx, 384)))
  103 + n_ubatch = int(self._config.get("n_ubatch", min(n_batch, 128)))
  104 + n_gpu_layers = int(self._config.get("n_gpu_layers", 24))
  105 + main_gpu = int(self._config.get("main_gpu", 0))
  106 + n_threads = int(self._config.get("n_threads", 2))
  107 + n_threads_batch = int(self._config.get("n_threads_batch", 4))
  108 + flash_attn = bool(self._config.get("flash_attn", True))
  109 + offload_kqv = bool(self._config.get("offload_kqv", True))
  110 + use_mmap = bool(self._config.get("use_mmap", True))
  111 + use_mlock = bool(self._config.get("use_mlock", False))
  112 + verbose = bool(self._config.get("verbose", False))
  113 + enable_warmup = bool(self._config.get("enable_warmup", True))
  114 +
  115 + if self._infer_batch_size <= 0:
  116 + raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
  117 + if n_ctx <= 0:
  118 + raise ValueError(f"n_ctx must be > 0, got {n_ctx}")
  119 + if n_batch <= 0 or n_ubatch <= 0:
  120 + raise ValueError(f"n_batch/n_ubatch must be > 0, got {n_batch}/{n_ubatch}")
  121 +
  122 + try:
  123 + from llama_cpp import Llama
  124 + except Exception as exc: # pragma: no cover - depends on optional dependency
  125 + raise RuntimeError(
  126 + f"{self._backend_name} backend requires llama-cpp-python. "
  127 + f"Install the {self._backend_name} backend venv first via "
  128 + f"scripts/setup_reranker_venv.sh {self._backend_name}."
  129 + ) from exc
  130 +
  131 + self._llama_class = Llama
  132 + self._n_ctx = n_ctx
  133 + self._n_batch = n_batch
  134 + self._n_ubatch = n_ubatch
  135 + self._n_gpu_layers = n_gpu_layers
  136 + self._enable_warmup = enable_warmup
  137 + self._infer_lock = threading.Lock()
  138 +
  139 + logger.info(
  140 + "[Qwen3_GGUF] Loading backend=%s repo=%s filename=%s model_path=%s n_ctx=%s n_batch=%s n_ubatch=%s n_gpu_layers=%s flash_attn=%s offload_kqv=%s reuse_query_state=%s",
  141 + self._backend_name,
  142 + self._repo_id,
  143 + self._filename,
  144 + self._model_path or None,
  145 + n_ctx,
  146 + n_batch,
  147 + n_ubatch,
  148 + n_gpu_layers,
  149 + flash_attn,
  150 + offload_kqv,
  151 + self._reuse_query_state,
  152 + )
  153 +
  154 + llm_kwargs = {
  155 + "n_ctx": n_ctx,
  156 + "n_batch": n_batch,
  157 + "n_ubatch": n_ubatch,
  158 + "n_gpu_layers": n_gpu_layers,
  159 + "main_gpu": main_gpu,
  160 + "n_threads": n_threads,
  161 + "n_threads_batch": n_threads_batch,
  162 + "logits_all": True,
  163 + "offload_kqv": offload_kqv,
  164 + "flash_attn": flash_attn,
  165 + "use_mmap": use_mmap,
  166 + "use_mlock": use_mlock,
  167 + "verbose": verbose,
  168 + }
  169 + llm_kwargs = {key: value for key, value in llm_kwargs.items() if value is not None}
  170 + self._llm = self._load_model(llm_kwargs)
  171 + self._model_name = self._model_path or f"{self._repo_id}:{self._filename}"
  172 +
  173 + self._prefix = (
  174 + "<|im_start|>system\n"
  175 + "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
  176 + 'Note that the answer can only be "yes" or "no".'
  177 + "<|im_end|>\n<|im_start|>user\n"
  178 + )
  179 + self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  180 + self._prefix_tokens = self._tokenize(self._prefix, special=True)
  181 + self._suffix_tokens = self._tokenize(self._suffix, special=True)
  182 + self._request_prefix_template = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "
  183 + self._effective_max_len = self._n_ctx - len(self._prefix_tokens) - len(self._suffix_tokens)
  184 + if self._effective_max_len <= 16:
  185 + raise RuntimeError(
  186 + f"n_ctx={self._n_ctx} is too small after prompt overhead; effective={self._effective_max_len}"
  187 + )
  188 +
  189 + self._true_token = self._single_token_id("yes")
  190 + self._false_token = self._single_token_id("no")
  191 +
  192 + if self._enable_warmup:
  193 + self._warmup()
  194 +
  195 + logger.info(
  196 + "[Qwen3_GGUF] Model ready | backend=%s model=%s effective_max_len=%s infer_batch_size=%s sort_by_doc_length=%s",
  197 + self._backend_name,
  198 + self._model_name,
  199 + self._effective_max_len,
  200 + self._infer_batch_size,
  201 + self._sort_by_doc_length,
  202 + )
  203 +
  204 + def _load_model(self, llm_kwargs: Dict[str, Any]):
  205 + if self._model_path:
  206 + return self._llama_class(model_path=self._model_path, **llm_kwargs)
  207 + if self._local_dir:
  208 + matches = sorted(
  209 + path for path in Path(self._local_dir).glob(self._filename) if path.is_file()
  210 + )
  211 + if matches:
  212 + local_model_path = str(matches[0].resolve())
  213 + logger.info("[Qwen3_GGUF] Using local GGUF file: %s", local_model_path)
  214 + return self._llama_class(model_path=local_model_path, **llm_kwargs)
  215 + return self._llama_class.from_pretrained(
  216 + repo_id=self._repo_id,
  217 + filename=self._filename,
  218 + local_dir=self._local_dir,
  219 + cache_dir=self._cache_dir,
  220 + **llm_kwargs,
  221 + )
  222 +
  223 + def _tokenize(self, text: str, *, special: bool) -> List[int]:
  224 + return list(
  225 + self._llm.tokenize(
  226 + text.encode("utf-8"),
  227 + add_bos=False,
  228 + special=special,
  229 + )
  230 + )
  231 +
  232 + def _single_token_id(self, text: str) -> int:
  233 + token_ids = self._tokenize(text, special=False)
  234 + if len(token_ids) != 1:
  235 + raise RuntimeError(f"Expected {text!r} to be one token, got {token_ids}")
  236 + return int(token_ids[0])
  237 +
  238 + def _warmup(self) -> None:
  239 + try:
  240 + prompt = self._build_prompt_tokens("warmup query", "warmup document")
  241 + with self._infer_lock:
  242 + self._eval_logits(prompt)
  243 + except Exception as exc: # pragma: no cover - defensive
  244 + logger.warning("[Qwen3_GGUF] Warmup failed: %s", exc)
  245 +
  246 + def _build_request_prefix_tokens(self, query: str) -> List[int]:
  247 + request_prefix = self._request_prefix_template.format(
  248 + instruction=self._instruction,
  249 + query=query,
  250 + )
  251 + return self._tokenize(request_prefix, special=False)
  252 +
  253 + def _build_prompt_tokens(self, query: str, doc: str) -> List[int]:
  254 + pair = _format_instruction(self._instruction, query, doc)
  255 + pair_tokens = self._tokenize(pair, special=False)
  256 + pair_tokens = pair_tokens[: self._effective_max_len]
  257 + return self._prefix_tokens + pair_tokens + self._suffix_tokens
  258 +
  259 + def _eval_logits(self, prompt_tokens: List[int]) -> List[float]:
  260 + self._llm.reset()
  261 + self._llm.eval(prompt_tokens)
  262 + logits = self._llm.eval_logits
  263 + if not logits:
  264 + raise RuntimeError("llama.cpp returned empty logits")
  265 + return list(logits[-1])
  266 +
  267 + def _score_prompt(self, prompt_tokens: List[int]) -> float:
  268 + logits = self._eval_logits(prompt_tokens)
  269 + true_logit = float(logits[self._true_token])
  270 + false_logit = float(logits[self._false_token])
  271 + max_logit = max(true_logit, false_logit)
  272 + true_exp = math.exp(true_logit - max_logit)
  273 + false_exp = math.exp(false_logit - max_logit)
  274 + return float(true_exp / (true_exp + false_exp))
  275 +
  276 + def _supports_query_state_reuse(self) -> bool:
  277 + return (
  278 + self._reuse_query_state
  279 + and hasattr(self._llm, "save_state")
  280 + and hasattr(self._llm, "load_state")
  281 + )
  282 +
  283 + def _build_query_state_locked(self, query: str):
  284 + request_prefix_tokens = self._build_request_prefix_tokens(query)
  285 + max_doc_tokens = self._effective_max_len - len(request_prefix_tokens)
  286 + if max_doc_tokens <= 0:
  287 + return None, 0
  288 + self._llm.reset()
  289 + self._llm.eval(self._prefix_tokens + request_prefix_tokens)
  290 + return self._llm.save_state(), max_doc_tokens
  291 +
  292 + def _score_doc_with_state_locked(self, state, doc_tokens: List[int], max_doc_tokens: int) -> float:
  293 + self._llm.load_state(state)
  294 + self._llm.eval(doc_tokens[:max_doc_tokens] + self._suffix_tokens)
  295 + logits = self._llm.eval_logits
  296 + if not logits:
  297 + raise RuntimeError("llama.cpp returned empty logits")
  298 + final_logits = list(logits[-1])
  299 + true_logit = float(final_logits[self._true_token])
  300 + false_logit = float(final_logits[self._false_token])
  301 + max_logit = max(true_logit, false_logit)
  302 + true_exp = math.exp(true_logit - max_logit)
  303 + false_exp = math.exp(false_logit - max_logit)
  304 + return float(true_exp / (true_exp + false_exp))
  305 +
  306 + def _estimate_doc_lengths(self, docs: List[str]) -> List[int]:
  307 + if self._length_sort_mode == "token":
  308 + return [len(self._tokenize(text, special=False)) for text in docs]
  309 + return [len(text) for text in docs]
  310 +
  311 + def score_with_meta(
  312 + self,
  313 + query: str,
  314 + docs: List[str],
  315 + normalize: bool = True,
  316 + ) -> Tuple[List[float], Dict[str, Any]]:
  317 + start_ts = time.time()
  318 + total_docs = len(docs) if docs else 0
  319 + output_scores: List[float] = [0.0] * total_docs
  320 +
  321 + query = "" if query is None else str(query).strip()
  322 + indexed: List[Tuple[int, str]] = []
  323 + for i, doc in enumerate(docs or []):
  324 + if doc is None:
  325 + continue
  326 + text = str(doc).strip()
  327 + if not text:
  328 + continue
  329 + indexed.append((i, text))
  330 +
  331 + if not query or not indexed:
  332 + elapsed_ms = (time.time() - start_ts) * 1000.0
  333 + return output_scores, {
  334 + "input_docs": total_docs,
  335 + "usable_docs": len(indexed),
  336 + "unique_docs": 0,
  337 + "dedup_ratio": 0.0,
  338 + "elapsed_ms": round(elapsed_ms, 3),
  339 + "model": self._model_name,
  340 + "backend": self._backend_name,
  341 + "normalize": normalize,
  342 + "infer_batch_size": self._infer_batch_size,
  343 + "inference_batches": 0,
  344 + "sort_by_doc_length": self._sort_by_doc_length,
  345 + "n_ctx": self._n_ctx,
  346 + "n_batch": self._n_batch,
  347 + "n_ubatch": self._n_ubatch,
  348 + "n_gpu_layers": self._n_gpu_layers,
  349 + }
  350 +
  351 + indexed_texts = [text for _, text in indexed]
  352 + unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)
  353 +
  354 + lengths = self._estimate_doc_lengths(unique_texts)
  355 + order = list(range(len(unique_texts)))
  356 + if self._sort_by_doc_length and len(unique_texts) > 1:
  357 + order = sorted(order, key=lambda i: lengths[i])
  358 +
  359 + unique_scores: List[float] = [0.0] * len(unique_texts)
  360 + unique_doc_tokens = [self._tokenize(text, special=False) for text in unique_texts]
  361 + inference_batches = 0
  362 + with self._infer_lock:
  363 + query_state = None
  364 + max_doc_tokens = self._effective_max_len
  365 + if self._supports_query_state_reuse():
  366 + query_state, max_doc_tokens = self._build_query_state_locked(query)
  367 + for start in range(0, len(order), self._infer_batch_size):
  368 + batch_indices = order[start : start + self._infer_batch_size]
  369 + inference_batches += 1
  370 + for idx in batch_indices:
  371 + if query_state is not None:
  372 + unique_scores[idx] = self._score_doc_with_state_locked(
  373 + query_state,
  374 + unique_doc_tokens[idx],
  375 + max_doc_tokens,
  376 + )
  377 + else:
  378 + prompt = self._build_prompt_tokens(query, unique_texts[idx])
  379 + unique_scores[idx] = self._score_prompt(prompt)
  380 +
  381 + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
  382 + output_scores[orig_idx] = float(unique_scores[unique_idx])
  383 +
  384 + elapsed_ms = (time.time() - start_ts) * 1000.0
  385 + dedup_ratio = 0.0
  386 + if indexed:
  387 + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
  388 +
  389 + meta = {
  390 + "input_docs": total_docs,
  391 + "usable_docs": len(indexed),
  392 + "unique_docs": len(unique_texts),
  393 + "dedup_ratio": round(dedup_ratio, 4),
  394 + "elapsed_ms": round(elapsed_ms, 3),
  395 + "model": self._model_name,
  396 + "backend": self._backend_name,
  397 + "normalize": normalize,
  398 + "infer_batch_size": self._infer_batch_size,
  399 + "inference_batches": inference_batches,
  400 + "sort_by_doc_length": self._sort_by_doc_length,
  401 + "length_sort_mode": self._length_sort_mode,
  402 + "n_ctx": self._n_ctx,
  403 + "n_batch": self._n_batch,
  404 + "n_ubatch": self._n_ubatch,
  405 + "n_gpu_layers": self._n_gpu_layers,
  406 + "reuse_query_state": query_state is not None,
  407 + }
  408 + return output_scores, meta
... ...
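For reference, the per-document score produced by `_score_prompt` / `_score_doc_with_state_locked` above is the two-class softmax over the "yes"/"no" classifier-token logits, i.e. a numerically stable logistic of the logit difference. A standalone restatement of that arithmetic:

```python
import math

# Standalone restatement of the scoring step in reranker/backends/qwen3_gguf.py above:
# probability of "yes" over {"yes", "no"} taken from the final-token logits.
def yes_probability(yes_logit: float, no_logit: float) -> float:
    m = max(yes_logit, no_logit)            # subtract the max for numerical stability
    yes_exp = math.exp(yes_logit - m)
    no_exp = math.exp(no_logit - m)
    return yes_exp / (yes_exp + no_exp)     # == 1 / (1 + exp(no_logit - yes_logit))

# Equivalent to sigmoid(yes_logit - no_logit):
assert abs(yes_probability(2.0, -1.0) - 1.0 / (1.0 + math.exp(-3.0))) < 1e-12
```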
reranker/backends/qwen3_transformers_packed.py 0 → 100644
... ... @@ -0,0 +1,398 @@
  1 +"""
  2 +Qwen3-Reranker backend using packed inference with Transformers.
  3 +
  4 +This backend implements the sequence stitching optimization described in
  5 +Qwen3-Reranker packed inference examples:
  6 +1. Share the query/instruction prefix across many documents.
  7 +2. Reset document ``position_ids`` relative to the shared prefix.
  8 +3. Use a custom causal attention mask so each document can attend to the
  9 + prefix and itself, but never to other documents.
  10 +
  11 +Compared with the standard per-pair batching path, this reduces repeated
  12 +prefix computation and removes inter-sample padding waste. For online search
  13 +requests like ``1 query + 400 docs``, the backend further packs documents into
  14 +multiple chunks under a configurable total token budget.
  15 +"""
  16 +
  17 +from __future__ import annotations
  18 +
  19 +import logging
  20 +import threading
  21 +import time
  22 +from typing import Any, Dict, List, Sequence, Tuple
  23 +
  24 +import torch
  25 +from transformers import AutoModelForCausalLM, AutoTokenizer
  26 +
  27 +logger = logging.getLogger("reranker.backends.qwen3_transformers_packed")
  28 +
  29 +_DEFAULT_PREFIX = (
  30 + "<|im_start|>system\n"
  31 + "Judge whether the Document meets the requirements based on the Query and the Instruct "
  32 + 'provided. Note that the answer can only be "yes" or "no".'
  33 + "<|im_end|>\n<|im_start|>user\n"
  34 +)
  35 +_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  36 +_DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "
  37 +
  38 +
  39 +def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
  40 + unique_texts: List[str] = []
  41 + position_to_unique: List[int] = []
  42 + seen: Dict[str, int] = {}
  43 +
  44 + for text in texts:
  45 + idx = seen.get(text)
  46 + if idx is None:
  47 + idx = len(unique_texts)
  48 + seen[text] = idx
  49 + unique_texts.append(text)
  50 + position_to_unique.append(idx)
  51 +
  52 + return unique_texts, position_to_unique
  53 +
  54 +
  55 +class Qwen3TransformersPackedRerankerBackend:
  56 + """
  57 + Qwen3-Reranker packed inference backend using Transformers.
  58 +
  59 + Config from ``services.rerank.backends.qwen3_transformers_packed``.
  60 + """
  61 +
  62 + def __init__(self, config: Dict[str, Any]) -> None:
  63 + self._config = config or {}
  64 + model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
  65 + self._instruction = str(
  66 + self._config.get("instruction")
  67 + or "Rank products by query with category & style match prioritized"
  68 + )
  69 + self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
  70 + self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
  71 + self._pair_prefix_template = str(
  72 + self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE
  73 + )
  74 +
  75 + max_model_len = int(self._config.get("max_model_len", 4096))
  76 + max_doc_len = int(self._config.get("max_doc_len", 160))
  77 + max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0))
  78 + use_fp16 = bool(self._config.get("use_fp16", True))
  79 + device = self._config.get("device")
  80 + attn_impl = str(self._config.get("attn_implementation") or "eager").strip()
  81 + sort_by_doc_length = self._config.get("sort_by_doc_length", True)
  82 +
  83 + self._model_name = model_name
  84 + self._max_model_len = max_model_len
  85 + self._max_doc_len = max_doc_len
  86 + self._max_docs_per_pack = max_docs_per_pack
  87 + self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
  88 + "1",
  89 + "true",
  90 + "yes",
  91 + "y",
  92 + "on",
  93 + }
  94 + self._attn_impl = attn_impl
  95 +
  96 + logger.info(
  97 + "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, "
  98 + "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)",
  99 + model_name,
  100 + max_model_len,
  101 + max_doc_len,
  102 + max_docs_per_pack,
  103 + use_fp16,
  104 + attn_impl,
  105 + )
  106 +
  107 + self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
  108 + self._tokenizer.pad_token = self._tokenizer.eos_token
  109 +
  110 + self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
  111 + self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
  112 + self._suffix_len = len(self._suffix_tokens)
  113 +
  114 + if not torch.cuda.is_available():
  115 + raise RuntimeError(
  116 + "qwen3_transformers_packed backend requires CUDA GPU, "
  117 + "but torch.cuda.is_available() is False"
  118 + )
  119 +
  120 + kwargs: Dict[str, Any] = {}
  121 + if use_fp16:
  122 + kwargs["torch_dtype"] = torch.float16
  123 + if attn_impl:
  124 + kwargs["attn_implementation"] = attn_impl
  125 +
  126 + self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
  127 + target_device = str(device).strip() if device is not None else "cuda"
  128 + if not target_device.startswith("cuda"):
  129 + raise ValueError(
  130 + "qwen3_transformers_packed backend is GPU-only. "
  131 + f"Unsupported device setting: {target_device!r}"
  132 + )
  133 + self._model = self._model.to(target_device)
  134 + self._device = next(self._model.parameters()).device
  135 + if self._device.type != "cuda":
  136 + raise RuntimeError(
  137 + "qwen3_transformers_packed backend failed to place model on CUDA. "
  138 + f"Current device: {self._device}"
  139 + )
  140 +
  141 + self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
  142 + self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
  143 + if self._token_true_id is None or self._token_false_id is None:
  144 + raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no")
  145 +
  146 + prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1
  147 + if self._max_model_len <= prefix_budget:
  148 + raise ValueError(
  149 + "max_model_len is too small for packed reranking. "
  150 + f"Need > {prefix_budget}, got {self._max_model_len}."
  151 + )
  152 + if self._max_doc_len <= 0:
  153 + raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}")
  154 + if self._max_docs_per_pack < 0:
  155 + raise ValueError(
  156 + f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}"
  157 + )
  158 +
  159 + self._infer_lock = threading.Lock()
  160 +
  161 + logger.info(
  162 + "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s",
  163 + model_name,
  164 + self._device,
  165 + )
  166 +
  167 + def _build_pair_prefix_tokens(self, query: str) -> List[int]:
  168 + pair_prefix = self._pair_prefix_template.format(
  169 + prefix=self._prefix,
  170 + instruction=self._instruction,
  171 + query=query,
  172 + )
  173 + return self._tokenizer.encode(pair_prefix, add_special_tokens=False)
  174 +
  175 + def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]:
  176 + max_doc_tokens = min(
  177 + self._max_doc_len,
  178 + max(1, self._max_model_len - query_prefix_len - self._suffix_len),
  179 + )
  180 + tokenized = self._tokenizer(
  181 + list(docs),
  182 + padding=False,
  183 + truncation=True,
  184 + max_length=max_doc_tokens,
  185 + add_special_tokens=False,
  186 + return_attention_mask=False,
  187 + )
  188 + return [list(ids) for ids in tokenized["input_ids"]]
  189 +
  190 + def _build_pack_plan(
  191 + self,
  192 + query_prefix_len: int,
  193 + doc_tokens: Sequence[Sequence[int]],
  194 + ) -> List[List[int]]:
  195 + order = list(range(len(doc_tokens)))
  196 + if self._sort_by_doc_length and len(order) > 1:
  197 + order.sort(key=lambda idx: len(doc_tokens[idx]))
  198 +
  199 + packs: List[List[int]] = []
  200 + current_pack: List[int] = []
  201 + current_len = query_prefix_len
  202 + for idx in order:
  203 + packed_doc_len = len(doc_tokens[idx]) + self._suffix_len
  204 + if packed_doc_len <= 0:
  205 + continue
  206 +
  207 + over_docs_cap = self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack
  208 + over_token_cap = current_pack and (current_len + packed_doc_len > self._max_model_len)
  209 + if over_docs_cap or over_token_cap:
  210 + packs.append(current_pack)
  211 + current_pack = []
  212 + current_len = query_prefix_len
  213 +
  214 + if query_prefix_len + packed_doc_len > self._max_model_len:
  215 + raise ValueError(
  216 + "Packed doc still exceeds max_model_len after truncation. "
  217 + f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, "
  218 + f"max_model_len={self._max_model_len}"
  219 + )
  220 +
  221 + current_pack.append(idx)
  222 + current_len += packed_doc_len
  223 +
  224 + if current_pack:
  225 + packs.append(current_pack)
  226 + return packs
  227 +
  228 + def _build_pack_inputs(
  229 + self,
  230 + query_prefix_tokens: Sequence[int],
  231 + doc_tokens: Sequence[Sequence[int]],
  232 + doc_indices: Sequence[int],
  233 + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
  234 + prefix_len = len(query_prefix_tokens)
  235 + input_ids_list = list(query_prefix_tokens)
  236 + position_ids_list = list(range(prefix_len))
  237 + spans: List[Tuple[int, int]] = []
  238 + current_len = prefix_len
  239 +
  240 + for idx in doc_indices:
  241 + doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens
  242 + start = current_len
  243 + end = start + len(doc_with_suffix)
  244 + spans.append((start, end))
  245 + input_ids_list.extend(doc_with_suffix)
  246 + position_ids_list.extend(range(prefix_len, prefix_len + len(doc_with_suffix)))
  247 + current_len = end
  248 +
  249 + total_len = len(input_ids_list)
  250 + device = self._device
  251 + neg_inf = torch.finfo(torch.float32).min
  252 +
  253 + allowed = torch.zeros((total_len, total_len), dtype=torch.bool, device=device)
  254 + prefix_causal = torch.tril(
  255 + torch.ones((prefix_len, prefix_len), dtype=torch.bool, device=device)
  256 + )
  257 + allowed[:prefix_len, :prefix_len] = prefix_causal
  258 + for start, end in spans:
  259 + allowed[start:end, :prefix_len] = True
  260 + doc_len = end - start
  261 + allowed[start:end, start:end] = torch.tril(
  262 + torch.ones((doc_len, doc_len), dtype=torch.bool, device=device)
  263 + )
  264 +
  265 + attention_mask = torch.full(
  266 + (total_len, total_len),
  267 + neg_inf,
  268 + dtype=torch.float32,
  269 + device=device,
  270 + )
  271 + attention_mask.masked_fill_(allowed, 0.0)
  272 +
  273 + inputs = {
  274 + "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device),
  275 + "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device),
  276 + "attention_mask": attention_mask.view(1, 1, total_len, total_len),
  277 + }
  278 + logits_ids = torch.tensor(
  279 + [end - 1 for _, end in spans],
  280 + dtype=torch.long,
  281 + device=device,
  282 + )
  283 + return inputs, logits_ids
  284 +
  285 + @torch.no_grad()
  286 + def _score_pack(
  287 + self,
  288 + query_prefix_tokens: Sequence[int],
  289 + doc_tokens: Sequence[Sequence[int]],
  290 + doc_indices: Sequence[int],
  291 + ) -> Tuple[List[float], int]:
  292 + inputs, logits_ids = self._build_pack_inputs(
  293 + query_prefix_tokens=query_prefix_tokens,
  294 + doc_tokens=doc_tokens,
  295 + doc_indices=doc_indices,
  296 + )
  297 + outputs = self._model(**inputs)
  298 + scores = outputs.logits[0, logits_ids, :]
  299 + true_vector = scores[:, self._token_true_id]
  300 + false_vector = scores[:, self._token_false_id]
  301 + pair_scores = torch.stack([false_vector, true_vector], dim=1)
  302 + pair_scores = torch.nn.functional.log_softmax(pair_scores, dim=1)
  303 + return pair_scores[:, 1].exp().tolist(), int(inputs["input_ids"].shape[1])
  304 +
  305 + def score_with_meta(
  306 + self,
  307 + query: str,
  308 + docs: List[str],
  309 + normalize: bool = True,
  310 + ) -> Tuple[List[float], Dict[str, Any]]:
  311 + start_ts = time.time()
  312 + total_docs = len(docs) if docs else 0
  313 + output_scores: List[float] = [0.0] * total_docs
  314 +
  315 + query = "" if query is None else str(query).strip()
  316 + indexed: List[Tuple[int, str]] = []
  317 + for i, doc in enumerate(docs or []):
  318 + if doc is None:
  319 + continue
  320 + text = str(doc).strip()
  321 + if not text:
  322 + continue
  323 + indexed.append((i, text))
  324 +
  325 + if not query or not indexed:
  326 + elapsed_ms = (time.time() - start_ts) * 1000.0
  327 + return output_scores, {
  328 + "input_docs": total_docs,
  329 + "usable_docs": len(indexed),
  330 + "unique_docs": 0,
  331 + "dedup_ratio": 0.0,
  332 + "elapsed_ms": round(elapsed_ms, 3),
  333 + "model": self._model_name,
  334 + "backend": "qwen3_transformers_packed",
  335 + "normalize": normalize,
  336 + "packed_batches": 0,
  337 + "max_model_len": self._max_model_len,
  338 + "max_doc_len": self._max_doc_len,
  339 + "sort_by_doc_length": self._sort_by_doc_length,
  340 + }
  341 +
  342 + indexed_texts = [text for _, text in indexed]
  343 + unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts)
  344 +
  345 + query_prefix_tokens = self._build_pair_prefix_tokens(query)
  346 + doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens))
  347 + pack_plan = self._build_pack_plan(
  348 + query_prefix_len=len(query_prefix_tokens),
  349 + doc_tokens=doc_tokens,
  350 + )
  351 +
  352 + unique_scores: List[float] = [0.0] * len(unique_texts)
  353 + pack_lengths: List[int] = []
  354 + with self._infer_lock:
  355 + for pack_doc_indices in pack_plan:
  356 + batch_scores, pack_seq_len = self._score_pack(
  357 + query_prefix_tokens=query_prefix_tokens,
  358 + doc_tokens=doc_tokens,
  359 + doc_indices=pack_doc_indices,
  360 + )
  361 + if len(batch_scores) != len(pack_doc_indices):
  362 + raise RuntimeError(
  363 + "Packed reranker score size mismatch: "
  364 + f"expected {len(pack_doc_indices)}, got {len(batch_scores)}"
  365 + )
  366 + for idx, score in zip(pack_doc_indices, batch_scores):
  367 + unique_scores[idx] = float(score)
  368 + pack_lengths.append(pack_seq_len)
  369 +
  370 + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
  371 + output_scores[orig_idx] = float(unique_scores[unique_idx])
  372 +
  373 + elapsed_ms = (time.time() - start_ts) * 1000.0
  374 + dedup_ratio = 0.0
  375 + if indexed:
  376 + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
  377 +
  378 + meta = {
  379 + "input_docs": total_docs,
  380 + "usable_docs": len(indexed),
  381 + "unique_docs": len(unique_texts),
  382 + "dedup_ratio": round(dedup_ratio, 4),
  383 + "elapsed_ms": round(elapsed_ms, 3),
  384 + "model": self._model_name,
  385 + "backend": "qwen3_transformers_packed",
  386 + "normalize": normalize,
  387 + "packed_batches": len(pack_plan),
  388 + "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0,
  389 + "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3)
  390 + if pack_lengths
  391 + else 0.0,
  392 + "max_model_len": self._max_model_len,
  393 + "max_doc_len": self._max_doc_len,
  394 + "max_docs_per_pack": self._max_docs_per_pack,
  395 + "sort_by_doc_length": self._sort_by_doc_length,
  396 + "attn_implementation": self._attn_impl,
  397 + }
  398 + return output_scores, meta
... ...
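A toy, self-contained sketch of the packing scheme that `_build_pack_inputs` above implements (shared prefix, per-document `position_ids` reset, block-causal attention); the sizes are made up for readability:

```python
import torch

# Toy pack: a shared prefix of 3 tokens followed by two documents of 2 tokens each,
# laid out as [p0 p1 p2 | d0a d0b | d1a d1b].
prefix_len, doc_lens = 3, [2, 2]
total = prefix_len + sum(doc_lens)

# position_ids restart at prefix_len for every document, so each doc sees the same
# positions it would have in an unpacked (prefix + doc) prompt.
position_ids = list(range(prefix_len))
for n in doc_lens:
    position_ids.extend(range(prefix_len, prefix_len + n))
assert position_ids == [0, 1, 2, 3, 4, 3, 4]

# Boolean "allowed to attend" mask: causal inside the prefix, each document attends to
# the full prefix plus causally to itself, and never to the other document.
allowed = torch.zeros((total, total), dtype=torch.bool)
allowed[:prefix_len, :prefix_len] = torch.tril(
    torch.ones(prefix_len, prefix_len, dtype=torch.bool)
)
start = prefix_len
for n in doc_lens:
    end = start + n
    allowed[start:end, :prefix_len] = True
    allowed[start:end, start:end] = torch.tril(torch.ones(n, n, dtype=torch.bool))
    start = end

# The backend converts this into an additive float mask (0 where allowed, -inf elsewhere)
# of shape (1, 1, total, total) and reads each document's score from its last position.
additive = torch.full((total, total), torch.finfo(torch.float32).min).masked_fill(allowed, 0.0)
print(allowed.int())
```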
reranker/backends/qwen3_vllm.py
... ... @@ -45,7 +45,7 @@ def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]:
45 45 return unique_texts, position_to_unique
46 46  
47 47  
48   -def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
  48 +def _format_instruction__standard(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
49 49 """Build chat messages for one (query, doc) pair."""
50 50 return [
51 51 {
... ... @@ -58,6 +58,18 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str
58 58 },
59 59 ]
60 60  
  61 +def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
  62 + """Build chat messages for one (query, doc) pair."""
  63 + return [
  64 + {
  65 + "role": "system",
  66 + "content": instruction,
  67 + },
  68 + {
  69 + "role": "user",
  70 + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}",
  71 + },
  72 + ]
61 73  
62 74 class Qwen3VLLMRerankerBackend:
63 75 """
... ... @@ -78,6 +90,17 @@ class Qwen3VLLMRerankerBackend:
78 90 self._config.get("instruction")
79 91 or "Given a query, score the product for relevance"
80 92 )
  93 + _fmt = str(self._config.get("instruction_format") or "compact").strip().lower()
  94 + if _fmt not in {"standard", "compact"}:
  95 + raise ValueError(
  96 + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}"
  97 + )
  98 + self._instruction_format = _fmt
  99 + self._format_messages = (
  100 + _format_instruction__standard
  101 + if self._instruction_format == "standard"
  102 + else _format_instruction
  103 + )
81 104 infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64)
82 105 sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH")
83 106 if sort_by_doc_length is None:
... ... @@ -95,13 +118,15 @@ class Qwen3VLLMRerankerBackend:
95 118 )
96 119  
97 120 logger.info(
98   - "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)",
  121 + "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, "
  122 + "instruction_format=%s)",
99 123 model_name,
100 124 max_model_len,
101 125 tensor_parallel_size,
102 126 gpu_memory_utilization,
103 127 dtype,
104 128 enable_prefix_caching,
  129 + self._instruction_format,
105 130 )
106 131  
107 132 self._llm = LLM(
... ... @@ -145,7 +170,7 @@ class Qwen3VLLMRerankerBackend:
145 170 ) -> List[TokensPrompt]:
146 171 """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template."""
147 172 messages_batch = [
148   - _format_instruction(self._instruction, q, d) for q, d in pairs
  173 + self._format_messages(self._instruction, q, d) for q, d in pairs
149 174 ]
150 175 tokenized = self._tokenizer.apply_chat_template(
151 176 messages_batch,
... ... @@ -242,6 +267,7 @@ class Qwen3VLLMRerankerBackend:
242 267 "infer_batch_size": self._infer_batch_size,
243 268 "inference_batches": 0,
244 269 "sort_by_doc_length": self._sort_by_doc_length,
  270 + "instruction_format": self._instruction_format,
245 271 }
246 272  
247 273 # Deduplicate globally by text, keep mapping to original indices.
... ... @@ -289,6 +315,7 @@ class Qwen3VLLMRerankerBackend:
289 315 "normalize": normalize,
290 316 "infer_batch_size": self._infer_batch_size,
291 317 "inference_batches": inference_batches,
292   - "sort_by_doc_length": self._sort_by_doc_length
  318 + "sort_by_doc_length": self._sort_by_doc_length,
  319 + "instruction_format": self._instruction_format,
293 320 }
294 321 return output_scores, meta
... ...
reranker/backends/qwen3_vllm_score.py 0 → 100644
... ... @@ -0,0 +1,323 @@
  1 +"""
  2 +Qwen3-Reranker via vLLM ``LLM.score()`` (pooling / cross-encoder score API).
  3 +
  4 +Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired
  5 +``llm.score(query_texts, doc_texts)`` with the recommended prefix/suffix templates.
  6 +Requires vLLM >= 0.17 (uses ``runner``/``convert`` auto, not legacy ``task="score"``).
  7 +
  8 +Dedicated venv: ``.venv-reranker-score`` + ``requirements_reranker_qwen3_vllm_score.txt``
  9 +(see ``./scripts/setup_reranker_venv.sh qwen3_vllm_score``). Default ``model_name`` can match
  10 +``qwen3_vllm``; only the Python env differs for pinned high-performance vLLM.
  11 +
  12 +Reference: https://docs.vllm.ai/ — Qwen3 reranker example
  13 +"""
  14 +
  15 +from __future__ import annotations
  16 +
  17 +import logging
  18 +import os
  19 +import threading
  20 +import time
  21 +from typing import Any, Dict, List, Tuple
  22 +
  23 +import torch
  24 +from vllm import LLM
  25 +
  26 +from reranker.backends.qwen3_vllm import deduplicate_with_positions
  27 +
  28 +logger = logging.getLogger("reranker.backends.qwen3_vllm_score")
  29 +
  30 +# Official vLLM Qwen3 reranker prompt layout (im_start blocks + assistant suffix).
  31 +_DEFAULT_PREFIX = (
  32 + "<|im_start|>system\n"
  33 + "Judge whether the Document meets the requirements based on the Query and the Instruct "
  34 + 'provided. Note that the answer can only be "yes" or "no".'
  35 + "<|im_end|>\n<|im_start|>user\n"
  36 +)
  37 +_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  38 +_DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
  39 +_DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}"
  40 +# compact: matches qwen3_vllm._format_instruction (instruction as the system message; Instruct repeated inside the user message)
  41 +_IM_USER_START = "<|im_end|>\n<|im_start|>user\n"
  42 +
  43 +
  44 +def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None:
  45 + """
  46 + vLLM 0.18 defaults to Flash-Attention paths that require compute capability >= 8 (Ampere+).
  47 + Turing / Volta (e.g. T4 sm_75) must use a non-FA backend such as TRITON_ATTN.
  48 + """
  49 + env = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip()
  50 + raw = config.get("vllm_attention_backend")
  51 + if env:
  52 + choice = env
  53 + elif raw is not None and str(raw).strip() and str(raw).strip().lower() != "auto":
  54 + choice = str(raw).strip()
  55 + else:
  56 + choice = ""
  57 + if choice:
  58 + backend = choice.strip().upper()
  59 + if backend == "AUTO":
  60 + choice = ""
  61 + else:
  62 + logger.info("[Qwen3_VLLM_SCORE] attention_config.backend=%s (from config/env)", backend)
  63 + return {"backend": backend}
  64 +
  65 + major, minor = torch.cuda.get_device_capability()
  66 + if major < 8:
  67 + logger.info(
  68 + "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
  69 + "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
  70 + "Override with services.rerank.backends.qwen3_vllm_score.vllm_attention_backend "
  71 + "or RERANK_VLLM_ATTENTION_BACKEND.",
  72 + major,
  73 + minor,
  74 + )
  75 + return {"backend": "TRITON_ATTN"}
  76 + return None
  77 +
  78 +
  79 +class Qwen3VLLMScoreRerankerBackend:
  80 + """
  81 + Qwen3 reranker using vLLM ``LLM.score()`` (pooling runner) for cross-encoder scores.
  82 +
  83 + Config from ``services.rerank.backends.qwen3_vllm_score``.
  84 + """
  85 +
  86 + def __init__(self, config: Dict[str, Any]) -> None:
  87 + self._config = config or {}
  88 + model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
  89 + max_model_len = int(self._config.get("max_model_len", 2048))
  90 + tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
  91 + gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
  92 + enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False))
  93 + enforce_eager = bool(self._config.get("enforce_eager", True))
  94 + dtype = str(self._config.get("dtype", "float16")).strip().lower()
  95 + use_hf_overrides = self._config.get("use_original_qwen3_hf_overrides")
  96 + if use_hf_overrides is None:
  97 + use_hf_overrides = True
  98 + use_hf_overrides = bool(use_hf_overrides)
  99 +
  100 + self._instruction = str(
  101 + self._config.get("instruction")
  102 + or "Given a query, score the product for relevance"
  103 + )
  104 + _fmt = str(self._config.get("instruction_format") or "standard").strip().lower()
  105 + if _fmt not in {"standard", "compact"}:
  106 + raise ValueError(
  107 + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}"
  108 + )
  109 + self._instruction_format = _fmt
  110 + self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
  111 + self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
  112 + self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE)
  113 + self._document_template = str(
  114 + self._config.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE
  115 + )
  116 +
  117 + infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get(
  118 + "infer_batch_size", 64
  119 + )
  120 + sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH")
  121 + if sort_by_doc_length is None:
  122 + sort_by_doc_length = self._config.get("sort_by_doc_length", True)
  123 +
  124 + self._infer_batch_size = int(infer_batch_size)
  125 + self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
  126 + "1",
  127 + "true",
  128 + "yes",
  129 + "y",
  130 + "on",
  131 + }
  132 +
  133 + if not torch.cuda.is_available():
  134 + raise RuntimeError(
  135 + "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False"
  136 + )
  137 + if dtype not in {"float16", "half", "auto"}:
  138 + raise ValueError(
  139 + f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto."
  140 + )
  141 + if self._infer_batch_size <= 0:
  142 + raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
  143 +
  144 + runner = str(self._config.get("vllm_runner") or "auto").strip().lower()
  145 + convert = str(self._config.get("vllm_convert") or "auto").strip().lower()
  146 + if runner not in {"auto", "generate", "pooling", "draft"}:
  147 + raise ValueError(f"Invalid vllm_runner: {runner!r}")
  148 + if convert not in {"auto", "none", "embed", "classify"}:
  149 + raise ValueError(f"Invalid vllm_convert: {convert!r}")
  150 +
  151 + logger.info(
  152 + "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, "
  153 + "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, "
  154 + "instruction_format=%s)",
  155 + model_name,
  156 + runner,
  157 + convert,
  158 + use_hf_overrides,
  159 + max_model_len,
  160 + tensor_parallel_size,
  161 + gpu_memory_utilization,
  162 + dtype,
  163 + enable_prefix_caching,
  164 + self._instruction_format,
  165 + )
  166 +
  167 + # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official
  168 + # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify.
  169 + llm_kwargs: Dict[str, Any] = {
  170 + "model": model_name,
  171 + "runner": runner,
  172 + "convert": convert,
  173 + "tensor_parallel_size": tensor_parallel_size,
  174 + "max_model_len": max_model_len,
  175 + "gpu_memory_utilization": gpu_memory_utilization,
  176 + "enable_prefix_caching": enable_prefix_caching,
  177 + "enforce_eager": enforce_eager,
  178 + "dtype": dtype,
  179 + }
  180 + hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {})
  181 + if use_hf_overrides:
  182 + hf_overrides = {
  183 + **hf_overrides,
  184 + "architectures": ["Qwen3ForSequenceClassification"],
  185 + "classifier_from_token": ["no", "yes"],
  186 + "is_original_qwen3_reranker": True,
  187 + }
  188 + if hf_overrides:
  189 + llm_kwargs["hf_overrides"] = hf_overrides
  190 +
  191 + attn_cfg = _resolve_vllm_attention_config(self._config)
  192 + if attn_cfg is not None:
  193 + llm_kwargs["attention_config"] = attn_cfg
  194 +
  195 + self._llm = LLM(**llm_kwargs)
  196 + # vLLM score path: single-process safety (mirrors generate backend until verified).
  197 + self._infer_lock = threading.Lock()
  198 +
  199 + self._model_name = model_name
  200 + logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name)
  201 +
  202 + def _format_pair(self, query: str, doc: str) -> Tuple[str, str]:
  203 + if self._instruction_format == "compact":
  204 + # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score().
  205 + compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}"
  206 + q_text = (
  207 + f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n"
  208 + )
  209 + d_text = f"\n<Document>: {doc}{self._suffix}"
  210 + return q_text, d_text
  211 + q_text = self._query_template.format(
  212 + prefix=self._prefix,
  213 + instruction=self._instruction,
  214 + query=query,
  215 + )
  216 + d_text = self._document_template.format(doc=doc, suffix=self._suffix)
  217 + return q_text, d_text
  218 +
  219 + def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]:
  220 + if not pairs:
  221 + return []
  222 + queries: List[str] = []
  223 + documents: List[str] = []
  224 + for q, d in pairs:
  225 + qt, dt = self._format_pair(q, d)
  226 + queries.append(qt)
  227 + documents.append(dt)
  228 + with self._infer_lock:
  229 + outputs = self._llm.score(queries, documents, use_tqdm=False)
  230 + scores: List[float] = []
  231 + for out in outputs:
  232 + so = out.outputs
  233 + scores.append(float(so.score))
  234 + return scores
  235 +
  236 + @staticmethod
  237 + def _estimate_doc_lengths(docs: List[str]) -> List[int]:
  238 + if not docs:
  239 + return []
  240 + return [len(text) for text in docs]
  241 +
  242 + def score_with_meta(
  243 + self,
  244 + query: str,
  245 + docs: List[str],
  246 + normalize: bool = True,
  247 + ) -> Tuple[List[float], Dict[str, Any]]:
  248 + start_ts = time.time()
  249 + total_docs = len(docs) if docs else 0
  250 + output_scores: List[float] = [0.0] * total_docs
  251 +
  252 + query = "" if query is None else str(query).strip()
  253 + indexed: List[Tuple[int, str]] = []
  254 + for i, doc in enumerate(docs or []):
  255 + if doc is None:
  256 + continue
  257 + text = str(doc).strip()
  258 + if not text:
  259 + continue
  260 + indexed.append((i, text))
  261 +
  262 + if not query or not indexed:
  263 + elapsed_ms = (time.time() - start_ts) * 1000.0
  264 + return output_scores, {
  265 + "input_docs": total_docs,
  266 + "usable_docs": len(indexed),
  267 + "unique_docs": 0,
  268 + "dedup_ratio": 0.0,
  269 + "elapsed_ms": round(elapsed_ms, 3),
  270 + "model": self._model_name,
  271 + "backend": "qwen3_vllm_score",
  272 + "normalize": normalize,
  273 + "infer_batch_size": self._infer_batch_size,
  274 + "inference_batches": 0,
  275 + "sort_by_doc_length": self._sort_by_doc_length,
  276 + "instruction_format": self._instruction_format,
  277 + }
  278 +
  279 + indexed_texts = [text for _, text in indexed]
  280 + unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)
  281 +
  282 + lengths = self._estimate_doc_lengths(unique_texts)
  283 + order = list(range(len(unique_texts)))
  284 + if self._sort_by_doc_length and len(unique_texts) > 1:
  285 + order = sorted(order, key=lambda i: lengths[i])
  286 +
  287 + unique_scores: List[float] = [0.0] * len(unique_texts)
  288 + inference_batches = 0
  289 + for start in range(0, len(order), self._infer_batch_size):
  290 + batch_indices = order[start : start + self._infer_batch_size]
  291 + inference_batches += 1
  292 + pairs = [(query, unique_texts[i]) for i in batch_indices]
  293 + batch_scores = self._score_batch(pairs)
  294 + if len(batch_scores) != len(batch_indices):
  295 + raise RuntimeError(
  296 + f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}"
  297 + )
  298 + for idx, score in zip(batch_indices, batch_scores):
  299 + unique_scores[idx] = float(score)
  300 +
  301 + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
  302 + output_scores[orig_idx] = float(unique_scores[unique_idx])
  303 +
  304 + elapsed_ms = (time.time() - start_ts) * 1000.0
  305 + dedup_ratio = 0.0
  306 + if indexed:
  307 + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
  308 +
  309 + meta = {
  310 + "input_docs": total_docs,
  311 + "usable_docs": len(indexed),
  312 + "unique_docs": len(unique_texts),
  313 + "dedup_ratio": round(dedup_ratio, 4),
  314 + "elapsed_ms": round(elapsed_ms, 3),
  315 + "model": self._model_name,
  316 + "backend": "qwen3_vllm_score",
  317 + "normalize": normalize,
  318 + "infer_batch_size": self._infer_batch_size,
  319 + "inference_batches": inference_batches,
  320 + "sort_by_doc_length": self._sort_by_doc_length,
  321 + "instruction_format": self._instruction_format,
  322 + }
  323 + return output_scores, meta
... ...
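Stripped of deduplication and batching, the backend above reduces to the plain vLLM `LLM.score()` call pattern. A minimal sketch under the same assumptions as the file's docstring (vLLM >= 0.17, official Qwen3-Reranker weights with the `hf_overrides` shown in the diff, standard-format templates); query and document strings are illustrative:

```python
from vllm import LLM

# Hedged sketch mirroring reranker/backends/qwen3_vllm_score.py above; not a drop-in
# replacement for the backend, just the core LLM.score() pattern it wraps.
llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    runner="auto",
    convert="auto",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct "
    'provided. Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
instruction = "Rank products by query with category & style match prioritized"

query_text = f"{prefix}<Instruct>: {instruction}\n<Query>: white oversized t-shirt\n"
doc_texts = [f"<Document>: {d}{suffix}" for d in ["Oversized cotton tee", "Leather ankle boots"]]

outputs = llm.score([query_text] * len(doc_texts), doc_texts, use_tqdm=False)
scores = [float(o.outputs.score) for o in outputs]
print(scores)
```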
reranker/server.py
... ... @@ -7,7 +7,7 @@ Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional
7 7 Response: { "scores": [float], "meta": {...} }
8 8  
9 9 Backend selected via config: services.rerank.backend
10   -(bge | qwen3_vllm | qwen3_transformers | dashscope_rerank), env RERANK_BACKEND.
  10 +(bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_transformers_packed | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank), env RERANK_BACKEND.
11 11 """
12 12  
13 13 import logging
... ... @@ -99,12 +99,17 @@ def health() -> Dict[str, Any]:
99 99 model_info = getattr(_reranker, "_model_name", None) or getattr(
100 100 _reranker, "_config", {}
101 101 ).get("model_name", _backend_name)
102   - return {
  102 + payload: Dict[str, Any] = {
103 103 "status": "ok" if _reranker is not None else "unavailable",
104 104 "model_loaded": _reranker is not None,
105 105 "model": model_info,
106 106 "backend": _backend_name,
107 107 }
  108 + if _reranker is not None:
  109 + _fmt = getattr(_reranker, "_instruction_format", None)
  110 + if _fmt is not None:
  111 + payload["instruction_format"] = _fmt
  112 + return payload
108 113  
109 114  
110 115 @app.post("/rerank", response_model=RerankResponse)
... ...
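For reference, an illustrative shape of the extended `/health` payload once a backend that exposes `_instruction_format` is loaded; field values are examples, not captured output:

```python
# Only backends defining _instruction_format (the Qwen3 vLLM variants above) get the
# extra field; other backends keep the original four-field payload.
expected_health_payload = {
    "status": "ok",
    "model_loaded": True,
    "model": "Qwen/Qwen3-Reranker-0.6B",
    "backend": "qwen3_vllm",
    "instruction_format": "compact",
}
```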
scripts/benchmark_reranker_gguf_local.py 0 → 100644
... ... @@ -0,0 +1,198 @@
  1 +#!/usr/bin/env python3
  2 +"""
  3 +Local tuning probe for GGUF reranker backends.
  4 +
  5 +Runs the backend directly in a fresh process per config to measure:
  6 +- load time
  7 +- GPU memory used by this process
  8 +- single-request rerank latency
  9 +
  10 +Example:
  11 + ./.venv-reranker-gguf/bin/python scripts/benchmark_reranker_gguf_local.py
  12 + ./.venv-reranker-gguf-06b/bin/python scripts/benchmark_reranker_gguf_local.py --backend-name qwen3_gguf_06b --docs 400
  13 +"""
  14 +
  15 +from __future__ import annotations
  16 +
  17 +import argparse
  18 +import json
  19 +import os
  20 +import random
  21 +import statistics
  22 +import subprocess
  23 +import sys
  24 +import time
  25 +from pathlib import Path
  26 +from typing import Any
  27 +
  28 +
  29 +DEFAULT_TITLES = Path("/home/ubuntu/rerank_test/titles.1.8w")
  30 +
  31 +
  32 +def load_titles(path: Path) -> list[str]:
  33 + items: list[str] = []
  34 + with path.open(encoding="utf-8", errors="replace") as fh:
  35 + for line in fh:
  36 + text = line.strip()
  37 + if text:
  38 + items.append(text)
  39 + return items
  40 +
  41 +
  42 +def gpu_mem_for_pid(pid: int) -> int:
  43 + try:
  44 + out = subprocess.check_output(
  45 + [
  46 + "nvidia-smi",
  47 + "--query-compute-apps=pid,used_gpu_memory",
  48 + "--format=csv,noheader,nounits",
  49 + ],
  50 + text=True,
  51 + )
  52 + except Exception:
  53 + return -1
  54 + for raw in out.splitlines():
  55 + parts = [p.strip() for p in raw.split(",")]
  56 + if len(parts) != 2:
  57 + continue
  58 + try:
  59 + row_pid = int(parts[0])
  60 + row_mem = int(parts[1])
  61 + except ValueError:
  62 + continue
  63 + if row_pid == pid:
  64 + return row_mem
  65 + return -1
  66 +
  67 +
  68 +def main() -> int:
  69 + parser = argparse.ArgumentParser()
  70 + parser.add_argument("--backend-name", type=str, default="qwen3_gguf")
  71 + parser.add_argument("--titles-file", type=Path, default=DEFAULT_TITLES)
  72 + parser.add_argument("--query", type=str, default="白色oversized T-shirt")
  73 + parser.add_argument("--docs", type=int, default=160)
  74 + parser.add_argument("--repeat", type=int, default=1)
  75 + parser.add_argument("--seed", type=int, default=42)
  76 + parser.add_argument(
  77 + "--configs-json",
  78 + type=str,
  79 + default="",
  80 + help="JSON array of config objects; when omitted, uses built-in scan set.",
  81 + )
  82 + args = parser.parse_args()
  83 +
  84 + if not args.titles_file.is_file():
  85 + print(f"missing titles file: {args.titles_file}", file=sys.stderr)
  86 + return 2
  87 +
  88 + titles = load_titles(args.titles_file)
  89 + if len(titles) < args.docs:
  90 + print(f"not enough titles: need {args.docs}, got {len(titles)}", file=sys.stderr)
  91 + return 2
  92 +
  93 + random.seed(args.seed)
  94 + docs = random.sample(titles, args.docs)
  95 +
  96 + if args.configs_json:
  97 + configs = json.loads(args.configs_json)
  98 + elif args.backend_name == "qwen3_gguf_06b":
  99 + configs = [
  100 + {"name": "gguf_06b_full_256", "n_ctx": 256, "n_batch": 256, "n_ubatch": 256, "n_gpu_layers": 999},
  101 + {"name": "gguf_06b_full_320", "n_ctx": 320, "n_batch": 320, "n_ubatch": 320, "n_gpu_layers": 999},
  102 + {"name": "gguf_06b_full_384", "n_ctx": 384, "n_batch": 384, "n_ubatch": 384, "n_gpu_layers": 999},
  103 + {"name": "gguf_06b_full_512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 512, "n_gpu_layers": 999},
  104 + ]
  105 + else:
  106 + configs = [
  107 + {"name": "gguf_t4_24g", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 24},
  108 + {"name": "gguf_t4_40g", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 40},
  109 + {"name": "gguf_t4_full", "n_ctx": 384, "n_batch": 384, "n_ubatch": 128, "n_gpu_layers": 999},
  110 + {"name": "gguf_t4_full_512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 256, "n_gpu_layers": 999},
  111 + {"name": "gguf_t4_full_512_u512", "n_ctx": 512, "n_batch": 512, "n_ubatch": 512, "n_gpu_layers": 999},
  112 + {"name": "gguf_t4_full_768", "n_ctx": 768, "n_batch": 768, "n_ubatch": 256, "n_gpu_layers": 999},
  113 + ]
  114 +
  115 + from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend
  116 +
  117 + default_cfg_by_backend: dict[str, dict[str, Any]] = {
  118 + "qwen3_gguf": {
  119 + "_backend_name": "qwen3_gguf",
  120 + "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
  121 + "filename": "*Q8_0.gguf",
  122 + "local_dir": "./models/reranker/qwen3-reranker-4b-gguf",
  123 + "infer_batch_size": 8,
  124 + },
  125 + "qwen3_gguf_06b": {
  126 + "_backend_name": "qwen3_gguf_06b",
  127 + "repo_id": "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF",
  128 + "filename": "qwen3-reranker-0.6b-q8_0.gguf",
  129 + "local_dir": "./models/reranker/qwen3-reranker-0.6b-q8_0-gguf",
  130 + "infer_batch_size": 32,
  131 + },
  132 + }
  133 + if args.backend_name not in default_cfg_by_backend:
  134 + print(f"unsupported backend: {args.backend_name}", file=sys.stderr)
  135 + return 2
  136 +
  137 + base_cfg: dict[str, Any] = {
  138 + **default_cfg_by_backend[args.backend_name],
  139 + "instruction": "Rank products by query with category & style match prioritized",
  140 + "cache_dir": "./model_cache",
  141 + "main_gpu": 0,
  142 + "n_threads": 2,
  143 + "n_threads_batch": 4,
  144 + "flash_attn": True,
  145 + "offload_kqv": True,
  146 + "use_mmap": True,
  147 + "use_mlock": False,
  148 + "sort_by_doc_length": True,
  149 + "length_sort_mode": "char",
  150 + "enable_warmup": True,
  151 + "verbose": False,
  152 + "reuse_query_state": True,
  153 + }
  154 +
  155 + all_results: list[dict[str, Any]] = []
  156 + for cfg in configs:
  157 + merged = dict(base_cfg)
  158 + merged.update(cfg)
  159 + name = str(merged.pop("name"))
  160 +
  161 + t0 = time.perf_counter()
  162 + backend = Qwen3GGUFRerankerBackend(merged)
  163 + load_ms = (time.perf_counter() - t0) * 1000.0
  164 + gpu_mem_mib = gpu_mem_for_pid(os.getpid())
  165 +
  166 + runs: list[float] = []
  167 + last_meta: dict[str, Any] = {}
  168 + for _ in range(args.repeat):
  169 + t1 = time.perf_counter()
  170 + _scores, meta = backend.score_with_meta(args.query, docs, normalize=True)
  171 + runs.append((time.perf_counter() - t1) * 1000.0)
  172 + last_meta = dict(meta)
  173 +
  174 + result = {
  175 + "name": name,
  176 + "config": merged,
  177 + "load_ms": round(load_ms, 2),
  178 + "gpu_mem_mib": gpu_mem_mib,
  179 + "latency_ms_min": round(min(runs), 2),
  180 + "latency_ms_avg": round(statistics.mean(runs), 2),
  181 + "latency_ms_max": round(max(runs), 2),
  182 + "meta": last_meta,
  183 + }
  184 + all_results.append(result)
  185 + print(json.dumps(result, ensure_ascii=False))
  186 + del backend
  187 +
  188 + print("SUMMARY")
  189 + for item in sorted(all_results, key=lambda x: x["latency_ms_avg"]):
  190 + print(
  191 + f'{item["name"]}: avg={item["latency_ms_avg"]}ms '
  192 + f'gpu={item["gpu_mem_mib"]}MiB load={item["load_ms"]}ms'
  193 + )
  194 + return 0
  195 +
  196 +
  197 +if __name__ == "__main__":
  198 + raise SystemExit(main())
... ...
scripts/benchmark_reranker_random_titles.py
... ... @@ -6,6 +6,7 @@ Randomly samples N titles from a text file (one title per line), POSTs to the
6 6 rerank HTTP API, prints wall-clock latency.
7 7  
8 8 Supports multiple N values (comma-separated) and multiple repeats per N.
  9 +Each invocation first issues 3 warmup requests with n=400; these warmup runs are excluded from the timing summaries.
9 10  
10 11 Example:
11 12 source activate.sh
... ... @@ -149,6 +150,23 @@ def main() -> int:
149 150 action="store_true",
150 151 help="Print first ~500 chars of response body on success (last run only).",
151 152 )
  153 + parser.add_argument(
  154 + "--tag",
  155 + type=str,
  156 + default=os.environ.get("BENCH_TAG", ""),
  157 + help="Optional label stored in --json-summary-out (default: env BENCH_TAG or empty).",
  158 + )
  159 + parser.add_argument(
  160 + "--json-summary-out",
  161 + type=Path,
  162 + default=None,
  163 + help="Write one JSON object with per-n latencies and aggregates for downstream tables.",
  164 + )
  165 + parser.add_argument(
  166 + "--quiet-runs",
  167 + action="store_true",
  168 + help="Suppress per-run lines; still prints warmup lines and text summaries.",
  169 + )
152 170 args = parser.parse_args()
153 171  
154 172 try:
... ... @@ -167,7 +185,9 @@ def main() -> int:
167 185 return 2
168 186  
169 187 titles = _load_titles(args.titles_file)
170   - max_n = max(doc_counts)
  188 + warmup_n = 400
  189 + warmup_runs = 3
  190 + max_n = max(max(doc_counts), warmup_n)
171 191 if len(titles) < max_n:
172 192 print(
173 193 f"error: file has only {len(titles)} non-empty lines, need at least {max_n}",
... ... @@ -181,6 +201,33 @@ def main() -&gt; int:
181 201 summary: dict[int, List[float]] = {n: [] for n in doc_counts}
182 202  
183 203 with httpx.Client(timeout=args.timeout) as client:
  204 + for w in range(warmup_runs):
  205 + if args.seed is not None:
  206 + random.seed(args.seed + 8_000_000 + w)
  207 + docs_w = random.sample(titles, warmup_n)
  208 + try:
  209 + ok_w, status_w, _elapsed_w, scores_len_w, _text_w = _do_rerank(
  210 + client,
  211 + args.url,
  212 + args.query,
  213 + docs_w,
  214 + top_n=top_n,
  215 + normalize=normalize,
  216 + )
  217 + except httpx.HTTPError as exc:
  218 + print(
  219 + f"warmup n={warmup_n} {w + 1}/{warmup_runs} error: request failed: {exc}",
  220 + file=sys.stderr,
  221 + )
  222 + any_fail = True
  223 + continue
  224 + if not ok_w:
  225 + any_fail = True
  226 + print(
  227 + f"warmup n={warmup_n} {w + 1}/{warmup_runs} status={status_w} "
  228 + f"scores={scores_len_w if scores_len_w is not None else 'n/a'} (not timed)"
  229 + )
  230 +
184 231 for n in doc_counts:
185 232 for run_idx in range(repeat):
186 233 if args.seed is not None:
... ... @@ -208,10 +255,11 @@ def main() -&gt; int:
208 255 else:
209 256 any_fail = True
210 257  
211   - print(
212   - f"n={n} run={run_idx + 1}/{repeat} status={status} "
213   - f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}"
214   - )
  258 + if not args.quiet_runs:
  259 + print(
  260 + f"n={n} run={run_idx + 1}/{repeat} status={status} "
  261 + f"latency_ms={elapsed_ms:.2f} scores={scores_len if scores_len is not None else 'n/a'}"
  262 + )
215 263 if args.print_body_preview and text and run_idx == repeat - 1 and n == doc_counts[-1]:
216 264 preview = text[:500] + ("…" if len(text) > 500 else "")
217 265 print(preview)
... ... @@ -230,6 +278,33 @@ def main() -&gt; int:
230 278 f"summary n={n} runs={len(lat)} min_ms={lo:.2f} max_ms={hi:.2f} avg_ms={avg:.2f}{extra}"
231 279 )
232 280  
  281 + if args.json_summary_out is not None:
  282 + per_n: dict = {}
  283 + for n in doc_counts:
  284 + lat = summary[n]
  285 + row: dict = {"values_ms": lat, "runs": len(lat)}
  286 + if lat:
  287 + row["mean_ms"] = statistics.mean(lat)
  288 + row["min_ms"] = min(lat)
  289 + row["max_ms"] = max(lat)
  290 + if len(lat) >= 2:
  291 + row["stdev_ms"] = statistics.stdev(lat)
  292 + per_n[str(n)] = row
  293 + out_obj = {
  294 + "tag": args.tag or None,
  295 + "doc_counts": doc_counts,
  296 + "repeat": repeat,
  297 + "url": args.url,
  298 + "per_n": per_n,
  299 + "failed": bool(any_fail),
  300 + }
  301 + args.json_summary_out.parent.mkdir(parents=True, exist_ok=True)
  302 + args.json_summary_out.write_text(
  303 + json.dumps(out_obj, ensure_ascii=False, indent=2) + "\n",
  304 + encoding="utf-8",
  305 + )
  306 + print(f"wrote json summary -> {args.json_summary_out}")
  307 +
233 308 return 1 if any_fail else 0
234 309  
235 310  
... ...
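A minimal sketch of reading one --json-summary-out artifact downstream (the path is illustrative); the field names follow the out_obj written above:

    import json
    from pathlib import Path

    summary = json.loads(Path("bench_summary.json").read_text(encoding="utf-8"))  # illustrative path
    print(f'tag={summary["tag"]} url={summary["url"]} failed={summary["failed"]}')
    for n, row in summary["per_n"].items():
        mean_ms = row.get("mean_ms")  # absent when every run for this n failed
        print(f'n={n} runs={row["runs"]} mean_ms={mean_ms if mean_ms is not None else "n/a"}')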
scripts/lib/reranker_backend_env.sh 0 → 100644
... ... @@ -0,0 +1,68 @@
  1 +#!/bin/bash
  2 +#
  3 +# Shared helpers for mapping reranker backends to isolated virtualenvs.
  4 +#
  5 +
  6 +set -euo pipefail
  7 +
  8 +detect_rerank_backend() {
  9 + local project_root="$1"
  10 + local backend="${RERANK_BACKEND:-}"
  11 +
  12 + if [[ -n "${backend}" ]]; then
  13 + printf '%s\n' "${backend}"
  14 + return 0
  15 + fi
  16 +
  17 + backend="$(
  18 + awk '
  19 + /^ rerank:$/ { in_rerank=1; next }
  20 + in_rerank && /^ [^ ]/ { in_rerank=0 }
  21 + in_rerank && /^ backend:/ {
  22 + gsub(/"/, "", $2)
  23 + print $2
  24 + exit
  25 + }
  26 + ' "${project_root}/config/config.yaml"
  27 + )"
  28 +
  29 + if [[ -z "${backend}" ]]; then
  30 + backend="qwen3_vllm"
  31 + fi
  32 +
  33 + printf '%s\n' "${backend}"
  34 +}
  35 +
  36 +reranker_backend_venv_dir() {
  37 + local project_root="$1"
  38 + local backend="$2"
  39 +
  40 + case "${backend}" in
  41 + qwen3_vllm) printf '%s/.venv-reranker\n' "${project_root}" ;;
  42 + qwen3_vllm_score) printf '%s/.venv-reranker-score\n' "${project_root}" ;;
  43 + qwen3_gguf) printf '%s/.venv-reranker-gguf\n' "${project_root}" ;;
  44 + qwen3_gguf_06b) printf '%s/.venv-reranker-gguf-06b\n' "${project_root}" ;;
  45 + qwen3_transformers) printf '%s/.venv-reranker-transformers\n' "${project_root}" ;;
  46 + qwen3_transformers_packed) printf '%s/.venv-reranker-transformers-packed\n' "${project_root}" ;;
  47 + bge) printf '%s/.venv-reranker-bge\n' "${project_root}" ;;
  48 + dashscope_rerank) printf '%s/.venv-reranker-dashscope\n' "${project_root}" ;;
  49 + *) printf '%s/.venv-reranker-%s\n' "${project_root}" "${backend}" ;;
  50 + esac
  51 +}
  52 +
  53 +reranker_backend_requirements_file() {
  54 + local project_root="$1"
  55 + local backend="$2"
  56 +
  57 + case "${backend}" in
  58 + qwen3_vllm) printf '%s/requirements_reranker_qwen3_vllm.txt\n' "${project_root}" ;;
  59 + qwen3_vllm_score) printf '%s/requirements_reranker_qwen3_vllm_score.txt\n' "${project_root}" ;;
  60 + qwen3_gguf) printf '%s/requirements_reranker_qwen3_gguf.txt\n' "${project_root}" ;;
  61 + qwen3_gguf_06b) printf '%s/requirements_reranker_qwen3_gguf_06b.txt\n' "${project_root}" ;;
  62 + qwen3_transformers) printf '%s/requirements_reranker_qwen3_transformers.txt\n' "${project_root}" ;;
  63 + qwen3_transformers_packed) printf '%s/requirements_reranker_qwen3_transformers_packed.txt\n' "${project_root}" ;;
  64 + bge) printf '%s/requirements_reranker_bge.txt\n' "${project_root}" ;;
  65 + dashscope_rerank) printf '%s/requirements_reranker_dashscope.txt\n' "${project_root}" ;;
  66 + *) return 1 ;;
  67 + esac
  68 +}
... ...
scripts/patch_rerank_vllm_benchmark_config.py 0 → 100755
... ... @@ -0,0 +1,100 @@
  1 +#!/usr/bin/env python3
  2 +"""
  3 +Surgically patch config/config.yaml:
  4 + services.rerank.backend
  5 + services.rerank.backends.qwen3_vllm.instruction_format
  6 + services.rerank.backends.qwen3_vllm_score.instruction_format
  7 +
  8 +Preserves comments and unrelated lines. Used for benchmark matrix runs.
  9 +"""
  10 +
  11 +from __future__ import annotations
  12 +
  13 +import argparse
  14 +import re
  15 +import sys
  16 +from pathlib import Path
  17 +
  18 +
  19 +def _with_stripped_body(line: str) -> tuple[str, str]:
  20 +    """Return (body without the trailing newline, newline suffix; '' if none)."""
  21 + if line.endswith("\r\n"):
  22 + return line[:-2], "\r\n"
  23 + if line.endswith("\n"):
  24 + return line[:-1], "\n"
  25 + return line, ""
  26 +
  27 +
  28 +def _patch_backend_in_rerank_block(lines: list[str], backend: str) -> None:
  29 + in_rerank = False
  30 + for i, line in enumerate(lines):
  31 + if line.startswith(" rerank:"):
  32 + in_rerank = True
  33 + continue
  34 + if in_rerank:
  35 + if line.startswith(" ") and not line.startswith(" ") and line.strip():
  36 + in_rerank = False
  37 + continue
  38 + body, nl = _with_stripped_body(line)
  39 + m = re.match(r'^(\s*backend:\s*")[^"]+(".*)$', body)
  40 + if m:
  41 + lines[i] = f'{m.group(1)}{backend}{m.group(2)}{nl}'
  42 + return
  43 + raise RuntimeError("services.rerank.backend line not found")
  44 +
  45 +
  46 +def _patch_instruction_format_under_backend(
  47 + lines: list[str], section: str, fmt: str
  48 +) -> None:
  49 + """section is 'qwen3_vllm' or 'qwen3_vllm_score' (first line is ' qwen3_vllm:')."""
  50 + header = f" {section}:"
  51 + start = None
  52 + for i, line in enumerate(lines):
  53 + if line.rstrip() == header:
  54 + start = i
  55 + break
  56 + if start is None:
  57 + raise RuntimeError(f"section {section!r} not found")
  58 +
  59 + for j in range(start + 1, len(lines)):
  60 + line = lines[j]
  61 + body, nl = _with_stripped_body(line)
  62 + if re.match(r"^ [a-zA-Z0-9_]+:\s*$", body):
  63 + break
  64 + m = re.match(r"^(\s*instruction_format:\s*)\S+", body)
  65 + if m:
  66 + lines[j] = f"{m.group(1)}{fmt}{nl}"
  67 + return
  68 + raise RuntimeError(f"instruction_format not found under {section!r}")
  69 +
  70 +
  71 +def main() -> int:
  72 + p = argparse.ArgumentParser()
  73 + p.add_argument(
  74 + "--config",
  75 + type=Path,
  76 + default=Path(__file__).resolve().parent.parent / "config" / "config.yaml",
  77 + )
  78 + p.add_argument("--backend", choices=("qwen3_vllm", "qwen3_vllm_score"), required=True)
  79 + p.add_argument(
  80 + "--instruction-format",
  81 + dest="instruction_format",
  82 + choices=("compact", "standard"),
  83 + required=True,
  84 + )
  85 + args = p.parse_args()
  86 + text = args.config.read_text(encoding="utf-8")
  87 + lines = text.splitlines(keepends=True)
  88 + if not lines:
  89 + print("empty config", file=sys.stderr)
  90 + return 2
  91 + _patch_backend_in_rerank_block(lines, args.backend)
  92 + _patch_instruction_format_under_backend(lines, "qwen3_vllm", args.instruction_format)
  93 + _patch_instruction_format_under_backend(lines, "qwen3_vllm_score", args.instruction_format)
  94 + args.config.write_text("".join(lines), encoding="utf-8")
  95 + print(f"patched {args.config}: backend={args.backend} instruction_format={args.instruction_format} (both vLLM blocks)")
  96 + return 0
  97 +
  98 +
  99 +if __name__ == "__main__":
  100 + raise SystemExit(main())
... ...
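For reference, a minimal sketch of the backend-line rewrite performed by _patch_backend_in_rerank_block; the sample line and its trailing comment are illustrative, not the exact config line:

    import re

    line = '    backend: "qwen3_vllm"  # bge | qwen3_vllm | qwen3_vllm_score'
    m = re.match(r'^(\s*backend:\s*")[^"]+(".*)$', line)
    print(f'{m.group(1)}qwen3_vllm_score{m.group(2)}')
    # -> '    backend: "qwen3_vllm_score"  # bge | qwen3_vllm | qwen3_vllm_score'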
scripts/run_reranker_vllm_instruction_benchmark.sh 0 → 100755
... ... @@ -0,0 +1,89 @@
  1 +#!/usr/bin/env bash
  2 +# Patch config, restart reranker, wait for /health, run benchmark_reranker_random_titles.py.
  3 +# Requires: curl and the project .venv; PyYAML is not needed (the config patch script is standalone Python).
  4 +
  5 +set -euo pipefail
  6 +ROOT="$(cd "$(dirname "$0")/.." && pwd)"
  7 +cd "$ROOT"
  8 +
  9 +PYTHON="${ROOT}/.venv/bin/python"
  10 +DAY="$(date +%F)"
  11 +OUT_DIR="${ROOT}/perf_reports/reranker_vllm_instruction/${DAY}"
  12 +mkdir -p "$OUT_DIR"
  13 +
  14 +health_ok() {
  15 + local want_backend="$1"
  16 + local want_fmt="$2"
  17 + local body
  18 + if ! body="$(curl -sS --connect-timeout 2 --max-time 5 "http://127.0.0.1:6007/health" 2>/dev/null)"; then
  19 + return 1
  20 + fi
  21 + echo "$body" | "$PYTHON" -c "
  22 +import json, sys
  23 +want_b, want_f = sys.argv[1], sys.argv[2]
  24 +d = json.load(sys.stdin)
  25 +if d.get('status') != 'ok' or not d.get('model_loaded'):
  26 + sys.exit(1)
  27 +if d.get('backend') != want_b:
  28 + sys.exit(1)
  29 +if d.get('instruction_format') != want_f:
  30 + sys.exit(1)
  31 +sys.exit(0)
  32 +" "$want_backend" "$want_fmt"
  33 +}
  34 +
  35 +wait_health() {
  36 + local want_backend="$1"
  37 + local want_fmt="$2"
  38 + local i
  39 + for i in $(seq 1 180); do
  40 + if health_ok "$want_backend" "$want_fmt"; then
  41 + curl -sS "http://127.0.0.1:6007/health" | "$PYTHON" -m json.tool
  42 + return 0
  43 + fi
  44 + echo "[wait] ${i}/180 backend=${want_backend} instruction_format=${want_fmt} ..."
  45 + sleep 3
  46 + done
  47 + echo "[error] health did not match in time" >&2
  48 + return 1
  49 +}
  50 +
  51 +run_one() {
  52 + local backend="$1"
  53 + local fmt="$2"
  54 + local tag="${backend}|${fmt}"
  55 + local jf="${OUT_DIR}/${backend}_${fmt}.json"
  56 +
  57 + echo "========== ${tag} =========="
  58 + "$PYTHON" "${ROOT}/scripts/patch_rerank_vllm_benchmark_config.py" \
  59 + --backend "$backend" --instruction-format "$fmt"
  60 +
  61 + "${ROOT}/restart.sh" reranker
  62 + wait_health "$backend" "$fmt"
  63 +
  64 + if ! "$PYTHON" "${ROOT}/scripts/benchmark_reranker_random_titles.py" \
  65 + 100,200,400,600,800,1000 \
  66 + --repeat 5 \
  67 + --seed 42 \
  68 + --quiet-runs \
  69 + --timeout 360 \
  70 + --tag "$tag" \
  71 + --json-summary-out "$jf"
  72 + then
  73 +        echo "[warn] benchmark exited non-zero for ${tag} (check the failed flag / partial runs in ${jf})" >&2
  74 + fi
  75 +
  76 + echo "artifact: $jf"
  77 +}
  78 +
  79 +run_one qwen3_vllm compact
  80 +run_one qwen3_vllm standard
  81 +run_one qwen3_vllm_score compact
  82 +run_one qwen3_vllm_score standard
  83 +
  84 +# Restore repo-default-style rerank settings (score + compact).
  85 +"$PYTHON" "${ROOT}/scripts/patch_rerank_vllm_benchmark_config.py" \
  86 + --backend qwen3_vllm_score --instruction-format compact
  87 +"${ROOT}/restart.sh" reranker
  88 +wait_health qwen3_vllm_score compact
  89 +echo "Restored config: qwen3_vllm_score + compact. Done. Artifacts under ${OUT_DIR}"
... ...
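A minimal sketch of comparing the four artifacts the matrix run writes under OUT_DIR; the date directory is an assumption and must match the actual run:

    import json
    from pathlib import Path

    out_dir = Path("perf_reports/reranker_vllm_instruction/2025-01-01")  # substitute the real run date
    for path in sorted(out_dir.glob("*.json")):
        data = json.loads(path.read_text(encoding="utf-8"))
        means = {n: round(row["mean_ms"], 1) for n, row in data["per_n"].items() if "mean_ms" in row}
        print(data["tag"], "FAILED" if data["failed"] else "ok", means)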
scripts/setup_reranker_venv.sh
1 1 #!/bin/bash
2 2 #
3   -# Create isolated venv for reranker service (.venv-reranker).
  3 +# Create isolated venv for one reranker backend.
4 4 #
5 5 set -euo pipefail
6 6  
7 7 PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
8 8 cd "${PROJECT_ROOT}"
9 9  
10   -VENV_DIR="${PROJECT_ROOT}/.venv-reranker"
11 10 PYTHON_BIN="${PYTHON_BIN:-python3}"
12 11 TMP_DIR="${RERANKER_PIP_TMPDIR:-${PROJECT_ROOT}/.tmp/reranker-pip}"
13 12  
  13 +# shellcheck source=scripts/lib/load_env.sh
  14 +source "${PROJECT_ROOT}/scripts/lib/load_env.sh"
  15 +load_env_file "${PROJECT_ROOT}/.env"
  16 +# shellcheck source=scripts/lib/reranker_backend_env.sh
  17 +source "${PROJECT_ROOT}/scripts/lib/reranker_backend_env.sh"
  18 +
  19 +BACKEND="${1:-$(detect_rerank_backend "${PROJECT_ROOT}")}"
  20 +VENV_DIR="${RERANKER_VENV:-$(reranker_backend_venv_dir "${PROJECT_ROOT}" "${BACKEND}")}"
  21 +REQ_FILE="$(reranker_backend_requirements_file "${PROJECT_ROOT}" "${BACKEND}")"
  22 +
  23 +if [[ ! -f "${REQ_FILE}" ]]; then
  24 + echo "ERROR: requirements file not found for reranker backend ${BACKEND}: ${REQ_FILE}" >&2
  25 + exit 1
  26 +fi
  27 +
14 28 if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then
15 29 echo "ERROR: python not found: ${PYTHON_BIN}" >&2
16 30 exit 1
... ... @@ -34,9 +48,35 @@ PIP_ARGS=(--no-cache-dir)
34 48  
35 49 echo "Using TMPDIR=${TMPDIR}"
36 50 "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" --upgrade pip wheel
37   -"${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" -r requirements_reranker_service.txt
  51 +"${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" -r "${REQ_FILE}"
  52 +
  53 +if [[ "${BACKEND}" == qwen3_gguf* ]]; then
  54 + if [[ -x "/usr/local/cuda/bin/nvcc" ]]; then
  55 + "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" \
  56 + cmake \
  57 + ninja \
  58 + scikit-build-core \
  59 + flit_core \
  60 + setuptools-scm
  61 + echo "Rebuilding llama-cpp-python with CUDA support for ${BACKEND}"
  62 + PATH="/usr/local/cuda/bin:/usr/bin:/bin" \
  63 + CC="/usr/bin/x86_64-linux-gnu-gcc" \
  64 + CXX="/usr/bin/x86_64-linux-gnu-g++" \
  65 + CUDACXX="/usr/local/cuda/bin/nvcc" \
  66 + CMAKE_ARGS="-DGGML_CUDA=on" \
  67 + FORCE_CMAKE=1 \
  68 + "${VENV_DIR}/bin/python" -m pip install "${PIP_ARGS[@]}" \
  69 + --force-reinstall \
  70 + --no-build-isolation \
  71 + "llama-cpp-python==0.3.18"
  72 + else
  73 + echo "WARNING: /usr/local/cuda/bin/nvcc not found; ${BACKEND} will be installed without CUDA support." >&2
  74 + fi
  75 +fi
38 76  
39 77 echo
40 78 echo "Done."
  79 +echo "Backend: ${BACKEND}"
41 80 echo "Reranker venv: ${VENV_DIR}"
  81 +echo "Requirements: ${REQ_FILE}"
42 82 echo "Start service: ./scripts/start_reranker.sh"
... ...
scripts/start_reranker.sh
1 1 #!/bin/bash
2 2 #
3   -# Start reranker service from isolated venv (.venv-reranker).
  3 +# Start reranker service from its backend-specific isolated venv.
4 4 #
5 5 set -euo pipefail
6 6  
7 7 PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
8 8 cd "${PROJECT_ROOT}"
9 9  
10   -RERANKER_VENV="${RERANKER_VENV:-${PROJECT_ROOT}/.venv-reranker}"
11   -PYTHON_BIN="${RERANKER_VENV}/bin/python"
12   -
13   -if [[ ! -x "${PYTHON_BIN}" ]]; then
14   - echo "ERROR: reranker venv not found: ${RERANKER_VENV}" >&2
15   - echo "Please run: ./scripts/setup_reranker_venv.sh" >&2
16   - exit 1
17   -fi
18   -
19 10 # Load .env without activating main venv.
20 11 # shellcheck source=scripts/lib/load_env.sh
21 12 source "${PROJECT_ROOT}/scripts/lib/load_env.sh"
22 13 load_env_file "${PROJECT_ROOT}/.env"
  14 +# shellcheck source=scripts/lib/reranker_backend_env.sh
  15 +source "${PROJECT_ROOT}/scripts/lib/reranker_backend_env.sh"
23 16  
24 17 RERANKER_HOST="${RERANKER_HOST:-0.0.0.0}"
25 18 RERANKER_PORT="${RERANKER_PORT:-6007}"
26   -RERANK_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_rerank_backend_config; print(get_rerank_backend_config()[0])")
  19 +RERANK_BACKEND="${RERANK_BACKEND:-$(detect_rerank_backend "${PROJECT_ROOT}")}"
  20 +RERANKER_VENV="${RERANKER_VENV:-$(reranker_backend_venv_dir "${PROJECT_ROOT}" "${RERANK_BACKEND}")}"
  21 +PYTHON_BIN="${RERANKER_VENV}/bin/python"
  22 +
  23 +if [[ ! -x "${PYTHON_BIN}" ]]; then
  24 + echo "ERROR: reranker venv not found for backend ${RERANK_BACKEND}: ${RERANKER_VENV}" >&2
  25 + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2
  26 + exit 1
  27 +fi
27 28  
28 29 # Keep vLLM/triton/torch caches out of system disk.
29 30 RERANKER_RUNTIME_DIR="${RERANKER_RUNTIME_DIR:-${PROJECT_ROOT}/.runtime/reranker}"
... ... @@ -42,23 +43,56 @@ export TMPDIR="${RERANKER_RUNTIME_DIR}/tmp"
42 43 export VLLM_NO_USAGE_STATS="${VLLM_NO_USAGE_STATS:-1}"
43 44 export PATH="${RERANKER_VENV}/bin:${PATH}"
44 45  
45   -if [[ "${RERANK_BACKEND}" == "qwen3_vllm" ]]; then
  46 +if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then
  47 + export HF_HUB_DISABLE_XET="${HF_HUB_DISABLE_XET:-1}"
  48 +fi
  49 +
  50 +if [[ "${RERANK_BACKEND}" == "qwen3_vllm" || "${RERANK_BACKEND}" == "qwen3_vllm_score" || "${RERANK_BACKEND}" == "qwen3_transformers_packed" ]]; then
46 51 if ! command -v nvidia-smi >/dev/null 2>&1 || ! nvidia-smi >/dev/null 2>&1; then
47   - echo "ERROR: qwen3_vllm backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2
  52 + echo "ERROR: ${RERANK_BACKEND} backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2
48 53 exit 1
49 54 fi
50 55 if ! "${PYTHON_BIN}" - <<'PY'
51 56 try:
52   - import vllm # noqa: F401
53 57 import torch
  58 + try:
  59 + import vllm # noqa: F401
  60 + except Exception:
  61 + pass
54 62 if not torch.cuda.is_available():
55 63 raise SystemExit(1)
56 64 except Exception:
57 65 raise SystemExit(1)
58 66 PY
59 67 then
60   - echo "ERROR: qwen3_vllm backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2
61   - echo "Please run: ./scripts/setup_reranker_venv.sh and verify CUDA is available." >&2
  68 + if [[ "${RERANK_BACKEND}" == "qwen3_transformers_packed" ]]; then
  69 + echo "ERROR: ${RERANK_BACKEND} backend requires torch + CUDA runtime in ${RERANKER_VENV}." >&2
  70 + else
  71 + echo "ERROR: ${RERANK_BACKEND} backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2
  72 + fi
  73 + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND} and verify CUDA is available." >&2
  74 + exit 1
  75 + fi
  76 +fi
  77 +
  78 +if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then
  79 + gguf_check_status=0
  80 + "${PYTHON_BIN}" - <<'PY' || gguf_check_status=$?
  81 +try:
  82 + import llama_cpp
  83 + if hasattr(llama_cpp, "llama_supports_gpu_offload") and not llama_cpp.llama_supports_gpu_offload():
  84 + raise SystemExit(2)
  85 +except Exception:
  86 + raise SystemExit(1)
  87 +PY
  88 + if [[ "${gguf_check_status}" != "0" ]]; then
  89 + if [[ "${gguf_check_status}" == "2" ]]; then
  90 + echo "ERROR: ${RERANK_BACKEND} backend detected a CPU-only llama-cpp-python build in ${RERANKER_VENV}." >&2
  91 + echo "Please rerun: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2
  92 + else
  93 + echo "ERROR: ${RERANK_BACKEND} backend requires llama-cpp-python in ${RERANKER_VENV}." >&2
  94 + echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND}" >&2
  95 + fi
62 96 exit 1
63 97 fi
64 98 fi
... ...
search/rerank_client.py
... ... @@ -200,19 +200,24 @@ def _multiply_fusion_factors(
200 200 knn_score: float,
201 201 fusion: RerankFusionConfig,
202 202 ) -> Tuple[float, float, float, float]:
203   - """(rerank_factor, text_factor, knn_factor, fused)."""
  203 + """(rerank_factor, text_factor, knn_factor, fused_without_style_boost)."""
204 204 r = (max(rerank_score, 0.0) + fusion.rerank_bias) ** fusion.rerank_exponent
205 205 t = (max(text_score, 0.0) + fusion.text_bias) ** fusion.text_exponent
206 206 k = (max(knn_score, 0.0) + fusion.knn_bias) ** fusion.knn_exponent
207 207 return r, t, k, r * t * k
208 208  
209 209  
  210 +def _has_selected_sku(hit: Dict[str, Any]) -> bool:
  211 + return bool(str(hit.get("_style_rerank_suffix") or "").strip())
  212 +
  213 +
210 214 def fuse_scores_and_resort(
211 215 es_hits: List[Dict[str, Any]],
212 216 rerank_scores: List[float],
213 217 weight_es: float = DEFAULT_WEIGHT_ES,
214 218 weight_ai: float = DEFAULT_WEIGHT_AI,
215 219 fusion: Optional[RerankFusionConfig] = None,
  220 + style_intent_selected_sku_boost: float = 1.2,
216 221 debug: bool = False,
217 222 rerank_debug_rows: Optional[List[Dict[str, Any]]] = None,
218 223 ) -> List[Dict[str, Any]]:
... ... @@ -220,7 +225,10 @@ def fuse_scores_and_resort(
 220 225        Fuse the ES score with the rerank score using the multiplicative formula (the original _score is not modified), then re-sort in descending order of the fused score.
221 226  
 222 227        Fusion form (bias / exponent configured via ``fusion``)::
223   - fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k
  228 + fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost
  229 +
  230 +    sku_boost only takes effect when the current hit has a selected SKU; it defaults to 1.2 and can be
  231 +    configured via ``query.style_intent.selected_sku_boost``.
224 232  
 225 233        The following fields are written to each hit:
 226 234        - _original_score: the original ES score
... ... @@ -252,12 +260,16 @@ def fuse_scores_and_resort(
252 260 rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors(
253 261 rerank_score, text_score, knn_score, f
254 262 )
  263 + sku_selected = _has_selected_sku(hit)
  264 + style_boost = style_intent_selected_sku_boost if sku_selected else 1.0
  265 + fused *= style_boost
255 266  
256 267 hit["_original_score"] = hit.get("_score")
257 268 hit["_rerank_score"] = rerank_score
258 269 hit["_text_score"] = text_score
259 270 hit["_knn_score"] = knn_score
260 271 hit["_fused_score"] = fused
  272 + hit["_style_intent_selected_sku_boost"] = style_boost
261 273 if debug:
262 274 hit["_text_source_score"] = text_components["source_score"]
263 275 hit["_text_translation_score"] = text_components["translation_score"]
... ... @@ -285,6 +297,8 @@ def fuse_scores_and_resort(
285 297 "rerank_factor": rerank_factor,
286 298 "text_factor": text_factor,
287 299 "knn_factor": knn_factor,
  300 + "style_intent_selected_sku": sku_selected,
  301 + "style_intent_selected_sku_boost": style_boost,
288 302 "matched_queries": matched_queries,
289 303 "fused_score": fused,
290 304 }
... ... @@ -311,6 +325,7 @@ def run_rerank(
311 325 top_n: Optional[int] = None,
312 326 debug: bool = False,
313 327 fusion: Optional[RerankFusionConfig] = None,
  328 + style_intent_selected_sku_boost: float = 1.2,
314 329 ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
315 330 """
 316 331        Full rerank flow: take hits from es_response -> build docs -> call the rerank service -> fuse scores and re-sort -> update max_score.
... ... @@ -345,6 +360,7 @@ def run_rerank(
345 360 weight_es=weight_es,
346 361 weight_ai=weight_ai,
347 362 fusion=fusion,
  363 + style_intent_selected_sku_boost=style_intent_selected_sku_boost,
348 364 debug=debug,
349 365 rerank_debug_rows=rerank_debug_rows,
350 366 )
... ...
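A worked example of the fused score with the selected-SKU boost applied on top of the three multiplicative factors; the bias/exponent numbers below are illustrative, not the repo defaults:

    rerank, text, knn = 0.8, 1.0, 0.0
    b_r, e_r = 0.0, 1.0          # illustrative rerank bias / exponent
    b_t, e_t = 0.1, 0.35         # illustrative text bias / exponent
    b_k, e_k = 0.6, 0.0          # illustrative knn bias / exponent

    base = (max(rerank, 0.0) + b_r) ** e_r * (max(text, 0.0) + b_t) ** e_t * (max(knn, 0.0) + b_k) ** e_k
    fused_plain = base * 1.0      # hit without a selected SKU
    fused_selected = base * 1.2   # hit with _style_rerank_suffix -> selected_sku_boost applies
    print(round(fused_plain, 4), round(fused_selected, 4))  # the selected hit scores exactly 1.2x higher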
search/searcher.py
... ... @@ -594,6 +594,7 @@ class Searcher:
594 594 top_n=(from_ + size),
595 595 debug=debug,
596 596 fusion=rc.fusion,
  597 + style_intent_selected_sku_boost=self.config.query_config.style_intent_selected_sku_boost,
597 598 )
598 599  
599 600 if rerank_meta is not None:
... ... @@ -1055,4 +1056,3 @@ class Searcher:
1055 1056 except Exception as e:
1056 1057 logger.error(f"Failed to get document {doc_id} from tenant {tenant_id}: {e}", exc_info=True)
1057 1058 return None
1058   -
... ...
search/sku_intent_selector.py
... ... @@ -5,12 +5,10 @@ SKU selection for style-intent-aware search results.
5 5 from __future__ import annotations
6 6  
7 7 from dataclasses import dataclass, field
8   -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
9   -
10   -import numpy as np
  8 +from typing import Any, Callable, Dict, List, Optional, Tuple
11 9  
12 10 from query.style_intent import StyleIntentProfile, StyleIntentRegistry
13   -from query.tokenization import normalize_query_text
  11 +from query.tokenization import normalize_query_text, simple_tokenize_query
14 12  
15 13  
16 14 @dataclass(frozen=True)
... ... @@ -34,24 +32,11 @@ class SkuSelectionDecision:
34 32  
35 33  
36 34 @dataclass
37   -class _SkuCandidate:
38   - index: int
39   - sku_id: str
40   - sku: Dict[str, Any]
41   - selection_text: str
42   - normalized_selection_text: str
43   - intent_values: Dict[str, str]
44   - normalized_intent_values: Dict[str, str]
45   -
46   -
47   -@dataclass
48 35 class _SelectionContext:
49   - query_texts: Tuple[str, ...]
50   - matched_terms_by_intent: Dict[str, Tuple[str, ...]]
51   - query_vector: Optional[np.ndarray]
  36 + attribute_terms_by_intent: Dict[str, Tuple[str, ...]]
  37 + normalized_text_cache: Dict[str, str] = field(default_factory=dict)
  38 + tokenized_text_cache: Dict[str, Tuple[str, ...]] = field(default_factory=dict)
52 39 text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict)
53   - selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict)
54   - similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict)
55 40  
56 41  
57 42 class StyleSkuSelector:
... ... @@ -76,7 +61,7 @@ class StyleSkuSelector:
76 61 if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active:
77 62 return decisions
78 63  
79   - selection_context = self._build_selection_context(parsed_query, style_profile)
  64 + selection_context = self._build_selection_context(style_profile)
80 65  
81 66 for hit in es_hits:
82 67 source = hit.get("_source")
... ... @@ -126,81 +111,37 @@ class StyleSkuSelector:
126 111 else:
127 112 hit.pop("_style_rerank_suffix", None)
128 113  
129   - def _build_query_texts(
130   - self,
131   - parsed_query: Any,
132   - style_profile: StyleIntentProfile,
133   - ) -> List[str]:
134   - texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text]
135   - if texts:
136   - return list(dict.fromkeys(texts))
137   -
138   - fallbacks: List[str] = []
139   - for value in (
140   - getattr(parsed_query, "original_query", None),
141   - getattr(parsed_query, "query_normalized", None),
142   - getattr(parsed_query, "rewritten_query", None),
143   - ):
144   - normalized = normalize_query_text(value)
145   - if normalized:
146   - fallbacks.append(normalized)
147   - translations = getattr(parsed_query, "translations", {}) or {}
148   - if isinstance(translations, dict):
149   - for value in translations.values():
150   - normalized = normalize_query_text(value)
151   - if normalized:
152   - fallbacks.append(normalized)
153   - return list(dict.fromkeys(fallbacks))
154   -
155   - def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]:
156   - query_vector = getattr(parsed_query, "query_vector", None)
157   - if query_vector is not None:
158   - return np.asarray(query_vector, dtype=np.float32)
159   -
160   - text_encoder = self._get_text_encoder()
161   - if text_encoder is None:
162   - return None
163   -
164   - query_text = (
165   - getattr(parsed_query, "rewritten_query", None)
166   - or getattr(parsed_query, "query_normalized", None)
167   - or getattr(parsed_query, "original_query", None)
168   - )
169   - if not query_text:
170   - return None
171   -
172   - vectors = text_encoder.encode([query_text], priority=1)
173   - if vectors is None or len(vectors) == 0 or vectors[0] is None:
174   - return None
175   - return np.asarray(vectors[0], dtype=np.float32)
176   -
177 114 def _build_selection_context(
178 115 self,
179   - parsed_query: Any,
180 116 style_profile: StyleIntentProfile,
181 117 ) -> _SelectionContext:
182   - matched_terms_by_intent: Dict[str, List[str]] = {}
  118 + attribute_terms_by_intent: Dict[str, List[str]] = {}
183 119 for intent in style_profile.intents:
184   - normalized_term = normalize_query_text(intent.matched_term)
185   - if not normalized_term:
186   - continue
187   - matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, [])
188   - if normalized_term not in matched_terms:
189   - matched_terms.append(normalized_term)
  120 + terms = attribute_terms_by_intent.setdefault(intent.intent_type, [])
  121 + for raw_term in intent.attribute_terms:
  122 + normalized_term = normalize_query_text(raw_term)
  123 + if not normalized_term or normalized_term in terms:
  124 + continue
  125 + terms.append(normalized_term)
190 126  
191 127 return _SelectionContext(
192   - query_texts=tuple(self._build_query_texts(parsed_query, style_profile)),
193   - matched_terms_by_intent={
  128 + attribute_terms_by_intent={
194 129 intent_type: tuple(terms)
195   - for intent_type, terms in matched_terms_by_intent.items()
  130 + for intent_type, terms in attribute_terms_by_intent.items()
196 131 },
197   - query_vector=self._get_query_vector(parsed_query),
198 132 )
199 133  
200   - def _get_text_encoder(self) -> Any:
201   - if self._text_encoder_getter is None:
202   - return None
203   - return self._text_encoder_getter()
  134 + @staticmethod
  135 + def _normalize_cached(selection_context: _SelectionContext, value: Any) -> str:
  136 + raw = str(value or "").strip()
  137 + if not raw:
  138 + return ""
  139 + cached = selection_context.normalized_text_cache.get(raw)
  140 + if cached is not None:
  141 + return cached
  142 + normalized = normalize_query_text(raw)
  143 + selection_context.normalized_text_cache[raw] = normalized
  144 + return normalized
204 145  
205 146 def _resolve_dimensions(
206 147 self,
... ... @@ -225,51 +166,6 @@ class StyleSkuSelector:
225 166 resolved[intent.intent_type] = matched_field
226 167 return resolved
227 168  
228   - def _build_candidates(
229   - self,
230   - skus: List[Dict[str, Any]],
231   - resolved_dimensions: Dict[str, Optional[str]],
232   - ) -> List[_SkuCandidate]:
233   - if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
234   - return []
235   -
236   - candidates: List[_SkuCandidate] = []
237   - for index, sku in enumerate(skus):
238   - intent_values: Dict[str, str] = {}
239   - normalized_intent_values: Dict[str, str] = {}
240   - for intent_type, field_name in resolved_dimensions.items():
241   - if not field_name:
242   - continue
243   - raw = str(sku.get(field_name) or "").strip()
244   - intent_values[intent_type] = raw
245   - normalized_intent_values[intent_type] = normalize_query_text(raw)
246   -
247   - selection_parts: List[str] = []
248   - norm_parts: List[str] = []
249   - seen: set[str] = set()
250   - for intent_type, raw in intent_values.items():
251   - nv = normalized_intent_values[intent_type]
252   - if not nv or nv in seen:
253   - continue
254   - seen.add(nv)
255   - selection_parts.append(raw)
256   - norm_parts.append(nv)
257   -
258   - selection_text = " ".join(selection_parts).strip()
259   - normalized_selection_text = " ".join(norm_parts).strip()
260   - candidates.append(
261   - _SkuCandidate(
262   - index=index,
263   - sku_id=str(sku.get("sku_id") or ""),
264   - sku=sku,
265   - selection_text=selection_text,
266   - normalized_selection_text=normalized_selection_text,
267   - intent_values=intent_values,
268   - normalized_intent_values=normalized_intent_values,
269   - )
270   - )
271   - return candidates
272   -
273 169 @staticmethod
274 170 def _empty_decision(
275 171 resolved_dimensions: Dict[str, Optional[str]],
... ... @@ -286,13 +182,10 @@ class StyleSkuSelector:
286 182 def _is_text_match(
287 183 self,
288 184 intent_type: str,
289   - value: str,
290 185 selection_context: _SelectionContext,
291 186 *,
292   - normalized_value: Optional[str] = None,
  187 + normalized_value: str,
293 188 ) -> bool:
294   - if normalized_value is None:
295   - normalized_value = normalize_query_text(value)
296 189 if not normalized_value:
297 190 return False
298 191  
... ... @@ -301,84 +194,94 @@ class StyleSkuSelector:
301 194 if cached is not None:
302 195 return cached
303 196  
304   - matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ())
305   - has_term_match = any(term in normalized_value for term in matched_terms if term)
306   - query_contains_value = any(
307   - normalized_value in query_text
308   - for query_text in selection_context.query_texts
  197 + attribute_terms = selection_context.attribute_terms_by_intent.get(intent_type, ())
  198 + value_tokens = self._tokenize_cached(selection_context, normalized_value)
  199 + matched = any(
  200 + self._matches_term_tokens(
  201 + term=term,
  202 + value_tokens=value_tokens,
  203 + selection_context=selection_context,
  204 + normalized_value=normalized_value,
  205 + )
  206 + for term in attribute_terms
  207 + if term
309 208 )
310   - matched = bool(has_term_match or query_contains_value)
311 209 selection_context.text_match_cache[cache_key] = matched
312 210 return matched
313 211  
314   - def _find_first_text_match(
  212 + @staticmethod
  213 + def _tokenize_cached(selection_context: _SelectionContext, value: str) -> Tuple[str, ...]:
  214 + normalized_value = normalize_query_text(value)
  215 + if not normalized_value:
  216 + return ()
  217 + cached = selection_context.tokenized_text_cache.get(normalized_value)
  218 + if cached is not None:
  219 + return cached
  220 + tokens = tuple(normalize_query_text(token) for token in simple_tokenize_query(normalized_value) if token)
  221 + selection_context.tokenized_text_cache[normalized_value] = tokens
  222 + return tokens
  223 +
  224 + def _matches_term_tokens(
315 225 self,
316   - candidates: Sequence[_SkuCandidate],
  226 + *,
  227 + term: str,
  228 + value_tokens: Tuple[str, ...],
317 229 selection_context: _SelectionContext,
318   - ) -> Optional[_SkuCandidate]:
319   - for candidate in candidates:
320   - if candidate.intent_values and all(
321   - self._is_text_match(
322   - intent_type,
323   - value,
324   - selection_context,
325   - normalized_value=candidate.normalized_intent_values[intent_type],
326   - )
327   - for intent_type, value in candidate.intent_values.items()
328   - ):
329   - return candidate
330   - return None
  230 + normalized_value: str,
  231 + ) -> bool:
  232 + normalized_term = normalize_query_text(term)
  233 + if not normalized_term:
  234 + return False
  235 + if normalized_term == normalized_value:
  236 + return True
331 237  
332   - def _select_by_embedding(
  238 + term_tokens = self._tokenize_cached(selection_context, normalized_term)
  239 + if not term_tokens or not value_tokens:
  240 + return normalized_term in normalized_value
  241 +
  242 + term_length = len(term_tokens)
  243 + value_length = len(value_tokens)
  244 + if term_length > value_length:
  245 + return False
  246 +
  247 + for start in range(value_length - term_length + 1):
  248 + if value_tokens[start:start + term_length] == term_tokens:
  249 + return True
  250 + return False
  251 +
  252 + def _find_first_text_match(
333 253 self,
334   - candidates: Sequence[_SkuCandidate],
  254 + skus: List[Dict[str, Any]],
  255 + resolved_dimensions: Dict[str, Optional[str]],
335 256 selection_context: _SelectionContext,
336   - ) -> Tuple[Optional[_SkuCandidate], Optional[float]]:
337   - if not candidates:
338   - return None, None
339   - text_encoder = self._get_text_encoder()
340   - if selection_context.query_vector is None or text_encoder is None:
341   - return None, None
342   -
343   - unique_texts = list(
344   - dict.fromkeys(
345   - candidate.normalized_selection_text
346   - for candidate in candidates
347   - if candidate.normalized_selection_text
348   - and candidate.normalized_selection_text not in selection_context.selection_vector_cache
349   - )
350   - )
351   - if unique_texts:
352   - vectors = text_encoder.encode(unique_texts, priority=1)
353   - for key, vector in zip(unique_texts, vectors):
354   - selection_context.selection_vector_cache[key] = (
355   - np.asarray(vector, dtype=np.float32) if vector is not None else None
356   - )
357   -
358   - best_candidate: Optional[_SkuCandidate] = None
359   - best_score: Optional[float] = None
360   - query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32)
361   - for candidate in candidates:
362   - normalized_text = candidate.normalized_selection_text
363   - if not normalized_text:
364   - continue
  257 + ) -> Optional[Tuple[str, str]]:
  258 + for sku in skus:
  259 + selection_parts: List[str] = []
  260 + seen_parts: set[str] = set()
  261 + matched = True
365 262  
366   - score = selection_context.similarity_cache.get(normalized_text)
367   - if score is None:
368   - candidate_vector = selection_context.selection_vector_cache.get(normalized_text)
369   - if candidate_vector is None:
370   - selection_context.similarity_cache[normalized_text] = None
371   - continue
372   - score = float(np.inner(query_vector_array, candidate_vector))
373   - selection_context.similarity_cache[normalized_text] = score
  263 + for intent_type, field_name in resolved_dimensions.items():
  264 + if not field_name:
  265 + matched = False
  266 + break
374 267  
375   - if score is None:
376   - continue
377   - if best_score is None or score > best_score:
378   - best_candidate = candidate
379   - best_score = score
  268 + raw_value = str(sku.get(field_name) or "").strip()
  269 + normalized_value = self._normalize_cached(selection_context, raw_value)
  270 + if not self._is_text_match(
  271 + intent_type,
  272 + selection_context,
  273 + normalized_value=normalized_value,
  274 + ):
  275 + matched = False
  276 + break
380 277  
381   - return best_candidate, best_score
  278 + if raw_value and normalized_value not in seen_parts:
  279 + seen_parts.add(normalized_value)
  280 + selection_parts.append(raw_value)
  281 +
  282 + if matched:
  283 + return str(sku.get("sku_id") or ""), " ".join(selection_parts).strip()
  284 + return None
382 285  
383 286 def _select_for_source(
384 287 self,
... ... @@ -395,36 +298,29 @@ class StyleSkuSelector:
395 298 if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
396 299 return self._empty_decision(resolved_dimensions, matched_stage="unresolved")
397 300  
398   - candidates = self._build_candidates(skus, resolved_dimensions)
399   - if not candidates:
400   - return self._empty_decision(resolved_dimensions, matched_stage="no_candidates")
401   -
402   - text_match = self._find_first_text_match(candidates, selection_context)
403   - if text_match is not None:
404   - return self._build_decision(text_match, resolved_dimensions, matched_stage="text")
405   -
406   - chosen, similarity_score = self._select_by_embedding(candidates, selection_context)
407   - if chosen is None:
  301 + text_match = self._find_first_text_match(skus, resolved_dimensions, selection_context)
  302 + if text_match is None:
408 303 return self._empty_decision(resolved_dimensions, matched_stage="no_match")
409 304 return self._build_decision(
410   - chosen,
411   - resolved_dimensions,
412   - matched_stage="embedding",
413   - similarity_score=similarity_score,
  305 + selected_sku_id=text_match[0],
  306 + selected_text=text_match[1],
  307 + resolved_dimensions=resolved_dimensions,
  308 + matched_stage="text",
414 309 )
415 310  
416 311 @staticmethod
417 312 def _build_decision(
418   - candidate: _SkuCandidate,
  313 + selected_sku_id: str,
  314 + selected_text: str,
419 315 resolved_dimensions: Dict[str, Optional[str]],
420 316 *,
421 317 matched_stage: str,
422 318 similarity_score: Optional[float] = None,
423 319 ) -> SkuSelectionDecision:
424 320 return SkuSelectionDecision(
425   - selected_sku_id=candidate.sku_id or None,
426   - rerank_suffix=str(candidate.selection_text or "").strip(),
427   - selected_text=str(candidate.selection_text or "").strip(),
  321 + selected_sku_id=selected_sku_id or None,
  322 + rerank_suffix=str(selected_text or "").strip(),
  323 + selected_text=str(selected_text or "").strip(),
428 324 matched_stage=matched_stage,
429 325 similarity_score=similarity_score,
430 326 resolved_dimensions=dict(resolved_dimensions),
... ...
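A minimal sketch of the contiguous token-window check used by _matches_term_tokens, with a plain lower-cased .split() standing in for normalize_query_text / simple_tokenize_query; the real method also short-circuits on exact normalized equality and falls back to substring matching when tokenization yields nothing:

    def window_match(term: str, value: str) -> bool:
        term_tokens = tuple(term.lower().split())
        value_tokens = tuple(value.lower().split())
        if not term_tokens or len(term_tokens) > len(value_tokens):
            return False
        return any(
            value_tokens[start:start + len(term_tokens)] == term_tokens
            for start in range(len(value_tokens) - len(term_tokens) + 1)
        )

    print(window_match("dark blue", "dark blue melange"))  # True: matches a contiguous token window
    print(window_match("blue", "navy blueish"))            # False: whole-token match, not substring match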
search/sku_intent_selector___deprecated.py 0 → 100644
... ... @@ -0,0 +1,452 @@
  1 +"""
  2 +SKU selection for style-intent-aware search results.
  3 +"""
  4 +
  5 +from __future__ import annotations
  6 +
  7 +from dataclasses import dataclass, field
  8 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
  9 +
  10 +import numpy as np
  11 +
  12 +from query.style_intent import StyleIntentProfile, StyleIntentRegistry
  13 +from query.tokenization import normalize_query_text
  14 +
  15 +
  16 +@dataclass(frozen=True)
  17 +class SkuSelectionDecision:
  18 + selected_sku_id: Optional[str]
  19 + rerank_suffix: str
  20 + selected_text: str
  21 + matched_stage: str
  22 + similarity_score: Optional[float] = None
  23 + resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict)
  24 +
  25 + def to_dict(self) -> Dict[str, Any]:
  26 + return {
  27 + "selected_sku_id": self.selected_sku_id,
  28 + "rerank_suffix": self.rerank_suffix,
  29 + "selected_text": self.selected_text,
  30 + "matched_stage": self.matched_stage,
  31 + "similarity_score": self.similarity_score,
  32 + "resolved_dimensions": dict(self.resolved_dimensions),
  33 + }
  34 +
  35 +
  36 +@dataclass
  37 +class _SkuCandidate:
  38 + index: int
  39 + sku_id: str
  40 + sku: Dict[str, Any]
  41 + selection_text: str
  42 + normalized_selection_text: str
  43 + intent_values: Dict[str, str]
  44 + normalized_intent_values: Dict[str, str]
  45 +
  46 +
  47 +@dataclass
  48 +class _SelectionContext:
  49 + query_texts: Tuple[str, ...]
  50 + matched_terms_by_intent: Dict[str, Tuple[str, ...]]
  51 + query_vector: Optional[np.ndarray]
  52 + text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict)
  53 + selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict)
  54 + similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict)
  55 +
  56 +
  57 +class StyleSkuSelector:
  58 + """Selects the best SKU for an SPU based on detected style intent."""
  59 +
  60 + def __init__(
  61 + self,
  62 + registry: StyleIntentRegistry,
  63 + *,
  64 + text_encoder_getter: Optional[Callable[[], Any]] = None,
  65 + ) -> None:
  66 + self.registry = registry
  67 + self._text_encoder_getter = text_encoder_getter
  68 +
  69 + def prepare_hits(
  70 + self,
  71 + es_hits: List[Dict[str, Any]],
  72 + parsed_query: Any,
  73 + ) -> Dict[str, SkuSelectionDecision]:
  74 + decisions: Dict[str, SkuSelectionDecision] = {}
  75 + style_profile = getattr(parsed_query, "style_intent_profile", None)
  76 + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active:
  77 + return decisions
  78 +
  79 + selection_context = self._build_selection_context(parsed_query, style_profile)
  80 +
  81 + for hit in es_hits:
  82 + source = hit.get("_source")
  83 + if not isinstance(source, dict):
  84 + continue
  85 +
  86 + decision = self._select_for_source(
  87 + source,
  88 + style_profile=style_profile,
  89 + selection_context=selection_context,
  90 + )
  91 + if decision is None:
  92 + continue
  93 +
  94 + if decision.rerank_suffix:
  95 + hit["_style_rerank_suffix"] = decision.rerank_suffix
  96 + else:
  97 + hit.pop("_style_rerank_suffix", None)
  98 +
  99 + doc_id = hit.get("_id")
  100 + if doc_id is not None:
  101 + decisions[str(doc_id)] = decision
  102 +
  103 + return decisions
  104 +
  105 + def apply_precomputed_decisions(
  106 + self,
  107 + es_hits: List[Dict[str, Any]],
  108 + decisions: Dict[str, SkuSelectionDecision],
  109 + ) -> None:
  110 + if not es_hits or not decisions:
  111 + return
  112 +
  113 + for hit in es_hits:
  114 + doc_id = hit.get("_id")
  115 + if doc_id is None:
  116 + continue
  117 + decision = decisions.get(str(doc_id))
  118 + if decision is None:
  119 + continue
  120 + source = hit.get("_source")
  121 + if not isinstance(source, dict):
  122 + continue
  123 + self._apply_decision_to_source(source, decision)
  124 + if decision.rerank_suffix:
  125 + hit["_style_rerank_suffix"] = decision.rerank_suffix
  126 + else:
  127 + hit.pop("_style_rerank_suffix", None)
  128 +
  129 + def _build_query_texts(
  130 + self,
  131 + parsed_query: Any,
  132 + style_profile: StyleIntentProfile,
  133 + ) -> List[str]:
  134 + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text]
  135 + if texts:
  136 + return list(dict.fromkeys(texts))
  137 +
  138 + fallbacks: List[str] = []
  139 + for value in (
  140 + getattr(parsed_query, "original_query", None),
  141 + getattr(parsed_query, "query_normalized", None),
  142 + getattr(parsed_query, "rewritten_query", None),
  143 + ):
  144 + normalized = normalize_query_text(value)
  145 + if normalized:
  146 + fallbacks.append(normalized)
  147 + translations = getattr(parsed_query, "translations", {}) or {}
  148 + if isinstance(translations, dict):
  149 + for value in translations.values():
  150 + normalized = normalize_query_text(value)
  151 + if normalized:
  152 + fallbacks.append(normalized)
  153 + return list(dict.fromkeys(fallbacks))
  154 +
  155 + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]:
  156 + query_vector = getattr(parsed_query, "query_vector", None)
  157 + if query_vector is not None:
  158 + return np.asarray(query_vector, dtype=np.float32)
  159 +
  160 + text_encoder = self._get_text_encoder()
  161 + if text_encoder is None:
  162 + return None
  163 +
  164 + query_text = (
  165 + getattr(parsed_query, "rewritten_query", None)
  166 + or getattr(parsed_query, "query_normalized", None)
  167 + or getattr(parsed_query, "original_query", None)
  168 + )
  169 + if not query_text:
  170 + return None
  171 +
  172 + vectors = text_encoder.encode([query_text], priority=1)
  173 + if vectors is None or len(vectors) == 0 or vectors[0] is None:
  174 + return None
  175 + return np.asarray(vectors[0], dtype=np.float32)
  176 +
  177 + def _build_selection_context(
  178 + self,
  179 + parsed_query: Any,
  180 + style_profile: StyleIntentProfile,
  181 + ) -> _SelectionContext:
  182 + matched_terms_by_intent: Dict[str, List[str]] = {}
  183 + for intent in style_profile.intents:
  184 + normalized_term = normalize_query_text(intent.matched_term)
  185 + if not normalized_term:
  186 + continue
  187 + matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, [])
  188 + if normalized_term not in matched_terms:
  189 + matched_terms.append(normalized_term)
  190 +
  191 + return _SelectionContext(
  192 + query_texts=tuple(self._build_query_texts(parsed_query, style_profile)),
  193 + matched_terms_by_intent={
  194 + intent_type: tuple(terms)
  195 + for intent_type, terms in matched_terms_by_intent.items()
  196 + },
  197 + query_vector=self._get_query_vector(parsed_query),
  198 + )
  199 +
  200 + def _get_text_encoder(self) -> Any:
  201 + if self._text_encoder_getter is None:
  202 + return None
  203 + return self._text_encoder_getter()
  204 +
  205 + def _resolve_dimensions(
  206 + self,
  207 + source: Dict[str, Any],
  208 + style_profile: StyleIntentProfile,
  209 + ) -> Dict[str, Optional[str]]:
  210 + option_names = {
  211 + "option1_value": normalize_query_text(source.get("option1_name")),
  212 + "option2_value": normalize_query_text(source.get("option2_name")),
  213 + "option3_value": normalize_query_text(source.get("option3_name")),
  214 + }
  215 + resolved: Dict[str, Optional[str]] = {}
  216 + for intent in style_profile.intents:
  217 + if intent.intent_type in resolved:
  218 + continue
  219 + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type))
  220 + matched_field = None
  221 + for field_name, option_name in option_names.items():
  222 + if option_name and option_name in aliases:
  223 + matched_field = field_name
  224 + break
  225 + resolved[intent.intent_type] = matched_field
  226 + return resolved
  227 +
  228 + def _build_candidates(
  229 + self,
  230 + skus: List[Dict[str, Any]],
  231 + resolved_dimensions: Dict[str, Optional[str]],
  232 + ) -> List[_SkuCandidate]:
  233 + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
  234 + return []
  235 +
  236 + candidates: List[_SkuCandidate] = []
  237 + for index, sku in enumerate(skus):
  238 + intent_values: Dict[str, str] = {}
  239 + normalized_intent_values: Dict[str, str] = {}
  240 + for intent_type, field_name in resolved_dimensions.items():
  241 + if not field_name:
  242 + continue
  243 + raw = str(sku.get(field_name) or "").strip()
  244 + intent_values[intent_type] = raw
  245 + normalized_intent_values[intent_type] = normalize_query_text(raw)
  246 +
  247 + selection_parts: List[str] = []
  248 + norm_parts: List[str] = []
  249 + seen: set[str] = set()
  250 + for intent_type, raw in intent_values.items():
  251 + nv = normalized_intent_values[intent_type]
  252 + if not nv or nv in seen:
  253 + continue
  254 + seen.add(nv)
  255 + selection_parts.append(raw)
  256 + norm_parts.append(nv)
  257 +
  258 + selection_text = " ".join(selection_parts).strip()
  259 + normalized_selection_text = " ".join(norm_parts).strip()
  260 + candidates.append(
  261 + _SkuCandidate(
  262 + index=index,
  263 + sku_id=str(sku.get("sku_id") or ""),
  264 + sku=sku,
  265 + selection_text=selection_text,
  266 + normalized_selection_text=normalized_selection_text,
  267 + intent_values=intent_values,
  268 + normalized_intent_values=normalized_intent_values,
  269 + )
  270 + )
  271 + return candidates
  272 +
  273 + @staticmethod
  274 + def _empty_decision(
  275 + resolved_dimensions: Dict[str, Optional[str]],
  276 + matched_stage: str,
  277 + ) -> SkuSelectionDecision:
  278 + return SkuSelectionDecision(
  279 + selected_sku_id=None,
  280 + rerank_suffix="",
  281 + selected_text="",
  282 + matched_stage=matched_stage,
  283 + resolved_dimensions=dict(resolved_dimensions),
  284 + )
  285 +
  286 + def _is_text_match(
  287 + self,
  288 + intent_type: str,
  289 + value: str,
  290 + selection_context: _SelectionContext,
  291 + *,
  292 + normalized_value: Optional[str] = None,
  293 + ) -> bool:
  294 + if normalized_value is None:
  295 + normalized_value = normalize_query_text(value)
  296 + if not normalized_value:
  297 + return False
  298 +
  299 + cache_key = (intent_type, normalized_value)
  300 + cached = selection_context.text_match_cache.get(cache_key)
  301 + if cached is not None:
  302 + return cached
  303 +
  304 + matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ())
  305 + has_term_match = any(term in normalized_value for term in matched_terms if term)
  306 + query_contains_value = any(
  307 + normalized_value in query_text
  308 + for query_text in selection_context.query_texts
  309 + )
  310 + matched = bool(has_term_match or query_contains_value)
  311 + selection_context.text_match_cache[cache_key] = matched
  312 + return matched
  313 +
  314 + def _find_first_text_match(
  315 + self,
  316 + candidates: Sequence[_SkuCandidate],
  317 + selection_context: _SelectionContext,
  318 + ) -> Optional[_SkuCandidate]:
  319 + for candidate in candidates:
  320 + if candidate.intent_values and all(
  321 + self._is_text_match(
  322 + intent_type,
  323 + value,
  324 + selection_context,
  325 + normalized_value=candidate.normalized_intent_values[intent_type],
  326 + )
  327 + for intent_type, value in candidate.intent_values.items()
  328 + ):
  329 + return candidate
  330 + return None
  331 +
  332 + def _select_by_embedding(
  333 + self,
  334 + candidates: Sequence[_SkuCandidate],
  335 + selection_context: _SelectionContext,
  336 + ) -> Tuple[Optional[_SkuCandidate], Optional[float]]:
  337 + if not candidates:
  338 + return None, None
  339 + text_encoder = self._get_text_encoder()
  340 + if selection_context.query_vector is None or text_encoder is None:
  341 + return None, None
  342 +
  343 + unique_texts = list(
  344 + dict.fromkeys(
  345 + candidate.normalized_selection_text
  346 + for candidate in candidates
  347 + if candidate.normalized_selection_text
  348 + and candidate.normalized_selection_text not in selection_context.selection_vector_cache
  349 + )
  350 + )
  351 + if unique_texts:
  352 + vectors = text_encoder.encode(unique_texts, priority=1)
  353 + for key, vector in zip(unique_texts, vectors):
  354 + selection_context.selection_vector_cache[key] = (
  355 + np.asarray(vector, dtype=np.float32) if vector is not None else None
  356 + )
  357 +
  358 + best_candidate: Optional[_SkuCandidate] = None
  359 + best_score: Optional[float] = None
  360 + query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32)
  361 + for candidate in candidates:
  362 + normalized_text = candidate.normalized_selection_text
  363 + if not normalized_text:
  364 + continue
  365 +
  366 + score = selection_context.similarity_cache.get(normalized_text)
  367 + if score is None:
  368 + candidate_vector = selection_context.selection_vector_cache.get(normalized_text)
  369 + if candidate_vector is None:
  370 + selection_context.similarity_cache[normalized_text] = None
  371 + continue
  372 + score = float(np.inner(query_vector_array, candidate_vector))
  373 + selection_context.similarity_cache[normalized_text] = score
  374 +
  375 + if score is None:
  376 + continue
  377 + if best_score is None or score > best_score:
  378 + best_candidate = candidate
  379 + best_score = score
  380 +
  381 + return best_candidate, best_score
  382 +
  383 + def _select_for_source(
  384 + self,
  385 + source: Dict[str, Any],
  386 + *,
  387 + style_profile: StyleIntentProfile,
  388 + selection_context: _SelectionContext,
  389 + ) -> Optional[SkuSelectionDecision]:
  390 + skus = source.get("skus")
  391 + if not isinstance(skus, list) or not skus:
  392 + return None
  393 +
  394 + resolved_dimensions = self._resolve_dimensions(source, style_profile)
  395 + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
  396 + return self._empty_decision(resolved_dimensions, matched_stage="unresolved")
  397 +
  398 + candidates = self._build_candidates(skus, resolved_dimensions)
  399 + if not candidates:
  400 + return self._empty_decision(resolved_dimensions, matched_stage="no_candidates")
  401 +
  402 + text_match = self._find_first_text_match(candidates, selection_context)
  403 + if text_match is not None:
  404 + return self._build_decision(text_match, resolved_dimensions, matched_stage="text")
  405 +
  406 + chosen, similarity_score = self._select_by_embedding(candidates, selection_context)
  407 + if chosen is None:
  408 + return self._empty_decision(resolved_dimensions, matched_stage="no_match")
  409 + return self._build_decision(
  410 + chosen,
  411 + resolved_dimensions,
  412 + matched_stage="embedding",
  413 + similarity_score=similarity_score,
  414 + )
  415 +
  416 + @staticmethod
  417 + def _build_decision(
  418 + candidate: _SkuCandidate,
  419 + resolved_dimensions: Dict[str, Optional[str]],
  420 + *,
  421 + matched_stage: str,
  422 + similarity_score: Optional[float] = None,
  423 + ) -> SkuSelectionDecision:
  424 + return SkuSelectionDecision(
  425 + selected_sku_id=candidate.sku_id or None,
  426 + rerank_suffix=str(candidate.selection_text or "").strip(),
  427 + selected_text=str(candidate.selection_text or "").strip(),
  428 + matched_stage=matched_stage,
  429 + similarity_score=similarity_score,
  430 + resolved_dimensions=dict(resolved_dimensions),
  431 + )
  432 +
  433 + @staticmethod
  434 + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None:
  435 + skus = source.get("skus")
  436 + if not isinstance(skus, list) or not skus or not decision.selected_sku_id:
  437 + return
  438 +
  439 + selected_index = None
  440 + for index, sku in enumerate(skus):
  441 + if str(sku.get("sku_id") or "") == decision.selected_sku_id:
  442 + selected_index = index
  443 + break
  444 + if selected_index is None:
  445 + return
  446 +
  447 + selected_sku = skus.pop(selected_index)
  448 + skus.insert(0, selected_sku)
  449 +
  450 + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc")
  451 + if image_src:
  452 + source["image_url"] = image_src
... ...
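Note on the embedding fallback above: `_select_by_embedding` scores each candidate with `np.inner(query_vector, candidate_vector)`, which equals cosine similarity only when the text encoder returns L2-normalized vectors. A minimal sketch of that scoring step under that assumption; `fake_encode` is a hypothetical stand-in for the encoder's `encode()`, not project code:

import numpy as np

def fake_encode(texts):
    # Hypothetical encoder: random unit-norm vectors, only for illustration.
    rng = np.random.default_rng(0)
    vecs = rng.normal(size=(len(texts), 8)).astype(np.float32)
    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

query_vec = fake_encode(["navy xl dress"])[0]
candidate_texts = ["Navy Blue X-Large", "Black M"]
candidate_vecs = fake_encode([t.lower() for t in candidate_texts])
scores = [float(np.inner(query_vec, v)) for v in candidate_vecs]
best = candidate_texts[int(np.argmax(scores))]  # highest dot product wins, as in _select_by_embedding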
tests/test_rerank_client.py
... ... @@ -118,3 +118,34 @@ def test_fuse_scores_and_resort_uses_configurable_fusion_params():
118 118 by_id = {h["_id"]: h for h in hits}
119 119 assert isclose(by_id["a"]["_fused_score"], 1.0, rel_tol=1e-9)
120 120 assert isclose(by_id["b"]["_fused_score"], 0.0, rel_tol=1e-9)
  121 +
  122 +
  123 +def test_fuse_scores_and_resort_boosts_hits_with_selected_sku():
  124 + hits = [
  125 + {
  126 + "_id": "style-selected",
  127 + "_score": 1.0,
  128 + "_style_rerank_suffix": "Blue XL",
  129 + "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
  130 + },
  131 + {
  132 + "_id": "plain",
  133 + "_score": 1.0,
  134 + "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
  135 + },
  136 + ]
  137 +
  138 + debug = fuse_scores_and_resort(
  139 + hits,
  140 + [1.0, 1.0],
  141 + style_intent_selected_sku_boost=1.2,
  142 + debug=True,
  143 + )
  144 +
  145 + by_id = {h["_id"]: h for h in hits}
  146 + assert isclose(by_id["style-selected"]["_fused_score"], by_id["plain"]["_fused_score"] * 1.2, rel_tol=1e-9)
  147 + assert by_id["style-selected"]["_style_intent_selected_sku_boost"] == 1.2
  148 + assert by_id["plain"]["_style_intent_selected_sku_boost"] == 1.0
  149 + assert [h["_id"] for h in hits] == ["style-selected", "plain"]
  150 + assert debug[0]["style_intent_selected_sku"] is True
  151 + assert debug[0]["style_intent_selected_sku_boost"] == 1.2
... ...
tests/test_reranker_qwen3_gguf_backend.py 0 → 100644
... ... @@ -0,0 +1,119 @@
  1 +from __future__ import annotations
  2 +
  3 +import sys
  4 +import types
  5 +
  6 +from reranker.backends import get_rerank_backend
  7 +from reranker.backends.qwen3_gguf import Qwen3GGUFRerankerBackend
  8 +
  9 +
  10 +class _FakeLlama:
  11 + def __init__(self, model_path: str | None = None, **kwargs):
  12 + self.model_path = model_path
  13 + self.kwargs = kwargs
  14 + self.eval_logits = []
  15 + self._tokens = []
  16 + self.eval_call_count = 0
  17 +
  18 + @classmethod
  19 + def from_pretrained(cls, repo_id: str, filename: str, local_dir=None, cache_dir=None, **kwargs):
  20 + inst = cls(model_path=f"{repo_id}/{filename}", **kwargs)
  21 + inst.repo_id = repo_id
  22 + inst.filename = filename
  23 + inst.local_dir = local_dir
  24 + inst.cache_dir = cache_dir
  25 + return inst
  26 +
  27 + def tokenize(self, text: bytes, add_bos: bool = False, special: bool = False):
  28 + raw = text.decode("utf-8")
  29 + if raw == "yes":
  30 + return [1]
  31 + if raw == "no":
  32 + return [2]
  33 + return [10 + (ord(ch) % 17) for ch in raw]
  34 +
  35 + def reset(self):
  36 + self._tokens = []
  37 + return None
  38 +
  39 + def eval(self, prompt_tokens):
  40 + self.eval_call_count += 1
  41 + self._tokens.extend(prompt_tokens)
  42 + pos = float(sum(self._tokens) % 11) + 3.0
  43 + neg = 1.0
  44 + logits = [0.0] * 64
  45 + logits[1] = pos
  46 + logits[2] = neg
  47 + self.eval_logits = [logits]
  48 +
  49 + def save_state(self):
  50 + return list(self._tokens)
  51 +
  52 + def load_state(self, state):
  53 + self._tokens = list(state)
  54 +
  55 +
  56 +def _install_fake_llama_cpp(monkeypatch):
  57 + fake_module = types.SimpleNamespace(Llama=_FakeLlama)
  58 + monkeypatch.setitem(sys.modules, "llama_cpp", fake_module)
  59 +
  60 +
  61 +def test_qwen3_gguf_backend_factory_loads(monkeypatch):
  62 + _install_fake_llama_cpp(monkeypatch)
  63 + backend = get_rerank_backend(
  64 + "qwen3_gguf",
  65 + {
  66 + "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
  67 + "filename": "*Q8_0.gguf",
  68 + "enable_warmup": False,
  69 + },
  70 + )
  71 + assert isinstance(backend, Qwen3GGUFRerankerBackend)
  72 + assert backend._backend_name == "qwen3_gguf"
  73 +
  74 +
  75 +def test_qwen3_gguf_06b_backend_factory_loads(monkeypatch):
  76 + _install_fake_llama_cpp(monkeypatch)
  77 + backend = get_rerank_backend(
  78 + "qwen3_gguf_06b",
  79 + {
  80 + "enable_warmup": False,
  81 + },
  82 + )
  83 + assert isinstance(backend, Qwen3GGUFRerankerBackend)
  84 + assert backend._backend_name == "qwen3_gguf_06b"
  85 + assert backend._repo_id == "ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF"
  86 + assert backend._filename == "qwen3-reranker-0.6b-q8_0.gguf"
  87 +
  88 +
  89 +def test_qwen3_gguf_backend_score_with_meta_dedup_and_restore(monkeypatch):
  90 + _install_fake_llama_cpp(monkeypatch)
  91 + backend = Qwen3GGUFRerankerBackend(
  92 + {
  93 + "repo_id": "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF",
  94 + "filename": "*Q8_0.gguf",
  95 + "enable_warmup": False,
  96 + "infer_batch_size": 2,
  97 + "sort_by_doc_length": True,
  98 + "reuse_query_state": True,
  99 + }
  100 + )
  101 +
  102 + scores, meta = backend.score_with_meta(
  103 + query="wireless mouse",
  104 + docs=["doc-a", "doc-b", "doc-a", "", " ", None],
  105 + normalize=True,
  106 + )
  107 +
  108 + assert len(scores) == 6
  109 + assert scores[0] == scores[2]
  110 + assert scores[0] > 0.5
  111 + assert scores[1] > 0.5
  112 + assert scores[3:] == [0.0, 0.0, 0.0]
  113 + assert meta["input_docs"] == 6
  114 + assert meta["usable_docs"] == 3
  115 + assert meta["unique_docs"] == 2
  116 + assert meta["backend"] == "qwen3_gguf"
  117 + assert meta["inference_batches"] == 1
  118 + assert meta["reuse_query_state"] is True
  119 + assert backend._llm.eval_call_count == 3
... ...
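The fake Llama above exposes a "yes" logit at token index 1 and a "no" logit at index 2, and the test expects normalized scores above 0.5 for usable documents. A hedged sketch (not the backend's code) of how such a logit pair can be collapsed into a normalized relevance score consistent with those expectations:

import math

def yes_no_score(yes_logit: float, no_logit: float) -> float:
    # A two-way softmax over (yes, no) reduces to a sigmoid of the logit difference.
    return 1.0 / (1.0 + math.exp(no_logit - yes_logit))

assert yes_no_score(3.0, 1.0) > 0.5  # the fake's smallest positive margin
assert yes_no_score(1.0, 3.0) < 0.5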
tests/test_search_rerank_window.py
... ... @@ -63,6 +63,7 @@ def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimensi
63 63 canonical_value=canonical_value,
64 64 matched_term=canonical_value,
65 65 matched_query_text=canonical_value,
  66 + attribute_terms=(canonical_value,),
66 67 dimension_aliases=tuple(aliases),
67 68 ),
68 69 )
... ...
tests/test_sku_intent_selector.py 0 → 100644
... ... @@ -0,0 +1,197 @@
  1 +from types import SimpleNamespace
  2 +
  3 +from config import QueryConfig
  4 +from query.style_intent import DetectedStyleIntent, StyleIntentProfile, StyleIntentRegistry
  5 +from search.sku_intent_selector import StyleSkuSelector
  6 +
  7 +
  8 +def test_style_sku_selector_matches_first_sku_by_attribute_terms():
  9 + registry = StyleIntentRegistry.from_query_config(
  10 + QueryConfig(
  11 + style_intent_terms={
  12 + "color": [{"en_terms": ["navy"], "zh_terms": ["藏青"], "attribute_terms": ["navy"]}],
  13 + "size": [{"en_terms": ["xl"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}],
  14 + },
  15 + style_intent_dimension_aliases={
  16 + "color": ["color", "颜色"],
  17 + "size": ["size", "尺码"],
  18 + },
  19 + )
  20 + )
  21 + selector = StyleSkuSelector(registry)
  22 + parsed_query = SimpleNamespace(
  23 + style_intent_profile=StyleIntentProfile(
  24 + intents=(
  25 + DetectedStyleIntent(
  26 + intent_type="color",
  27 + canonical_value="navy",
  28 + matched_term="藏青",
  29 + matched_query_text="藏青",
  30 + attribute_terms=("navy",),
  31 + dimension_aliases=("color", "颜色"),
  32 + ),
  33 + DetectedStyleIntent(
  34 + intent_type="size",
  35 + canonical_value="x-large",
  36 + matched_term="xl",
  37 + matched_query_text="xl",
  38 + attribute_terms=("x-large",),
  39 + dimension_aliases=("size", "尺码"),
  40 + ),
  41 + ),
  42 + )
  43 + )
  44 + source = {
  45 + "option1_name": "Color",
  46 + "option2_name": "Size",
  47 + "skus": [
  48 + {"sku_id": "1", "option1_value": "Black", "option2_value": "M"},
  49 + {"sku_id": "2", "option1_value": "Navy Blue", "option2_value": "X-Large", "image_src": "matched.jpg"},
  50 + {"sku_id": "3", "option1_value": "Navy", "option2_value": "XL"},
  51 + ],
  52 + }
  53 + hits = [{"_id": "spu-1", "_source": source}]
  54 +
  55 + decisions = selector.prepare_hits(hits, parsed_query)
  56 + decision = decisions["spu-1"]
  57 +
  58 + assert decision.selected_sku_id == "2"
  59 + assert decision.selected_text == "Navy Blue X-Large"
  60 + assert decision.matched_stage == "text"
  61 +
  62 + selector.apply_precomputed_decisions(hits, decisions)
  63 +
  64 + assert source["skus"][0]["sku_id"] == "2"
  65 + assert source["image_url"] == "matched.jpg"
  66 +
  67 +
  68 +def test_style_sku_selector_returns_no_match_without_attribute_contains():
  69 + registry = StyleIntentRegistry.from_query_config(
  70 + QueryConfig(
  71 + style_intent_terms={
  72 + "color": [{"en_terms": ["beige"], "zh_terms": ["米色"], "attribute_terms": ["beige"]}],
  73 + },
  74 + style_intent_dimension_aliases={"color": ["color", "颜色"]},
  75 + )
  76 + )
  77 + selector = StyleSkuSelector(registry)
  78 + parsed_query = SimpleNamespace(
  79 + style_intent_profile=StyleIntentProfile(
  80 + intents=(
  81 + DetectedStyleIntent(
  82 + intent_type="color",
  83 + canonical_value="beige",
  84 + matched_term="米色",
  85 + matched_query_text="米色",
  86 + attribute_terms=("beige",),
  87 + dimension_aliases=("color", "颜色"),
  88 + ),
  89 + ),
  90 + )
  91 + )
  92 + hits = [{
  93 + "_id": "spu-1",
  94 + "_source": {
  95 + "option1_name": "Color",
  96 + "skus": [
  97 + {"sku_id": "1", "option1_value": "Khaki"},
  98 + {"sku_id": "2", "option1_value": "Light Brown"},
  99 + ],
  100 + },
  101 + }]
  102 +
  103 + decisions = selector.prepare_hits(hits, parsed_query)
  104 +
  105 + assert decisions["spu-1"].selected_sku_id is None
  106 + assert decisions["spu-1"].matched_stage == "no_match"
  107 +
  108 +
  109 +def test_is_text_match_uses_token_boundaries_for_sizes():
  110 + registry = StyleIntentRegistry.from_query_config(
  111 + QueryConfig(
  112 + style_intent_terms={
  113 + "size": [{"en_terms": ["l"], "zh_terms": ["大码"], "attribute_terms": ["l"]}],
  114 + },
  115 + style_intent_dimension_aliases={"size": ["size", "尺码"]},
  116 + )
  117 + )
  118 + selector = StyleSkuSelector(registry)
  119 + style_profile = StyleIntentProfile(
  120 + intents=(
  121 + DetectedStyleIntent(
  122 + intent_type="size",
  123 + canonical_value="l",
  124 + matched_term="l",
  125 + matched_query_text="l",
  126 + attribute_terms=("l",),
  127 + dimension_aliases=("size", "尺码"),
  128 + ),
  129 + ),
  130 + )
  131 + selection_context = selector._build_selection_context(style_profile)
  132 +
  133 + assert selector._is_text_match("size", "l", selection_context, normalized_value="l")
  134 + assert not selector._is_text_match("size", "xl", selection_context, normalized_value="xl")
  135 + assert not selector._is_text_match("size", "xxl", selection_context, normalized_value="xxl")
  136 +
  137 +
  138 +def test_is_text_match_handles_punctuation_and_descriptive_attribute_values():
  139 + registry = StyleIntentRegistry.from_query_config(
  140 + QueryConfig(
  141 + style_intent_terms={
  142 + "color": [{"en_terms": ["blue"], "zh_terms": ["蓝色"], "attribute_terms": ["blue"]}],
  143 + "style": [{"en_terms": ["off-white"], "zh_terms": ["米白"], "attribute_terms": ["off-white"]}],
  144 + "accessory": [{"en_terms": ["headscarf"], "zh_terms": ["头巾"], "attribute_terms": ["headscarf"]}],
  145 + "size": [{"en_terms": ["2xl"], "zh_terms": ["2xl"], "attribute_terms": ["2xl"]}],
  146 + },
  147 + style_intent_dimension_aliases={
  148 + "color": ["color", "颜色"],
  149 + "style": ["style", "风格"],
  150 + "accessory": ["accessory", "配饰"],
  151 + "size": ["size", "尺码"],
  152 + },
  153 + )
  154 + )
  155 + selector = StyleSkuSelector(registry)
  156 + style_profile = StyleIntentProfile(
  157 + intents=(
  158 + DetectedStyleIntent(
  159 + intent_type="color",
  160 + canonical_value="blue",
  161 + matched_term="blue",
  162 + matched_query_text="blue",
  163 + attribute_terms=("blue",),
  164 + dimension_aliases=("color", "颜色"),
  165 + ),
  166 + DetectedStyleIntent(
  167 + intent_type="style",
  168 + canonical_value="off-white",
  169 + matched_term="off-white",
  170 + matched_query_text="off-white",
  171 + attribute_terms=("off-white",),
  172 + dimension_aliases=("style", "风格"),
  173 + ),
  174 + DetectedStyleIntent(
  175 + intent_type="accessory",
  176 + canonical_value="headscarf",
  177 + matched_term="headscarf",
  178 + matched_query_text="headscarf",
  179 + attribute_terms=("headscarf",),
  180 + dimension_aliases=("accessory", "配饰"),
  181 + ),
  182 + DetectedStyleIntent(
  183 + intent_type="size",
  184 + canonical_value="2xl",
  185 + matched_term="2xl",
  186 + matched_query_text="2xl",
  187 + attribute_terms=("2xl",),
  188 + dimension_aliases=("size", "尺码"),
  189 + ),
  190 + ),
  191 + )
  192 + selection_context = selector._build_selection_context(style_profile)
  193 +
  194 + assert selector._is_text_match("color", "gray blue", selection_context, normalized_value="gray blue")
  195 + assert selector._is_text_match("style", "off-white/lined", selection_context, normalized_value="off-white/lined")
  196 + assert selector._is_text_match("accessory", "army green + headscarf", selection_context, normalized_value="army green + headscarf")
  197 + assert selector._is_text_match("size", "2xl recommended 65-70kg", selection_context, normalized_value="2xl recommended 65-70kg")
... ...
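The `_is_text_match` tests above pin down token-boundary semantics: "l" must not match inside "xl"/"xxl", while values with punctuation ("off-white/lined"), connectors ("army green + headscarf"), and trailing descriptions ("2xl recommended 65-70kg") still match. A sketch of one way to satisfy those expectations, assuming the selector tokenizes on whitespace and common separators; this is illustrative, not the selector's actual implementation:

import re

def token_match(term: str, value: str) -> bool:
    # Split on whitespace and common separators so "l" is a whole token, not a substring of "xl".
    tokens = re.split(r"[\s/+,()]+", value.lower())
    return term.lower() in tokens

assert token_match("l", "l") and not token_match("l", "xl") and not token_match("l", "xxl")
assert token_match("off-white", "off-white/lined")
assert token_match("headscarf", "army green + headscarf")
assert token_match("2xl", "2xl recommended 65-70kg")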
tests/test_style_intent.py
... ... @@ -7,8 +7,8 @@ from query.style_intent import StyleIntentDetector, StyleIntentRegistry
7 7 def test_style_intent_detector_matches_original_and_translated_queries():
8 8 query_config = QueryConfig(
9 9 style_intent_terms={
10   - "color": [["black", "黑色", "black"]],
11   - "size": [["xl", "x-large", "加大码"]],
  10 + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}],
  11 + "size": [{"en_terms": ["xl", "x-large"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}],
12 12 },
13 13 style_intent_dimension_aliases={
14 14 "color": ["color", "颜色"],
... ... @@ -31,5 +31,30 @@ def test_style_intent_detector_matches_original_and_translated_queries():
31 31  
32 32 assert profile.is_active is True
33 33 assert profile.get_canonical_values("color") == {"black"}
34   - assert profile.get_canonical_values("size") == {"xl"}
  34 + assert profile.get_canonical_values("size") == {"x-large"}
35 35 assert len(profile.query_variants) == 2
  36 +
  37 +
  38 +def test_style_intent_detector_uses_original_query_when_language_translation_missing():
  39 + query_config = QueryConfig(
  40 + style_intent_terms={
  41 + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}],
  42 + },
  43 + style_intent_dimension_aliases={"color": ["color", "颜色"]},
  44 + )
  45 + detector = StyleIntentDetector(
  46 + StyleIntentRegistry.from_query_config(query_config),
  47 + tokenizer=lambda text: text.split(),
  48 + )
  49 +
  50 + parsed_query = SimpleNamespace(
  51 + original_query="black dress",
  52 + query_normalized="black dress",
  53 + rewritten_query="black dress",
  54 + translations={"zh": "连衣裙"},
  55 + )
  56 +
  57 + profile = detector.detect(parsed_query)
  58 +
  59 + assert profile.get_canonical_values("color") == {"black"}
  60 + assert profile.intents[0].attribute_terms == ("black",)
... ...