Commit 6e3e677078c096cef77127bd0426099e16159cef

Authored by tangwang
1 parent 9f33fe3c

suggest文档维护

docs/issues/issue-2026-04-06-推理优化-2.md deleted
... ... @@ -1,69 +0,0 @@
1   -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查:
2   -
3   -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。
4   -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。
5   -
6   -示例prompt和对应的9个分类词:
7   -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}
8   -
9   -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。
10   -
11   -使用Tesla T4,因此:
12   -1. 使用FP16。不用 BF16 作为主路径
13   -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention
14   -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。
15   -
16   -
17   -主要考虑优化方向为:
18   -hidden_last -> N-class scorer -> argmax
19   -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV
20   -去 full vocab logits
21   -去 decode / constrained decode
22   -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)
23   -专用 tail kernel(输出 N 类原始分数)
24   -
25   -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:
26   -15.1 预分配
27   -对每个 bucket 预分配:
28   -input ids buffer
29   -attention scratch buffer
30   -output hidden buffer
31   -class id buffer
32   -15.2 Graph capture
33   -
34   -每个 bucket 预先 capture 一张图:
35   -embedding
36   -transformer 主干
37   -compact last-state
38   -9-way tail kernel
39   -
40   -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。
41   -
42   -你有sudo权限,你可以执行为本项目安装自己的环境
43   -
44   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)
45   -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。
46   -
47   -
48   -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。
49   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。
50   -注意使用fp16版本,不要使用量化版本。
51   -
52   -上面的需求你已经满足了大部分,
53   -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。
54   -需要通过多 token 标签做极致的性能优化,避免串行decode。
55   -一个考虑的方向是:
56   -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch)
57   -
58   -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。
59   -
60   -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。
61   -
62   -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。
63   -
64   -参考文档:
65   -
66   -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
67   -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
68   -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
69   -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
70 0 \ No newline at end of file
docs/issues/issue-2026-04-06-推理优化-3.md deleted
... ... @@ -1,98 +0,0 @@
1   -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查:
2   -
3   -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。
4   -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。
5   -
6   -示例prompt和对应的9个分类词:
7   -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}
8   -
9   -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。
10   -
11   -使用Tesla T4,因此:
12   -1. 使用FP16。不用 BF16 作为主路径
13   -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention
14   -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。
15   -
16   -
17   -主要考虑优化方向为:
18   -hidden_last -> N-class scorer -> argmax
19   -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV
20   -去 full vocab logits
21   -去 decode / constrained decode
22   -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)
23   -专用 tail kernel(输出 N 类原始分数)
24   -
25   -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:
26   -15.1 预分配
27   -对每个 bucket 预分配:
28   -input ids buffer
29   -attention scratch buffer
30   -output hidden buffer
31   -class id buffer
32   -15.2 Graph capture
33   -
34   -每个 bucket 预先 capture 一张图:
35   -embedding
36   -transformer 主干
37   -compact last-state
38   -9-way tail kernel
39   -
40   -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。
41   -
42   -你有sudo权限,你可以执行为本项目安装自己的环境
43   -
44   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)
45   -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。
46   -
47   -
48   -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。
49   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。
50   -注意使用fp16版本,不要使用量化版本。
51   -
52   -上面的需求你已经满足了大部分,
53   -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。
54   -需要通过多 token 标签做极致的性能优化,避免串行decode。
55   -一个考虑的方向是:
56   -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch)
57   -
58   -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。
59   -
60   -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。
61   -
62   -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。
63   -
64   -参考文档:
65   -
66   -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
67   -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
68   -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
69   -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
70   -
71   -
72   -现在还有以下问题:
73   -1. 请满足:
74   -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。
75   -
76   -2. 打分结果不对:不管输入什么都输出jeans,输入裙子也输出jeans,只有精确的输入skirt才能输出skirt。
77   -
78   -3. 现在是一个prompt,我需要增加6个prompt,都进行配置化,每次输入一个query、对7个prompt进行批量的推理,对每个prompt都输出词典中每个分类的打分分布
79   -人群
80   -Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query}
81   -
82   -尺寸
83   -Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query}
84   -
85   -领口
86   -Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query}
87   -
88   -袖长类型
89   -Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query}
90   -
91   -上衣长度类型
92   -Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query}
93   -
94   -面料
95   -Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. Output nothing else query:{query}
96   -
97   -
98   -注意要使用测试用例进行测试。包括打分结果是否符合预期、性能测试。
docs/issues/issue-2026-04-06-推理优化-4.md deleted
... ... @@ -1,4 +0,0 @@
1   -
2   -1. “仓库启动时会为所有 batch_sizes x continuation_buckets x prompt 预建 bucket”,之前可能是根据“极致的性能要求”、“不要做任何懒加载,确保真实请求发生时得到极致的响应时间”所做的设计,这显然太过了,我是希望一些基本的可以事先加载的应该先加载,但是牺牲巨大的显存占用来换取微弱的耗时提升是不提倡的,请你站在更高的角度,理会我的需求(先加载好模型、跨session的KV cache、并针对特殊用法 即score方式而不是逐步decode方式,来极致的优化性能,用于对线上的单个query,以最短耗时得到7个prompt的分类结果)
3   -
4   -2. 现在的 7 个 prompt推理是串行的,MultiPromptRunner.score_query() 里就是 for runner in self.runners 一个一个跑,需要把执行模型改成“按 prompt 分组批量并行”,但是现在有 2 个 fast path 和 5 个 multi-token path,是各合成一次 forward,还是有可能合成一个?因为multi-token也可以一次线prefill进去,是否能做到跟fast path同级别的性能?请你站在更高的角度进行思考,保证性能的同时降低复杂性。
5 0 \ No newline at end of file
docs/issues/issue-2026-04-06-推理优化-重建.md deleted
... ... @@ -1,1003 +0,0 @@
1   -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。
2   -先专注于推理的优化,最后再考虑服务化,支持一定程度的并发(比如4)的请求,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,使用N个prompt进行N个维度的分类,对于每个prompt,推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。
3   -
4   -下面有一些参考技术资料,但是你并不需要严格,你应该有一定的灵活度,来追求极致的性能。
5   -
6   -在 Tesla T4 上,用 3B 到 6B 级别的开源 decoder-only 基座模型做 query 分类。
7   -启动时完成 tokenizer、权重、prefix cache 和共享执行器准备工作。
8   -每次输入一个 query,输出每个 prompt 下每个 label 的分数分布,以及预测耗时和阶段耗时。
9   -不走通用生成路径,不做 decode,不取 full vocab logits,不做 constrained decode。
10   -对 multi-token label 做专门优化,避免 Python 侧串行 decode。
11   -prompt 和 label 集合必须可配置(目前只有两个,以后我会加到8个,每次请求输入一个query,并行的调用8个prompt进行推理得到得分最高的label:
12   -
13   -
14   -prompts:
15   - - name: category
16   - prompt_template: |
17   - <|im_start|>system
18   - Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|>
19   - <|im_start|>user
20   - query: {query}<|im_end|>
21   - <|im_start|>assistant
22   - label:
23   - label_prefix: " "
24   - labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]
25   -
26   - - name: audience
27   - prompt_template: |
28   - <|im_start|>system
29   - Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|>
30   - <|im_start|>user
31   - query: {query}<|im_end|>
32   - <|im_start|>assistant
33   - label:
34   - label_prefix: " "
35   - labels: [boy, girl, man, woman, pregnant, none]
36   -
37   -
38   -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟。
39   -
40   -主要考虑优化方向为:
41   -1. hidden_last -> N-class scorer -> argmax
42   -2. 参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV
43   -3. 去 full vocab logits
44   -4. 去 decode / constrained decode
45   -5. 专用 tail kernel(输出 N 类原始分数)
46   -6. 配置的N个 prompt推理要并行推理(2-8个)
47   -7. 使用Tesla T4,因此不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention
48   -多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。
49   -
50   -
51   -你有sudo权限,你可以执行为本项目安装自己的环境
52   -
53   -使用Qwen/Qwen3-8B的Q4或Q8模型,具体用哪个版本,请你查找huggingface相关资料,选择合适的版本完成部署,并进行推理耗时的测试。
54   -
55   -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。
56   -
57   -一个重要的问题:一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况。
58   -需要通过多 token 标签做极致的性能优化,避免串行decode。
59   -我们最终目的是得到哪个label的得分最高,不一定要精确的概率,计算log P(id1 | query, prompt) + log P(id2 | query, prompt, id1)有可能导致难以优化性能,精确的概率是可以考虑放弃的,要清楚我们的最终目的,达到分类的目的即可,只要得到分类,优先考虑性能,精确的概率可以放下。
60   -如何通过一次模型 forward处理包括多token label的整个 batch,是你需要探索的问题。
61   -
62   -单 token fast path 的做法比较确定: last_hidden -> small class scorer -> argmax。 只取目标 label 对应 LM head 行,不做 full vocab 输出。
63   -multi-token 怎么做需要搜索相关资料进行考量,最好要做到跟单token开销相同(放弃精确的log-prob的前提下。但是:多token和单token的label的打分的对比,一定要是可比的才能正确的分类,兼顾性能和打分的准确性)
64   -
65   -还需要增加一个配置:force_single_token_labels,所有 label 都按首 token 处理,因为,如果各个label收token不同,那么可以近似的认为首token打分代表整个label打分。
66   -你需要找到多label打分性能和准确性上面的最佳实践。同时也支持force_single_token_labels以达到极致的性能。
67   -
68   -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。
69   -以下是一些参考示例:
70   -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
71   -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
72   -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
73   -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
74   -
75   -TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
76   -TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
77   -FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
78   -vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
79   -SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
80   -FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
81   -xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers
82   -
83   -注意:已经有一个项目 llm-qp, llm-qp2,这两个项目,对于单token的处理方式是可以的:
84   -SDPA
85   -prefix cache
86   -prebuilt bucket + CUDA graph
87   -他的核心代码是:
88   -from __future__ import annotations
89   -
90   -import hashlib
91   -import time
92   -from dataclasses import asdict, dataclass
93   -from typing import Iterable
94   -
95   -import torch
96   -import torch.nn.functional as F
97   -from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
98   -
99   -try:
100   - from transformers import BitsAndBytesConfig
101   -except ImportError: # pragma: no cover
102   - BitsAndBytesConfig = None
103   -
104   -from llm_qp.config import PromptTaskConfig, RuntimeConfig
105   -from llm_qp.scorer import SmallClassScorer
106   -
107   -try:
108   - from torch.nn.attention import SDPBackend, sdpa_kernel
109   -except ImportError: # pragma: no cover
110   - SDPBackend = None
111   - sdpa_kernel = None
112   -
113   -
114   -@dataclass(slots=True)
115   -class EncodedLabel:
116   - text: str
117   - token_ids: list[int]
118   -
119   -
120   -@dataclass(slots=True)
121   -class PrefixCache:
122   - prefix_ids: list[int]
123   - prefix_hashes: list[str]
124   - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]
125   -
126   - @property
127   - def prefix_len(self) -> int:
128   - return len(self.prefix_ids)
129   -
130   -
131   -@dataclass(slots=True)
132   -class MultiTokenTables:
133   - label_token_ids: torch.Tensor
134   - label_token_mask: torch.Tensor
135   - label_prefix_ids: torch.Tensor
136   - label_prefix_mask: torch.Tensor
137   - label_position_offsets: torch.Tensor
138   -
139   - @property
140   - def max_label_len(self) -> int:
141   - return self.label_token_ids.shape[1]
142   -
143   - @property
144   - def max_label_prefix_len(self) -> int:
145   - return self.label_prefix_ids.shape[1]
146   -
147   -
148   -@dataclass(slots=True)
149   -class QueryScoreResult:
150   - task_name: str
151   - query: str
152   - predicted_label: str
153   - scores: list[tuple[str, float, float]]
154   - total_ms: float
155   - stage_ms: dict[str, float]
156   - fast_path: bool
157   - prefix_tokens: int
158   - continuation_tokens: int
159   - label_token_lengths: dict[str, int]
160   -
161   - @property
162   - def predicted_prob(self) -> float:
163   - for label, _score, prob in self.scores:
164   - if label == self.predicted_label:
165   - return prob
166   - return 0.0
167   -
168   -
169   -@dataclass(slots=True)
170   -class MultiPromptScoreResult:
171   - query: str
172   - total_ms: float
173   - details: list[QueryScoreResult]
174   - stage_ms: dict[str, float]
175   -
176   - def http_json(self) -> dict[str, object]:
177   - return {
178   - "query": self.query,
179   - "total_ms": self.total_ms,
180   - "stage_ms": self.stage_ms,
181   - "details": [asdict(t) for t in self.details],
182   - "task_results": {
183   - t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none'
184   - },
185   - }
186   -
187   -
188   -@dataclass(slots=True)
189   -class BatchScoreResult:
190   - batch_size: int
191   - total_ms: float
192   - results: list[MultiPromptScoreResult]
193   - stage_ms: dict[str, float]
194   -
195   -
196   -@dataclass(slots=True)
197   -class SharedRuntime:
198   - device: torch.device
199   - dtype: torch.dtype
200   - tokenizer: object
201   - model: object
202   - backbone: object
203   - hidden_size: int
204   - graph_capture_pool: object | None = None
205   - graph_capture_stream: torch.cuda.Stream | None = None
206   -
207   -
208   -@dataclass(slots=True)
209   -class PromptBatchPlan:
210   - runner: "PromptClassifierRunner"
211   - row_start: int
212   - row_count: int
213   - score_buffer: torch.Tensor
214   -
215   - @property
216   - def row_stop(self) -> int:
217   - return self.row_start + self.row_count
218   -
219   -
220   -@dataclass(slots=True)
221   -class MixedPrefixCache:
222   - batch_size: int
223   - total_rows: int
224   - prefix_lengths: torch.Tensor
225   - attention_mask: torch.Tensor
226   - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]
227   -
228   - @property
229   - def max_prefix_len(self) -> int:
230   - return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0
231   -
232   -
233   -@dataclass(slots=True)
234   -class BatchLayout:
235   - batch_size: int
236   - total_rows: int
237   - plans: list[PromptBatchPlan]
238   -
239   -
240   -@dataclass(slots=True)
241   -class MixedBucketRuntime:
242   - batch_size: int
243   - total_rows: int
244   - continuation_len: int
245   - max_input_len: int
246   - input_ids: torch.Tensor
247   - attention_mask: torch.Tensor
248   - position_ids: torch.Tensor
249   - last_hidden_state: torch.Tensor
250   - graph: torch.cuda.CUDAGraph | None = None
251   -
252   -
253   -@dataclass(slots=True)
254   -class PreloadReport:
255   - total_ms: float
256   - stage_ms: dict[str, float]
257   - runtime: dict[str, object]
258   -
259   -
260   -def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
261   - token_list = list(token_ids)
262   - hashes: list[str] = []
263   - for start in range(0, len(token_list), block_size):
264   - block = token_list[start : start + block_size]
265   - payload = ",".join(str(x) for x in block).encode("utf-8")
266   - hashes.append(hashlib.sha1(payload).hexdigest())
267   - return hashes
268   -
269   -
270   -def _expand_legacy_cache(
271   - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
272   - batch_size: int,
273   -) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
274   - expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
275   - for key, value in raw_cache:
276   - expanded.append(
277   - (
278   - key.expand(batch_size, *key.shape[1:]).contiguous(),
279   - value.expand(batch_size, *value.shape[1:]).contiguous(),
280   - )
281   - )
282   - return tuple(expanded)
283   -
284   -
285   -class PromptClassifierRunner:
286   - def __init__(
287   - self,
288   - cfg: RuntimeConfig,
289   - task_cfg: PromptTaskConfig,
290   - shared_runtime: SharedRuntime,
291   - ):
292   - self.cfg = cfg
293   - self.task_cfg = task_cfg
294   - self.device = shared_runtime.device
295   - self.dtype = shared_runtime.dtype
296   - self.tokenizer = shared_runtime.tokenizer
297   - self.model = shared_runtime.model
298   - self.backbone = shared_runtime.backbone
299   - self.hidden_size = shared_runtime.hidden_size
300   - self.prefix_text, self.suffix_text = task_cfg.prompt_parts
301   - self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
302   - self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
303   - self.labels = list(task_cfg.labels)
304   - self.encoded_labels = [
305   - EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
306   - for label in self.labels
307   - ]
308   - self.num_labels = len(self.labels)
309   - self.lm_head = self.model.get_output_embeddings()
310   - self.lm_head_weight = self.lm_head.weight.detach()
311   - self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
312   - if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
313   - raise ValueError(
314   - f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
315   - )
316   - self.fast_path = self._has_unique_single_token_labels()
317   - self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
318   - self.scorer = self._build_scorer() if self.fast_path else None
319   - self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
320   - self.prefix_cache = self._build_prefix_cache()
321   -
322   - def _encode_label_token_ids(self, label: str) -> list[int]:
323   - token_ids = self.tokenizer.encode(
324   - f"{self.task_cfg.label_prefix}{label}",
325   - add_special_tokens=False,
326   - )
327   - if not token_ids:
328   - raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
329   - if self.cfg.force_single_token_labels:
330   - return token_ids[:1]
331   - return token_ids
332   -
333   - def _has_unique_single_token_labels(self) -> bool:
334   - token_ids: list[int] = []
335   - for item in self.encoded_labels:
336   - if len(item.token_ids) != 1:
337   - return False
338   - token_ids.append(item.token_ids[0])
339   - return len(token_ids) == len(set(token_ids))
340   -
341   - def _build_scorer(self) -> SmallClassScorer:
342   - token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
343   - weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
344   - bias = None
345   - if self.lm_head_bias is not None:
346   - bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
347   - return SmallClassScorer(weights=weights, bias=bias)
348   -
349   - def _build_multi_token_tables(self) -> MultiTokenTables:
350   - max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
351   - max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
352   - label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
353   - label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
354   - label_prefix_ids = torch.full(
355   - (self.num_labels, max_prefix_len),
356   - fill_value=self.tokenizer.pad_token_id,
357   - device=self.device,
358   - dtype=torch.long,
359   - )
360   - label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
361   - for idx, item in enumerate(self.encoded_labels):
362   - token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
363   - token_len = token_ids.numel()
364   - label_token_ids[idx, :token_len] = token_ids
365   - label_token_mask[idx, :token_len] = 1.0
366   - if token_len > 1:
367   - prefix_len = token_len - 1
368   - label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
369   - label_prefix_mask[idx, :prefix_len] = 1
370   - return MultiTokenTables(
371   - label_token_ids=label_token_ids.contiguous(),
372   - label_token_mask=label_token_mask.contiguous(),
373   - label_prefix_ids=label_prefix_ids.contiguous(),
374   - label_prefix_mask=label_prefix_mask.contiguous(),
375   - label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
376   - )
377   -
378   - @torch.inference_mode()
379   - def _build_prefix_cache(self) -> PrefixCache:
380   - if not self.prefix_ids:
381   - return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
382   - prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
383   - attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
384   - outputs = self.model(
385   - input_ids=prefix_tensor,
386   - attention_mask=attention_mask,
387   - use_cache=True,
388   - return_dict=True,
389   - )
390   - raw_cache = tuple(
391   - (layer.keys.detach(), layer.values.detach())
392   - for layer in outputs.past_key_values.layers
393   - )
394   - return PrefixCache(
395   - prefix_ids=list(self.prefix_ids),
396   - prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
397   - raw_cache=raw_cache,
398   - )
399   -
400   - def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
401   - if not self.prefix_cache.raw_cache:
402   - return tuple()
403   - return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)
404   -
405   - def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
406   - continuation = query_ids + self.suffix_ids
407   - if not continuation:
408   - raise ValueError("prompt continuation is empty after substituting query")
409   - if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
410   - raise ValueError(
411   - f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
412   - )
413   - return continuation
414   -
415   - @torch.inference_mode()
416   - def reduce_fast_scores(
417   - self,
418   - hidden: torch.Tensor,
419   - out_scores: torch.Tensor,
420   - ) -> None:
421   - assert self.scorer is not None
422   - out_scores.copy_(self.scorer(hidden))
423   -
424   - @torch.inference_mode()
425   - def reduce_multi_token_scores(
426   - self,
427   - last_hidden_state: torch.Tensor,
428   - batch_size: int,
429   - max_input_len: int,
430   - score_positions: torch.Tensor,
431   - out_scores: torch.Tensor,
432   - ) -> None:
433   - assert self.multi_token_tables is not None
434   - hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
435   - gather_index = score_positions[:, None, :, None].expand(
436   - batch_size,
437   - self.num_labels,
438   - self.multi_token_tables.max_label_len,
439   - self.hidden_size,
440   - )
441   - gathered_hidden = torch.gather(hidden, 2, gather_index)
442   - used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()
443   -
444   - token_log_probs = torch.zeros(
445   - (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
446   - device=self.device,
447   - dtype=torch.float32,
448   - )
449   - if used_mask.any():
450   - flat_hidden = gathered_hidden[used_mask]
451   - flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]
452   - linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()
453   - linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()
454   - linear_bias = self.lm_head_bias
455   - if linear_bias is not None and self.device.type != "cuda":
456   - linear_bias = linear_bias.float()
457   - flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)
458   - flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()
459   - flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)
460   - token_log_probs[used_mask] = flat_selected - flat_log_norm
461   - out_scores.copy_(token_log_probs.sum(dim=-1))
462   -
463   - def build_score_result(
464   - self,
465   - query: str,
466   - scores: torch.Tensor,
467   - stage_ms: dict[str, float],
468   - continuation_tokens: int,
469   - ) -> QueryScoreResult:
470   - score_values = scores.detach().float().cpu().tolist()
471   - best_idx = max(range(len(score_values)), key=score_values.__getitem__)
472   - probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()
473   - return QueryScoreResult(
474   - task_name=self.task_cfg.name,
475   - query=query,
476   - predicted_label=self.labels[best_idx],
477   - scores=[
478   - (label, score, prob)
479   - for label, score, prob in zip(self.labels, score_values, probs, strict=True)
480   - ],
481   - total_ms=sum(stage_ms.values()),
482   - stage_ms=stage_ms,
483   - fast_path=self.fast_path,
484   - prefix_tokens=self.prefix_cache.prefix_len,
485   - continuation_tokens=continuation_tokens,
486   - label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},
487   - )
488   -
489   -
490   -class MultiPromptRunner:
491   - def __init__(self, cfg: RuntimeConfig):
492   - self.cfg = cfg
493   - t0 = time.perf_counter()
494   - self.shared_runtime = self.build_shared_runtime(cfg)
495   - t1 = time.perf_counter()
496   - self.device = self.shared_runtime.device
497   - self.dtype = self.shared_runtime.dtype
498   - self.tokenizer = self.shared_runtime.tokenizer
499   - self.model = self.shared_runtime.model
500   - self.backbone = self.shared_runtime.backbone
501   - self.hidden_size = self.shared_runtime.hidden_size
502   - self.graph_capture_pool = self.shared_runtime.graph_capture_pool
503   - self.graph_capture_stream = self.shared_runtime.graph_capture_stream
504   - self.runners = [
505   - PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)
506   - for task_cfg in cfg.tasks
507   - ]
508   - t2 = time.perf_counter()
509   - self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}
510   - t3 = time.perf_counter()
511   - self.mixed_prefix_caches = {
512   - batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])
513   - for batch_size in self.cfg.batch_sizes
514   - }
515   - t4 = time.perf_counter()
516   - self.max_label_prefix_len = max(
517   - (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)
518   - for runner in self.runners
519   - )
520   - self.mixed_buckets = {
521   - (batch_size, continuation_len): self._build_mixed_bucket(
522   - self.batch_layouts[batch_size],
523   - self.mixed_prefix_caches[batch_size],
524   - continuation_len,
525   - )
526   - for batch_size in self.cfg.batch_sizes
527   - for continuation_len in self.cfg.continuation_buckets
528   - }
529   - t5 = time.perf_counter()
530   - self._warmup_results: dict[int, BatchScoreResult] = {}
531   - self._preload_report: PreloadReport | None = None
532   - self._init_stage_ms = {
533   - "load_model_and_tokenizer": (t1 - t0) * 1000.0,
534   - "build_prompt_runtimes": (t2 - t1) * 1000.0,
535   - "build_batch_layouts": (t3 - t2) * 1000.0,
536   - "build_mixed_prefix_caches": (t4 - t3) * 1000.0,
537   - "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,
538   - }
539   - self._init_total_ms = sum(self._init_stage_ms.values())
540   -
541   - @staticmethod
542   - def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:
543   - device = torch.device(cfg.device)
544   - dtype = torch.float16
545   - tokenizer = AutoTokenizer.from_pretrained(
546   - cfg.resolved_model_source,
547   - trust_remote_code=cfg.resolved_trust_remote_code,
548   - token=cfg.hf_token,
549   - cache_dir=cfg.hf_cache_dir,
550   - local_files_only=cfg.resolved_local_files_only,
551   - )
552   - if tokenizer.pad_token_id is None:
553   - tokenizer.pad_token = tokenizer.eos_token
554   - attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)
555   - quantization_config = None
556   - model_kwargs: dict[str, object] = {
557   - "trust_remote_code": cfg.resolved_trust_remote_code,
558   - "attn_implementation": attn_impl,
559   - "token": cfg.hf_token,
560   - "cache_dir": cfg.hf_cache_dir,
561   - "local_files_only": cfg.resolved_local_files_only,
562   - }
563   - if cfg.load_in_4bit:
564   - if BitsAndBytesConfig is None:
565   - raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")
566   - quantization_config = BitsAndBytesConfig(
567   - load_in_4bit=True,
568   - bnb_4bit_compute_dtype=dtype,
569   - bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
570   - bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
571   - )
572   - model_kwargs["quantization_config"] = quantization_config
573   - model_kwargs["device_map"] = {"": device.index or 0}
574   - else:
575   - model_kwargs["dtype"] = dtype
576   - model_kwargs["device_map"] = None
577   - model = AutoModelForCausalLM.from_pretrained(
578   - cfg.resolved_model_source,
579   - **model_kwargs,
580   - ).eval()
581   - if not cfg.load_in_4bit:
582   - model = model.to(device)
583   - backbone = model.get_submodule(model.base_model_prefix)
584   - hidden_size = model.get_output_embeddings().weight.shape[1]
585   - graph_capture_pool = None
586   - graph_capture_stream = None
587   - if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:
588   - graph_capture_pool = torch.cuda.graph_pool_handle()
589   - graph_capture_stream = torch.cuda.Stream(device=device)
590   - return SharedRuntime(
591   - device=device,
592   - dtype=dtype,
593   - tokenizer=tokenizer,
594   - model=model,
595   - backbone=backbone,
596   - hidden_size=hidden_size,
597   - graph_capture_pool=graph_capture_pool,
598   - graph_capture_stream=graph_capture_stream,
599   - )
600   -
601   - @staticmethod
602   - def _resolve_attn_impl(requested: str) -> str:
603   - if requested in {"sdpa", "eager"}:
604   - return requested
605   - if requested == "auto":
606   - return "sdpa"
607   - raise ValueError(f"unsupported attn_backend: {requested}")
608   -
609   - def _attn_context(self):
610   - if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
611   - return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
612   - return torch.no_grad()
613   -
614   - def _sync(self) -> None:
615   - if self.device.type == "cuda":
616   - torch.cuda.synchronize()
617   -
618   - def _pick_bucket(self, continuation_len: int) -> int:
619   - for bucket in self.cfg.continuation_buckets:
620   - if continuation_len <= bucket:
621   - return bucket
622   - if self.cfg.pad_to_bucket:
623   - raise ValueError(
624   - f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
625   - )
626   - return continuation_len
627   -
628   - def _build_batch_layout(self, batch_size: int) -> BatchLayout:
629   - plans: list[PromptBatchPlan] = []
630   - row_start = 0
631   - for runner in self.runners:
632   - row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
633   - plans.append(
634   - PromptBatchPlan(
635   - runner=runner,
636   - row_start=row_start,
637   - row_count=row_count,
638   - score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
639   - )
640   - )
641   - row_start += row_count
642   - return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)
643   -
644   - def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
645   - prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
646   - non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
647   - if not non_empty:
648   - return MixedPrefixCache(
649   - batch_size=layout.batch_size,
650   - total_rows=layout.total_rows,
651   - prefix_lengths=prefix_lengths,
652   - attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
653   - raw_cache=tuple(),
654   - )
655   -
656   - max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
657   - num_layers = len(non_empty[0])
658   - attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
659   - raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
660   - for layer_idx in range(num_layers):
661   - sample_key, sample_value = non_empty[0][layer_idx]
662   - merged_key = sample_key.new_zeros(
663   - (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
664   - )
665   - merged_value = sample_value.new_zeros(
666   - (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
667   - )
668   - raw_layers.append((merged_key, merged_value))
669   -
670   - for plan in layout.plans:
671   - runner = plan.runner
672   - prefix_len = runner.prefix_cache.prefix_len
673   - row_slice = slice(plan.row_start, plan.row_stop)
674   - prefix_lengths[row_slice] = prefix_len
675   - if prefix_len == 0:
676   - continue
677   - attention_mask[row_slice, :prefix_len] = 1
678   - raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
679   - for layer_idx, (key, value) in enumerate(raw_cache):
680   - merged_key, merged_value = raw_layers[layer_idx]
681   - merged_key[row_slice, :, :prefix_len, :] = key
682   - merged_value[row_slice, :, :prefix_len, :] = value
683   -
684   - return MixedPrefixCache(
685   - batch_size=layout.batch_size,
686   - total_rows=layout.total_rows,
687   - prefix_lengths=prefix_lengths,
688   - attention_mask=attention_mask.contiguous(),
689   - raw_cache=tuple(raw_layers),
690   - )
691   -
692   - def _build_mixed_bucket(
693   - self,
694   - layout: BatchLayout,
695   - prefix_cache: MixedPrefixCache,
696   - continuation_len: int,
697   - ) -> MixedBucketRuntime:
698   - max_input_len = continuation_len + self.max_label_prefix_len
699   - total_len = prefix_cache.max_prefix_len + max_input_len
700   - input_ids = torch.full(
701   - (layout.total_rows, max_input_len),
702   - fill_value=self.tokenizer.pad_token_id,
703   - device=self.device,
704   - dtype=torch.long,
705   - )
706   - attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
707   - if prefix_cache.max_prefix_len:
708   - attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
709   - position_ids = (
710   - prefix_cache.prefix_lengths[:, None]
711   - + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
712   - ).contiguous()
713   - last_hidden_state = torch.empty(
714   - (layout.total_rows, max_input_len, self.hidden_size),
715   - device=self.device,
716   - dtype=self.dtype,
717   - )
718   - bucket = MixedBucketRuntime(
719   - batch_size=layout.batch_size,
720   - total_rows=layout.total_rows,
721   - continuation_len=continuation_len,
722   - max_input_len=max_input_len,
723   - input_ids=input_ids,
724   - attention_mask=attention_mask,
725   - position_ids=position_ids,
726   - last_hidden_state=last_hidden_state,
727   - )
728   - if self.cfg.cuda_graphs:
729   - self._capture_mixed_bucket(bucket, prefix_cache)
730   - return bucket
731   -
732   - @torch.inference_mode()
733   - def _run_mixed_backbone(
734   - self,
735   - bucket: MixedBucketRuntime,
736   - prefix_cache: MixedPrefixCache,
737   - ) -> None:
738   - cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
739   - with self._attn_context():
740   - outputs = self.backbone(
741   - input_ids=bucket.input_ids,
742   - attention_mask=bucket.attention_mask,
743   - position_ids=bucket.position_ids,
744   - past_key_values=cache,
745   - use_cache=False,
746   - return_dict=True,
747   - )
748   - bucket.last_hidden_state.copy_(outputs.last_hidden_state)
749   -
750   - def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
751   - if not torch.cuda.is_available():
752   - return
753   - try:
754   - torch.cuda.synchronize()
755   - stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
756   - with torch.cuda.stream(stream):
757   - for _ in range(self.cfg.graph_warmups):
758   - self._run_mixed_backbone(bucket, prefix_cache)
759   - stream.synchronize()
760   - graph = torch.cuda.CUDAGraph()
761   - with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
762   - self._run_mixed_backbone(bucket, prefix_cache)
763   - bucket.graph = graph
764   - except RuntimeError:
765   - bucket.graph = None
766   -
767   - def _prepare_bucket(
768   - self,
769   - layout: BatchLayout,
770   - prefix_cache: MixedPrefixCache,
771   - bucket: MixedBucketRuntime,
772   - query_ids_batch: list[list[int]],
773   - ) -> tuple[list[list[int]], dict[str, list[object]]]:
774   - del prefix_cache
775   - bucket.input_ids.fill_(self.tokenizer.pad_token_id)
776   - bucket.attention_mask.zero_()
777   - if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
778   - bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
779   - self.mixed_prefix_caches[layout.batch_size].attention_mask
780   - )
781   - continuation_lengths_per_task: dict[str, list[int]] = {}
782   - continuation_tokens_per_task: dict[str, list[list[int]]] = {}
783   - prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
784   - for plan in layout.plans:
785   - runner = plan.runner
786   - per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
787   - continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
788   - continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
789   - if runner.fast_path:
790   - for batch_idx, continuation in enumerate(per_query_continuations):
791   - cont_len = len(continuation)
792   - row_idx = plan.row_start + batch_idx
793   - bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
794   - bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
795   - continue
796   -
797   - assert runner.multi_token_tables is not None
798   - for batch_idx, continuation in enumerate(per_query_continuations):
799   - cont_len = len(continuation)
800   - row_start = plan.row_start + batch_idx * runner.num_labels
801   - row_stop = row_start + runner.num_labels
802   - row_slice = slice(row_start, row_stop)
803   - cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
804   - bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
805   - bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
806   - if runner.multi_token_tables.max_label_prefix_len:
807   - bucket.input_ids[
808   - row_slice,
809   - cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
810   - ] = runner.multi_token_tables.label_prefix_ids
811   - bucket.attention_mask[
812   - row_slice,
813   - prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
814   - ] = runner.multi_token_tables.label_prefix_mask
815   - return query_ids_batch, {
816   - "continuation_lengths_per_task": continuation_lengths_per_task,
817   - "continuation_tokens_per_task": continuation_tokens_per_task,
818   - }
819   -
820   - def _reduce_prompt_scores(
821   - self,
822   - layout: BatchLayout,
823   - bucket: MixedBucketRuntime,
824   - query_texts: list[str],
825   - prep_meta: dict[str, list[object]],
826   - shared_stage_ms: dict[str, float],
827   - ) -> list[MultiPromptScoreResult]:
828   - result_rows = [[] for _ in range(layout.batch_size)]
829   - prompt_reduce_total_ms = 0.0
830   - for plan in layout.plans:
831   - runner = plan.runner
832   - continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
833   - reduce_start = time.perf_counter()
834   - if runner.fast_path:
835   - hidden_rows = []
836   - row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
837   - for batch_idx, cont_len in enumerate(continuation_lengths):
838   - hidden_rows.append(row_slice[batch_idx, cont_len - 1])
839   - hidden = torch.stack(hidden_rows, dim=0)
840   - runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
841   - stage_name = "tail_scorer"
842   - else:
843   - assert runner.multi_token_tables is not None
844   - score_positions = torch.stack(
845   - [
846   - cont_len - 1 + runner.multi_token_tables.label_position_offsets
847   - for cont_len in continuation_lengths
848   - ],
849   - dim=0,
850   - )
851   - runner.reduce_multi_token_scores(
852   - last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
853   - batch_size=layout.batch_size,
854   - max_input_len=bucket.max_input_len,
855   - score_positions=score_positions,
856   - out_scores=plan.score_buffer,
857   - )
858   - stage_name = "candidate_reduce"
859   - self._sync()
860   - reduce_end = time.perf_counter()
861   - reduce_ms = (reduce_end - reduce_start) * 1000.0
862   - prompt_reduce_total_ms += reduce_ms
863   - for batch_idx, query in enumerate(query_texts):
864   - stage_ms = dict(shared_stage_ms)
865   - stage_ms[stage_name] = reduce_ms / layout.batch_size
866   - result_rows[batch_idx].append(
867   - runner.build_score_result(
868   - query=query,
869   - scores=plan.score_buffer[batch_idx],
870   - stage_ms=stage_ms,
871   - continuation_tokens=continuation_lengths[batch_idx],
872   - )
873   - )
874   -
875   - batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
876   - shared_plus_reduce = dict(shared_stage_ms)
877   - shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
878   - results: list[MultiPromptScoreResult] = []
879   - for batch_idx, query in enumerate(query_texts):
880   - results.append(
881   - MultiPromptScoreResult(
882   - query=query,
883   - total_ms=batch_total_ms / layout.batch_size,
884   - details=result_rows[batch_idx],
885   - stage_ms={
886   - **shared_plus_reduce,
887   - "per_query_total_estimate": batch_total_ms / layout.batch_size,
888   - },
889   - )
890   - )
891   - return results
892   -
893   - @torch.inference_mode()
894   - def score_queries(self, queries: list[str]) -> BatchScoreResult:
895   - if not queries:
896   - raise ValueError("queries must not be empty")
897   - batch_size = len(queries)
898   - if batch_size not in self.batch_layouts:
899   - raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
900   - layout = self.batch_layouts[batch_size]
901   - prefix_cache = self.mixed_prefix_caches[batch_size]
902   -
903   - self._sync()
904   - t0 = time.perf_counter()
905   - query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
906   - self._sync()
907   - t1 = time.perf_counter()
908   -
909   - max_continuation_len = max(
910   - len(plan.runner.build_continuation_from_query_ids(query_ids))
911   - for plan in layout.plans
912   - for query_ids in query_ids_batch
913   - )
914   - picked_bucket = self._pick_bucket(max_continuation_len)
915   - bucket = self.mixed_buckets[(batch_size, picked_bucket)]
916   - _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
917   - self._sync()
918   - t2 = time.perf_counter()
919   -
920   - if bucket.graph is not None:
921   - bucket.graph.replay()
922   - else:
923   - self._run_mixed_backbone(bucket, prefix_cache)
924   - self._sync()
925   - t3 = time.perf_counter()
926   -
927   - shared_stage_ms = {
928   - "encode_queries_shared": (t1 - t0) * 1000.0,
929   - "prepare_batch_shared": (t2 - t1) * 1000.0,
930   - "backbone_shared": (t3 - t2) * 1000.0,
931   - }
932   - results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
933   - total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
934   - return BatchScoreResult(
935   - batch_size=batch_size,
936   - total_ms=total_ms,
937   - results=results,
938   - stage_ms={
939   - **shared_stage_ms,
940   - "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
941   - },
942   - )
943   -
944   - def score_query(self, query: str) -> MultiPromptScoreResult:
945   - return self.score_queries([query]).results[0]
946   -
947   - def preload(self) -> PreloadReport:
948   - if self._preload_report is not None:
949   - return self._preload_report
950   - stage_ms: dict[str, float] = dict(self._init_stage_ms)
951   - start = time.perf_counter()
952   - self._sync()
953   - t0 = time.perf_counter()
954   - warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
955   - for batch_size in warmup_batch_sizes:
956   - queries = [self.cfg.warmup_query] * batch_size
957   - self._warmup_results[batch_size] = self.score_queries(queries)
958   - self._sync()
959   - t1 = time.perf_counter()
960   - stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
961   - stage_ms["startup_total_before_warmup"] = self._init_total_ms
962   - total_ms = self._init_total_ms + (t1 - start) * 1000.0
963   - runtime = self.preload_report()
964   - self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
965   - return self._preload_report
966   -
967   - def preload_report(self) -> dict[str, object]:
968   - return {
969   - "model_name": self.cfg.resolved_model_name,
970   - "model_source": self.cfg.resolved_model_source,
971   - "device": str(self.device),
972   - "dtype": self.cfg.dtype,
973   - "attn_backend": self.cfg.attn_backend,
974   - "execution_model": "single_mixed_backbone_per_batch",
975   - "num_tasks": len(self.runners),
976   - "task_names": [runner.task_cfg.name for runner in self.runners],
977   - "batch_sizes": list(self.cfg.batch_sizes),
978   - "continuation_buckets": list(self.cfg.continuation_buckets),
979   - "mixed_bucket_count": len(self.mixed_buckets),
980   - "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
981   - "all_configured_buckets_preloaded": True,
982   - "init_stage_ms": dict(self._init_stage_ms),
983   - "init_total_ms": self._init_total_ms,
984   - "force_single_token_labels": self.cfg.force_single_token_labels,
985   - "warmup_query": self.cfg.warmup_query,
986   - "tasks": [
987   - {
988   - "task_name": runner.task_cfg.name,
989   - "fast_path": runner.fast_path,
990   - "num_labels": runner.num_labels,
991   - "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
992   - "prefix_tokens": runner.prefix_cache.prefix_len,
993   - "prefix_hashes": runner.prefix_cache.prefix_hashes,
994   - "label_prefix": runner.task_cfg.label_prefix,
995   - }
996   - for runner in self.runners
997   - ],
998   - }
999   -
1000   -但是关于多token的处理方式是低效的、是错的,不要参考他,请你重新实现。
1001   -本地的.venv已经创建好,是复用llm-qp的,请使用该环境
1002   -有用的代码我已经引用过来了,请你基于需求,搜寻相关资料,专注于重新开发全新的版本。
1003   -任务较重,请耐心逐步完成,慢慢来,要长时间迭代一直到服务建设完成、测试完全符合预期。
docs/issues/issue-2026-04-06-推理优化.md deleted
... ... @@ -1,59 +0,0 @@
1   -
2   -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。
3   -先专注于推理的优化,不用考虑服务后,可以程序启动后标准输入读取query,输出分类词。
4   -
5   -示例prompt和对应的9个分类词:
6   -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}
7   -
8   -做专用执行路径,不是在通用生成引擎上做配置优化。
9   -
10   -prompt已经固定(不要考虑蒸馏、微调、或者缩短prompt。缩短prompt是可以我自己调整的,并且prompt可能改为其他场景,与你做专用推理优化不相关,我给的prompt只是一个例子,你专注于专用推理的框架,适配配置化的prompt+分类词列表打分检查)
11   -
12   -
13   -使用Tesla T4,因此:
14   -1. 使用FP16。不用 BF16 作为主路径
15   -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention
16   -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。
17   -
18   -
19   -主要考虑优化方向为:
20   -hidden_last -> N-class scorer -> argmax
21   -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV
22   -去 full vocab logits
23   -去 decode / constrained decode
24   -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)
25   -专用 tail kernel(输出 N 类原始分数)
26   -
27   -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:
28   -15.1 预分配
29   -对每个 bucket 预分配:
30   -input ids buffer
31   -attention scratch buffer
32   -output hidden buffer
33   -class id buffer
34   -15.2 Graph capture
35   -
36   -每个 bucket 预先 capture 一张图:
37   -embedding
38   -transformer 主干
39   -compact last-state
40   -9-way tail kernel
41   -
42   -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。
43   -
44   -你有sudo权限,你可以执行为本项目安装自己的环境
45   -
46   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)
47   -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。
48   -
49   -
50   -
51   -
52   -
53   -
54   -
55   -另外,我想要有个命令行工具,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。
56   -输入query,输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下),目前是否已经满足了。
57   -
58   -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。
59   -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。
docs/issues/issue-2026-04-07-服务化.md deleted
... ... @@ -1,66 +0,0 @@
1   -彻底清除掉命令行交互方式的代码,改造为提供 HTTP 接口,端口6001。
2   -并提供scripts/service_ctl.sh的方式管理服务的启停
3   -返回分类结果的list(results):list由多个dict组成,每个dict的key为对应的分类任务,value是三元组,即值、打分、概率(如果值为none则该项不输出)
4   -完善日志系统,当前cli方式输出的重要信息,在日志以及http接口的details字段体现。
5   -提供一个压测脚本放到本项目的合适的目录下,作为压测工具。该压测工具针对每条请求打印出结果,并最后给出性能指标,参考:
6   -#!/bin/bash
7   -
8   -# 默认值
9   -concurrency=${1:-1}
10   -top_lines=${2:-100}
11   -
12   -# 固定查询文件路径
13   -query_file="/data/saas-search/scripts/evaluation/queries/queries.txt"
14   -
15   -# 检查文件是否存在
16   -if [ ! -f "$query_file" ]; then
17   - echo "错误: 查询文件不存在: $query_file" >&2
18   - exit 1
19   -fi
20   -
21   -# 检查 jq 是否可用
22   -if ! command -v jq &> /dev/null; then
23   - echo "错误: 需要安装 jq 来解析 JSON" >&2
24   - exit 1
25   -fi
26   -
27   -url="http://127.0.0.1:6001/..."
28   -max_jobs=$concurrency
29   -job_count=0
30   -
31   -# 读取文件前 top_lines 行,每行作为一个 query
32   -while IFS= read -r query; do
33   - # 跳过空行
34   - [ -z "$query" ] && continue
35   -
36   - # 启动子进程执行请求
37   - (
38   - # 安全构建 JSON payload
39   - payload=$(jq -n --arg q "$query" '{query: $q}')
40   - # 发送请求并获取响应
41   - response=$(curl -s -X POST "$url" \
42   - -H 'Content-Type: application/json' \
43   - -d "$payload")
44   - # 提取 results 字段(紧凑 JSON 格式)
45   - results=$(echo "$response" | jq -c '.results')
46   - # 输出 query 和对应的 results
47   - printf "%s\t%s\n" "$query" "$results"
48   - ) &
49   -
50   - # 控制并发数量
51   - ((job_count++))
52   - if (( job_count >= max_jobs )); then
53   - wait -n # 等待任意一个后台进程完成
54   - ((job_count--))
55   - fi
56   -done < <(head -n "$top_lines" "$query_file")
57   -
58   -# 等待所有剩余后台进程完成
59   -# 在这里统计处性能情况,指标:
60   - 平均耗时:
61   - 最大耗时:
62   - 最小耗时:
63   - TP50:
64   - TP90:
65   - TP99:
66   -wait
67 0 \ No newline at end of file
docs/issues/issue.md
1 1 项目 TODO 清单
2 2  
  3 +CLAUDE.md需要更新
  4 +
3 5 2. 核心搜索功能优化
4 6  
5 7 2.1 意图识别模块
... ...
scripts/evaluation/eval_framework/__init__.py
... ... @@ -20,7 +20,6 @@ from .constants import ( # noqa: E402
20 20 RELEVANCE_LOW,
21 21 RELEVANCE_NON_IRRELEVANT,
22 22 VALID_LABELS,
23   - normalize_stored_label,
24 23 )
25 24 from .framework import SearchEvaluationFramework # noqa: E402
26 25 from .store import EvalStore, QueryBuildResult # noqa: E402
... ... @@ -51,7 +50,6 @@ __all__ = [
51 50 "create_web_app",
52 51 "ensure_dir",
53 52 "main",
54   - "normalize_stored_label",
55 53 "render_batch_report_markdown",
56 54 "sha1_text",
57 55 "utc_now_iso",
... ...
scripts/evaluation/eval_framework/constants.py
... ... @@ -42,20 +42,6 @@ STOP_PROB_MAP = {
42 42 RELEVANCE_IRRELEVANT: 0.0,
43 43 }
44 44  
45   -_LEGACY_LABEL_MAP = {
46   - "Exact": RELEVANCE_EXACT,
47   - "Partial": RELEVANCE_HIGH,
48   -}
49   -
50   -
51   -def normalize_stored_label(label: str) -> str:
52   - """Map legacy 3-way SQLite labels to current 4-way strings; pass through canonical labels."""
53   - s = str(label).strip()
54   - if s in VALID_LABELS:
55   - return s
56   - return _LEGACY_LABEL_MAP.get(s, s)
57   -
58   -
59 45 DEFAULT_ARTIFACT_ROOT = PROJECT_ROOT / "artifacts" / "search_evaluation"
60 46 DEFAULT_QUERY_FILE = _SCRIPTS_EVAL_DIR / "queries" / "queries.txt"
61 47  
... ...
scripts/evaluation/eval_framework/static/eval_web.css
... ... @@ -48,9 +48,9 @@
48 48 .results { display: grid; gap: 10px; }
49 49 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; }
50 50 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; }
51   - .label-exact-match { background: var(--exact); }
52   - .label-high-relevant { background: var(--high); }
53   - .label-low-relevant { background: var(--low); }
  51 + .label-fully-relevant { background: var(--exact); }
  52 + .label-mostly-relevant { background: var(--high); }
  53 + .label-weakly-relevant { background: var(--low); }
54 54 .label-irrelevant { background: var(--irrelevant); }
55 55 .badge-unknown { background: #637381; }
56 56 .thumb { width: 100px; height: 100px; object-fit: cover; border-radius: 14px; background: #e7e1d4; }
... ...
scripts/evaluation/eval_framework/store.py
... ... @@ -8,7 +8,7 @@ from dataclasses import dataclass
8 8 from pathlib import Path
9 9 from typing import Any, Dict, List, Optional, Sequence
10 10  
11   -from .constants import VALID_LABELS, normalize_stored_label
  11 +from .constants import VALID_LABELS
12 12 from .utils import ensure_dir, safe_json_dumps, utc_now_iso
13 13  
14 14  
... ... @@ -220,7 +220,7 @@ class EvalStore:
220 220 """,
221 221 (tenant_id, query_text),
222 222 ).fetchall()
223   - return {str(row["spu_id"]): normalize_stored_label(str(row["label"])) for row in rows}
  223 + return {str(row["spu_id"]): str(row["label"]) for row in rows}
224 224  
225 225 def upsert_labels(
226 226 self,
... ... @@ -379,8 +379,8 @@ class EvalStore:
379 379 SELECT
380 380 query_text,
381 381 COUNT(*) AS total,
382   - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count,
383   - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count,
  382 + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count,
  383 + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count,
384 384 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count,
385 385 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count,
386 386 MAX(updated_at) AS updated_at
... ... @@ -409,8 +409,8 @@ class EvalStore:
409 409 """
410 410 SELECT
411 411 COUNT(*) AS total,
412   - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count,
413   - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count,
  412 + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count,
  413 + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count,
414 414 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count,
415 415 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count,
416 416 MAX(updated_at) AS updated_at
... ...
suggestion/ARCHITECTURE_V2.md deleted
... ... @@ -1,304 +0,0 @@
1   -# Suggestion 架构方案 V2(仅 Suggest,去除结果直达)
2   -
3   -## 0. 结论
4   -
5   -本方案将 Suggest 设计为**独立高性能检索系统**,只返回建议词,不再返回商品卡片,也不做历史兼容。
6   -
7   -- 只保留 `/search/suggestions` 的词级自动补全能力
8   -- 完全移除 `with_results/result_size/products[]` 链路
9   -- 多语言优先,支持高并发、低延迟、可持续演进
10   -
11   ----
12   -
13   -## 1. 当前实现的关键问题(基于现有代码审视)
14   -
15   -1. 在线链路曾包含“suggest -> 二次商品查询”,属于典型 N+1 放大,QPS 上升后延迟和 ES 负载都不稳定。
16   -2. `builder.py` 全量构建使用“大量 in-memory 聚合 + fetchall”,大租户下内存风险高。
17   -3. 查询参数上限过大(原 `size<=200`),不符合自动补全接口性能边界。
18   -4. 文档与实现长期混合(README 仍包含结果直达),导致认知不一致。
19   -5. 多语言归一化仍偏基础(仅 lower/空白折叠),对 Unicode、变音符、跨语系兼容不够。
20   -
21   ----
22   -
23   -## 2. 目标与 SLO
24   -
25   -### 2.1 业务目标
26   -
27   -- 输入时实时返回高相关建议词(query suggestion)
28   -- 多语言稳定(至少覆盖租户配置 `index_languages`)
29   -- 支持词级排序和运营治理(黑白名单、降噪、降权)
30   -
31   -### 2.2 性能目标(建议)
32   -
33   -- P50 < 10ms,P95 < 25ms,P99 < 50ms(ES 查询耗时,不含网关)
34   -- 单集群支持高并发(千级 QPS 可横向扩展)
35   -- 数据新鲜度:增量 5-15 分钟可见
36   -
37   ----
38   -
39   -## 3. 总体架构
40   -
41   -## 3.1 在线路径(单跳)
42   -
43   -Client -> API `/search/suggestions` -> ES `search_suggestions_v2` -> 返回 suggestions
44   -
45   -原则:
46   -
47   -- **单次 ES 查询完成主路径**(可选双召回融合,但仍在同一次 API 请求内完成)
48   -- 不调用 `search_products`,不返回商品结果
49   -- 通过 `routing=tenant_id` 避免跨分片 fan-out
50   -
51   -## 3.2 离线路径(构建)
52   -
53   -数据源:
54   -
55   -- 商品字段:`title.{lang}`、`qanchors.{lang}`
56   -- 搜索日志:`shoplazza_search_log`(含 `language/request_params`)
57   -- 行为信号(可选增强):点击、加购、下单
58   -
59   -产物:
60   -
61   -- Suggest 文档(`tenant_id + lang + text_norm` 唯一)
62   -- completion + prefix 检索字段
63   -- 排序特征(热度、近期度、质量分)
64   -
65   -发布方式:
66   -
67   -- 写入新物理索引(版本化)
68   -- 原子切换 alias(零停机)
69   -
70   ----
71   -
72   -## 4. 索引设计(ES)
73   -
74   -## 4.1 索引组织
75   -
76   -推荐两级策略:
77   -
78   -1. 默认:环境级共享索引(降低海量租户 index 数量)
79   -2. 大租户:可升级为租户独享索引(隔离资源)
80   -
81   -统一通过 alias 暴露:
82   -
83   -- `search_suggestions_v2_current`
84   -
85   -## 4.2 Mapping(核心字段)
86   -
87   -```json
88   -{
89   - "settings": {
90   - "number_of_shards": 3,
91   - "number_of_replicas": 1,
92   - "refresh_interval": "30s"
93   - },
94   - "mappings": {
95   - "properties": {
96   - "tenant_id": { "type": "keyword" },
97   - "lang": { "type": "keyword" },
98   - "text": { "type": "keyword" },
99   - "text_norm": { "type": "keyword" },
100   - "status": { "type": "byte" },
101   - "sources": { "type": "keyword" },
102   -
103   - "query_count_7d": { "type": "integer" },
104   - "query_count_30d": { "type": "integer" },
105   - "ctr_30d": { "type": "float" },
106   - "order_rate_30d": { "type": "float" },
107   - "rank_score": { "type": "float" },
108   -
109   - "suggest": {
110   - "type": "completion",
111   - "contexts": [
112   - { "name": "tenant", "type": "category" },
113   - { "name": "lang", "type": "category" }
114   - ]
115   - },
116   -
117   - "sat": {
118   - "properties": {
119   - "zh": { "type": "search_as_you_type", "analyzer": "index_ik" },
120   - "en": { "type": "search_as_you_type", "analyzer": "english" },
121   - "ar": { "type": "search_as_you_type", "analyzer": "arabic" }
122   - }
123   - },
124   -
125   - "updated_at": { "type": "date" }
126   - }
127   - }
128   -}
129   -```
130   -
131   -说明:
132   -
133   -- `completion` 负责极速前缀命中(主召回)
134   -- `search_as_you_type` 负责多词前缀和召回兜底
135   -- `contexts` 强制租户与语言隔离
136   -
137   ----
138   -
139   -## 5. 多语言策略
140   -
141   -1. 语言归属优先级:`log.language > request_params.language > 脚本识别 > tenant.primary_language`
142   -2. 统一归一化:NFKC、大小写折叠、空白折叠、标点清洗
143   -3. 分词器按语言配置:
144   - - 中文:IK/ANSJ(与主索引保持一致)
145   - - 拉丁语系:对应内置 analyzer
146   - - 未覆盖语种:`standard + ICU folding` 兜底
147   -4. 保证写入语言必须在租户 `index_languages` 内
148   -
149   ----
150   -
151   -## 6. 在线检索策略(高性能)
152   -
153   -## 6.1 双通道召回(推荐)
154   -
155   -1. 通道 A:`completion suggester`(prefix,skip_duplicates)
156   -2. 通道 B:`multi_match(type=bool_prefix)` on `search_as_you_type`
157   -3. 融合去重:按 `text_norm` 去重,按最终分排序截断
158   -
159   -## 6.2 查询约束
160   -
161   -- 默认 `size=10`,最大 `size=50`
162   -- `track_total_hits=false`
163   -- `_source` 仅返回必要字段(`text/lang/rank_score/sources`)
164   -- `routing=tenant_id`
165   -
166   -## 6.3 打分建议
167   -
168   -```text
169   -final_score =
170   - es_score
171   - + a1*log1p(query_count_30d)
172   - + a2*log1p(query_count_7d)
173   - + a3*ctr_30d
174   - + a4*order_rate_30d
175   - + a5*freshness_decay
176   -```
177   -
178   ----
179   -
180   -## 7. 构建与发布
181   -
182   -## 7.1 构建模式
183   -
184   -- 每日全量:重建全量特征,清理脏词
185   -- 小时级增量:只处理新日志窗口
186   -
187   -## 7.2 工程要求
188   -
189   -- 禁止 `fetchall` 全量入内存,改为流式读取(分页/游标)
190   -- ES 扫描采用 `search_after` 流式聚合
191   -- 批量写入采用 bulk(分块 + 重试 + 失败重放)
192   -
193   -## 7.3 发布策略
194   -
195   -1. `search_suggestions_v2_YYYYMMDDHHmm` 写入完成
196   -2. 校验 count/抽样查询/核心词覆盖
197   -3. alias 原子切换到新索引
198   -4. 保留上一个版本用于快速回滚
199   -
200   ----
201   -
202   -## 8. API 契约(V2)
203   -
204   -请求:
205   -
206   -- `GET /search/suggestions`
207   -- 参数:`q`、`language`、`size`
208   -- Header:`X-Tenant-ID`
209   -
210   -响应:
211   -
212   -```json
213   -{
214   - "query": "iph",
215   - "language": "en",
216   - "resolved_language": "en",
217   - "suggestions": [
218   - {
219   - "text": "iphone 15",
220   - "lang": "en",
221   - "score": 8.31,
222   - "rank_score": 6.72,
223   - "sources": ["query_log", "qanchor"]
224   - }
225   - ],
226   - "took_ms": 12
227   -}
228   -```
229   -
230   -删除项(明确不支持):
231   -
232   -- `with_results`
233   -- `result_size`
234   -- `products[]`
235   -
236   ----
237   -
238   -## 9. 观测与治理
239   -
240   -核心监控:
241   -
242   -- QPS、P50/P95/P99、错误率
243   -- 空结果率(按语言、按租户)
244   -- suggestion 覆盖率(top query 是否命中)
245   -- 语言冲突率(log vs request_params)
246   -- 噪声词比例、黑名单命中率
247   -
248   -治理机制:
249   -
250   -- 黑名单:强制下线
251   -- 白名单:强制保留并可加权
252   -- 最小热度阈值:低频垃圾词过滤
253   -- 时间衰减:过期词自动下沉
254   -
255   ----
256   -
257   -## 10. 与官方最佳实践对齐(ES)
258   -
259   -本方案直接采用以下官方建议:
260   -
261   -1. `completion` 适合高性能自动补全,支持 `skip_duplicates` 与上下文过滤。
262   -2. `search_as_you_type + bool_prefix` 是官方推荐的 as-you-type 查询方式。
263   -3. `edge_ngram` 仅用于索引时分词,查询时应用普通 analyzer(`search_analyzer`)。
264   -4. 多语言场景使用 ICU Analysis 插件增强 Unicode 处理。
265   -5. 通过 `routing` 将租户请求路由到单分片,降低 fan-out。
266   -
267   ----
268   -
269   -## 11. 分阶段落地
270   -
271   -1. Phase 1(本次):去除结果直达,稳定 Suggest 单能力
272   -2. Phase 2:流式增量构建 + alias 原子发布
273   -3. Phase 3:行为信号排序(CTR/CVR)+ 运营治理台
274   -4. Phase 4:大租户独享索引自动升降级
275   -
276   ----
277   -
278   -## 12. Phase 2 落地命令(当前仓库)
279   -
280   -全量重建(版本化索引 + alias 发布):
281   -
282   -```bash
283   -python main.py build-suggestions \
284   - --tenant-id 162 \
285   - --mode full \
286   - --days 365 \
287   - --publish-alias \
288   - --keep-versions 2
289   -```
290   -
291   -增量更新(基于 watermark):
292   -
293   -```bash
294   -python main.py build-suggestions \
295   - --tenant-id 162 \
296   - --mode incremental \
297   - --overlap-minutes 30
298   -```
299   -
300   -一键脚本(全量 + 增量 + ES/API 验证):
301   -
302   -```bash
303   -./scripts/rebuild_suggestions.sh 162
304   -```
suggestion/README.md
1   -# Suggestion 模块说明(统一入口)
  1 +# Suggestion 模块说明
2 2  
3   -本文档是 suggestion 模块的统一入口,遵循 `docs/DEVELOPER_GUIDE.md` 的“单一入口、避免分叉”原则
  3 +`suggestion/` 目录负责搜索框自动补全能力,当前实现只关注 suggestion 本身:离线构建建议词索引,在线根据输入前缀返回建议词列表
4 4  
5   -## 1. 当前状态(Phase 2)
  5 +这份 README 以当前代码实现为准,重点说明模块现状、关键设计、索引结构、构建发布方式,以及在线检索和排序细节。
6 6  
7   -- 仅保留 Suggest 自动补全能力
8   -- 不支持结果直达(`with_results` / `result_size` / `products[]` 已移除)
9   -- 索引采用版本化发布:
10   - - 物理索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>`
11   - - 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current`
12   -- 支持增量更新(watermark + overlap)
  7 +## 1. 当前能力边界
13 8  
14   -## 2. 文档导航(唯一推荐顺序)
  9 +- 对外接口:`GET /search/suggestions`
  10 +- 输入参数:`q`、`size`、`language`
  11 +- Header:`X-Tenant-ID`
  12 +- 返回内容:建议词列表 `suggestions[]`
  13 +- 不做商品结果拼接,也不走二次商品查询链路
  14 +
  15 +当前模块由三部分组成:
  16 +
  17 +1. 离线构建:从商品索引和搜索日志构建 suggestion 文档
  18 +2. 索引发布:写入版本化索引,并通过 alias 原子切换
  19 +3. 在线查询:优先走 completion,必要时再走 `search_as_you_type` 兜底召回
  20 +
  21 +## 2. 目录与关键代码
  22 +
  23 +- [builder.py](/data/saas-search/suggestion/builder.py):离线构建、增量更新、alias 发布、meta 状态维护
  24 +- [mapping.py](/data/saas-search/suggestion/mapping.py):suggestion 索引 settings 和 mappings 生成
  25 +- [service.py](/data/saas-search/suggestion/service.py):在线查询服务,负责语言归一化、双路召回、去重和最终排序
  26 +- [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md):构建、发布、验证操作说明
  27 +- [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md):常见问题排查
  28 +
  29 +命令入口在 [main.py](/data/saas-search/main.py) 中的 `build-suggestions` 子命令。
  30 +
  31 +## 3. 整体架构
  32 +
  33 +在线路径:
  34 +
  35 +`Client -> /search/suggestions -> SuggestionService -> Elasticsearch suggestion alias`
  36 +
  37 +离线路径:
  38 +
  39 +`商品索引 + 搜索日志 -> SuggestionIndexBuilder -> 版本化 suggestion index -> alias publish`
  40 +
  41 +设计上有几个核心点:
  42 +
  43 +- suggestion 独立建索引,不依赖在线商品检索
  44 +- 每个租户单独维护 suggestion alias,避免租户间相互影响
  45 +- 全量构建写新索引,切换 alias 时零停机
  46 +- 增量更新只处理 query log 增量,减少重建成本
  47 +
  48 +## 4. 索引组织与发布
  49 +
  50 +索引命名在 [builder.py](/data/saas-search/suggestion/builder.py) 中统一定义:
  51 +
  52 +- 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current`
  53 +- 版本索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>`
  54 +- 元信息索引:`{ES_INDEX_NAMESPACE}search_suggestions_meta`
  55 +
  56 +当前实现是“每租户一个 suggestion alias + 多个版本索引”的模式,而不是环境级共享大索引。
  57 +
  58 +全量构建时的发布流程:
  59 +
  60 +1. 创建新的版本化索引
  61 +2. 写入本次构建出的 suggestion 文档
  62 +3. 校验新索引可分配、可读
  63 +4. alias 原子切换到新索引
  64 +5. 清理旧版本索引,只保留最近若干份
  65 +6. 更新 `search_suggestions_meta`
  66 +
  67 +元信息索引里记录:
  68 +
  69 +- `active_alias`
  70 +- `active_index`
  71 +- `last_full_build_at`
  72 +- `last_incremental_build_at`
  73 +- `last_incremental_watermark`
  74 +
  75 +这些信息主要服务于增量更新和排障。
  76 +
  77 +## 5. Mapping 与索引字段
  78 +
  79 +[mapping.py](/data/saas-search/suggestion/mapping.py) 会根据租户的 `index_languages` 动态生成字段。
  80 +
  81 +### 5.1 索引设置
  82 +
  83 +- `number_of_shards = 1`
  84 +- `number_of_replicas = 0`
  85 +- `refresh_interval = 30s`
  86 +
  87 +中文使用自定义 analyzer:
  88 +
  89 +- `index_ik`:`ik_max_word + lowercase + asciifolding`
  90 +- `query_ik`:`ik_smart + lowercase + asciifolding`
  91 +
  92 +其他语言优先使用 Elasticsearch 内置 analyzer,例如 `english`、`arabic`、`french`、`german` 等;未覆盖语言回退到 `standard`。
  93 +
  94 +### 5.2 核心字段
  95 +
  96 +- `tenant_id`:租户隔离
  97 +- `lang`:建议词所属语言
  98 +- `text`:原始展示文本
  99 +- `text_norm`:归一化文本,用于唯一键和去重
  100 +- `sources`:来源集合,可能包含 `title`、`qanchor`、`tag`、`query_log`
  101 +- `title_doc_count` / `qanchor_doc_count` / `tag_doc_count`:该词被多少商品字段支撑
  102 +- `query_count_7d` / `query_count_30d`:近 7/30 天搜索热度
  103 +- `rank_score`:离线预计算排序分
  104 +- `lang_confidence` / `lang_source` / `lang_conflict`:语言识别与冲突信息
  105 +- `status`:当前是否有效
  106 +- `updated_at`:最近更新时间
  107 +
  108 +### 5.3 两类检索字段
  109 +
  110 +1. `completion.<lang>`
  111 +2. `sat.<lang>`
  112 +
  113 +`completion.<lang>` 用于极速前缀补全,是短 query 下的主召回通道。
  114 +
  115 +`sat.<lang>` 使用 `search_as_you_type`,用于多词前缀和 completion 未补足时的兜底召回。
  116 +
  117 +也就是说,当前线上不是只靠一种召回方式,而是 completion 优先、SAT 补全。
  118 +
  119 +## 6. 候选词从哪里来
  120 +
  121 +[builder.py](/data/saas-search/suggestion/builder.py) 在全量构建中会聚合两大类数据源。
  122 +
  123 +### 6.1 商品侧
  124 +
  125 +从租户商品索引中流式读取:
  126 +
  127 +- `title`
  128 +- `qanchors`
  129 +- `enriched_tags`
  130 +
  131 +处理方式:
  132 +
  133 +- `title.<lang>`:经 `_prepare_title_for_suggest()` 裁剪后作为候选词
  134 +- `qanchors.<lang>`:按分隔符拆分后作为候选词
  135 +- `enriched_tags`:支持多语言对象或普通列表,必要时做语言识别
  136 +
  137 +商品扫描不是一次性全量拉入内存,而是通过 `search_after` 分批读取,这一点在 [_iter_products()](/data/saas-search/suggestion/builder.py#L363) 已实现。
  138 +
  139 +### 6.2 搜索日志侧
  140 +
  141 +从 MySQL `shoplazza_search_log` 中按时间窗口流式读取:
  142 +
  143 +- `query`
  144 +- `language`
  145 +- `request_params`
  146 +- `create_time`
  147 +
  148 +读取方式使用 `stream_results=True + fetchmany()`,避免 `fetchall()` 带来的内存风险,这也是当前实现相对旧方案的重要改进。
  149 +
  150 +搜索日志主要用于补充:
  151 +
  152 +- 用户真实搜索词
  153 +- 近 7/30 天热度
  154 +- 语言归属信息
  155 +
  156 +## 7. 文本清洗与语言策略
  157 +
  158 +### 7.1 文本归一化
  159 +
  160 +在 [_normalize_text()](/data/saas-search/suggestion/builder.py#L176) 中,当前实现会做:
  161 +
  162 +- Unicode `NFKC` 归一化
  163 +- 去首尾空白
  164 +- 转小写
  165 +- 多空白折叠为单空格
  166 +
  167 +这份 `text_norm` 是 suggestion 文档的稳定键的一部分,文档 `_id` 形式为:
  168 +
  169 +`{tenant_id}|{lang}|{text_norm}`
  170 +
  171 +这保证了同租户、同语言、同一归一化词面只会保留一份文档。
  172 +
  173 +### 7.2 噪声过滤
  174 +
  175 +在 [_looks_noise()](/data/saas-search/suggestion/builder.py#L264) 中,以下内容会被过滤:
  176 +
  177 +- 空文本
  178 +- 长度超过 120
  179 +- 全部由符号组成的文本
15 180  
16   -1. `ARCHITECTURE_V2.md`:架构与设计原则
17   -2. `RUNBOOK.md`:构建/发布/验证流程
18   -3. `TROUBLESHOOTING.md`:常见问题排查
  181 +### 7.3 语言判定优先级
19 182  
20   -## 3. 命令入口
  183 +日志 query 的语言归属由 [_resolve_query_language()](/data/saas-search/suggestion/builder.py#L299) 负责,优先级是:
21 184  
22   -- 全量或增量构建:
  185 +1. `shoplazza_search_log.language`
  186 +2. `request_params.language`
  187 +3. `detect_text_language_for_suggestions()`
  188 +4. 租户 `primary_language`
  189 +
  190 +同时会记录:
  191 +
  192 +- `lang_source`:语言来自哪里
  193 +- `lang_confidence`:识别置信度
  194 +- `lang_conflict`:日志语言与请求语言是否冲突
  195 +
  196 +在线查询侧在 [_resolve_language()](/data/saas-search/suggestion/service.py#L24) 也会做一次语言归一化,确保查询只打到租户允许的 `index_languages`。
  197 +
  198 +## 8. 排序与 rank 细节
  199 +
  200 +当前排序分成两层:离线 `rank_score`,以及在线最终排序。
  201 +
  202 +### 8.1 离线 `rank_score`
  203 +
  204 +在 [_compute_rank_score()](/data/saas-search/suggestion/builder.py#L338) 中,当前公式是:
  205 +
  206 +```text
  207 +rank_score =
  208 + 1.8 * log1p(query_count_30d)
  209 + + 1.2 * log1p(query_count_7d)
  210 + + 1.0 * log1p(qanchor_doc_count)
  211 + + 0.85 * log1p(tag_doc_count)
  212 + + 0.6 * log1p(title_doc_count)
  213 +```
  214 +
  215 +含义上是:
  216 +
  217 +- 搜索日志热度权重大于商品静态字段
  218 +- 30 天热度权重大于 7 天热度,但 7 天热度也会强化近期趋势
  219 +- `qanchor` 比普通标题更像“可搜索表达”,所以权重更高
  220 +- `tag` 次之
  221 +- `title` 提供基础覆盖,但权重相对更低
  222 +
  223 +这个分数会被写入:
  224 +
  225 +- 文档字段 `rank_score`
  226 +- `completion.<lang>.weight`
  227 +
  228 +因此它同时影响 completion 通道和 SAT 通道。
  229 +
  230 +### 8.2 在线召回排序
  231 +
  232 +[service.py](/data/saas-search/suggestion/service.py) 中的在线策略如下:
  233 +
  234 +1. 先查 `completion.<lang>`
  235 +2. 若 query 长度大于 2 且 completion 结果不足,再查 `sat.<lang>`
  236 +3. 按 `text` 归一化结果去重
  237 +4. 最终排序后截断
  238 +
  239 +completion 通道本身依赖 ES completion 的 `_score`;SAT 通道则用 `function_score + field_value_factor(rank_score)`。
  240 +
  241 +最终排序由 [_finalize_suggestion_list()](/data/saas-search/suggestion/service.py#L155) 负责,排序 key 为:
  242 +
  243 +1. `score * 长度惩罚系数`
  244 +2. `rank_score`
  245 +
  246 +长度惩罚定义在 [_suggestion_length_factor()](/data/saas-search/suggestion/service.py#L16):
  247 +
  248 +```text
  249 +length_factor = 1 / sqrt(token_len)
  250 +```
  251 +
  252 +这意味着在分数相近时,较短、较直接的 suggestion 会更容易排在前面,避免长尾长句把前缀补全结果“顶掉”。
  253 +
  254 +## 9. 在线查询细节
  255 +
  256 +[SuggestionService.search()](/data/saas-search/suggestion/service.py#L110) 的行为可以概括为:
  257 +
  258 +- 如果 alias 不存在,直接返回空数组,不抛 500
  259 +- 短 query 优先走 completion 快速返回
  260 +- 对于更长 query,再补一次 `bool_prefix`
  261 +- 查询时始终带 `routing=tenant_id`
  262 +
  263 +这里的 `routing` 很重要,它保证 suggestion 查询尽量只落在目标租户对应的分片路由上,减少无效 fan-out。
  264 +
  265 +SAT 查询部分还有两个显式过滤条件:
  266 +
  267 +- `lang == resolved_language`
  268 +- `status == 1`
  269 +
  270 +这能保证召回结果只来自当前语言、当前有效文档。
  271 +
  272 +## 10. 全量构建与增量更新
  273 +
  274 +### 10.1 全量构建
  275 +
  276 +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode full`。
  277 +
  278 +行为是:
  279 +
  280 +1. 读取租户配置中的 `index_languages` 和 `primary_language`
  281 +2. 创建新版本索引并等待 ready
  282 +3. 聚合商品数据和搜索日志,构造候选词
  283 +4. 计算 `rank_score`
  284 +5. bulk 写入新索引
  285 +6. refresh
  286 +7. 发布 alias
  287 +8. 更新 meta 信息
  288 +
  289 +### 10.2 增量更新
  290 +
  291 +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode incremental`。
  292 +
  293 +当前增量只处理 query log,不回扫商品数据。它依赖 meta 中的 watermark:
  294 +
  295 +- `last_incremental_watermark`
  296 +- 不存在时回退到 `last_full_build_at`
  297 +- 再不行就使用 `fallback_days`
  298 +
  299 +为了避免边界时间漏数,会额外减去 `overlap_minutes`,形成一个带重叠窗口的增量区间。
  300 +
  301 +增量写入不是整文档重建,而是通过 `scripted_upsert` 做原地累加:
  302 +
  303 +- 增加 `query_count_30d`
  304 +- 增加 `query_count_7d`
  305 +- 更新 `lang_confidence` / `lang_source` / `lang_conflict`
  306 +- 重新计算 `rank_score`
  307 +- 更新 `completion` 和 `sat`
  308 +
  309 +对应逻辑在 [_build_incremental_update_script()](/data/saas-search/suggestion/builder.py#L834)。
  310 +
  311 +如果 alias 尚不存在,而 `bootstrap_if_missing=True`,增量任务会先自动做一次全量构建作为初始化。
  312 +
  313 +## 11. 当前实现的一些取舍
  314 +
  315 +### 11.1 优点
  316 +
  317 +- 在线链路很短,没有 suggestion 后再查商品的放大成本
  318 +- 构建和发布流程清晰,支持零停机切换
  319 +- 商品侧和日志侧都采用流式处理,能控制内存占用
  320 +- completion + SAT 双路召回兼顾低延迟和补全能力
  321 +- 语言、热度、来源信息都保存在索引中,便于后续优化
  322 +
  323 +### 11.2 现阶段边界
  324 +
  325 +- 增量更新目前只增量处理 query log,商品标题、qanchor、tag 变更仍依赖全量构建刷新
  326 +- `rank_score` 目前只使用热度和商品字段覆盖度,没有接入点击、转化等行为质量信号
  327 +- 文本归一化目前以 `NFKC + lower + whitespace fold` 为主,尚未做更激进的跨语种归并策略
  328 +- `number_of_replicas=0` 更偏开发或成本优先配置,生产是否需要副本要结合集群策略评估
  329 +
  330 +## 12. 常用命令
  331 +
  332 +全量构建:
23 333  
24 334 ```bash
25 335 ./scripts/build_suggestions.sh <tenant_id> --mode full
26   -./scripts/build_suggestions.sh <tenant_id> --mode incremental
27 336 ```
28 337  
29   -- 一键重建 + 验证
  338 +增量构建
30 339  
31 340 ```bash
32   -./scripts/rebuild_suggestions.sh <tenant_id>
  341 +./scripts/build_suggestions.sh <tenant_id> --mode incremental
33 342 ```
34 343  
35   -## 4. API 约定(简版)
  344 +一键重建并验证:
36 345  
37   -- 端点:`GET /search/suggestions`
38   -- 参数:`q`, `size`, `language`
39   -- Header:`X-Tenant-ID`
  346 +```bash
  347 +./scripts/rebuild_suggestions.sh <tenant_id>
  348 +```
40 349  
41   -示例:
  350 +接口示例:
42 351  
43 352 ```bash
44 353 curl "http://localhost:6002/search/suggestions?q=shi&size=10&language=en" \
45 354 -H "X-Tenant-ID: 162"
46 355 ```
  356 +
  357 +更多操作细节见 [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md),故障排查看 [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md)。
... ...
suggestion/RUNBOOK.md
... ... @@ -34,7 +34,7 @@ DB_PASSWORD=...
34 34 ### 3.1 执行
35 35  
36 36 ```bash
37   -./scripts/build_suggestions.sh 162 \
  37 +./scripts/build_suggestions.sh 163 \
38 38 --mode full \
39 39 --days 365 \
40 40 --publish-alias \
... ... @@ -56,7 +56,7 @@ DB_PASSWORD=...
56 56 ### 4.1 执行
57 57  
58 58 ```bash
59   -./scripts/build_suggestions.sh 162 \
  59 +./scripts/build_suggestions.sh 163 \
60 60 --mode incremental \
61 61 --overlap-minutes 30
62 62 ```
... ... @@ -76,7 +76,7 @@ DB_PASSWORD=...
76 76 > 若 ES 开启鉴权,请附带 `-u "$ES_USERNAME:$ES_PASSWORD"`。
77 77  
78 78 ```bash
79   -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_162_current"
  79 +ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_163_current"
80 80  
81 81 curl "$ES_HOST/$ALIAS_NAME/_count?pretty"
82 82  
... ... @@ -97,10 +97,10 @@ curl &quot;$ES_HOST/$ALIAS_NAME/_search?pretty&quot; -H &#39;Content-Type: application/json&#39; -
97 97  
98 98 ```bash
99 99 curl "http://localhost:6002/search/suggestions?q=shirt&size=10&language=en" \
100   - -H "X-Tenant-ID: 162"
  100 + -H "X-Tenant-ID: 163"
101 101  
102 102 curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \
103   - -H "X-Tenant-ID: 162"
  103 + -H "X-Tenant-ID: 163"
104 104 ```
105 105  
106 106 通过标准:
... ... @@ -112,7 +112,7 @@ curl &quot;http://localhost:6002/search/suggestions?q=玩具&amp;size=10&amp;language=zh&quot; \
112 112 ## 7. 一键验证脚本
113 113  
114 114 ```bash
115   -./scripts/rebuild_suggestions.sh 162
  115 +./scripts/rebuild_suggestions.sh 163
116 116 ```
117 117  
118 118 该脚本执行:
... ...