Commit 6e3e677078c096cef77127bd0426099e16159cef

Authored by tangwang
1 parent 9f33fe3c

suggest文档维护

docs/issues/issue-2026-04-06-推理优化-2.md deleted
@@ -1,69 +0,0 @@ @@ -1,69 +0,0 @@
1 -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查:  
2 -  
3 -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。  
4 -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。  
5 -  
6 -示例prompt和对应的9个分类词:  
7 -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}  
8 -  
9 -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。  
10 -  
11 -使用Tesla T4,因此:  
12 -1. 使用FP16。不用 BF16 作为主路径  
13 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention  
14 -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。  
15 -  
16 -  
17 -主要考虑优化方向为:  
18 -hidden_last -> N-class scorer -> argmax  
19 -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV  
20 -去 full vocab logits  
21 -去 decode / constrained decode  
22 -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)  
23 -专用 tail kernel(输出 N 类原始分数)  
24 -  
25 -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:  
26 -15.1 预分配  
27 -对每个 bucket 预分配:  
28 -input ids buffer  
29 -attention scratch buffer  
30 -output hidden buffer  
31 -class id buffer  
32 -15.2 Graph capture  
33 -  
34 -每个 bucket 预先 capture 一张图:  
35 -embedding  
36 -transformer 主干  
37 -compact last-state  
38 -9-way tail kernel  
39 -  
40 -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。  
41 -  
42 -你有sudo权限,你可以执行为本项目安装自己的环境  
43 -  
44 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)  
45 -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。  
46 -  
47 -  
48 -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。  
49 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。  
50 -注意使用fp16版本,不要使用量化版本。  
51 -  
52 -上面的需求你已经满足了大部分,  
53 -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。  
54 -需要通过多 token 标签做极致的性能优化,避免串行decode。  
55 -一个考虑的方向是:  
56 -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch)  
57 -  
58 -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。  
59 -  
60 -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。  
61 -  
62 -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。  
63 -  
64 -参考文档:  
65 -  
66 -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/  
67 -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/  
68 -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html  
69 -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq  
70 \ No newline at end of file 0 \ No newline at end of file
docs/issues/issue-2026-04-06-推理优化-3.md deleted
@@ -1,98 +0,0 @@ @@ -1,98 +0,0 @@
1 -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查:  
2 -  
3 -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。  
4 -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。  
5 -  
6 -示例prompt和对应的9个分类词:  
7 -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}  
8 -  
9 -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。  
10 -  
11 -使用Tesla T4,因此:  
12 -1. 使用FP16。不用 BF16 作为主路径  
13 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention  
14 -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。  
15 -  
16 -  
17 -主要考虑优化方向为:  
18 -hidden_last -> N-class scorer -> argmax  
19 -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV  
20 -去 full vocab logits  
21 -去 decode / constrained decode  
22 -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)  
23 -专用 tail kernel(输出 N 类原始分数)  
24 -  
25 -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:  
26 -15.1 预分配  
27 -对每个 bucket 预分配:  
28 -input ids buffer  
29 -attention scratch buffer  
30 -output hidden buffer  
31 -class id buffer  
32 -15.2 Graph capture  
33 -  
34 -每个 bucket 预先 capture 一张图:  
35 -embedding  
36 -transformer 主干  
37 -compact last-state  
38 -9-way tail kernel  
39 -  
40 -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。  
41 -  
42 -你有sudo权限,你可以执行为本项目安装自己的环境  
43 -  
44 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)  
45 -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。  
46 -  
47 -  
48 -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。  
49 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。  
50 -注意使用fp16版本,不要使用量化版本。  
51 -  
52 -上面的需求你已经满足了大部分,  
53 -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。  
54 -需要通过多 token 标签做极致的性能优化,避免串行decode。  
55 -一个考虑的方向是:  
56 -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch)  
57 -  
58 -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。  
59 -  
60 -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。  
61 -  
62 -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。  
63 -  
64 -参考文档:  
65 -  
66 -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/  
67 -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/  
68 -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html  
69 -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq  
70 -  
71 -  
72 -现在还有以下问题:  
73 -1. 请满足:  
74 -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。  
75 -  
76 -2. 打分结果不对:不管输入什么都输出jeans,输入裙子也输出jeans,只有精确的输入skirt才能输出skirt。  
77 -  
78 -3. 现在是一个prompt,我需要增加6个prompt,都进行配置化,每次输入一个query、对7个prompt进行批量的推理,对每个prompt都输出词典中每个分类的打分分布  
79 -人群  
80 -Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query}  
81 -  
82 -尺寸  
83 -Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query}  
84 -  
85 -领口  
86 -Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query}  
87 -  
88 -袖长类型  
89 -Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query}  
90 -  
91 -上衣长度类型  
92 -Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query}  
93 -  
94 -面料  
95 -Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. Output nothing else query:{query}  
96 -  
97 -  
98 -注意要使用测试用例进行测试。包括打分结果是否符合预期、性能测试。  
docs/issues/issue-2026-04-06-推理优化-4.md deleted
@@ -1,4 +0,0 @@ @@ -1,4 +0,0 @@
1 -  
2 -1. “仓库启动时会为所有 batch_sizes x continuation_buckets x prompt 预建 bucket”,之前可能是根据“极致的性能要求”、“不要做任何懒加载,确保真实请求发生时得到极致的响应时间”所做的设计,这显然太过了,我是希望一些基本的可以事先加载的应该先加载,但是牺牲巨大的显存占用来换取微弱的耗时提升是不提倡的,请你站在更高的角度,理会我的需求(先加载好模型、跨session的KV cache、并针对特殊用法 即score方式而不是逐步decode方式,来极致的优化性能,用于对线上的单个query,以最短耗时得到7个prompt的分类结果)  
3 -  
4 -2. 现在的 7 个 prompt推理是串行的,MultiPromptRunner.score_query() 里就是 for runner in self.runners 一个一个跑,需要把执行模型改成“按 prompt 分组批量并行”,但是现在有 2 个 fast path 和 5 个 multi-token path,是各合成一次 forward,还是有可能合成一个?因为multi-token也可以一次线prefill进去,是否能做到跟fast path同级别的性能?请你站在更高的角度进行思考,保证性能的同时降低复杂性。  
5 \ No newline at end of file 0 \ No newline at end of file
docs/issues/issue-2026-04-06-推理优化-重建.md deleted
@@ -1,1003 +0,0 @@ @@ -1,1003 +0,0 @@
1 -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。  
2 -先专注于推理的优化,最后再考虑服务化,支持一定程度的并发(比如4)的请求,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,使用N个prompt进行N个维度的分类,对于每个prompt,推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。  
3 -  
4 -下面有一些参考技术资料,但是你并不需要严格,你应该有一定的灵活度,来追求极致的性能。  
5 -  
6 -在 Tesla T4 上,用 3B 到 6B 级别的开源 decoder-only 基座模型做 query 分类。  
7 -启动时完成 tokenizer、权重、prefix cache 和共享执行器准备工作。  
8 -每次输入一个 query,输出每个 prompt 下每个 label 的分数分布,以及预测耗时和阶段耗时。  
9 -不走通用生成路径,不做 decode,不取 full vocab logits,不做 constrained decode。  
10 -对 multi-token label 做专门优化,避免 Python 侧串行 decode。  
11 -prompt 和 label 集合必须可配置(目前只有两个,以后我会加到8个,每次请求输入一个query,并行的调用8个prompt进行推理得到得分最高的label:  
12 -  
13 -  
14 -prompts:  
15 - - name: category  
16 - prompt_template: |  
17 - <|im_start|>system  
18 - Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|>  
19 - <|im_start|>user  
20 - query: {query}<|im_end|>  
21 - <|im_start|>assistant  
22 - label:  
23 - label_prefix: " "  
24 - labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]  
25 -  
26 - - name: audience  
27 - prompt_template: |  
28 - <|im_start|>system  
29 - Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|>  
30 - <|im_start|>user  
31 - query: {query}<|im_end|>  
32 - <|im_start|>assistant  
33 - label:  
34 - label_prefix: " "  
35 - labels: [boy, girl, man, woman, pregnant, none]  
36 -  
37 -  
38 -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟。  
39 -  
40 -主要考虑优化方向为:  
41 -1. hidden_last -> N-class scorer -> argmax  
42 -2. 参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV  
43 -3. 去 full vocab logits  
44 -4. 去 decode / constrained decode  
45 -5. 专用 tail kernel(输出 N 类原始分数)  
46 -6. 配置的N个 prompt推理要并行推理(2-8个)  
47 -7. 使用Tesla T4,因此不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention  
48 -多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。  
49 -  
50 -  
51 -你有sudo权限,你可以执行为本项目安装自己的环境  
52 -  
53 -使用Qwen/Qwen3-8B的Q4或Q8模型,具体用哪个版本,请你查找huggingface相关资料,选择合适的版本完成部署,并进行推理耗时的测试。  
54 -  
55 -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。  
56 -  
57 -一个重要的问题:一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况。  
58 -需要通过多 token 标签做极致的性能优化,避免串行decode。  
59 -我们最终目的是得到哪个label的得分最高,不一定要精确的概率,计算log P(id1 | query, prompt) + log P(id2 | query, prompt, id1)有可能导致难以优化性能,精确的概率是可以考虑放弃的,要清楚我们的最终目的,达到分类的目的即可,只要得到分类,优先考虑性能,精确的概率可以放下。  
60 -如何通过一次模型 forward处理包括多token label的整个 batch,是你需要探索的问题。  
61 -  
62 -单 token fast path 的做法比较确定: last_hidden -> small class scorer -> argmax。 只取目标 label 对应 LM head 行,不做 full vocab 输出。  
63 -multi-token 怎么做需要搜索相关资料进行考量,最好要做到跟单token开销相同(放弃精确的log-prob的前提下。但是:多token和单token的label的打分的对比,一定要是可比的才能正确的分类,兼顾性能和打分的准确性)  
64 -  
65 -还需要增加一个配置:force_single_token_labels,所有 label 都按首 token 处理,因为,如果各个label收token不同,那么可以近似的认为首token打分代表整个label打分。  
66 -你需要找到多label打分性能和准确性上面的最佳实践。同时也支持force_single_token_labels以达到极致的性能。  
67 -  
68 -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。  
69 -以下是一些参考示例:  
70 -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/  
71 -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/  
72 -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html  
73 -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq  
74 -  
75 -TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf  
76 -TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html  
77 -FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention  
78 -vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/  
79 -SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html  
80 -FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer  
81 -xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers  
82 -  
83 -注意:已经有一个项目 llm-qp, llm-qp2,这两个项目,对于单token的处理方式是可以的:  
84 -SDPA  
85 -prefix cache  
86 -prebuilt bucket + CUDA graph  
87 -他的核心代码是:  
88 -from __future__ import annotations  
89 -  
90 -import hashlib  
91 -import time  
92 -from dataclasses import asdict, dataclass  
93 -from typing import Iterable  
94 -  
95 -import torch  
96 -import torch.nn.functional as F  
97 -from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache  
98 -  
99 -try:  
100 - from transformers import BitsAndBytesConfig  
101 -except ImportError: # pragma: no cover  
102 - BitsAndBytesConfig = None  
103 -  
104 -from llm_qp.config import PromptTaskConfig, RuntimeConfig  
105 -from llm_qp.scorer import SmallClassScorer  
106 -  
107 -try:  
108 - from torch.nn.attention import SDPBackend, sdpa_kernel  
109 -except ImportError: # pragma: no cover  
110 - SDPBackend = None  
111 - sdpa_kernel = None  
112 -  
113 -  
114 -@dataclass(slots=True)  
115 -class EncodedLabel:  
116 - text: str  
117 - token_ids: list[int]  
118 -  
119 -  
120 -@dataclass(slots=True)  
121 -class PrefixCache:  
122 - prefix_ids: list[int]  
123 - prefix_hashes: list[str]  
124 - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]  
125 -  
126 - @property  
127 - def prefix_len(self) -> int:  
128 - return len(self.prefix_ids)  
129 -  
130 -  
131 -@dataclass(slots=True)  
132 -class MultiTokenTables:  
133 - label_token_ids: torch.Tensor  
134 - label_token_mask: torch.Tensor  
135 - label_prefix_ids: torch.Tensor  
136 - label_prefix_mask: torch.Tensor  
137 - label_position_offsets: torch.Tensor  
138 -  
139 - @property  
140 - def max_label_len(self) -> int:  
141 - return self.label_token_ids.shape[1]  
142 -  
143 - @property  
144 - def max_label_prefix_len(self) -> int:  
145 - return self.label_prefix_ids.shape[1]  
146 -  
147 -  
148 -@dataclass(slots=True)  
149 -class QueryScoreResult:  
150 - task_name: str  
151 - query: str  
152 - predicted_label: str  
153 - scores: list[tuple[str, float, float]]  
154 - total_ms: float  
155 - stage_ms: dict[str, float]  
156 - fast_path: bool  
157 - prefix_tokens: int  
158 - continuation_tokens: int  
159 - label_token_lengths: dict[str, int]  
160 -  
161 - @property  
162 - def predicted_prob(self) -> float:  
163 - for label, _score, prob in self.scores:  
164 - if label == self.predicted_label:  
165 - return prob  
166 - return 0.0  
167 -  
168 -  
169 -@dataclass(slots=True)  
170 -class MultiPromptScoreResult:  
171 - query: str  
172 - total_ms: float  
173 - details: list[QueryScoreResult]  
174 - stage_ms: dict[str, float]  
175 -  
176 - def http_json(self) -> dict[str, object]:  
177 - return {  
178 - "query": self.query,  
179 - "total_ms": self.total_ms,  
180 - "stage_ms": self.stage_ms,  
181 - "details": [asdict(t) for t in self.details],  
182 - "task_results": {  
183 - t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none'  
184 - },  
185 - }  
186 -  
187 -  
188 -@dataclass(slots=True)  
189 -class BatchScoreResult:  
190 - batch_size: int  
191 - total_ms: float  
192 - results: list[MultiPromptScoreResult]  
193 - stage_ms: dict[str, float]  
194 -  
195 -  
196 -@dataclass(slots=True)  
197 -class SharedRuntime:  
198 - device: torch.device  
199 - dtype: torch.dtype  
200 - tokenizer: object  
201 - model: object  
202 - backbone: object  
203 - hidden_size: int  
204 - graph_capture_pool: object | None = None  
205 - graph_capture_stream: torch.cuda.Stream | None = None  
206 -  
207 -  
208 -@dataclass(slots=True)  
209 -class PromptBatchPlan:  
210 - runner: "PromptClassifierRunner"  
211 - row_start: int  
212 - row_count: int  
213 - score_buffer: torch.Tensor  
214 -  
215 - @property  
216 - def row_stop(self) -> int:  
217 - return self.row_start + self.row_count  
218 -  
219 -  
220 -@dataclass(slots=True)  
221 -class MixedPrefixCache:  
222 - batch_size: int  
223 - total_rows: int  
224 - prefix_lengths: torch.Tensor  
225 - attention_mask: torch.Tensor  
226 - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]  
227 -  
228 - @property  
229 - def max_prefix_len(self) -> int:  
230 - return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0  
231 -  
232 -  
233 -@dataclass(slots=True)  
234 -class BatchLayout:  
235 - batch_size: int  
236 - total_rows: int  
237 - plans: list[PromptBatchPlan]  
238 -  
239 -  
240 -@dataclass(slots=True)  
241 -class MixedBucketRuntime:  
242 - batch_size: int  
243 - total_rows: int  
244 - continuation_len: int  
245 - max_input_len: int  
246 - input_ids: torch.Tensor  
247 - attention_mask: torch.Tensor  
248 - position_ids: torch.Tensor  
249 - last_hidden_state: torch.Tensor  
250 - graph: torch.cuda.CUDAGraph | None = None  
251 -  
252 -  
253 -@dataclass(slots=True)  
254 -class PreloadReport:  
255 - total_ms: float  
256 - stage_ms: dict[str, float]  
257 - runtime: dict[str, object]  
258 -  
259 -  
260 -def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:  
261 - token_list = list(token_ids)  
262 - hashes: list[str] = []  
263 - for start in range(0, len(token_list), block_size):  
264 - block = token_list[start : start + block_size]  
265 - payload = ",".join(str(x) for x in block).encode("utf-8")  
266 - hashes.append(hashlib.sha1(payload).hexdigest())  
267 - return hashes  
268 -  
269 -  
270 -def _expand_legacy_cache(  
271 - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],  
272 - batch_size: int,  
273 -) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:  
274 - expanded: list[tuple[torch.Tensor, torch.Tensor]] = []  
275 - for key, value in raw_cache:  
276 - expanded.append(  
277 - (  
278 - key.expand(batch_size, *key.shape[1:]).contiguous(),  
279 - value.expand(batch_size, *value.shape[1:]).contiguous(),  
280 - )  
281 - )  
282 - return tuple(expanded)  
283 -  
284 -  
285 -class PromptClassifierRunner:  
286 - def __init__(  
287 - self,  
288 - cfg: RuntimeConfig,  
289 - task_cfg: PromptTaskConfig,  
290 - shared_runtime: SharedRuntime,  
291 - ):  
292 - self.cfg = cfg  
293 - self.task_cfg = task_cfg  
294 - self.device = shared_runtime.device  
295 - self.dtype = shared_runtime.dtype  
296 - self.tokenizer = shared_runtime.tokenizer  
297 - self.model = shared_runtime.model  
298 - self.backbone = shared_runtime.backbone  
299 - self.hidden_size = shared_runtime.hidden_size  
300 - self.prefix_text, self.suffix_text = task_cfg.prompt_parts  
301 - self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)  
302 - self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)  
303 - self.labels = list(task_cfg.labels)  
304 - self.encoded_labels = [  
305 - EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))  
306 - for label in self.labels  
307 - ]  
308 - self.num_labels = len(self.labels)  
309 - self.lm_head = self.model.get_output_embeddings()  
310 - self.lm_head_weight = self.lm_head.weight.detach()  
311 - self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None  
312 - if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():  
313 - raise ValueError(  
314 - f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"  
315 - )  
316 - self.fast_path = self._has_unique_single_token_labels()  
317 - self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []  
318 - self.scorer = self._build_scorer() if self.fast_path else None  
319 - self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None  
320 - self.prefix_cache = self._build_prefix_cache()  
321 -  
322 - def _encode_label_token_ids(self, label: str) -> list[int]:  
323 - token_ids = self.tokenizer.encode(  
324 - f"{self.task_cfg.label_prefix}{label}",  
325 - add_special_tokens=False,  
326 - )  
327 - if not token_ids:  
328 - raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")  
329 - if self.cfg.force_single_token_labels:  
330 - return token_ids[:1]  
331 - return token_ids  
332 -  
333 - def _has_unique_single_token_labels(self) -> bool:  
334 - token_ids: list[int] = []  
335 - for item in self.encoded_labels:  
336 - if len(item.token_ids) != 1:  
337 - return False  
338 - token_ids.append(item.token_ids[0])  
339 - return len(token_ids) == len(set(token_ids))  
340 -  
341 - def _build_scorer(self) -> SmallClassScorer:  
342 - token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)  
343 - weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()  
344 - bias = None  
345 - if self.lm_head_bias is not None:  
346 - bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()  
347 - return SmallClassScorer(weights=weights, bias=bias)  
348 -  
349 - def _build_multi_token_tables(self) -> MultiTokenTables:  
350 - max_label_len = max(len(item.token_ids) for item in self.encoded_labels)  
351 - max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)  
352 - label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)  
353 - label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)  
354 - label_prefix_ids = torch.full(  
355 - (self.num_labels, max_prefix_len),  
356 - fill_value=self.tokenizer.pad_token_id,  
357 - device=self.device,  
358 - dtype=torch.long,  
359 - )  
360 - label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)  
361 - for idx, item in enumerate(self.encoded_labels):  
362 - token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)  
363 - token_len = token_ids.numel()  
364 - label_token_ids[idx, :token_len] = token_ids  
365 - label_token_mask[idx, :token_len] = 1.0  
366 - if token_len > 1:  
367 - prefix_len = token_len - 1  
368 - label_prefix_ids[idx, :prefix_len] = token_ids[:-1]  
369 - label_prefix_mask[idx, :prefix_len] = 1  
370 - return MultiTokenTables(  
371 - label_token_ids=label_token_ids.contiguous(),  
372 - label_token_mask=label_token_mask.contiguous(),  
373 - label_prefix_ids=label_prefix_ids.contiguous(),  
374 - label_prefix_mask=label_prefix_mask.contiguous(),  
375 - label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),  
376 - )  
377 -  
378 - @torch.inference_mode()  
379 - def _build_prefix_cache(self) -> PrefixCache:  
380 - if not self.prefix_ids:  
381 - return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())  
382 - prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)  
383 - attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)  
384 - outputs = self.model(  
385 - input_ids=prefix_tensor,  
386 - attention_mask=attention_mask,  
387 - use_cache=True,  
388 - return_dict=True,  
389 - )  
390 - raw_cache = tuple(  
391 - (layer.keys.detach(), layer.values.detach())  
392 - for layer in outputs.past_key_values.layers  
393 - )  
394 - return PrefixCache(  
395 - prefix_ids=list(self.prefix_ids),  
396 - prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),  
397 - raw_cache=raw_cache,  
398 - )  
399 -  
400 - def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:  
401 - if not self.prefix_cache.raw_cache:  
402 - return tuple()  
403 - return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)  
404 -  
405 - def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:  
406 - continuation = query_ids + self.suffix_ids  
407 - if not continuation:  
408 - raise ValueError("prompt continuation is empty after substituting query")  
409 - if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:  
410 - raise ValueError(  
411 - f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"  
412 - )  
413 - return continuation  
414 -  
415 - @torch.inference_mode()  
416 - def reduce_fast_scores(  
417 - self,  
418 - hidden: torch.Tensor,  
419 - out_scores: torch.Tensor,  
420 - ) -> None:  
421 - assert self.scorer is not None  
422 - out_scores.copy_(self.scorer(hidden))  
423 -  
424 - @torch.inference_mode()  
425 - def reduce_multi_token_scores(  
426 - self,  
427 - last_hidden_state: torch.Tensor,  
428 - batch_size: int,  
429 - max_input_len: int,  
430 - score_positions: torch.Tensor,  
431 - out_scores: torch.Tensor,  
432 - ) -> None:  
433 - assert self.multi_token_tables is not None  
434 - hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)  
435 - gather_index = score_positions[:, None, :, None].expand(  
436 - batch_size,  
437 - self.num_labels,  
438 - self.multi_token_tables.max_label_len,  
439 - self.hidden_size,  
440 - )  
441 - gathered_hidden = torch.gather(hidden, 2, gather_index)  
442 - used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()  
443 -  
444 - token_log_probs = torch.zeros(  
445 - (batch_size, self.num_labels, self.multi_token_tables.max_label_len),  
446 - device=self.device,  
447 - dtype=torch.float32,  
448 - )  
449 - if used_mask.any():  
450 - flat_hidden = gathered_hidden[used_mask]  
451 - flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]  
452 - linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()  
453 - linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()  
454 - linear_bias = self.lm_head_bias  
455 - if linear_bias is not None and self.device.type != "cuda":  
456 - linear_bias = linear_bias.float()  
457 - flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)  
458 - flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()  
459 - flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)  
460 - token_log_probs[used_mask] = flat_selected - flat_log_norm  
461 - out_scores.copy_(token_log_probs.sum(dim=-1))  
462 -  
463 - def build_score_result(  
464 - self,  
465 - query: str,  
466 - scores: torch.Tensor,  
467 - stage_ms: dict[str, float],  
468 - continuation_tokens: int,  
469 - ) -> QueryScoreResult:  
470 - score_values = scores.detach().float().cpu().tolist()  
471 - best_idx = max(range(len(score_values)), key=score_values.__getitem__)  
472 - probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()  
473 - return QueryScoreResult(  
474 - task_name=self.task_cfg.name,  
475 - query=query,  
476 - predicted_label=self.labels[best_idx],  
477 - scores=[  
478 - (label, score, prob)  
479 - for label, score, prob in zip(self.labels, score_values, probs, strict=True)  
480 - ],  
481 - total_ms=sum(stage_ms.values()),  
482 - stage_ms=stage_ms,  
483 - fast_path=self.fast_path,  
484 - prefix_tokens=self.prefix_cache.prefix_len,  
485 - continuation_tokens=continuation_tokens,  
486 - label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},  
487 - )  
488 -  
489 -  
490 -class MultiPromptRunner:  
491 - def __init__(self, cfg: RuntimeConfig):  
492 - self.cfg = cfg  
493 - t0 = time.perf_counter()  
494 - self.shared_runtime = self.build_shared_runtime(cfg)  
495 - t1 = time.perf_counter()  
496 - self.device = self.shared_runtime.device  
497 - self.dtype = self.shared_runtime.dtype  
498 - self.tokenizer = self.shared_runtime.tokenizer  
499 - self.model = self.shared_runtime.model  
500 - self.backbone = self.shared_runtime.backbone  
501 - self.hidden_size = self.shared_runtime.hidden_size  
502 - self.graph_capture_pool = self.shared_runtime.graph_capture_pool  
503 - self.graph_capture_stream = self.shared_runtime.graph_capture_stream  
504 - self.runners = [  
505 - PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)  
506 - for task_cfg in cfg.tasks  
507 - ]  
508 - t2 = time.perf_counter()  
509 - self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}  
510 - t3 = time.perf_counter()  
511 - self.mixed_prefix_caches = {  
512 - batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])  
513 - for batch_size in self.cfg.batch_sizes  
514 - }  
515 - t4 = time.perf_counter()  
516 - self.max_label_prefix_len = max(  
517 - (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)  
518 - for runner in self.runners  
519 - )  
520 - self.mixed_buckets = {  
521 - (batch_size, continuation_len): self._build_mixed_bucket(  
522 - self.batch_layouts[batch_size],  
523 - self.mixed_prefix_caches[batch_size],  
524 - continuation_len,  
525 - )  
526 - for batch_size in self.cfg.batch_sizes  
527 - for continuation_len in self.cfg.continuation_buckets  
528 - }  
529 - t5 = time.perf_counter()  
530 - self._warmup_results: dict[int, BatchScoreResult] = {}  
531 - self._preload_report: PreloadReport | None = None  
532 - self._init_stage_ms = {  
533 - "load_model_and_tokenizer": (t1 - t0) * 1000.0,  
534 - "build_prompt_runtimes": (t2 - t1) * 1000.0,  
535 - "build_batch_layouts": (t3 - t2) * 1000.0,  
536 - "build_mixed_prefix_caches": (t4 - t3) * 1000.0,  
537 - "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,  
538 - }  
539 - self._init_total_ms = sum(self._init_stage_ms.values())  
540 -  
541 - @staticmethod  
542 - def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:  
543 - device = torch.device(cfg.device)  
544 - dtype = torch.float16  
545 - tokenizer = AutoTokenizer.from_pretrained(  
546 - cfg.resolved_model_source,  
547 - trust_remote_code=cfg.resolved_trust_remote_code,  
548 - token=cfg.hf_token,  
549 - cache_dir=cfg.hf_cache_dir,  
550 - local_files_only=cfg.resolved_local_files_only,  
551 - )  
552 - if tokenizer.pad_token_id is None:  
553 - tokenizer.pad_token = tokenizer.eos_token  
554 - attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)  
555 - quantization_config = None  
556 - model_kwargs: dict[str, object] = {  
557 - "trust_remote_code": cfg.resolved_trust_remote_code,  
558 - "attn_implementation": attn_impl,  
559 - "token": cfg.hf_token,  
560 - "cache_dir": cfg.hf_cache_dir,  
561 - "local_files_only": cfg.resolved_local_files_only,  
562 - }  
563 - if cfg.load_in_4bit:  
564 - if BitsAndBytesConfig is None:  
565 - raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")  
566 - quantization_config = BitsAndBytesConfig(  
567 - load_in_4bit=True,  
568 - bnb_4bit_compute_dtype=dtype,  
569 - bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,  
570 - bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,  
571 - )  
572 - model_kwargs["quantization_config"] = quantization_config  
573 - model_kwargs["device_map"] = {"": device.index or 0}  
574 - else:  
575 - model_kwargs["dtype"] = dtype  
576 - model_kwargs["device_map"] = None  
577 - model = AutoModelForCausalLM.from_pretrained(  
578 - cfg.resolved_model_source,  
579 - **model_kwargs,  
580 - ).eval()  
581 - if not cfg.load_in_4bit:  
582 - model = model.to(device)  
583 - backbone = model.get_submodule(model.base_model_prefix)  
584 - hidden_size = model.get_output_embeddings().weight.shape[1]  
585 - graph_capture_pool = None  
586 - graph_capture_stream = None  
587 - if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:  
588 - graph_capture_pool = torch.cuda.graph_pool_handle()  
589 - graph_capture_stream = torch.cuda.Stream(device=device)  
590 - return SharedRuntime(  
591 - device=device,  
592 - dtype=dtype,  
593 - tokenizer=tokenizer,  
594 - model=model,  
595 - backbone=backbone,  
596 - hidden_size=hidden_size,  
597 - graph_capture_pool=graph_capture_pool,  
598 - graph_capture_stream=graph_capture_stream,  
599 - )  
600 -  
601 - @staticmethod  
602 - def _resolve_attn_impl(requested: str) -> str:  
603 - if requested in {"sdpa", "eager"}:  
604 - return requested  
605 - if requested == "auto":  
606 - return "sdpa"  
607 - raise ValueError(f"unsupported attn_backend: {requested}")  
608 -  
609 - def _attn_context(self):  
610 - if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:  
611 - return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])  
612 - return torch.no_grad()  
613 -  
614 - def _sync(self) -> None:  
615 - if self.device.type == "cuda":  
616 - torch.cuda.synchronize()  
617 -  
618 - def _pick_bucket(self, continuation_len: int) -> int:  
619 - for bucket in self.cfg.continuation_buckets:  
620 - if continuation_len <= bucket:  
621 - return bucket  
622 - if self.cfg.pad_to_bucket:  
623 - raise ValueError(  
624 - f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"  
625 - )  
626 - return continuation_len  
627 -  
628 - def _build_batch_layout(self, batch_size: int) -> BatchLayout:  
629 - plans: list[PromptBatchPlan] = []  
630 - row_start = 0  
631 - for runner in self.runners:  
632 - row_count = batch_size if runner.fast_path else batch_size * runner.num_labels  
633 - plans.append(  
634 - PromptBatchPlan(  
635 - runner=runner,  
636 - row_start=row_start,  
637 - row_count=row_count,  
638 - score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),  
639 - )  
640 - )  
641 - row_start += row_count  
642 - return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)  
643 -  
644 - def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:  
645 - prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)  
646 - non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]  
647 - if not non_empty:  
648 - return MixedPrefixCache(  
649 - batch_size=layout.batch_size,  
650 - total_rows=layout.total_rows,  
651 - prefix_lengths=prefix_lengths,  
652 - attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),  
653 - raw_cache=tuple(),  
654 - )  
655 -  
656 - max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)  
657 - num_layers = len(non_empty[0])  
658 - attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)  
659 - raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []  
660 - for layer_idx in range(num_layers):  
661 - sample_key, sample_value = non_empty[0][layer_idx]  
662 - merged_key = sample_key.new_zeros(  
663 - (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])  
664 - )  
665 - merged_value = sample_value.new_zeros(  
666 - (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])  
667 - )  
668 - raw_layers.append((merged_key, merged_value))  
669 -  
670 - for plan in layout.plans:  
671 - runner = plan.runner  
672 - prefix_len = runner.prefix_cache.prefix_len  
673 - row_slice = slice(plan.row_start, plan.row_stop)  
674 - prefix_lengths[row_slice] = prefix_len  
675 - if prefix_len == 0:  
676 - continue  
677 - attention_mask[row_slice, :prefix_len] = 1  
678 - raw_cache = runner.expand_prefix_raw_cache(plan.row_count)  
679 - for layer_idx, (key, value) in enumerate(raw_cache):  
680 - merged_key, merged_value = raw_layers[layer_idx]  
681 - merged_key[row_slice, :, :prefix_len, :] = key  
682 - merged_value[row_slice, :, :prefix_len, :] = value  
683 -  
684 - return MixedPrefixCache(  
685 - batch_size=layout.batch_size,  
686 - total_rows=layout.total_rows,  
687 - prefix_lengths=prefix_lengths,  
688 - attention_mask=attention_mask.contiguous(),  
689 - raw_cache=tuple(raw_layers),  
690 - )  
691 -  
692 - def _build_mixed_bucket(  
693 - self,  
694 - layout: BatchLayout,  
695 - prefix_cache: MixedPrefixCache,  
696 - continuation_len: int,  
697 - ) -> MixedBucketRuntime:  
698 - max_input_len = continuation_len + self.max_label_prefix_len  
699 - total_len = prefix_cache.max_prefix_len + max_input_len  
700 - input_ids = torch.full(  
701 - (layout.total_rows, max_input_len),  
702 - fill_value=self.tokenizer.pad_token_id,  
703 - device=self.device,  
704 - dtype=torch.long,  
705 - )  
706 - attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)  
707 - if prefix_cache.max_prefix_len:  
708 - attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask  
709 - position_ids = (  
710 - prefix_cache.prefix_lengths[:, None]  
711 - + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)  
712 - ).contiguous()  
713 - last_hidden_state = torch.empty(  
714 - (layout.total_rows, max_input_len, self.hidden_size),  
715 - device=self.device,  
716 - dtype=self.dtype,  
717 - )  
718 - bucket = MixedBucketRuntime(  
719 - batch_size=layout.batch_size,  
720 - total_rows=layout.total_rows,  
721 - continuation_len=continuation_len,  
722 - max_input_len=max_input_len,  
723 - input_ids=input_ids,  
724 - attention_mask=attention_mask,  
725 - position_ids=position_ids,  
726 - last_hidden_state=last_hidden_state,  
727 - )  
728 - if self.cfg.cuda_graphs:  
729 - self._capture_mixed_bucket(bucket, prefix_cache)  
730 - return bucket  
731 -  
732 - @torch.inference_mode()  
733 - def _run_mixed_backbone(  
734 - self,  
735 - bucket: MixedBucketRuntime,  
736 - prefix_cache: MixedPrefixCache,  
737 - ) -> None:  
738 - cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)  
739 - with self._attn_context():  
740 - outputs = self.backbone(  
741 - input_ids=bucket.input_ids,  
742 - attention_mask=bucket.attention_mask,  
743 - position_ids=bucket.position_ids,  
744 - past_key_values=cache,  
745 - use_cache=False,  
746 - return_dict=True,  
747 - )  
748 - bucket.last_hidden_state.copy_(outputs.last_hidden_state)  
749 -  
750 - def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:  
751 - if not torch.cuda.is_available():  
752 - return  
753 - try:  
754 - torch.cuda.synchronize()  
755 - stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)  
756 - with torch.cuda.stream(stream):  
757 - for _ in range(self.cfg.graph_warmups):  
758 - self._run_mixed_backbone(bucket, prefix_cache)  
759 - stream.synchronize()  
760 - graph = torch.cuda.CUDAGraph()  
761 - with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):  
762 - self._run_mixed_backbone(bucket, prefix_cache)  
763 - bucket.graph = graph  
764 - except RuntimeError:  
765 - bucket.graph = None  
766 -  
767 - def _prepare_bucket(  
768 - self,  
769 - layout: BatchLayout,  
770 - prefix_cache: MixedPrefixCache,  
771 - bucket: MixedBucketRuntime,  
772 - query_ids_batch: list[list[int]],  
773 - ) -> tuple[list[list[int]], dict[str, list[object]]]:  
774 - del prefix_cache  
775 - bucket.input_ids.fill_(self.tokenizer.pad_token_id)  
776 - bucket.attention_mask.zero_()  
777 - if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:  
778 - bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (  
779 - self.mixed_prefix_caches[layout.batch_size].attention_mask  
780 - )  
781 - continuation_lengths_per_task: dict[str, list[int]] = {}  
782 - continuation_tokens_per_task: dict[str, list[list[int]]] = {}  
783 - prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len  
784 - for plan in layout.plans:  
785 - runner = plan.runner  
786 - per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]  
787 - continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations  
788 - continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]  
789 - if runner.fast_path:  
790 - for batch_idx, continuation in enumerate(per_query_continuations):  
791 - cont_len = len(continuation)  
792 - row_idx = plan.row_start + batch_idx  
793 - bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)  
794 - bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1  
795 - continue  
796 -  
797 - assert runner.multi_token_tables is not None  
798 - for batch_idx, continuation in enumerate(per_query_continuations):  
799 - cont_len = len(continuation)  
800 - row_start = plan.row_start + batch_idx * runner.num_labels  
801 - row_stop = row_start + runner.num_labels  
802 - row_slice = slice(row_start, row_stop)  
803 - cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)  
804 - bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)  
805 - bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1  
806 - if runner.multi_token_tables.max_label_prefix_len:  
807 - bucket.input_ids[  
808 - row_slice,  
809 - cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,  
810 - ] = runner.multi_token_tables.label_prefix_ids  
811 - bucket.attention_mask[  
812 - row_slice,  
813 - prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,  
814 - ] = runner.multi_token_tables.label_prefix_mask  
815 - return query_ids_batch, {  
816 - "continuation_lengths_per_task": continuation_lengths_per_task,  
817 - "continuation_tokens_per_task": continuation_tokens_per_task,  
818 - }  
819 -  
820 - def _reduce_prompt_scores(  
821 - self,  
822 - layout: BatchLayout,  
823 - bucket: MixedBucketRuntime,  
824 - query_texts: list[str],  
825 - prep_meta: dict[str, list[object]],  
826 - shared_stage_ms: dict[str, float],  
827 - ) -> list[MultiPromptScoreResult]:  
828 - result_rows = [[] for _ in range(layout.batch_size)]  
829 - prompt_reduce_total_ms = 0.0  
830 - for plan in layout.plans:  
831 - runner = plan.runner  
832 - continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]  
833 - reduce_start = time.perf_counter()  
834 - if runner.fast_path:  
835 - hidden_rows = []  
836 - row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]  
837 - for batch_idx, cont_len in enumerate(continuation_lengths):  
838 - hidden_rows.append(row_slice[batch_idx, cont_len - 1])  
839 - hidden = torch.stack(hidden_rows, dim=0)  
840 - runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)  
841 - stage_name = "tail_scorer"  
842 - else:  
843 - assert runner.multi_token_tables is not None  
844 - score_positions = torch.stack(  
845 - [  
846 - cont_len - 1 + runner.multi_token_tables.label_position_offsets  
847 - for cont_len in continuation_lengths  
848 - ],  
849 - dim=0,  
850 - )  
851 - runner.reduce_multi_token_scores(  
852 - last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],  
853 - batch_size=layout.batch_size,  
854 - max_input_len=bucket.max_input_len,  
855 - score_positions=score_positions,  
856 - out_scores=plan.score_buffer,  
857 - )  
858 - stage_name = "candidate_reduce"  
859 - self._sync()  
860 - reduce_end = time.perf_counter()  
861 - reduce_ms = (reduce_end - reduce_start) * 1000.0  
862 - prompt_reduce_total_ms += reduce_ms  
863 - for batch_idx, query in enumerate(query_texts):  
864 - stage_ms = dict(shared_stage_ms)  
865 - stage_ms[stage_name] = reduce_ms / layout.batch_size  
866 - result_rows[batch_idx].append(  
867 - runner.build_score_result(  
868 - query=query,  
869 - scores=plan.score_buffer[batch_idx],  
870 - stage_ms=stage_ms,  
871 - continuation_tokens=continuation_lengths[batch_idx],  
872 - )  
873 - )  
874 -  
875 - batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms  
876 - shared_plus_reduce = dict(shared_stage_ms)  
877 - shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms  
878 - results: list[MultiPromptScoreResult] = []  
879 - for batch_idx, query in enumerate(query_texts):  
880 - results.append(  
881 - MultiPromptScoreResult(  
882 - query=query,  
883 - total_ms=batch_total_ms / layout.batch_size,  
884 - details=result_rows[batch_idx],  
885 - stage_ms={  
886 - **shared_plus_reduce,  
887 - "per_query_total_estimate": batch_total_ms / layout.batch_size,  
888 - },  
889 - )  
890 - )  
891 - return results  
892 -  
893 - @torch.inference_mode()  
894 - def score_queries(self, queries: list[str]) -> BatchScoreResult:  
895 - if not queries:  
896 - raise ValueError("queries must not be empty")  
897 - batch_size = len(queries)  
898 - if batch_size not in self.batch_layouts:  
899 - raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")  
900 - layout = self.batch_layouts[batch_size]  
901 - prefix_cache = self.mixed_prefix_caches[batch_size]  
902 -  
903 - self._sync()  
904 - t0 = time.perf_counter()  
905 - query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]  
906 - self._sync()  
907 - t1 = time.perf_counter()  
908 -  
909 - max_continuation_len = max(  
910 - len(plan.runner.build_continuation_from_query_ids(query_ids))  
911 - for plan in layout.plans  
912 - for query_ids in query_ids_batch  
913 - )  
914 - picked_bucket = self._pick_bucket(max_continuation_len)  
915 - bucket = self.mixed_buckets[(batch_size, picked_bucket)]  
916 - _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)  
917 - self._sync()  
918 - t2 = time.perf_counter()  
919 -  
920 - if bucket.graph is not None:  
921 - bucket.graph.replay()  
922 - else:  
923 - self._run_mixed_backbone(bucket, prefix_cache)  
924 - self._sync()  
925 - t3 = time.perf_counter()  
926 -  
927 - shared_stage_ms = {  
928 - "encode_queries_shared": (t1 - t0) * 1000.0,  
929 - "prepare_batch_shared": (t2 - t1) * 1000.0,  
930 - "backbone_shared": (t3 - t2) * 1000.0,  
931 - }  
932 - results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)  
933 - total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]  
934 - return BatchScoreResult(  
935 - batch_size=batch_size,  
936 - total_ms=total_ms,  
937 - results=results,  
938 - stage_ms={  
939 - **shared_stage_ms,  
940 - "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],  
941 - },  
942 - )  
943 -  
944 - def score_query(self, query: str) -> MultiPromptScoreResult:  
945 - return self.score_queries([query]).results[0]  
946 -  
947 - def preload(self) -> PreloadReport:  
948 - if self._preload_report is not None:  
949 - return self._preload_report  
950 - stage_ms: dict[str, float] = dict(self._init_stage_ms)  
951 - start = time.perf_counter()  
952 - self._sync()  
953 - t0 = time.perf_counter()  
954 - warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes  
955 - for batch_size in warmup_batch_sizes:  
956 - queries = [self.cfg.warmup_query] * batch_size  
957 - self._warmup_results[batch_size] = self.score_queries(queries)  
958 - self._sync()  
959 - t1 = time.perf_counter()  
960 - stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0  
961 - stage_ms["startup_total_before_warmup"] = self._init_total_ms  
962 - total_ms = self._init_total_ms + (t1 - start) * 1000.0  
963 - runtime = self.preload_report()  
964 - self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)  
965 - return self._preload_report  
966 -  
967 - def preload_report(self) -> dict[str, object]:  
968 - return {  
969 - "model_name": self.cfg.resolved_model_name,  
970 - "model_source": self.cfg.resolved_model_source,  
971 - "device": str(self.device),  
972 - "dtype": self.cfg.dtype,  
973 - "attn_backend": self.cfg.attn_backend,  
974 - "execution_model": "single_mixed_backbone_per_batch",  
975 - "num_tasks": len(self.runners),  
976 - "task_names": [runner.task_cfg.name for runner in self.runners],  
977 - "batch_sizes": list(self.cfg.batch_sizes),  
978 - "continuation_buckets": list(self.cfg.continuation_buckets),  
979 - "mixed_bucket_count": len(self.mixed_buckets),  
980 - "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),  
981 - "all_configured_buckets_preloaded": True,  
982 - "init_stage_ms": dict(self._init_stage_ms),  
983 - "init_total_ms": self._init_total_ms,  
984 - "force_single_token_labels": self.cfg.force_single_token_labels,  
985 - "warmup_query": self.cfg.warmup_query,  
986 - "tasks": [  
987 - {  
988 - "task_name": runner.task_cfg.name,  
989 - "fast_path": runner.fast_path,  
990 - "num_labels": runner.num_labels,  
991 - "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},  
992 - "prefix_tokens": runner.prefix_cache.prefix_len,  
993 - "prefix_hashes": runner.prefix_cache.prefix_hashes,  
994 - "label_prefix": runner.task_cfg.label_prefix,  
995 - }  
996 - for runner in self.runners  
997 - ],  
998 - }  
999 -  
1000 -但是关于多token的处理方式是低效的、是错的,不要参考他,请你重新实现。  
1001 -本地的.venv已经创建好,是复用llm-qp的,请使用该环境  
1002 -有用的代码我已经引用过来了,请你基于需求,搜寻相关资料,专注于重新开发全新的版本。  
1003 -任务较重,请耐心逐步完成,慢慢来,要长时间迭代一直到服务建设完成、测试完全符合预期。  
docs/issues/issue-2026-04-06-推理优化.md deleted
@@ -1,59 +0,0 @@ @@ -1,59 +0,0 @@
1 -  
2 -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。  
3 -先专注于推理的优化,不用考虑服务后,可以程序启动后标准输入读取query,输出分类词。  
4 -  
5 -示例prompt和对应的9个分类词:  
6 -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}  
7 -  
8 -做专用执行路径,不是在通用生成引擎上做配置优化。  
9 -  
10 -prompt已经固定(不要考虑蒸馏、微调、或者缩短prompt。缩短prompt是可以我自己调整的,并且prompt可能改为其他场景,与你做专用推理优化不相关,我给的prompt只是一个例子,你专注于专用推理的框架,适配配置化的prompt+分类词列表打分检查)  
11 -  
12 -  
13 -使用Tesla T4,因此:  
14 -1. 使用FP16。不用 BF16 作为主路径  
15 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention  
16 -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。  
17 -  
18 -  
19 -主要考虑优化方向为:  
20 -hidden_last -> N-class scorer -> argmax  
21 -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV  
22 -去 full vocab logits  
23 -去 decode / constrained decode  
24 -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8)  
25 -专用 tail kernel(输出 N 类原始分数)  
26 -  
27 -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化:  
28 -15.1 预分配  
29 -对每个 bucket 预分配:  
30 -input ids buffer  
31 -attention scratch buffer  
32 -output hidden buffer  
33 -class id buffer  
34 -15.2 Graph capture  
35 -  
36 -每个 bucket 预先 capture 一张图:  
37 -embedding  
38 -transformer 主干  
39 -compact last-state  
40 -9-way tail kernel  
41 -  
42 -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。  
43 -  
44 -你有sudo权限,你可以执行为本项目安装自己的环境  
45 -  
46 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型)  
47 -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。  
48 -  
49 -  
50 -  
51 -  
52 -  
53 -  
54 -  
55 -另外,我想要有个命令行工具,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。  
56 -输入query,输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下),目前是否已经满足了。  
57 -  
58 -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。  
59 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。  
docs/issues/issue-2026-04-07-服务化.md deleted
@@ -1,66 +0,0 @@ @@ -1,66 +0,0 @@
1 -彻底清除掉命令行交互方式的代码,改造为提供 HTTP 接口,端口6001。  
2 -并提供scripts/service_ctl.sh的方式管理服务的启停  
3 -返回分类结果的list(results):list由多个dict组成,每个dict的key为对应的分类任务,value是三元组,即值、打分、概率(如果值为none则该项不输出)  
4 -完善日志系统,当前cli方式输出的重要信息,在日志以及http接口的details字段体现。  
5 -提供一个压测脚本放到本项目的合适的目录下,作为压测工具。该压测工具针对每条请求打印出结果,并最后给出性能指标,参考:  
6 -#!/bin/bash  
7 -  
8 -# 默认值  
9 -concurrency=${1:-1}  
10 -top_lines=${2:-100}  
11 -  
12 -# 固定查询文件路径  
13 -query_file="/data/saas-search/scripts/evaluation/queries/queries.txt"  
14 -  
15 -# 检查文件是否存在  
16 -if [ ! -f "$query_file" ]; then  
17 - echo "错误: 查询文件不存在: $query_file" >&2  
18 - exit 1  
19 -fi  
20 -  
21 -# 检查 jq 是否可用  
22 -if ! command -v jq &> /dev/null; then  
23 - echo "错误: 需要安装 jq 来解析 JSON" >&2  
24 - exit 1  
25 -fi  
26 -  
27 -url="http://127.0.0.1:6001/..."  
28 -max_jobs=$concurrency  
29 -job_count=0  
30 -  
31 -# 读取文件前 top_lines 行,每行作为一个 query  
32 -while IFS= read -r query; do  
33 - # 跳过空行  
34 - [ -z "$query" ] && continue  
35 -  
36 - # 启动子进程执行请求  
37 - (  
38 - # 安全构建 JSON payload  
39 - payload=$(jq -n --arg q "$query" '{query: $q}')  
40 - # 发送请求并获取响应  
41 - response=$(curl -s -X POST "$url" \  
42 - -H 'Content-Type: application/json' \  
43 - -d "$payload")  
44 - # 提取 results 字段(紧凑 JSON 格式)  
45 - results=$(echo "$response" | jq -c '.results')  
46 - # 输出 query 和对应的 results  
47 - printf "%s\t%s\n" "$query" "$results"  
48 - ) &  
49 -  
50 - # 控制并发数量  
51 - ((job_count++))  
52 - if (( job_count >= max_jobs )); then  
53 - wait -n # 等待任意一个后台进程完成  
54 - ((job_count--))  
55 - fi  
56 -done < <(head -n "$top_lines" "$query_file")  
57 -  
58 -# 等待所有剩余后台进程完成  
59 -# 在这里统计处性能情况,指标:  
60 - 平均耗时:  
61 - 最大耗时:  
62 - 最小耗时:  
63 - TP50:  
64 - TP90:  
65 - TP99:  
66 -wait  
67 \ No newline at end of file 0 \ No newline at end of file
docs/issues/issue.md
1 项目 TODO 清单 1 项目 TODO 清单
2 2
  3 +CLAUDE.md需要更新
  4 +
3 2. 核心搜索功能优化 5 2. 核心搜索功能优化
4 6
5 2.1 意图识别模块 7 2.1 意图识别模块
scripts/evaluation/eval_framework/__init__.py
@@ -20,7 +20,6 @@ from .constants import ( # noqa: E402 @@ -20,7 +20,6 @@ from .constants import ( # noqa: E402
20 RELEVANCE_LOW, 20 RELEVANCE_LOW,
21 RELEVANCE_NON_IRRELEVANT, 21 RELEVANCE_NON_IRRELEVANT,
22 VALID_LABELS, 22 VALID_LABELS,
23 - normalize_stored_label,  
24 ) 23 )
25 from .framework import SearchEvaluationFramework # noqa: E402 24 from .framework import SearchEvaluationFramework # noqa: E402
26 from .store import EvalStore, QueryBuildResult # noqa: E402 25 from .store import EvalStore, QueryBuildResult # noqa: E402
@@ -51,7 +50,6 @@ __all__ = [ @@ -51,7 +50,6 @@ __all__ = [
51 "create_web_app", 50 "create_web_app",
52 "ensure_dir", 51 "ensure_dir",
53 "main", 52 "main",
54 - "normalize_stored_label",  
55 "render_batch_report_markdown", 53 "render_batch_report_markdown",
56 "sha1_text", 54 "sha1_text",
57 "utc_now_iso", 55 "utc_now_iso",
scripts/evaluation/eval_framework/constants.py
@@ -42,20 +42,6 @@ STOP_PROB_MAP = { @@ -42,20 +42,6 @@ STOP_PROB_MAP = {
42 RELEVANCE_IRRELEVANT: 0.0, 42 RELEVANCE_IRRELEVANT: 0.0,
43 } 43 }
44 44
45 -_LEGACY_LABEL_MAP = {  
46 - "Exact": RELEVANCE_EXACT,  
47 - "Partial": RELEVANCE_HIGH,  
48 -}  
49 -  
50 -  
51 -def normalize_stored_label(label: str) -> str:  
52 - """Map legacy 3-way SQLite labels to current 4-way strings; pass through canonical labels."""  
53 - s = str(label).strip()  
54 - if s in VALID_LABELS:  
55 - return s  
56 - return _LEGACY_LABEL_MAP.get(s, s)  
57 -  
58 -  
59 DEFAULT_ARTIFACT_ROOT = PROJECT_ROOT / "artifacts" / "search_evaluation" 45 DEFAULT_ARTIFACT_ROOT = PROJECT_ROOT / "artifacts" / "search_evaluation"
60 DEFAULT_QUERY_FILE = _SCRIPTS_EVAL_DIR / "queries" / "queries.txt" 46 DEFAULT_QUERY_FILE = _SCRIPTS_EVAL_DIR / "queries" / "queries.txt"
61 47
scripts/evaluation/eval_framework/static/eval_web.css
@@ -48,9 +48,9 @@ @@ -48,9 +48,9 @@
48 .results { display: grid; gap: 10px; } 48 .results { display: grid; gap: 10px; }
49 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } 49 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; }
50 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } 50 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; }
51 - .label-exact-match { background: var(--exact); }  
52 - .label-high-relevant { background: var(--high); }  
53 - .label-low-relevant { background: var(--low); } 51 + .label-fully-relevant { background: var(--exact); }
  52 + .label-mostly-relevant { background: var(--high); }
  53 + .label-weakly-relevant { background: var(--low); }
54 .label-irrelevant { background: var(--irrelevant); } 54 .label-irrelevant { background: var(--irrelevant); }
55 .badge-unknown { background: #637381; } 55 .badge-unknown { background: #637381; }
56 .thumb { width: 100px; height: 100px; object-fit: cover; border-radius: 14px; background: #e7e1d4; } 56 .thumb { width: 100px; height: 100px; object-fit: cover; border-radius: 14px; background: #e7e1d4; }
scripts/evaluation/eval_framework/store.py
@@ -8,7 +8,7 @@ from dataclasses import dataclass @@ -8,7 +8,7 @@ from dataclasses import dataclass
8 from pathlib import Path 8 from pathlib import Path
9 from typing import Any, Dict, List, Optional, Sequence 9 from typing import Any, Dict, List, Optional, Sequence
10 10
11 -from .constants import VALID_LABELS, normalize_stored_label 11 +from .constants import VALID_LABELS
12 from .utils import ensure_dir, safe_json_dumps, utc_now_iso 12 from .utils import ensure_dir, safe_json_dumps, utc_now_iso
13 13
14 14
@@ -220,7 +220,7 @@ class EvalStore: @@ -220,7 +220,7 @@ class EvalStore:
220 """, 220 """,
221 (tenant_id, query_text), 221 (tenant_id, query_text),
222 ).fetchall() 222 ).fetchall()
223 - return {str(row["spu_id"]): normalize_stored_label(str(row["label"])) for row in rows} 223 + return {str(row["spu_id"]): str(row["label"]) for row in rows}
224 224
225 def upsert_labels( 225 def upsert_labels(
226 self, 226 self,
@@ -379,8 +379,8 @@ class EvalStore: @@ -379,8 +379,8 @@ class EvalStore:
379 SELECT 379 SELECT
380 query_text, 380 query_text,
381 COUNT(*) AS total, 381 COUNT(*) AS total,
382 - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count,  
383 - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, 382 + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count,
  383 + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count,
384 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, 384 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count,
385 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, 385 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count,
386 MAX(updated_at) AS updated_at 386 MAX(updated_at) AS updated_at
@@ -409,8 +409,8 @@ class EvalStore: @@ -409,8 +409,8 @@ class EvalStore:
409 """ 409 """
410 SELECT 410 SELECT
411 COUNT(*) AS total, 411 COUNT(*) AS total,
412 - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count,  
413 - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, 412 + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count,
  413 + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count,
414 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, 414 SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count,
415 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, 415 SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count,
416 MAX(updated_at) AS updated_at 416 MAX(updated_at) AS updated_at
suggestion/ARCHITECTURE_V2.md deleted
@@ -1,304 +0,0 @@ @@ -1,304 +0,0 @@
1 -# Suggestion 架构方案 V2(仅 Suggest,去除结果直达)  
2 -  
3 -## 0. 结论  
4 -  
5 -本方案将 Suggest 设计为**独立高性能检索系统**,只返回建议词,不再返回商品卡片,也不做历史兼容。  
6 -  
7 -- 只保留 `/search/suggestions` 的词级自动补全能力  
8 -- 完全移除 `with_results/result_size/products[]` 链路  
9 -- 多语言优先,支持高并发、低延迟、可持续演进  
10 -  
11 ----  
12 -  
13 -## 1. 当前实现的关键问题(基于现有代码审视)  
14 -  
15 -1. 在线链路曾包含“suggest -> 二次商品查询”,属于典型 N+1 放大,QPS 上升后延迟和 ES 负载都不稳定。  
16 -2. `builder.py` 全量构建使用“大量 in-memory 聚合 + fetchall”,大租户下内存风险高。  
17 -3. 查询参数上限过大(原 `size<=200`),不符合自动补全接口性能边界。  
18 -4. 文档与实现长期混合(README 仍包含结果直达),导致认知不一致。  
19 -5. 多语言归一化仍偏基础(仅 lower/空白折叠),对 Unicode、变音符、跨语系兼容不够。  
20 -  
21 ----  
22 -  
23 -## 2. 目标与 SLO  
24 -  
25 -### 2.1 业务目标  
26 -  
27 -- 输入时实时返回高相关建议词(query suggestion)  
28 -- 多语言稳定(至少覆盖租户配置 `index_languages`)  
29 -- 支持词级排序和运营治理(黑白名单、降噪、降权)  
30 -  
31 -### 2.2 性能目标(建议)  
32 -  
33 -- P50 < 10ms,P95 < 25ms,P99 < 50ms(ES 查询耗时,不含网关)  
34 -- 单集群支持高并发(千级 QPS 可横向扩展)  
35 -- 数据新鲜度:增量 5-15 分钟可见  
36 -  
37 ----  
38 -  
39 -## 3. 总体架构  
40 -  
41 -## 3.1 在线路径(单跳)  
42 -  
43 -Client -> API `/search/suggestions` -> ES `search_suggestions_v2` -> 返回 suggestions  
44 -  
45 -原则:  
46 -  
47 -- **单次 ES 查询完成主路径**(可选双召回融合,但仍在同一次 API 请求内完成)  
48 -- 不调用 `search_products`,不返回商品结果  
49 -- 通过 `routing=tenant_id` 避免跨分片 fan-out  
50 -  
51 -## 3.2 离线路径(构建)  
52 -  
53 -数据源:  
54 -  
55 -- 商品字段:`title.{lang}`、`qanchors.{lang}`  
56 -- 搜索日志:`shoplazza_search_log`(含 `language/request_params`)  
57 -- 行为信号(可选增强):点击、加购、下单  
58 -  
59 -产物:  
60 -  
61 -- Suggest 文档(`tenant_id + lang + text_norm` 唯一)  
62 -- completion + prefix 检索字段  
63 -- 排序特征(热度、近期度、质量分)  
64 -  
65 -发布方式:  
66 -  
67 -- 写入新物理索引(版本化)  
68 -- 原子切换 alias(零停机)  
69 -  
70 ----  
71 -  
72 -## 4. 索引设计(ES)  
73 -  
74 -## 4.1 索引组织  
75 -  
76 -推荐两级策略:  
77 -  
78 -1. 默认:环境级共享索引(降低海量租户 index 数量)  
79 -2. 大租户:可升级为租户独享索引(隔离资源)  
80 -  
81 -统一通过 alias 暴露:  
82 -  
83 -- `search_suggestions_v2_current`  
84 -  
85 -## 4.2 Mapping(核心字段)  
86 -  
87 -```json  
88 -{  
89 - "settings": {  
90 - "number_of_shards": 3,  
91 - "number_of_replicas": 1,  
92 - "refresh_interval": "30s"  
93 - },  
94 - "mappings": {  
95 - "properties": {  
96 - "tenant_id": { "type": "keyword" },  
97 - "lang": { "type": "keyword" },  
98 - "text": { "type": "keyword" },  
99 - "text_norm": { "type": "keyword" },  
100 - "status": { "type": "byte" },  
101 - "sources": { "type": "keyword" },  
102 -  
103 - "query_count_7d": { "type": "integer" },  
104 - "query_count_30d": { "type": "integer" },  
105 - "ctr_30d": { "type": "float" },  
106 - "order_rate_30d": { "type": "float" },  
107 - "rank_score": { "type": "float" },  
108 -  
109 - "suggest": {  
110 - "type": "completion",  
111 - "contexts": [  
112 - { "name": "tenant", "type": "category" },  
113 - { "name": "lang", "type": "category" }  
114 - ]  
115 - },  
116 -  
117 - "sat": {  
118 - "properties": {  
119 - "zh": { "type": "search_as_you_type", "analyzer": "index_ik" },  
120 - "en": { "type": "search_as_you_type", "analyzer": "english" },  
121 - "ar": { "type": "search_as_you_type", "analyzer": "arabic" }  
122 - }  
123 - },  
124 -  
125 - "updated_at": { "type": "date" }  
126 - }  
127 - }  
128 -}  
129 -```  
130 -  
131 -说明:  
132 -  
133 -- `completion` 负责极速前缀命中(主召回)  
134 -- `search_as_you_type` 负责多词前缀和召回兜底  
135 -- `contexts` 强制租户与语言隔离  
136 -  
137 ----  
138 -  
139 -## 5. 多语言策略  
140 -  
141 -1. 语言归属优先级:`log.language > request_params.language > 脚本识别 > tenant.primary_language`  
142 -2. 统一归一化:NFKC、大小写折叠、空白折叠、标点清洗  
143 -3. 分词器按语言配置:  
144 - - 中文:IK/ANSJ(与主索引保持一致)  
145 - - 拉丁语系:对应内置 analyzer  
146 - - 未覆盖语种:`standard + ICU folding` 兜底  
147 -4. 保证写入语言必须在租户 `index_languages` 内  
148 -  
149 ----  
150 -  
151 -## 6. 在线检索策略(高性能)  
152 -  
153 -## 6.1 双通道召回(推荐)  
154 -  
155 -1. 通道 A:`completion suggester`(prefix,skip_duplicates)  
156 -2. 通道 B:`multi_match(type=bool_prefix)` on `search_as_you_type`  
157 -3. 融合去重:按 `text_norm` 去重,按最终分排序截断  
158 -  
159 -## 6.2 查询约束  
160 -  
161 -- 默认 `size=10`,最大 `size=50`  
162 -- `track_total_hits=false`  
163 -- `_source` 仅返回必要字段(`text/lang/rank_score/sources`)  
164 -- `routing=tenant_id`  
165 -  
166 -## 6.3 打分建议  
167 -  
168 -```text  
169 -final_score =  
170 - es_score  
171 - + a1*log1p(query_count_30d)  
172 - + a2*log1p(query_count_7d)  
173 - + a3*ctr_30d  
174 - + a4*order_rate_30d  
175 - + a5*freshness_decay  
176 -```  
177 -  
178 ----  
179 -  
180 -## 7. 构建与发布  
181 -  
182 -## 7.1 构建模式  
183 -  
184 -- 每日全量:重建全量特征,清理脏词  
185 -- 小时级增量:只处理新日志窗口  
186 -  
187 -## 7.2 工程要求  
188 -  
189 -- 禁止 `fetchall` 全量入内存,改为流式读取(分页/游标)  
190 -- ES 扫描采用 `search_after` 流式聚合  
191 -- 批量写入采用 bulk(分块 + 重试 + 失败重放)  
192 -  
193 -## 7.3 发布策略  
194 -  
195 -1. `search_suggestions_v2_YYYYMMDDHHmm` 写入完成  
196 -2. 校验 count/抽样查询/核心词覆盖  
197 -3. alias 原子切换到新索引  
198 -4. 保留上一个版本用于快速回滚  
199 -  
200 ----  
201 -  
202 -## 8. API 契约(V2)  
203 -  
204 -请求:  
205 -  
206 -- `GET /search/suggestions`  
207 -- 参数:`q`、`language`、`size`  
208 -- Header:`X-Tenant-ID`  
209 -  
210 -响应:  
211 -  
212 -```json  
213 -{  
214 - "query": "iph",  
215 - "language": "en",  
216 - "resolved_language": "en",  
217 - "suggestions": [  
218 - {  
219 - "text": "iphone 15",  
220 - "lang": "en",  
221 - "score": 8.31,  
222 - "rank_score": 6.72,  
223 - "sources": ["query_log", "qanchor"]  
224 - }  
225 - ],  
226 - "took_ms": 12  
227 -}  
228 -```  
229 -  
230 -删除项(明确不支持):  
231 -  
232 -- `with_results`  
233 -- `result_size`  
234 -- `products[]`  
235 -  
236 ----  
237 -  
238 -## 9. 观测与治理  
239 -  
240 -核心监控:  
241 -  
242 -- QPS、P50/P95/P99、错误率  
243 -- 空结果率(按语言、按租户)  
244 -- suggestion 覆盖率(top query 是否命中)  
245 -- 语言冲突率(log vs request_params)  
246 -- 噪声词比例、黑名单命中率  
247 -  
248 -治理机制:  
249 -  
250 -- 黑名单:强制下线  
251 -- 白名单:强制保留并可加权  
252 -- 最小热度阈值:低频垃圾词过滤  
253 -- 时间衰减:过期词自动下沉  
254 -  
255 ----  
256 -  
257 -## 10. 与官方最佳实践对齐(ES)  
258 -  
259 -本方案直接采用以下官方建议:  
260 -  
261 -1. `completion` 适合高性能自动补全,支持 `skip_duplicates` 与上下文过滤。  
262 -2. `search_as_you_type + bool_prefix` 是官方推荐的 as-you-type 查询方式。  
263 -3. `edge_ngram` 仅用于索引时分词,查询时应用普通 analyzer(`search_analyzer`)。  
264 -4. 多语言场景使用 ICU Analysis 插件增强 Unicode 处理。  
265 -5. 通过 `routing` 将租户请求路由到单分片,降低 fan-out。  
266 -  
267 ----  
268 -  
269 -## 11. 分阶段落地  
270 -  
271 -1. Phase 1(本次):去除结果直达,稳定 Suggest 单能力  
272 -2. Phase 2:流式增量构建 + alias 原子发布  
273 -3. Phase 3:行为信号排序(CTR/CVR)+ 运营治理台  
274 -4. Phase 4:大租户独享索引自动升降级  
275 -  
276 ----  
277 -  
278 -## 12. Phase 2 落地命令(当前仓库)  
279 -  
280 -全量重建(版本化索引 + alias 发布):  
281 -  
282 -```bash  
283 -python main.py build-suggestions \  
284 - --tenant-id 162 \  
285 - --mode full \  
286 - --days 365 \  
287 - --publish-alias \  
288 - --keep-versions 2  
289 -```  
290 -  
291 -增量更新(基于 watermark):  
292 -  
293 -```bash  
294 -python main.py build-suggestions \  
295 - --tenant-id 162 \  
296 - --mode incremental \  
297 - --overlap-minutes 30  
298 -```  
299 -  
300 -一键脚本(全量 + 增量 + ES/API 验证):  
301 -  
302 -```bash  
303 -./scripts/rebuild_suggestions.sh 162  
304 -```  
suggestion/README.md
1 -# Suggestion 模块说明(统一入口) 1 +# Suggestion 模块说明
2 2
3 -本文档是 suggestion 模块的统一入口,遵循 `docs/DEVELOPER_GUIDE.md` 的“单一入口、避免分叉”原则 3 +`suggestion/` 目录负责搜索框自动补全能力,当前实现只关注 suggestion 本身:离线构建建议词索引,在线根据输入前缀返回建议词列表
4 4
5 -## 1. 当前状态(Phase 2) 5 +这份 README 以当前代码实现为准,重点说明模块现状、关键设计、索引结构、构建发布方式,以及在线检索和排序细节。
6 6
7 -- 仅保留 Suggest 自动补全能力  
8 -- 不支持结果直达(`with_results` / `result_size` / `products[]` 已移除)  
9 -- 索引采用版本化发布:  
10 - - 物理索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>`  
11 - - 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current`  
12 -- 支持增量更新(watermark + overlap) 7 +## 1. 当前能力边界
13 8
14 -## 2. 文档导航(唯一推荐顺序) 9 +- 对外接口:`GET /search/suggestions`
  10 +- 输入参数:`q`、`size`、`language`
  11 +- Header:`X-Tenant-ID`
  12 +- 返回内容:建议词列表 `suggestions[]`
  13 +- 不做商品结果拼接,也不走二次商品查询链路
  14 +
  15 +当前模块由三部分组成:
  16 +
  17 +1. 离线构建:从商品索引和搜索日志构建 suggestion 文档
  18 +2. 索引发布:写入版本化索引,并通过 alias 原子切换
  19 +3. 在线查询:优先走 completion,必要时再走 `search_as_you_type` 兜底召回
  20 +
  21 +## 2. 目录与关键代码
  22 +
  23 +- [builder.py](/data/saas-search/suggestion/builder.py):离线构建、增量更新、alias 发布、meta 状态维护
  24 +- [mapping.py](/data/saas-search/suggestion/mapping.py):suggestion 索引 settings 和 mappings 生成
  25 +- [service.py](/data/saas-search/suggestion/service.py):在线查询服务,负责语言归一化、双路召回、去重和最终排序
  26 +- [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md):构建、发布、验证操作说明
  27 +- [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md):常见问题排查
  28 +
  29 +命令入口在 [main.py](/data/saas-search/main.py) 中的 `build-suggestions` 子命令。
  30 +
  31 +## 3. 整体架构
  32 +
  33 +在线路径:
  34 +
  35 +`Client -> /search/suggestions -> SuggestionService -> Elasticsearch suggestion alias`
  36 +
  37 +离线路径:
  38 +
  39 +`商品索引 + 搜索日志 -> SuggestionIndexBuilder -> 版本化 suggestion index -> alias publish`
  40 +
  41 +设计上有几个核心点:
  42 +
  43 +- suggestion 独立建索引,不依赖在线商品检索
  44 +- 每个租户单独维护 suggestion alias,避免租户间相互影响
  45 +- 全量构建写新索引,切换 alias 时零停机
  46 +- 增量更新只处理 query log 增量,减少重建成本
  47 +
  48 +## 4. 索引组织与发布
  49 +
  50 +索引命名在 [builder.py](/data/saas-search/suggestion/builder.py) 中统一定义:
  51 +
  52 +- 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current`
  53 +- 版本索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>`
  54 +- 元信息索引:`{ES_INDEX_NAMESPACE}search_suggestions_meta`
  55 +
  56 +当前实现是“每租户一个 suggestion alias + 多个版本索引”的模式,而不是环境级共享大索引。
  57 +
  58 +全量构建时的发布流程:
  59 +
  60 +1. 创建新的版本化索引
  61 +2. 写入本次构建出的 suggestion 文档
  62 +3. 校验新索引可分配、可读
  63 +4. alias 原子切换到新索引
  64 +5. 清理旧版本索引,只保留最近若干份
  65 +6. 更新 `search_suggestions_meta`
  66 +
  67 +元信息索引里记录:
  68 +
  69 +- `active_alias`
  70 +- `active_index`
  71 +- `last_full_build_at`
  72 +- `last_incremental_build_at`
  73 +- `last_incremental_watermark`
  74 +
  75 +这些信息主要服务于增量更新和排障。
  76 +
  77 +## 5. Mapping 与索引字段
  78 +
  79 +[mapping.py](/data/saas-search/suggestion/mapping.py) 会根据租户的 `index_languages` 动态生成字段。
  80 +
  81 +### 5.1 索引设置
  82 +
  83 +- `number_of_shards = 1`
  84 +- `number_of_replicas = 0`
  85 +- `refresh_interval = 30s`
  86 +
  87 +中文使用自定义 analyzer:
  88 +
  89 +- `index_ik`:`ik_max_word + lowercase + asciifolding`
  90 +- `query_ik`:`ik_smart + lowercase + asciifolding`
  91 +
  92 +其他语言优先使用 Elasticsearch 内置 analyzer,例如 `english`、`arabic`、`french`、`german` 等;未覆盖语言回退到 `standard`。
  93 +
  94 +### 5.2 核心字段
  95 +
  96 +- `tenant_id`:租户隔离
  97 +- `lang`:建议词所属语言
  98 +- `text`:原始展示文本
  99 +- `text_norm`:归一化文本,用于唯一键和去重
  100 +- `sources`:来源集合,可能包含 `title`、`qanchor`、`tag`、`query_log`
  101 +- `title_doc_count` / `qanchor_doc_count` / `tag_doc_count`:该词被多少商品字段支撑
  102 +- `query_count_7d` / `query_count_30d`:近 7/30 天搜索热度
  103 +- `rank_score`:离线预计算排序分
  104 +- `lang_confidence` / `lang_source` / `lang_conflict`:语言识别与冲突信息
  105 +- `status`:当前是否有效
  106 +- `updated_at`:最近更新时间
  107 +
  108 +### 5.3 两类检索字段
  109 +
  110 +1. `completion.<lang>`
  111 +2. `sat.<lang>`
  112 +
  113 +`completion.<lang>` 用于极速前缀补全,是短 query 下的主召回通道。
  114 +
  115 +`sat.<lang>` 使用 `search_as_you_type`,用于多词前缀和 completion 未补足时的兜底召回。
  116 +
  117 +也就是说,当前线上不是只靠一种召回方式,而是 completion 优先、SAT 补全。
  118 +
  119 +## 6. 候选词从哪里来
  120 +
  121 +[builder.py](/data/saas-search/suggestion/builder.py) 在全量构建中会聚合两大类数据源。
  122 +
  123 +### 6.1 商品侧
  124 +
  125 +从租户商品索引中流式读取:
  126 +
  127 +- `title`
  128 +- `qanchors`
  129 +- `enriched_tags`
  130 +
  131 +处理方式:
  132 +
  133 +- `title.<lang>`:经 `_prepare_title_for_suggest()` 裁剪后作为候选词
  134 +- `qanchors.<lang>`:按分隔符拆分后作为候选词
  135 +- `enriched_tags`:支持多语言对象或普通列表,必要时做语言识别
  136 +
  137 +商品扫描不是一次性全量拉入内存,而是通过 `search_after` 分批读取,这一点在 [_iter_products()](/data/saas-search/suggestion/builder.py#L363) 已实现。
  138 +
  139 +### 6.2 搜索日志侧
  140 +
  141 +从 MySQL `shoplazza_search_log` 中按时间窗口流式读取:
  142 +
  143 +- `query`
  144 +- `language`
  145 +- `request_params`
  146 +- `create_time`
  147 +
  148 +读取方式使用 `stream_results=True + fetchmany()`,避免 `fetchall()` 带来的内存风险,这也是当前实现相对旧方案的重要改进。
  149 +
  150 +搜索日志主要用于补充:
  151 +
  152 +- 用户真实搜索词
  153 +- 近 7/30 天热度
  154 +- 语言归属信息
  155 +
  156 +## 7. 文本清洗与语言策略
  157 +
  158 +### 7.1 文本归一化
  159 +
  160 +在 [_normalize_text()](/data/saas-search/suggestion/builder.py#L176) 中,当前实现会做:
  161 +
  162 +- Unicode `NFKC` 归一化
  163 +- 去首尾空白
  164 +- 转小写
  165 +- 多空白折叠为单空格
  166 +
  167 +这份 `text_norm` 是 suggestion 文档的稳定键的一部分,文档 `_id` 形式为:
  168 +
  169 +`{tenant_id}|{lang}|{text_norm}`
  170 +
  171 +这保证了同租户、同语言、同一归一化词面只会保留一份文档。
  172 +
  173 +### 7.2 噪声过滤
  174 +
  175 +在 [_looks_noise()](/data/saas-search/suggestion/builder.py#L264) 中,以下内容会被过滤:
  176 +
  177 +- 空文本
  178 +- 长度超过 120
  179 +- 全部由符号组成的文本
15 180
16 -1. `ARCHITECTURE_V2.md`:架构与设计原则  
17 -2. `RUNBOOK.md`:构建/发布/验证流程  
18 -3. `TROUBLESHOOTING.md`:常见问题排查 181 +### 7.3 语言判定优先级
19 182
20 -## 3. 命令入口 183 +日志 query 的语言归属由 [_resolve_query_language()](/data/saas-search/suggestion/builder.py#L299) 负责,优先级是:
21 184
22 -- 全量或增量构建: 185 +1. `shoplazza_search_log.language`
  186 +2. `request_params.language`
  187 +3. `detect_text_language_for_suggestions()`
  188 +4. 租户 `primary_language`
  189 +
  190 +同时会记录:
  191 +
  192 +- `lang_source`:语言来自哪里
  193 +- `lang_confidence`:识别置信度
  194 +- `lang_conflict`:日志语言与请求语言是否冲突
  195 +
  196 +在线查询侧在 [_resolve_language()](/data/saas-search/suggestion/service.py#L24) 也会做一次语言归一化,确保查询只打到租户允许的 `index_languages`。
  197 +
  198 +## 8. 排序与 rank 细节
  199 +
  200 +当前排序分成两层:离线 `rank_score`,以及在线最终排序。
  201 +
  202 +### 8.1 离线 `rank_score`
  203 +
  204 +在 [_compute_rank_score()](/data/saas-search/suggestion/builder.py#L338) 中,当前公式是:
  205 +
  206 +```text
  207 +rank_score =
  208 + 1.8 * log1p(query_count_30d)
  209 + + 1.2 * log1p(query_count_7d)
  210 + + 1.0 * log1p(qanchor_doc_count)
  211 + + 0.85 * log1p(tag_doc_count)
  212 + + 0.6 * log1p(title_doc_count)
  213 +```
  214 +
  215 +含义上是:
  216 +
  217 +- 搜索日志热度权重大于商品静态字段
  218 +- 30 天热度权重大于 7 天热度,但 7 天热度也会强化近期趋势
  219 +- `qanchor` 比普通标题更像“可搜索表达”,所以权重更高
  220 +- `tag` 次之
  221 +- `title` 提供基础覆盖,但权重相对更低
  222 +
  223 +这个分数会被写入:
  224 +
  225 +- 文档字段 `rank_score`
  226 +- `completion.<lang>.weight`
  227 +
  228 +因此它同时影响 completion 通道和 SAT 通道。
  229 +
  230 +### 8.2 在线召回排序
  231 +
  232 +[service.py](/data/saas-search/suggestion/service.py) 中的在线策略如下:
  233 +
  234 +1. 先查 `completion.<lang>`
  235 +2. 若 query 长度大于 2 且 completion 结果不足,再查 `sat.<lang>`
  236 +3. 按 `text` 归一化结果去重
  237 +4. 最终排序后截断
  238 +
  239 +completion 通道本身依赖 ES completion 的 `_score`;SAT 通道则用 `function_score + field_value_factor(rank_score)`。
  240 +
  241 +最终排序由 [_finalize_suggestion_list()](/data/saas-search/suggestion/service.py#L155) 负责,排序 key 为:
  242 +
  243 +1. `score * 长度惩罚系数`
  244 +2. `rank_score`
  245 +
  246 +长度惩罚定义在 [_suggestion_length_factor()](/data/saas-search/suggestion/service.py#L16):
  247 +
  248 +```text
  249 +length_factor = 1 / sqrt(token_len)
  250 +```
  251 +
  252 +这意味着在分数相近时,较短、较直接的 suggestion 会更容易排在前面,避免长尾长句把前缀补全结果“顶掉”。
  253 +
  254 +## 9. 在线查询细节
  255 +
  256 +[SuggestionService.search()](/data/saas-search/suggestion/service.py#L110) 的行为可以概括为:
  257 +
  258 +- 如果 alias 不存在,直接返回空数组,不抛 500
  259 +- 短 query 优先走 completion 快速返回
  260 +- 对于更长 query,再补一次 `bool_prefix`
  261 +- 查询时始终带 `routing=tenant_id`
  262 +
  263 +这里的 `routing` 很重要,它保证 suggestion 查询尽量只落在目标租户对应的分片路由上,减少无效 fan-out。
  264 +
  265 +SAT 查询部分还有两个显式过滤条件:
  266 +
  267 +- `lang == resolved_language`
  268 +- `status == 1`
  269 +
  270 +这能保证召回结果只来自当前语言、当前有效文档。
  271 +
  272 +## 10. 全量构建与增量更新
  273 +
  274 +### 10.1 全量构建
  275 +
  276 +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode full`。
  277 +
  278 +行为是:
  279 +
  280 +1. 读取租户配置中的 `index_languages` 和 `primary_language`
  281 +2. 创建新版本索引并等待 ready
  282 +3. 聚合商品数据和搜索日志,构造候选词
  283 +4. 计算 `rank_score`
  284 +5. bulk 写入新索引
  285 +6. refresh
  286 +7. 发布 alias
  287 +8. 更新 meta 信息
  288 +
  289 +### 10.2 增量更新
  290 +
  291 +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode incremental`。
  292 +
  293 +当前增量只处理 query log,不回扫商品数据。它依赖 meta 中的 watermark:
  294 +
  295 +- `last_incremental_watermark`
  296 +- 不存在时回退到 `last_full_build_at`
  297 +- 再不行就使用 `fallback_days`
  298 +
  299 +为了避免边界时间漏数,会额外减去 `overlap_minutes`,形成一个带重叠窗口的增量区间。
  300 +
  301 +增量写入不是整文档重建,而是通过 `scripted_upsert` 做原地累加:
  302 +
  303 +- 增加 `query_count_30d`
  304 +- 增加 `query_count_7d`
  305 +- 更新 `lang_confidence` / `lang_source` / `lang_conflict`
  306 +- 重新计算 `rank_score`
  307 +- 更新 `completion` 和 `sat`
  308 +
  309 +对应逻辑在 [_build_incremental_update_script()](/data/saas-search/suggestion/builder.py#L834)。
  310 +
  311 +如果 alias 尚不存在,而 `bootstrap_if_missing=True`,增量任务会先自动做一次全量构建作为初始化。
  312 +
  313 +## 11. 当前实现的一些取舍
  314 +
  315 +### 11.1 优点
  316 +
  317 +- 在线链路很短,没有 suggestion 后再查商品的放大成本
  318 +- 构建和发布流程清晰,支持零停机切换
  319 +- 商品侧和日志侧都采用流式处理,能控制内存占用
  320 +- completion + SAT 双路召回兼顾低延迟和补全能力
  321 +- 语言、热度、来源信息都保存在索引中,便于后续优化
  322 +
  323 +### 11.2 现阶段边界
  324 +
  325 +- 增量更新目前只增量处理 query log,商品标题、qanchor、tag 变更仍依赖全量构建刷新
  326 +- `rank_score` 目前只使用热度和商品字段覆盖度,没有接入点击、转化等行为质量信号
  327 +- 文本归一化目前以 `NFKC + lower + whitespace fold` 为主,尚未做更激进的跨语种归并策略
  328 +- `number_of_replicas=0` 更偏开发或成本优先配置,生产是否需要副本要结合集群策略评估
  329 +
  330 +## 12. 常用命令
  331 +
  332 +全量构建:
23 333
24 ```bash 334 ```bash
25 ./scripts/build_suggestions.sh <tenant_id> --mode full 335 ./scripts/build_suggestions.sh <tenant_id> --mode full
26 -./scripts/build_suggestions.sh <tenant_id> --mode incremental  
27 ``` 336 ```
28 337
29 -- 一键重建 + 验证 338 +增量构建
30 339
31 ```bash 340 ```bash
32 -./scripts/rebuild_suggestions.sh <tenant_id> 341 +./scripts/build_suggestions.sh <tenant_id> --mode incremental
33 ``` 342 ```
34 343
35 -## 4. API 约定(简版) 344 +一键重建并验证:
36 345
37 -- 端点:`GET /search/suggestions`  
38 -- 参数:`q`, `size`, `language`  
39 -- Header:`X-Tenant-ID` 346 +```bash
  347 +./scripts/rebuild_suggestions.sh <tenant_id>
  348 +```
40 349
41 -示例: 350 +接口示例:
42 351
43 ```bash 352 ```bash
44 curl "http://localhost:6002/search/suggestions?q=shi&size=10&language=en" \ 353 curl "http://localhost:6002/search/suggestions?q=shi&size=10&language=en" \
45 -H "X-Tenant-ID: 162" 354 -H "X-Tenant-ID: 162"
46 ``` 355 ```
  356 +
  357 +更多操作细节见 [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md),故障排查看 [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md)。
suggestion/RUNBOOK.md
@@ -34,7 +34,7 @@ DB_PASSWORD=... @@ -34,7 +34,7 @@ DB_PASSWORD=...
34 ### 3.1 执行 34 ### 3.1 执行
35 35
36 ```bash 36 ```bash
37 -./scripts/build_suggestions.sh 162 \ 37 +./scripts/build_suggestions.sh 163 \
38 --mode full \ 38 --mode full \
39 --days 365 \ 39 --days 365 \
40 --publish-alias \ 40 --publish-alias \
@@ -56,7 +56,7 @@ DB_PASSWORD=... @@ -56,7 +56,7 @@ DB_PASSWORD=...
56 ### 4.1 执行 56 ### 4.1 执行
57 57
58 ```bash 58 ```bash
59 -./scripts/build_suggestions.sh 162 \ 59 +./scripts/build_suggestions.sh 163 \
60 --mode incremental \ 60 --mode incremental \
61 --overlap-minutes 30 61 --overlap-minutes 30
62 ``` 62 ```
@@ -76,7 +76,7 @@ DB_PASSWORD=... @@ -76,7 +76,7 @@ DB_PASSWORD=...
76 > 若 ES 开启鉴权,请附带 `-u "$ES_USERNAME:$ES_PASSWORD"`。 76 > 若 ES 开启鉴权,请附带 `-u "$ES_USERNAME:$ES_PASSWORD"`。
77 77
78 ```bash 78 ```bash
79 -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_162_current" 79 +ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_163_current"
80 80
81 curl "$ES_HOST/$ALIAS_NAME/_count?pretty" 81 curl "$ES_HOST/$ALIAS_NAME/_count?pretty"
82 82
@@ -97,10 +97,10 @@ curl &quot;$ES_HOST/$ALIAS_NAME/_search?pretty&quot; -H &#39;Content-Type: application/json&#39; - @@ -97,10 +97,10 @@ curl &quot;$ES_HOST/$ALIAS_NAME/_search?pretty&quot; -H &#39;Content-Type: application/json&#39; -
97 97
98 ```bash 98 ```bash
99 curl "http://localhost:6002/search/suggestions?q=shirt&size=10&language=en" \ 99 curl "http://localhost:6002/search/suggestions?q=shirt&size=10&language=en" \
100 - -H "X-Tenant-ID: 162" 100 + -H "X-Tenant-ID: 163"
101 101
102 curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \ 102 curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \
103 - -H "X-Tenant-ID: 162" 103 + -H "X-Tenant-ID: 163"
104 ``` 104 ```
105 105
106 通过标准: 106 通过标准:
@@ -112,7 +112,7 @@ curl &quot;http://localhost:6002/search/suggestions?q=玩具&amp;size=10&amp;language=zh&quot; \ @@ -112,7 +112,7 @@ curl &quot;http://localhost:6002/search/suggestions?q=玩具&amp;size=10&amp;language=zh&quot; \
112 ## 7. 一键验证脚本 112 ## 7. 一键验证脚本
113 113
114 ```bash 114 ```bash
115 -./scripts/rebuild_suggestions.sh 162 115 +./scripts/rebuild_suggestions.sh 163
116 ``` 116 ```
117 117
118 该脚本执行: 118 该脚本执行: