Commit 6e3e677078c096cef77127bd0426099e16159cef
1 parent 9f33fe3c
suggest: docs maintenance
Showing 14 changed files with 351 additions and 1657 deletions
docs/issues/issue-2026-04-06-推理优化-2.md deleted
| ... | ... | @@ -1,69 +0,0 @@ |
| 1 | -This is my first version of the requirements. You have already completed most of it; please check it against the current state of the code: | |
| 2 | - | |
| 3 | -Overall goal: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). The prompt and the label words must be configurable. | |
| 4 | -Focus on inference optimization first; do not worry about serving. A CLI that reads queries is enough. At program startup, make sure everything that should be loaded is loaded, with no lazy loading of any kind, so that real requests get the lowest possible response time. For every query the CLI receives, run inference and output each label's score plus the total prediction latency (and per-stage latencies too if that is easy to support). | |
| 5 | - | |
| 6 | -Example prompt and its 9 label words: | |
| 7 | -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | |
| 8 | - | |
| 9 | -Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency and an exact score for every label. | |
| 10 | - | |
| 11 | -The target is a Tesla T4, therefore: | |
| 12 | -1. Use FP16. Do not use BF16 as the main path. | |
| 13 | -2. Do not use FlashAttention-3 as the main path. Choose the best attention implementation for the Tesla T4. | |
| 14 | -3. Research widely: consult open-source inference-optimization projects and practices that work on the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; check that the hardware support matrix covers Turing/T4. | |
| 15 | - | |
| 16 | - | |
| 17 | -The main optimization directions to consider: | |
| 18 | -hidden_last -> N-class scorer -> argmax | |
| 19 | -Follow vLLM's automatic prefix caching to reuse identical prefixes and manage KV by block hash | |
| 20 | -Drop full-vocab logits | |
| 21 | -Drop decode / constrained decode | |
| 22 | -Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8) | |
| 23 | -A dedicated tail kernel (outputs the N raw class scores) | |
| 24 | - | |
| 25 | -The following is for reference only; decide the concrete approach yourself, using an open-source project you find as the baseline to optimize from: | |
| 26 | -15.1 Pre-allocation | |
| 27 | -Pre-allocate for every bucket: | |
| 28 | -input ids buffer | |
| 29 | -attention scratch buffer | |
| 30 | -output hidden buffer | |
| 31 | -class id buffer | |
| 32 | -15.2 Graph capture | |
| 33 | - | |
| 34 | -Capture one graph per bucket in advance: | |
| 35 | -embedding | |
| 36 | -transformer backbone | |
| 37 | -compact last-state | |
| 38 | -9-way tail kernel | |
| 39 | - | |
| 40 | -Pay close attention to Triton / CUDA C++ techniques, and their open-source projects, for exactly this pattern of single-step decode with fixed-vocabulary scoring; find a suitable baseline and develop against it. | |
| 41 | - | |
| 42 | -You have sudo; you may install whatever environment this project needs. | |
| 43 | - | |
| 44 | -Use the qwen3:4b-instruct-2507-q4_K_M model. (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M.) | |
| 45 | -Alternatively, find another model on Hugging Face, deploy it, and benchmark its inference latency. | |
| 46 | - | |
| 47 | - | |
| 48 | -Analyze the per-stage latencies in depth and keep researching whether performance can be pushed to the limit. | |
| 49 | -Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M); deploy it and benchmark inference latency. | |
| 50 | -Note: use the FP16 version, not a quantized one. | |
| 51 | - | |
| 52 | -You have satisfied most of the requirements above, | |
| 53 | -but one important, still-unresolved issue remains: the current version appears designed for single-token decode, yet some label words are not a single token (even though they look like one word). Handle the multi-token case; do not design only for single tokens, although every label is still only a few tokens. | |
| 54 | -Optimize multi-token labels aggressively and avoid serial decode. | |
| 55 | -One direction to consider: | |
| 56 | -Replace HF's multi-token candidate_reduce path with a fused/vectorized implementation (batch padding, one model forward that processes the whole batch) | |
| 57 | - | |
| 58 | -Do only one model forward, with input length total_tokens (usually far smaller than the looped total of batch_size * max_label_len). | |
| 59 | - | |
| 60 | -All subsequent aggregation is vectorized (nearly free on the GPU). | |
| 61 | - | |
| 62 | -Also search carefully for relevant material, especially Triton / Ollama / CUDA C++ best practices used by frameworks in this scenario; apply them and reach SOTA, fully optimized performance for query classification on the T4. | |
| 63 | - | |
| 64 | -Reference documents: | |
| 65 | - | |
| 66 | -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ | |
| 67 | -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ | |
| 68 | -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 69 | -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq | |
| 70 | 0 | \ No newline at end of file |
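The length-bucket plus batch-mode scheme the issue above asks for (capture CUDA graphs only for batch = 1, 2, 4, 8) amounts to picking the smallest pre-built shape that fits a request. A minimal sketch, assuming hypothetical bucket boundaries; `LENGTH_BUCKETS` and the helper name `pick_bucket` are illustrative, not names from the repo:

```python
import bisect

# Illustrative token-length buckets (assumed values) and the batch modes
# named in the issue, each with a pre-captured CUDA graph.
LENGTH_BUCKETS = [16, 32, 64, 128]
BATCH_MODES = [1, 2, 4, 8]

def pick_bucket(num_tokens: int, batch_size: int) -> tuple[int, int]:
    """Return (length_bucket, batch_mode): the smallest pre-allocated shapes
    that fit the request, so a pre-captured graph can be replayed after padding."""
    li = bisect.bisect_left(LENGTH_BUCKETS, num_tokens)
    if li == len(LENGTH_BUCKETS):
        raise ValueError(f"{num_tokens} tokens exceeds the largest bucket")
    bi = bisect.bisect_left(BATCH_MODES, batch_size)
    if bi == len(BATCH_MODES):
        raise ValueError(f"batch size {batch_size} exceeds the largest captured mode")
    return LENGTH_BUCKETS[li], BATCH_MODES[bi]

print(pick_bucket(20, 3))  # -> (32, 4): pad 20 tokens up to 32, pad batch 3 up to 4
```

Requests that land in the same (length, batch) cell reuse the same pre-allocated buffers and captured graph, which is why only a handful of shapes are captured up front.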
docs/issues/issue-2026-04-06-推理优化-3.md deleted
| ... | ... | @@ -1,98 +0,0 @@ |
| 1 | -This is my first version of the requirements. You have already completed most of it; please check it against the current state of the code: | |
| 2 | - | |
| 3 | -Overall goal: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). The prompt and the label words must be configurable. | |
| 4 | -Focus on inference optimization first; do not worry about serving. A CLI that reads queries is enough. At program startup, make sure everything that should be loaded is loaded, with no lazy loading of any kind, so that real requests get the lowest possible response time. For every query the CLI receives, run inference and output each label's score plus the total prediction latency (and per-stage latencies too if that is easy to support). | |
| 5 | - | |
| 6 | -Example prompt and its 9 label words: | |
| 7 | -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | |
| 8 | - | |
| 9 | -Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency and an exact score for every label. | |
| 10 | - | |
| 11 | -The target is a Tesla T4, therefore: | |
| 12 | -1. Use FP16. Do not use BF16 as the main path. | |
| 13 | -2. Do not use FlashAttention-3 as the main path. Choose the best attention implementation for the Tesla T4. | |
| 14 | -3. Research widely: consult open-source inference-optimization projects and practices that work on the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; check that the hardware support matrix covers Turing/T4. | |
| 15 | - | |
| 16 | - | |
| 17 | -The main optimization directions to consider: | |
| 18 | -hidden_last -> N-class scorer -> argmax | |
| 19 | -Follow vLLM's automatic prefix caching to reuse identical prefixes and manage KV by block hash | |
| 20 | -Drop full-vocab logits | |
| 21 | -Drop decode / constrained decode | |
| 22 | -Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8) | |
| 23 | -A dedicated tail kernel (outputs the N raw class scores) | |
| 24 | - | |
| 25 | -The following is for reference only; decide the concrete approach yourself, using an open-source project you find as the baseline to optimize from: | |
| 26 | -15.1 Pre-allocation | |
| 27 | -Pre-allocate for every bucket: | |
| 28 | -input ids buffer | |
| 29 | -attention scratch buffer | |
| 30 | -output hidden buffer | |
| 31 | -class id buffer | |
| 32 | -15.2 Graph capture | |
| 33 | - | |
| 34 | -Capture one graph per bucket in advance: | |
| 35 | -embedding | |
| 36 | -transformer backbone | |
| 37 | -compact last-state | |
| 38 | -9-way tail kernel | |
| 39 | - | |
| 40 | -Pay close attention to Triton / CUDA C++ techniques, and their open-source projects, for exactly this pattern of single-step decode with fixed-vocabulary scoring; find a suitable baseline and develop against it. | |
| 41 | - | |
| 42 | -You have sudo; you may install whatever environment this project needs. | |
| 43 | - | |
| 44 | -Use the qwen3:4b-instruct-2507-q4_K_M model. (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M.) | |
| 45 | -Alternatively, find another model on Hugging Face, deploy it, and benchmark its inference latency. | |
| 46 | - | |
| 47 | - | |
| 48 | -Analyze the per-stage latencies in depth and keep researching whether performance can be pushed to the limit. | |
| 49 | -Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M); deploy it and benchmark inference latency. | |
| 50 | -Note: use the FP16 version, not a quantized one. | |
| 51 | - | |
| 52 | -You have satisfied most of the requirements above, | |
| 53 | -but one important, still-unresolved issue remains: the current version appears designed for single-token decode, yet some label words are not a single token (even though they look like one word). Handle the multi-token case; do not design only for single tokens, although every label is still only a few tokens. | |
| 54 | -Optimize multi-token labels aggressively and avoid serial decode. | |
| 55 | -One direction to consider: | |
| 56 | -Replace HF's multi-token candidate_reduce path with a fused/vectorized implementation (batch padding, one model forward that processes the whole batch) | |
| 57 | - | |
| 58 | -Do only one model forward, with input length total_tokens (usually far smaller than the looped total of batch_size * max_label_len). | |
| 59 | - | |
| 60 | -All subsequent aggregation is vectorized (nearly free on the GPU). | |
| 61 | - | |
| 62 | -Also search carefully for relevant material, especially Triton / Ollama / CUDA C++ best practices used by frameworks in this scenario; apply them and reach SOTA, fully optimized performance for query classification on the T4. | |
| 63 | - | |
| 64 | -Reference documents: | |
| 65 | - | |
| 66 | -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ | |
| 67 | -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ | |
| 68 | -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 69 | -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq | |
| 70 | - | |
| 71 | - | |
| 72 | -The following issues remain: | |
| 73 | -1. Please satisfy: | |
| 74 | -Focus on inference optimization first; do not worry about serving. A CLI that reads queries is enough. At program startup, make sure everything that should be loaded is loaded, with no lazy loading of any kind, so that real requests get the lowest possible response time. For every query the CLI receives, run inference and output each label's score plus the total prediction latency (and per-stage latencies too if that is easy to support). | |
| 75 | - | |
| 76 | -2. The scores are wrong: every input predicts jeans; even the input 裙子 ("skirt" in Chinese) outputs jeans, and only the exact input skirt yields skirt. | |
| 77 | - | |
| 78 | -3. There is currently one prompt; I need to add six more, all configurable. For each input query, run batched inference over all 7 prompts and, for each prompt, output the score distribution over every label in its dictionary | |
| 79 | -Audience | |
| 80 | -Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query} | |
| 81 | - | |
| 82 | -Size | |
| 83 | -Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query} | |
| 84 | - | |
| 85 | -Neckline | |
| 86 | -Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query} | |
| 87 | - | |
| 88 | -Sleeve length | |
| 89 | -Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query} | |
| 90 | - | |
| 91 | -Top length | |
| 92 | -Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query} | |
| 93 | - | |
| 94 | -Fabric | |
| 95 | -Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. Output nothing else query:{query} | |
| 96 | - | |
| 97 | - | |
| 98 | -Make sure to test with test cases, covering both score correctness (does the scoring match expectations?) and performance. |
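The per-prompt score distribution requested in point 3 above is just a softmax over each prompt's raw label scores, followed by an argmax. A minimal sketch; the raw scores below are hypothetical, and `label_distribution` is an illustrative helper name, not code from the repo:

```python
import math

def label_distribution(labels: list[str], scores: list[float]) -> dict[str, float]:
    """Turn raw per-label scores into the probability distribution the CLI
    should print for one prompt (numerically stable via max-subtraction)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return {lab: e / z for lab, e in zip(labels, exps)}

# Hypothetical raw scores for the "audience" prompt's five labels.
dist = label_distribution(["boy", "girl", "man", "woman", "pregnant"],
                          [-1.2, 0.3, -0.5, 2.1, -3.0])
best = max(dist, key=dist.get)
print(best)  # -> woman
```

Since softmax is monotonic, the argmax of the distribution always matches the argmax of the raw scores, so printing probabilities costs nothing in classification accuracy.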
docs/issues/issue-2026-04-06-推理优化-4.md deleted
| ... | ... | @@ -1,4 +0,0 @@ |
| 1 | - | |
| 2 | -1. "At startup the repo pre-builds a bucket for every batch_sizes x continuation_buckets x prompt combination." That design probably came from the earlier "extreme performance" and "no lazy loading, guarantee the lowest response time for real requests" requirements, but it clearly goes too far. I do want the basics that can be loaded up front to be loaded up front, but trading a huge amount of GPU memory for a marginal latency gain is not encouraged. Please take a higher-level view of my requirement: preload the model and a cross-session KV cache, and optimize aggressively for the specific usage pattern (scoring rather than step-by-step decode), so that a single online query gets its 7 prompt classification results in the shortest possible time. | |
| 3 | - | |
| 4 | -2. The 7 prompts currently run serially: MultiPromptRunner.score_query() is just for runner in self.runners, one at a time. Change the execution model to "batch prompts in parallel by group". There are currently 2 fast-path and 5 multi-token-path prompts: should each group be fused into its own forward, or can everything be fused into a single forward? Since multi-token labels can also be prefilled in one shot, can they reach the same performance level as the fast path? Please think about this from a higher level and reduce complexity while preserving performance. | |
| 5 | 0 | \ No newline at end of file |
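The cross-session KV cache asked for above rests on the block-hash bookkeeping that the requirements borrow from vLLM's automatic prefix caching. A minimal sketch of the idea, assuming a hypothetical block size and helper name; note that vLLM chains each block's hash to the previous block's so that one hash identifies the entire prefix up to that block, whereas the repo's own `_hash_blocks` hashes blocks independently:

```python
import hashlib

BLOCK_SIZE = 16  # tokens per KV block (assumed; vLLM uses a fixed block size)

def chained_block_hashes(token_ids: list[int], block_size: int = BLOCK_SIZE) -> list[str]:
    """vLLM-style prefix hashing: each full block's hash covers the previous
    block's hash plus this block's tokens, so equal hashes mean equal prefixes."""
    hashes: list[str] = []
    prev = ""
    # Only complete blocks are hashed; a trailing partial block gets no cache entry.
    full_len = len(token_ids) - len(token_ids) % block_size
    for start in range(0, full_len, block_size):
        block = token_ids[start:start + block_size]
        payload = (prev + ":" + ",".join(map(str, block))).encode()
        prev = hashlib.sha1(payload).hexdigest()
        hashes.append(prev)
    return hashes

a = chained_block_hashes(list(range(40)))  # 2 full blocks -> 2 hashes
b = chained_block_hashes(list(range(32)))  # same first 32 tokens
assert a == b  # shared prefix blocks hash identically, enabling KV reuse
```

Two sessions whose prompts share a prefix therefore produce identical leading hashes, and the cached KV blocks for that prefix can be looked up and reused instead of recomputed.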
docs/issues/issue-2026-04-06-推理优化-重建.md deleted
| ... | ... | @@ -1,1003 +0,0 @@ |
| 1 | -Overall goal: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). Prompt and label words configurable. | |
| 2 | -Focus on inference optimization first and consider serving last, supporting a modest level of concurrency (say 4 requests). At program startup, make sure everything that should be loaded is loaded, with no lazy loading of any kind, so that real requests get the lowest possible response time. For every query the CLI receives, classify along N dimensions with N prompts; for each prompt, output every label's score plus the total prediction latency (and per-stage latencies too if that is easy to support). | |
| 3 | - | |
| 4 | -Some reference material follows, but you need not follow it strictly; keep some flexibility in pursuit of extreme performance. | |
| 5 | - | |
| 6 | -On a Tesla T4, classify queries with a 3B to 6B open-source decoder-only base model. | |
| 7 | -At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor. | |
| 8 | -For each input query, output the score distribution of every label under every prompt, plus the prediction latency and per-stage latencies. | |
| 9 | -Do not take the general generation path; no decode, no full-vocab logits, no constrained decode. | |
| 10 | -Optimize multi-token labels specifically; avoid Python-side serial decode. | |
| 11 | -The prompt and label sets must be configurable (there are only two right now; I will later grow this to 8. Each request takes one query and calls the 8 prompts in parallel, returning the highest-scoring label): | |
| 12 | - | |
| 13 | - | |
| 14 | -prompts: | |
| 15 | - - name: category | |
| 16 | - prompt_template: | | |
| 17 | - <|im_start|>system | |
| 18 | - Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|> | |
| 19 | - <|im_start|>user | |
| 20 | - query: {query}<|im_end|> | |
| 21 | - <|im_start|>assistant | |
| 22 | - label: | |
| 23 | - label_prefix: " " | |
| 24 | - labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none] | |
| 25 | - | |
| 26 | - - name: audience | |
| 27 | - prompt_template: | | |
| 28 | - <|im_start|>system | |
| 29 | - Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|> | |
| 30 | - <|im_start|>user | |
| 31 | - query: {query}<|im_end|> | |
| 32 | - <|im_start|>assistant | |
| 33 | - label: | |
| 34 | - label_prefix: " " | |
| 35 | - labels: [boy, girl, man, woman, pregnant, none] | |
| 36 | - | |
| 37 | - | |
| 38 | -Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency. | |
| 39 | - | |
| 40 | -The main optimization directions to consider: | |
| 41 | -1. hidden_last -> N-class scorer -> argmax | |
| 42 | -2. Follow vLLM's automatic prefix caching to reuse identical prefixes and manage KV by block hash | |
| 43 | -3. Drop full-vocab logits | |
| 44 | -4. Drop decode / constrained decode | |
| 45 | -5. A dedicated tail kernel (outputs the N raw class scores) | |
| 46 | -6. The N configured prompts (2-8) must be inferred in parallel | |
| 47 | -7. The target is a Tesla T4, so do not use FlashAttention-3 as the main path; choose the best attention implementation for the Tesla T4. | |
| 48 | -Research widely: consult open-source inference-optimization projects and practices that work on the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; check that the hardware support matrix covers Turing/T4. Pay close attention to Triton / CUDA C++ techniques, and their open-source projects, for exactly this pattern of single-step decode with fixed-vocabulary scoring; find a suitable baseline and develop against it. | |
| 49 | - | |
| 50 | - | |
| 51 | -You have sudo; you may install whatever environment this project needs. | |
| 52 | - | |
| 53 | -Use a Q4 or Q8 build of Qwen/Qwen3-8B. For which exact version, research the relevant material on Hugging Face, choose a suitable one, deploy it, and benchmark its inference latency. | |
| 54 | - | |
| 55 | -Analyze the per-stage latencies in depth and keep researching whether performance can be pushed to the limit. | |
| 56 | - | |
| 57 | -One important issue: some label words are not a single token (even though they look like one word), so the multi-token case must be handled. | |
| 58 | -Optimize multi-token labels aggressively and avoid serial decode. | |
| 59 | -Our end goal is only to find which label scores highest; exact probabilities are not required. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may make performance hard to optimize, so exact probabilities can be sacrificed. Be clear about the end goal: classification is enough. As long as we get the classification, prioritize performance; exact probabilities can be set aside. | |
| 60 | -How to process an entire batch, multi-token labels included, in a single model forward is the problem you need to explore. | |
| 61 | - | |
| 62 | -The single-token fast path is fairly settled: last_hidden -> small class scorer -> argmax. Select only the LM-head rows for the target labels; never produce full-vocab output. | |
| 63 | -How to handle multi-token labels needs research and judgment; ideally it should cost the same as the single-token path (given that we give up exact log-probs; however, multi-token and single-token label scores must stay comparable for classification to be correct, so balance performance against scoring accuracy). | |
| 64 | - | |
| 65 | -Also add a config option, force_single_token_labels: treat every label by its first token only, because if the labels' first tokens all differ, the first-token score can approximate the whole label's score. | |
| 66 | -Find the best practice for multi-label scoring performance and accuracy, while also supporting force_single_token_labels for maximum performance. | |
| 67 | - | |
| 68 | -Also search carefully for relevant material, especially Triton / Ollama / CUDA C++ best practices used by frameworks in this scenario; apply them and reach SOTA, fully optimized performance for query classification on the T4. | |
| 69 | -Here are some reference examples: | |
| 70 | -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ | |
| 71 | -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ | |
| 72 | -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 73 | -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq | |
| 74 | - | |
| 75 | -TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf | |
| 76 | -TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 77 | -FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention | |
| 78 | -vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/ | |
| 79 | -SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html | |
| 80 | -FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer | |
| 81 | -xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers | |
| 82 | - | |
| 83 | -Note: two projects already exist, llm-qp and llm-qp2; their handling of the single-token case is sound: | |
| 84 | -SDPA | |
| 85 | -prefix cache | |
| 86 | -prebuilt bucket + CUDA graph | |
| 87 | -Their core code is: | |
| 88 | -from __future__ import annotations | |
| 89 | - | |
| 90 | -import hashlib | |
| 91 | -import time | |
| 92 | -from dataclasses import asdict, dataclass | |
| 93 | -from typing import Iterable | |
| 94 | - | |
| 95 | -import torch | |
| 96 | -import torch.nn.functional as F | |
| 97 | -from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache | |
| 98 | - | |
| 99 | -try: | |
| 100 | - from transformers import BitsAndBytesConfig | |
| 101 | -except ImportError: # pragma: no cover | |
| 102 | - BitsAndBytesConfig = None | |
| 103 | - | |
| 104 | -from llm_qp.config import PromptTaskConfig, RuntimeConfig | |
| 105 | -from llm_qp.scorer import SmallClassScorer | |
| 106 | - | |
| 107 | -try: | |
| 108 | - from torch.nn.attention import SDPBackend, sdpa_kernel | |
| 109 | -except ImportError: # pragma: no cover | |
| 110 | - SDPBackend = None | |
| 111 | - sdpa_kernel = None | |
| 112 | - | |
| 113 | - | |
| 114 | -@dataclass(slots=True) | |
| 115 | -class EncodedLabel: | |
| 116 | - text: str | |
| 117 | - token_ids: list[int] | |
| 118 | - | |
| 119 | - | |
| 120 | -@dataclass(slots=True) | |
| 121 | -class PrefixCache: | |
| 122 | - prefix_ids: list[int] | |
| 123 | - prefix_hashes: list[str] | |
| 124 | - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] | |
| 125 | - | |
| 126 | - @property | |
| 127 | - def prefix_len(self) -> int: | |
| 128 | - return len(self.prefix_ids) | |
| 129 | - | |
| 130 | - | |
| 131 | -@dataclass(slots=True) | |
| 132 | -class MultiTokenTables: | |
| 133 | - label_token_ids: torch.Tensor | |
| 134 | - label_token_mask: torch.Tensor | |
| 135 | - label_prefix_ids: torch.Tensor | |
| 136 | - label_prefix_mask: torch.Tensor | |
| 137 | - label_position_offsets: torch.Tensor | |
| 138 | - | |
| 139 | - @property | |
| 140 | - def max_label_len(self) -> int: | |
| 141 | - return self.label_token_ids.shape[1] | |
| 142 | - | |
| 143 | - @property | |
| 144 | - def max_label_prefix_len(self) -> int: | |
| 145 | - return self.label_prefix_ids.shape[1] | |
| 146 | - | |
| 147 | - | |
| 148 | -@dataclass(slots=True) | |
| 149 | -class QueryScoreResult: | |
| 150 | - task_name: str | |
| 151 | - query: str | |
| 152 | - predicted_label: str | |
| 153 | - scores: list[tuple[str, float, float]] | |
| 154 | - total_ms: float | |
| 155 | - stage_ms: dict[str, float] | |
| 156 | - fast_path: bool | |
| 157 | - prefix_tokens: int | |
| 158 | - continuation_tokens: int | |
| 159 | - label_token_lengths: dict[str, int] | |
| 160 | - | |
| 161 | - @property | |
| 162 | - def predicted_prob(self) -> float: | |
| 163 | - for label, _score, prob in self.scores: | |
| 164 | - if label == self.predicted_label: | |
| 165 | - return prob | |
| 166 | - return 0.0 | |
| 167 | - | |
| 168 | - | |
| 169 | -@dataclass(slots=True) | |
| 170 | -class MultiPromptScoreResult: | |
| 171 | - query: str | |
| 172 | - total_ms: float | |
| 173 | - details: list[QueryScoreResult] | |
| 174 | - stage_ms: dict[str, float] | |
| 175 | - | |
| 176 | - def http_json(self) -> dict[str, object]: | |
| 177 | - return { | |
| 178 | - "query": self.query, | |
| 179 | - "total_ms": self.total_ms, | |
| 180 | - "stage_ms": self.stage_ms, | |
| 181 | - "details": [asdict(t) for t in self.details], | |
| 182 | - "task_results": { | |
| 183 | - t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none' | |
| 184 | - }, | |
| 185 | - } | |
| 186 | - | |
| 187 | - | |
| 188 | -@dataclass(slots=True) | |
| 189 | -class BatchScoreResult: | |
| 190 | - batch_size: int | |
| 191 | - total_ms: float | |
| 192 | - results: list[MultiPromptScoreResult] | |
| 193 | - stage_ms: dict[str, float] | |
| 194 | - | |
| 195 | - | |
| 196 | -@dataclass(slots=True) | |
| 197 | -class SharedRuntime: | |
| 198 | - device: torch.device | |
| 199 | - dtype: torch.dtype | |
| 200 | - tokenizer: object | |
| 201 | - model: object | |
| 202 | - backbone: object | |
| 203 | - hidden_size: int | |
| 204 | - graph_capture_pool: object | None = None | |
| 205 | - graph_capture_stream: torch.cuda.Stream | None = None | |
| 206 | - | |
| 207 | - | |
| 208 | -@dataclass(slots=True) | |
| 209 | -class PromptBatchPlan: | |
| 210 | - runner: "PromptClassifierRunner" | |
| 211 | - row_start: int | |
| 212 | - row_count: int | |
| 213 | - score_buffer: torch.Tensor | |
| 214 | - | |
| 215 | - @property | |
| 216 | - def row_stop(self) -> int: | |
| 217 | - return self.row_start + self.row_count | |
| 218 | - | |
| 219 | - | |
| 220 | -@dataclass(slots=True) | |
| 221 | -class MixedPrefixCache: | |
| 222 | - batch_size: int | |
| 223 | - total_rows: int | |
| 224 | - prefix_lengths: torch.Tensor | |
| 225 | - attention_mask: torch.Tensor | |
| 226 | - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] | |
| 227 | - | |
| 228 | - @property | |
| 229 | - def max_prefix_len(self) -> int: | |
| 230 | - return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0 | |
| 231 | - | |
| 232 | - | |
| 233 | -@dataclass(slots=True) | |
| 234 | -class BatchLayout: | |
| 235 | - batch_size: int | |
| 236 | - total_rows: int | |
| 237 | - plans: list[PromptBatchPlan] | |
| 238 | - | |
| 239 | - | |
| 240 | -@dataclass(slots=True) | |
| 241 | -class MixedBucketRuntime: | |
| 242 | - batch_size: int | |
| 243 | - total_rows: int | |
| 244 | - continuation_len: int | |
| 245 | - max_input_len: int | |
| 246 | - input_ids: torch.Tensor | |
| 247 | - attention_mask: torch.Tensor | |
| 248 | - position_ids: torch.Tensor | |
| 249 | - last_hidden_state: torch.Tensor | |
| 250 | - graph: torch.cuda.CUDAGraph | None = None | |
| 251 | - | |
| 252 | - | |
| 253 | -@dataclass(slots=True) | |
| 254 | -class PreloadReport: | |
| 255 | - total_ms: float | |
| 256 | - stage_ms: dict[str, float] | |
| 257 | - runtime: dict[str, object] | |
| 258 | - | |
| 259 | - | |
| 260 | -def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]: | |
| 261 | - token_list = list(token_ids) | |
| 262 | - hashes: list[str] = [] | |
| 263 | - for start in range(0, len(token_list), block_size): | |
| 264 | - block = token_list[start : start + block_size] | |
| 265 | - payload = ",".join(str(x) for x in block).encode("utf-8") | |
| 266 | - hashes.append(hashlib.sha1(payload).hexdigest()) | |
| 267 | - return hashes | |
| 268 | - | |
| 269 | - | |
| 270 | -def _expand_legacy_cache( | |
| 271 | - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...], | |
| 272 | - batch_size: int, | |
| 273 | -) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: | |
| 274 | - expanded: list[tuple[torch.Tensor, torch.Tensor]] = [] | |
| 275 | - for key, value in raw_cache: | |
| 276 | - expanded.append( | |
| 277 | - ( | |
| 278 | - key.expand(batch_size, *key.shape[1:]).contiguous(), | |
| 279 | - value.expand(batch_size, *value.shape[1:]).contiguous(), | |
| 280 | - ) | |
| 281 | - ) | |
| 282 | - return tuple(expanded) | |
| 283 | - | |
| 284 | - | |
| 285 | -class PromptClassifierRunner: | |
| 286 | - def __init__( | |
| 287 | - self, | |
| 288 | - cfg: RuntimeConfig, | |
| 289 | - task_cfg: PromptTaskConfig, | |
| 290 | - shared_runtime: SharedRuntime, | |
| 291 | - ): | |
| 292 | - self.cfg = cfg | |
| 293 | - self.task_cfg = task_cfg | |
| 294 | - self.device = shared_runtime.device | |
| 295 | - self.dtype = shared_runtime.dtype | |
| 296 | - self.tokenizer = shared_runtime.tokenizer | |
| 297 | - self.model = shared_runtime.model | |
| 298 | - self.backbone = shared_runtime.backbone | |
| 299 | - self.hidden_size = shared_runtime.hidden_size | |
| 300 | - self.prefix_text, self.suffix_text = task_cfg.prompt_parts | |
| 301 | - self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False) | |
| 302 | - self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False) | |
| 303 | - self.labels = list(task_cfg.labels) | |
| 304 | - self.encoded_labels = [ | |
| 305 | - EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label)) | |
| 306 | - for label in self.labels | |
| 307 | - ] | |
| 308 | - self.num_labels = len(self.labels) | |
| 309 | - self.lm_head = self.model.get_output_embeddings() | |
| 310 | - self.lm_head_weight = self.lm_head.weight.detach() | |
| 311 | - self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None | |
| 312 | - if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels(): | |
| 313 | - raise ValueError( | |
| 314 | - f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide" | |
| 315 | - ) | |
| 316 | - self.fast_path = self._has_unique_single_token_labels() | |
| 317 | - self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else [] | |
| 318 | - self.scorer = self._build_scorer() if self.fast_path else None | |
| 319 | - self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None | |
| 320 | - self.prefix_cache = self._build_prefix_cache() | |
| 321 | - | |
| 322 | - def _encode_label_token_ids(self, label: str) -> list[int]: | |
| 323 | - token_ids = self.tokenizer.encode( | |
| 324 | - f"{self.task_cfg.label_prefix}{label}", | |
| 325 | - add_special_tokens=False, | |
| 326 | - ) | |
| 327 | - if not token_ids: | |
| 328 | - raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence") | |
| 329 | - if self.cfg.force_single_token_labels: | |
| 330 | - return token_ids[:1] | |
| 331 | - return token_ids | |
| 332 | - | |
| 333 | - def _has_unique_single_token_labels(self) -> bool: | |
| 334 | - token_ids: list[int] = [] | |
| 335 | - for item in self.encoded_labels: | |
| 336 | - if len(item.token_ids) != 1: | |
| 337 | - return False | |
| 338 | - token_ids.append(item.token_ids[0]) | |
| 339 | - return len(token_ids) == len(set(token_ids)) | |
| 340 | - | |
| 341 | - def _build_scorer(self) -> SmallClassScorer: | |
| 342 | - token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device) | |
| 343 | - weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous() | |
| 344 | - bias = None | |
| 345 | - if self.lm_head_bias is not None: | |
| 346 | - bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous() | |
| 347 | - return SmallClassScorer(weights=weights, bias=bias) | |
| 348 | - | |
| 349 | - def _build_multi_token_tables(self) -> MultiTokenTables: | |
| 350 | - max_label_len = max(len(item.token_ids) for item in self.encoded_labels) | |
| 351 | - max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels) | |
| 352 | - label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long) | |
| 353 | - label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32) | |
| 354 | - label_prefix_ids = torch.full( | |
| 355 | - (self.num_labels, max_prefix_len), | |
| 356 | - fill_value=self.tokenizer.pad_token_id, | |
| 357 | - device=self.device, | |
| 358 | - dtype=torch.long, | |
| 359 | - ) | |
| 360 | - label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long) | |
| 361 | - for idx, item in enumerate(self.encoded_labels): | |
| 362 | - token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long) | |
| 363 | - token_len = token_ids.numel() | |
| 364 | - label_token_ids[idx, :token_len] = token_ids | |
| 365 | - label_token_mask[idx, :token_len] = 1.0 | |
| 366 | - if token_len > 1: | |
| 367 | - prefix_len = token_len - 1 | |
| 368 | - label_prefix_ids[idx, :prefix_len] = token_ids[:-1] | |
| 369 | - label_prefix_mask[idx, :prefix_len] = 1 | |
| 370 | - return MultiTokenTables( | |
| 371 | - label_token_ids=label_token_ids.contiguous(), | |
| 372 | - label_token_mask=label_token_mask.contiguous(), | |
| 373 | - label_prefix_ids=label_prefix_ids.contiguous(), | |
| 374 | - label_prefix_mask=label_prefix_mask.contiguous(), | |
| 375 | - label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long), | |
| 376 | - ) | |
| 377 | - | |
| 378 | - @torch.inference_mode() | |
| 379 | - def _build_prefix_cache(self) -> PrefixCache: | |
| 380 | - if not self.prefix_ids: | |
| 381 | - return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple()) | |
| 382 | - prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device) | |
| 383 | - attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device) | |
| 384 | - outputs = self.model( | |
| 385 | - input_ids=prefix_tensor, | |
| 386 | - attention_mask=attention_mask, | |
| 387 | - use_cache=True, | |
| 388 | - return_dict=True, | |
| 389 | - ) | |
| 390 | - raw_cache = tuple( | |
| 391 | - (layer.keys.detach(), layer.values.detach()) | |
| 392 | - for layer in outputs.past_key_values.layers | |
| 393 | - ) | |
| 394 | - return PrefixCache( | |
| 395 | - prefix_ids=list(self.prefix_ids), | |
| 396 | - prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size), | |
| 397 | - raw_cache=raw_cache, | |
| 398 | - ) | |
| 399 | - | |
| 400 | - def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: | |
| 401 | - if not self.prefix_cache.raw_cache: | |
| 402 | - return tuple() | |
| 403 | - return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size) | |
| 404 | - | |
| 405 | - def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]: | |
| 406 | - continuation = query_ids + self.suffix_ids | |
| 407 | - if not continuation: | |
| 408 | - raise ValueError("prompt continuation is empty after substituting query") | |
| 409 | - if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length: | |
| 410 | - raise ValueError( | |
| 411 | - f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}" | |
| 412 | - ) | |
| 413 | - return continuation | |
| 414 | - | |
| 415 | - @torch.inference_mode() | |
| 416 | - def reduce_fast_scores( | |
| 417 | - self, | |
| 418 | - hidden: torch.Tensor, | |
| 419 | - out_scores: torch.Tensor, | |
| 420 | - ) -> None: | |
| 421 | - assert self.scorer is not None | |
| 422 | - out_scores.copy_(self.scorer(hidden)) | |
| 423 | - | |
| 424 | - @torch.inference_mode() | |
| 425 | - def reduce_multi_token_scores( | |
| 426 | - self, | |
| 427 | - last_hidden_state: torch.Tensor, | |
| 428 | - batch_size: int, | |
| 429 | - max_input_len: int, | |
| 430 | - score_positions: torch.Tensor, | |
| 431 | - out_scores: torch.Tensor, | |
| 432 | - ) -> None: | |
| 433 | - assert self.multi_token_tables is not None | |
| 434 | - hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size) | |
| 435 | - gather_index = score_positions[:, None, :, None].expand( | |
| 436 | - batch_size, | |
| 437 | - self.num_labels, | |
| 438 | - self.multi_token_tables.max_label_len, | |
| 439 | - self.hidden_size, | |
| 440 | - ) | |
| 441 | - gathered_hidden = torch.gather(hidden, 2, gather_index) | |
| 442 | - used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool() | |
| 443 | - | |
| 444 | - token_log_probs = torch.zeros( | |
| 445 | - (batch_size, self.num_labels, self.multi_token_tables.max_label_len), | |
| 446 | - device=self.device, | |
| 447 | - dtype=torch.float32, | |
| 448 | - ) | |
| 449 | - if used_mask.any(): | |
| 450 | - flat_hidden = gathered_hidden[used_mask] | |
| 451 | - flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask] | |
| 452 | - linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float() | |
| 453 | - linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float() | |
| 454 | - linear_bias = self.lm_head_bias | |
| 455 | - if linear_bias is not None and self.device.type != "cuda": | |
| 456 | - linear_bias = linear_bias.float() | |
| 457 | - flat_logits = F.linear(linear_hidden, linear_weight, linear_bias) | |
| 458 | - flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float() | |
| 459 | - flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1) | |
| 460 | - token_log_probs[used_mask] = flat_selected - flat_log_norm | |
| 461 | - out_scores.copy_(token_log_probs.sum(dim=-1)) | |
| 462 | - | |
| 463 | - def build_score_result( | |
| 464 | - self, | |
| 465 | - query: str, | |
| 466 | - scores: torch.Tensor, | |
| 467 | - stage_ms: dict[str, float], | |
| 468 | - continuation_tokens: int, | |
| 469 | - ) -> QueryScoreResult: | |
| 470 | - score_values = scores.detach().float().cpu().tolist() | |
| 471 | - best_idx = max(range(len(score_values)), key=score_values.__getitem__) | |
| 472 | - probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist() | |
| 473 | - return QueryScoreResult( | |
| 474 | - task_name=self.task_cfg.name, | |
| 475 | - query=query, | |
| 476 | - predicted_label=self.labels[best_idx], | |
| 477 | - scores=[ | |
| 478 | - (label, score, prob) | |
| 479 | - for label, score, prob in zip(self.labels, score_values, probs, strict=True) | |
| 480 | - ], | |
| 481 | - total_ms=sum(stage_ms.values()), | |
| 482 | - stage_ms=stage_ms, | |
| 483 | - fast_path=self.fast_path, | |
| 484 | - prefix_tokens=self.prefix_cache.prefix_len, | |
| 485 | - continuation_tokens=continuation_tokens, | |
| 486 | - label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels}, | |
| 487 | - ) | |
| 488 | - | |
| 489 | - | |
| 490 | -class MultiPromptRunner: | |
| 491 | - def __init__(self, cfg: RuntimeConfig): | |
| 492 | - self.cfg = cfg | |
| 493 | - t0 = time.perf_counter() | |
| 494 | - self.shared_runtime = self.build_shared_runtime(cfg) | |
| 495 | - t1 = time.perf_counter() | |
| 496 | - self.device = self.shared_runtime.device | |
| 497 | - self.dtype = self.shared_runtime.dtype | |
| 498 | - self.tokenizer = self.shared_runtime.tokenizer | |
| 499 | - self.model = self.shared_runtime.model | |
| 500 | - self.backbone = self.shared_runtime.backbone | |
| 501 | - self.hidden_size = self.shared_runtime.hidden_size | |
| 502 | - self.graph_capture_pool = self.shared_runtime.graph_capture_pool | |
| 503 | - self.graph_capture_stream = self.shared_runtime.graph_capture_stream | |
| 504 | - self.runners = [ | |
| 505 | - PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime) | |
| 506 | - for task_cfg in cfg.tasks | |
| 507 | - ] | |
| 508 | - t2 = time.perf_counter() | |
| 509 | - self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes} | |
| 510 | - t3 = time.perf_counter() | |
| 511 | - self.mixed_prefix_caches = { | |
| 512 | - batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size]) | |
| 513 | - for batch_size in self.cfg.batch_sizes | |
| 514 | - } | |
| 515 | - t4 = time.perf_counter() | |
| 516 | - self.max_label_prefix_len = max( | |
| 517 | - (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0) | |
| 518 | - for runner in self.runners | |
| 519 | - ) | |
| 520 | - self.mixed_buckets = { | |
| 521 | - (batch_size, continuation_len): self._build_mixed_bucket( | |
| 522 | - self.batch_layouts[batch_size], | |
| 523 | - self.mixed_prefix_caches[batch_size], | |
| 524 | - continuation_len, | |
| 525 | - ) | |
| 526 | - for batch_size in self.cfg.batch_sizes | |
| 527 | - for continuation_len in self.cfg.continuation_buckets | |
| 528 | - } | |
| 529 | - t5 = time.perf_counter() | |
| 530 | - self._warmup_results: dict[int, BatchScoreResult] = {} | |
| 531 | - self._preload_report: PreloadReport | None = None | |
| 532 | - self._init_stage_ms = { | |
| 533 | - "load_model_and_tokenizer": (t1 - t0) * 1000.0, | |
| 534 | - "build_prompt_runtimes": (t2 - t1) * 1000.0, | |
| 535 | - "build_batch_layouts": (t3 - t2) * 1000.0, | |
| 536 | - "build_mixed_prefix_caches": (t4 - t3) * 1000.0, | |
| 537 | - "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0, | |
| 538 | - } | |
| 539 | - self._init_total_ms = sum(self._init_stage_ms.values()) | |
| 540 | - | |
| 541 | - @staticmethod | |
| 542 | - def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime: | |
| 543 | - device = torch.device(cfg.device) | |
| 544 | - dtype = torch.float16 | |
| 545 | - tokenizer = AutoTokenizer.from_pretrained( | |
| 546 | - cfg.resolved_model_source, | |
| 547 | - trust_remote_code=cfg.resolved_trust_remote_code, | |
| 548 | - token=cfg.hf_token, | |
| 549 | - cache_dir=cfg.hf_cache_dir, | |
| 550 | - local_files_only=cfg.resolved_local_files_only, | |
| 551 | - ) | |
| 552 | - if tokenizer.pad_token_id is None: | |
| 553 | - tokenizer.pad_token = tokenizer.eos_token | |
| 554 | - attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend) | |
| 555 | - quantization_config = None | |
| 556 | - model_kwargs: dict[str, object] = { | |
| 557 | - "trust_remote_code": cfg.resolved_trust_remote_code, | |
| 558 | - "attn_implementation": attn_impl, | |
| 559 | - "token": cfg.hf_token, | |
| 560 | - "cache_dir": cfg.hf_cache_dir, | |
| 561 | - "local_files_only": cfg.resolved_local_files_only, | |
| 562 | - } | |
| 563 | - if cfg.load_in_4bit: | |
| 564 | - if BitsAndBytesConfig is None: | |
| 565 | - raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first") | |
| 566 | - quantization_config = BitsAndBytesConfig( | |
| 567 | - load_in_4bit=True, | |
| 568 | - bnb_4bit_compute_dtype=dtype, | |
| 569 | - bnb_4bit_quant_type=cfg.bnb_4bit_quant_type, | |
| 570 | - bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant, | |
| 571 | - ) | |
| 572 | - model_kwargs["quantization_config"] = quantization_config | |
| 573 | - model_kwargs["device_map"] = {"": device.index or 0} | |
| 574 | - else: | |
| 575 | - model_kwargs["dtype"] = dtype | |
| 576 | - model_kwargs["device_map"] = None | |
| 577 | - model = AutoModelForCausalLM.from_pretrained( | |
| 578 | - cfg.resolved_model_source, | |
| 579 | - **model_kwargs, | |
| 580 | - ).eval() | |
| 581 | - if not cfg.load_in_4bit: | |
| 582 | - model = model.to(device) | |
| 583 | - backbone = model.get_submodule(model.base_model_prefix) | |
| 584 | - hidden_size = model.get_output_embeddings().weight.shape[1] | |
| 585 | - graph_capture_pool = None | |
| 586 | - graph_capture_stream = None | |
| 587 | - if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit: | |
| 588 | - graph_capture_pool = torch.cuda.graph_pool_handle() | |
| 589 | - graph_capture_stream = torch.cuda.Stream(device=device) | |
| 590 | - return SharedRuntime( | |
| 591 | - device=device, | |
| 592 | - dtype=dtype, | |
| 593 | - tokenizer=tokenizer, | |
| 594 | - model=model, | |
| 595 | - backbone=backbone, | |
| 596 | - hidden_size=hidden_size, | |
| 597 | - graph_capture_pool=graph_capture_pool, | |
| 598 | - graph_capture_stream=graph_capture_stream, | |
| 599 | - ) | |
| 600 | - | |
| 601 | - @staticmethod | |
| 602 | - def _resolve_attn_impl(requested: str) -> str: | |
| 603 | - if requested in {"sdpa", "eager"}: | |
| 604 | - return requested | |
| 605 | - if requested == "auto": | |
| 606 | - return "sdpa" | |
| 607 | - raise ValueError(f"unsupported attn_backend: {requested}") | |
| 608 | - | |
| 609 | - def _attn_context(self): | |
| 610 | - if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}: | |
| 611 | - return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) | |
| 612 | - return torch.no_grad() | |
| 613 | - | |
| 614 | - def _sync(self) -> None: | |
| 615 | - if self.device.type == "cuda": | |
| 616 | - torch.cuda.synchronize() | |
| 617 | - | |
| 618 | - def _pick_bucket(self, continuation_len: int) -> int: | |
| 619 | - for bucket in self.cfg.continuation_buckets: | |
| 620 | - if continuation_len <= bucket: | |
| 621 | - return bucket | |
| 622 | - if self.cfg.pad_to_bucket: | |
| 623 | - raise ValueError( | |
| 624 | - f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets" | |
| 625 | - ) | |
| 626 | - return continuation_len | |
| 627 | - | |
| 628 | - def _build_batch_layout(self, batch_size: int) -> BatchLayout: | |
| 629 | - plans: list[PromptBatchPlan] = [] | |
| 630 | - row_start = 0 | |
| 631 | - for runner in self.runners: | |
| 632 | - row_count = batch_size if runner.fast_path else batch_size * runner.num_labels | |
| 633 | - plans.append( | |
| 634 | - PromptBatchPlan( | |
| 635 | - runner=runner, | |
| 636 | - row_start=row_start, | |
| 637 | - row_count=row_count, | |
| 638 | - score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32), | |
| 639 | - ) | |
| 640 | - ) | |
| 641 | - row_start += row_count | |
| 642 | - return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans) | |
| 643 | - | |
| 644 | - def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache: | |
| 645 | - prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long) | |
| 646 | - non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache] | |
| 647 | - if not non_empty: | |
| 648 | - return MixedPrefixCache( | |
| 649 | - batch_size=layout.batch_size, | |
| 650 | - total_rows=layout.total_rows, | |
| 651 | - prefix_lengths=prefix_lengths, | |
| 652 | - attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long), | |
| 653 | - raw_cache=tuple(), | |
| 654 | - ) | |
| 655 | - | |
| 656 | - max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans) | |
| 657 | - num_layers = len(non_empty[0]) | |
| 658 | - attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long) | |
| 659 | - raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = [] | |
| 660 | - for layer_idx in range(num_layers): | |
| 661 | - sample_key, sample_value = non_empty[0][layer_idx] | |
| 662 | - merged_key = sample_key.new_zeros( | |
| 663 | - (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3]) | |
| 664 | - ) | |
| 665 | - merged_value = sample_value.new_zeros( | |
| 666 | - (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3]) | |
| 667 | - ) | |
| 668 | - raw_layers.append((merged_key, merged_value)) | |
| 669 | - | |
| 670 | - for plan in layout.plans: | |
| 671 | - runner = plan.runner | |
| 672 | - prefix_len = runner.prefix_cache.prefix_len | |
| 673 | - row_slice = slice(plan.row_start, plan.row_stop) | |
| 674 | - prefix_lengths[row_slice] = prefix_len | |
| 675 | - if prefix_len == 0: | |
| 676 | - continue | |
| 677 | - attention_mask[row_slice, :prefix_len] = 1 | |
| 678 | - raw_cache = runner.expand_prefix_raw_cache(plan.row_count) | |
| 679 | - for layer_idx, (key, value) in enumerate(raw_cache): | |
| 680 | - merged_key, merged_value = raw_layers[layer_idx] | |
| 681 | - merged_key[row_slice, :, :prefix_len, :] = key | |
| 682 | - merged_value[row_slice, :, :prefix_len, :] = value | |
| 683 | - | |
| 684 | - return MixedPrefixCache( | |
| 685 | - batch_size=layout.batch_size, | |
| 686 | - total_rows=layout.total_rows, | |
| 687 | - prefix_lengths=prefix_lengths, | |
| 688 | - attention_mask=attention_mask.contiguous(), | |
| 689 | - raw_cache=tuple(raw_layers), | |
| 690 | - ) | |
| 691 | - | |
| 692 | - def _build_mixed_bucket( | |
| 693 | - self, | |
| 694 | - layout: BatchLayout, | |
| 695 | - prefix_cache: MixedPrefixCache, | |
| 696 | - continuation_len: int, | |
| 697 | - ) -> MixedBucketRuntime: | |
| 698 | - max_input_len = continuation_len + self.max_label_prefix_len | |
| 699 | - total_len = prefix_cache.max_prefix_len + max_input_len | |
| 700 | - input_ids = torch.full( | |
| 701 | - (layout.total_rows, max_input_len), | |
| 702 | - fill_value=self.tokenizer.pad_token_id, | |
| 703 | - device=self.device, | |
| 704 | - dtype=torch.long, | |
| 705 | - ) | |
| 706 | - attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long) | |
| 707 | - if prefix_cache.max_prefix_len: | |
| 708 | - attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask | |
| 709 | - position_ids = ( | |
| 710 | - prefix_cache.prefix_lengths[:, None] | |
| 711 | - + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0) | |
| 712 | - ).contiguous() | |
| 713 | - last_hidden_state = torch.empty( | |
| 714 | - (layout.total_rows, max_input_len, self.hidden_size), | |
| 715 | - device=self.device, | |
| 716 | - dtype=self.dtype, | |
| 717 | - ) | |
| 718 | - bucket = MixedBucketRuntime( | |
| 719 | - batch_size=layout.batch_size, | |
| 720 | - total_rows=layout.total_rows, | |
| 721 | - continuation_len=continuation_len, | |
| 722 | - max_input_len=max_input_len, | |
| 723 | - input_ids=input_ids, | |
| 724 | - attention_mask=attention_mask, | |
| 725 | - position_ids=position_ids, | |
| 726 | - last_hidden_state=last_hidden_state, | |
| 727 | - ) | |
| 728 | - if self.cfg.cuda_graphs: | |
| 729 | - self._capture_mixed_bucket(bucket, prefix_cache) | |
| 730 | - return bucket | |
| 731 | - | |
| 732 | - @torch.inference_mode() | |
| 733 | - def _run_mixed_backbone( | |
| 734 | - self, | |
| 735 | - bucket: MixedBucketRuntime, | |
| 736 | - prefix_cache: MixedPrefixCache, | |
| 737 | - ) -> None: | |
| 738 | - cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config) | |
| 739 | - with self._attn_context(): | |
| 740 | - outputs = self.backbone( | |
| 741 | - input_ids=bucket.input_ids, | |
| 742 | - attention_mask=bucket.attention_mask, | |
| 743 | - position_ids=bucket.position_ids, | |
| 744 | - past_key_values=cache, | |
| 745 | - use_cache=False, | |
| 746 | - return_dict=True, | |
| 747 | - ) | |
| 748 | - bucket.last_hidden_state.copy_(outputs.last_hidden_state) | |
| 749 | - | |
| 750 | - def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None: | |
| 751 | - if not torch.cuda.is_available(): | |
| 752 | - return | |
| 753 | - try: | |
| 754 | - torch.cuda.synchronize() | |
| 755 | - stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device) | |
| 756 | - with torch.cuda.stream(stream): | |
| 757 | - for _ in range(self.cfg.graph_warmups): | |
| 758 | - self._run_mixed_backbone(bucket, prefix_cache) | |
| 759 | - stream.synchronize() | |
| 760 | - graph = torch.cuda.CUDAGraph() | |
| 761 | - with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream): | |
| 762 | - self._run_mixed_backbone(bucket, prefix_cache) | |
| 763 | - bucket.graph = graph | |
| 764 | - except RuntimeError: | |
| 765 | - bucket.graph = None | |
| 766 | - | |
| 767 | - def _prepare_bucket( | |
| 768 | - self, | |
| 769 | - layout: BatchLayout, | |
| 770 | - prefix_cache: MixedPrefixCache, | |
| 771 | - bucket: MixedBucketRuntime, | |
| 772 | - query_ids_batch: list[list[int]], | |
| 773 | - ) -> tuple[list[list[int]], dict[str, list[object]]]: | |
| 774 | - del prefix_cache | |
| 775 | - bucket.input_ids.fill_(self.tokenizer.pad_token_id) | |
| 776 | - bucket.attention_mask.zero_() | |
| 777 | - if self.mixed_prefix_caches[layout.batch_size].max_prefix_len: | |
| 778 | - bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = ( | |
| 779 | - self.mixed_prefix_caches[layout.batch_size].attention_mask | |
| 780 | - ) | |
| 781 | - continuation_lengths_per_task: dict[str, list[int]] = {} | |
| 782 | - continuation_tokens_per_task: dict[str, list[list[int]]] = {} | |
| 783 | - prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len | |
| 784 | - for plan in layout.plans: | |
| 785 | - runner = plan.runner | |
| 786 | - per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch] | |
| 787 | - continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations | |
| 788 | - continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations] | |
| 789 | - if runner.fast_path: | |
| 790 | - for batch_idx, continuation in enumerate(per_query_continuations): | |
| 791 | - cont_len = len(continuation) | |
| 792 | - row_idx = plan.row_start + batch_idx | |
| 793 | - bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long) | |
| 794 | - bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1 | |
| 795 | - continue | |
| 796 | - | |
| 797 | - assert runner.multi_token_tables is not None | |
| 798 | - for batch_idx, continuation in enumerate(per_query_continuations): | |
| 799 | - cont_len = len(continuation) | |
| 800 | - row_start = plan.row_start + batch_idx * runner.num_labels | |
| 801 | - row_stop = row_start + runner.num_labels | |
| 802 | - row_slice = slice(row_start, row_stop) | |
| 803 | - cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long) | |
| 804 | - bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1) | |
| 805 | - bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1 | |
| 806 | - if runner.multi_token_tables.max_label_prefix_len: | |
| 807 | - bucket.input_ids[ | |
| 808 | - row_slice, | |
| 809 | - cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len, | |
| 810 | - ] = runner.multi_token_tables.label_prefix_ids | |
| 811 | - bucket.attention_mask[ | |
| 812 | - row_slice, | |
| 813 | - prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len, | |
| 814 | - ] = runner.multi_token_tables.label_prefix_mask | |
| 815 | - return query_ids_batch, { | |
| 816 | - "continuation_lengths_per_task": continuation_lengths_per_task, | |
| 817 | - "continuation_tokens_per_task": continuation_tokens_per_task, | |
| 818 | - } | |
| 819 | - | |
| 820 | - def _reduce_prompt_scores( | |
| 821 | - self, | |
| 822 | - layout: BatchLayout, | |
| 823 | - bucket: MixedBucketRuntime, | |
| 824 | - query_texts: list[str], | |
| 825 | - prep_meta: dict[str, list[object]], | |
| 826 | - shared_stage_ms: dict[str, float], | |
| 827 | - ) -> list[MultiPromptScoreResult]: | |
| 828 | - result_rows = [[] for _ in range(layout.batch_size)] | |
| 829 | - prompt_reduce_total_ms = 0.0 | |
| 830 | - for plan in layout.plans: | |
| 831 | - runner = plan.runner | |
| 832 | - continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name] | |
| 833 | - reduce_start = time.perf_counter() | |
| 834 | - if runner.fast_path: | |
| 835 | - hidden_rows = [] | |
| 836 | - row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size] | |
| 837 | - for batch_idx, cont_len in enumerate(continuation_lengths): | |
| 838 | - hidden_rows.append(row_slice[batch_idx, cont_len - 1]) | |
| 839 | - hidden = torch.stack(hidden_rows, dim=0) | |
| 840 | - runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer) | |
| 841 | - stage_name = "tail_scorer" | |
| 842 | - else: | |
| 843 | - assert runner.multi_token_tables is not None | |
| 844 | - score_positions = torch.stack( | |
| 845 | - [ | |
| 846 | - cont_len - 1 + runner.multi_token_tables.label_position_offsets | |
| 847 | - for cont_len in continuation_lengths | |
| 848 | - ], | |
| 849 | - dim=0, | |
| 850 | - ) | |
| 851 | - runner.reduce_multi_token_scores( | |
| 852 | - last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop], | |
| 853 | - batch_size=layout.batch_size, | |
| 854 | - max_input_len=bucket.max_input_len, | |
| 855 | - score_positions=score_positions, | |
| 856 | - out_scores=plan.score_buffer, | |
| 857 | - ) | |
| 858 | - stage_name = "candidate_reduce" | |
| 859 | - self._sync() | |
| 860 | - reduce_end = time.perf_counter() | |
| 861 | - reduce_ms = (reduce_end - reduce_start) * 1000.0 | |
| 862 | - prompt_reduce_total_ms += reduce_ms | |
| 863 | - for batch_idx, query in enumerate(query_texts): | |
| 864 | - stage_ms = dict(shared_stage_ms) | |
| 865 | - stage_ms[stage_name] = reduce_ms / layout.batch_size | |
| 866 | - result_rows[batch_idx].append( | |
| 867 | - runner.build_score_result( | |
| 868 | - query=query, | |
| 869 | - scores=plan.score_buffer[batch_idx], | |
| 870 | - stage_ms=stage_ms, | |
| 871 | - continuation_tokens=continuation_lengths[batch_idx], | |
| 872 | - ) | |
| 873 | - ) | |
| 874 | - | |
| 875 | - batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms | |
| 876 | - shared_plus_reduce = dict(shared_stage_ms) | |
| 877 | - shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms | |
| 878 | - results: list[MultiPromptScoreResult] = [] | |
| 879 | - for batch_idx, query in enumerate(query_texts): | |
| 880 | - results.append( | |
| 881 | - MultiPromptScoreResult( | |
| 882 | - query=query, | |
| 883 | - total_ms=batch_total_ms / layout.batch_size, | |
| 884 | - details=result_rows[batch_idx], | |
| 885 | - stage_ms={ | |
| 886 | - **shared_plus_reduce, | |
| 887 | - "per_query_total_estimate": batch_total_ms / layout.batch_size, | |
| 888 | - }, | |
| 889 | - ) | |
| 890 | - ) | |
| 891 | - return results | |
| 892 | - | |
| 893 | - @torch.inference_mode() | |
| 894 | - def score_queries(self, queries: list[str]) -> BatchScoreResult: | |
| 895 | - if not queries: | |
| 896 | - raise ValueError("queries must not be empty") | |
| 897 | - batch_size = len(queries) | |
| 898 | - if batch_size not in self.batch_layouts: | |
| 899 | - raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}") | |
| 900 | - layout = self.batch_layouts[batch_size] | |
| 901 | - prefix_cache = self.mixed_prefix_caches[batch_size] | |
| 902 | - | |
| 903 | - self._sync() | |
| 904 | - t0 = time.perf_counter() | |
| 905 | - query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries] | |
| 906 | - self._sync() | |
| 907 | - t1 = time.perf_counter() | |
| 908 | - | |
| 909 | - max_continuation_len = max( | |
| 910 | - len(plan.runner.build_continuation_from_query_ids(query_ids)) | |
| 911 | - for plan in layout.plans | |
| 912 | - for query_ids in query_ids_batch | |
| 913 | - ) | |
| 914 | - picked_bucket = self._pick_bucket(max_continuation_len) | |
| 915 | - bucket = self.mixed_buckets[(batch_size, picked_bucket)] | |
| 916 | - _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch) | |
| 917 | - self._sync() | |
| 918 | - t2 = time.perf_counter() | |
| 919 | - | |
| 920 | - if bucket.graph is not None: | |
| 921 | - bucket.graph.replay() | |
| 922 | - else: | |
| 923 | - self._run_mixed_backbone(bucket, prefix_cache) | |
| 924 | - self._sync() | |
| 925 | - t3 = time.perf_counter() | |
| 926 | - | |
| 927 | - shared_stage_ms = { | |
| 928 | - "encode_queries_shared": (t1 - t0) * 1000.0, | |
| 929 | - "prepare_batch_shared": (t2 - t1) * 1000.0, | |
| 930 | - "backbone_shared": (t3 - t2) * 1000.0, | |
| 931 | - } | |
| 932 | - results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms) | |
| 933 | - total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"] | |
| 934 | - return BatchScoreResult( | |
| 935 | - batch_size=batch_size, | |
| 936 | - total_ms=total_ms, | |
| 937 | - results=results, | |
| 938 | - stage_ms={ | |
| 939 | - **shared_stage_ms, | |
| 940 | - "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"], | |
| 941 | - }, | |
| 942 | - ) | |
| 943 | - | |
| 944 | - def score_query(self, query: str) -> MultiPromptScoreResult: | |
| 945 | - return self.score_queries([query]).results[0] | |
| 946 | - | |
| 947 | - def preload(self) -> PreloadReport: | |
| 948 | - if self._preload_report is not None: | |
| 949 | - return self._preload_report | |
| 950 | - stage_ms: dict[str, float] = dict(self._init_stage_ms) | |
| 951 | - start = time.perf_counter() | |
| 952 | - self._sync() | |
| 953 | - t0 = time.perf_counter() | |
| 954 | - warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes | |
| 955 | - for batch_size in warmup_batch_sizes: | |
| 956 | - queries = [self.cfg.warmup_query] * batch_size | |
| 957 | - self._warmup_results[batch_size] = self.score_queries(queries) | |
| 958 | - self._sync() | |
| 959 | - t1 = time.perf_counter() | |
| 960 | - stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0 | |
| 961 | - stage_ms["startup_total_before_warmup"] = self._init_total_ms | |
| 962 | - total_ms = self._init_total_ms + (t1 - start) * 1000.0 | |
| 963 | - runtime = self.preload_report() | |
| 964 | - self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime) | |
| 965 | - return self._preload_report | |
| 966 | - | |
| 967 | - def preload_report(self) -> dict[str, object]: | |
| 968 | - return { | |
| 969 | - "model_name": self.cfg.resolved_model_name, | |
| 970 | - "model_source": self.cfg.resolved_model_source, | |
| 971 | - "device": str(self.device), | |
| 972 | - "dtype": self.cfg.dtype, | |
| 973 | - "attn_backend": self.cfg.attn_backend, | |
| 974 | - "execution_model": "single_mixed_backbone_per_batch", | |
| 975 | - "num_tasks": len(self.runners), | |
| 976 | - "task_names": [runner.task_cfg.name for runner in self.runners], | |
| 977 | - "batch_sizes": list(self.cfg.batch_sizes), | |
| 978 | - "continuation_buckets": list(self.cfg.continuation_buckets), | |
| 979 | - "mixed_bucket_count": len(self.mixed_buckets), | |
| 980 | - "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()), | |
| 981 | - "all_configured_buckets_preloaded": True, | |
| 982 | - "init_stage_ms": dict(self._init_stage_ms), | |
| 983 | - "init_total_ms": self._init_total_ms, | |
| 984 | - "force_single_token_labels": self.cfg.force_single_token_labels, | |
| 985 | - "warmup_query": self.cfg.warmup_query, | |
| 986 | - "tasks": [ | |
| 987 | - { | |
| 988 | - "task_name": runner.task_cfg.name, | |
| 989 | - "fast_path": runner.fast_path, | |
| 990 | - "num_labels": runner.num_labels, | |
| 991 | - "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels}, | |
| 992 | - "prefix_tokens": runner.prefix_cache.prefix_len, | |
| 993 | - "prefix_hashes": runner.prefix_cache.prefix_hashes, | |
| 994 | - "label_prefix": runner.task_cfg.label_prefix, | |
| 995 | - } | |
| 996 | - for runner in self.runners | |
| 997 | - ], | |
| 998 | - } | |
| 999 | - | |
| 1000 | -However, the multi-token handling above is inefficient and incorrect; do not use it as a reference, reimplement it. | |
| 1001 | -The local .venv is already set up (reused from llm-qp); use that environment. | |
| 1002 | -The useful code has already been pulled in; based on the requirements, research relevant material and focus on building a brand-new version. | |
| 1003 | -This is a heavy task; work through it patiently, step by step, iterating over the long run until the service is built and tests fully meet expectations. |
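The fast path this issue asks for (last hidden state -> N-class scorer -> argmax, with no full-vocab logits) can be sketched as below. This is a hypothetical illustration in plain PyTorch, not code from this repo; `score_labels` and the toy shapes are invented for the example:

```python
import torch

def score_labels(last_hidden: torch.Tensor, label_rows: torch.Tensor) -> torch.Tensor:
    """Score N single-token labels directly from the final hidden state.

    last_hidden: (batch, hidden) -- hidden state at the last prompt position.
    label_rows:  (num_labels, hidden) -- lm_head rows for each label token,
                 gathered once at startup so the big (hidden, vocab) projection
                 is replaced by a tiny (hidden, num_labels) matmul.
    """
    return last_hidden @ label_rows.T

torch.manual_seed(0)
hidden = torch.randn(2, 8)            # batch of 2, toy hidden size 8
lm_head = torch.randn(20, 8)          # toy vocab of 20
label_ids = torch.tensor([3, 7, 11])  # token ids of 3 single-token labels

scores = score_labels(hidden, lm_head[label_ids])
pred = scores.argmax(dim=-1)

# Same values as slicing the full-vocab logits, without computing them:
full_logits = hidden @ lm_head.T
assert torch.allclose(full_logits[:, label_ids], scores)
```

Because softmax is monotone in the logits, the argmax over these restricted scores matches the argmax over full-vocab probabilities; this shortcut only covers single-token labels, since a multi-token label still needs each step's log-softmax normalized over the full vocabulary.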
docs/issues/issue-2026-04-06-推理优化.md deleted
| ... | ... | @@ -1,59 +0,0 @@ |
| 1 | - | |
| 2 | -Overall requirement: query classification on a Tesla T4 GPU using an open-source base LLM (3-6B scale). The prompt and classification labels must be configurable. | |
| 3 | -Focus on inference optimization first; service deployment can wait. After startup, the program can read queries from stdin and print the predicted label. | |
| 4 | - | |
| 5 | -Example prompt and its 9 classification labels: | |
| 6 | -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | |
| 7 | - | |
| 8 | -Build a dedicated execution path rather than configuration-level tuning on a general-purpose generation engine. | |
| 9 | - | |
| 10 | -The prompt is already fixed (do not consider distillation, fine-tuning, or shortening the prompt. I can shorten the prompt myself, and it may be swapped for other scenarios, which is unrelated to your dedicated inference optimization; the prompt above is only an example. Focus on the dedicated inference framework, supporting a configurable prompt + label-list scoring check). | |
| 11 | - | |
| 12 | - | |
| 13 | -The target is a Tesla T4, therefore: | |
| 14 | -1. Use FP16. Do not use BF16 as the main path. | |
| 15 | -2. Do not use FlashAttention-3 as the main path. Choose the attention implementation best suited to the Tesla T4. | |
| 16 | -3. Research widely; study open-source inference-optimization projects and practices that fit the Tesla T4, including but not limited to Python/C++ runtimes and TensorRT engine toolkits; note that the hardware support matrix must cover Turing/T4. | |
| 17 | - | |
| 18 | - | |
| 19 | -Main optimization directions to consider: | |
| 20 | -hidden_last -> N-class scorer -> argmax | |
| 21 | -Reuse identical prefixes via vLLM-style automatic prefix caching, managing KV by block hash | |
| 22 | -Drop full-vocab logits | |
| 23 | -Drop decode / constrained decode | |
| 24 | -Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8) | |
| 25 | -Dedicated tail kernel (outputs raw scores for the N classes) | |
| 26 | - | |
| 27 | -The following is for reference only; decide the concrete approach yourself, using open-source projects you find as the optimization baseline: | |
| 28 | -15.1 Preallocation | |
| 29 | -Preallocate for each bucket: | |
| 30 | -input ids buffer | |
| 31 | -attention scratch buffer | |
| 32 | -output hidden buffer | |
| 33 | -class id buffer | |
| 34 | -15.2 Graph capture | |
| 35 | - | |
| 36 | -Capture one graph per bucket ahead of time: | |
| 37 | -embedding | |
| 38 | -transformer backbone | |
| 39 | -compact last-state | |
| 40 | -9-way tail kernel | |
| 41 | - | |
| 42 | -Study Triton / CUDA C++ techniques (and their open-source projects) for aggressively optimizing this kind of single-step decode with fixed-vocabulary scoring; find a suitable baseline and develop against it. | |
| 43 | - | |
| 44 | -You have sudo privileges and may install your own environment for this project. | |
| 45 | - | |
| 46 | -Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence this variant). | |
| 47 | -Alternatively, find another model on Hugging Face, deploy it, and benchmark inference latency. | |
| 48 | - | |
| 49 | - | |
| 50 | - | |
| 51 | - | |
| 52 | - | |
| 53 | - | |
| 54 | - | |
| 55 | -Additionally, I want a command-line tool that loads everything up front at program startup, with no lazy loading whatsoever, so real requests get the fastest possible response time. | |
| 56 | -Given an input query, it outputs each label's score and the total prediction latency (per-stage timings are even better; support them if easy). Check whether this is already satisfied. | |
| 57 | - | |
| 58 | -Analyze per-stage latency in depth and keep researching whether performance has been pushed to the limit. | |
| 59 | -Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence this variant); complete deployment and benchmark inference latency. |
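The bucketing-plus-preallocation scheme this issue describes (fixed continuation-length buckets with static per-bucket buffers, so shapes stay constant for CUDA graph capture and replay) can be sketched as follows. The bucket sizes and buffer names are illustrative assumptions, not this repo's actual values:

```python
import torch

BUCKETS = (16, 32, 64)  # assumed continuation-length buckets

def pick_bucket(length: int, buckets: tuple = BUCKETS) -> int:
    """Return the smallest bucket that fits; padding every request to a
    fixed bucket keeps tensor shapes static, which is what allows one CUDA
    graph to be captured per (batch, bucket) pair and replayed later."""
    for bucket in buckets:
        if length <= bucket:
            return bucket
    raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")

def alloc_bucket_buffers(batch: int, bucket: int, hidden: int, pad_id: int = 0) -> dict:
    """Preallocate the per-bucket static buffers once at startup."""
    return {
        "input_ids": torch.full((batch, bucket), pad_id, dtype=torch.long),
        "attention_mask": torch.zeros((batch, bucket), dtype=torch.long),
        "last_hidden": torch.empty((batch, bucket, hidden)),
    }

bufs = alloc_bucket_buffers(batch=4, bucket=pick_bucket(20), hidden=8)
```

At inference time the query is copied into the preallocated `input_ids` slot and the bucket's captured graph is replayed, rather than allocating fresh tensors per request.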
docs/issues/issue-2026-04-07-服务化.md deleted
| ... | ... | @@ -1,66 +0,0 @@ |
| 1 | -Completely remove the CLI-interaction code and rework it to serve an HTTP interface on port 6001. | |
| 2 | -Provide scripts/service_ctl.sh to manage starting and stopping the service. | |
| 3 | -Return the classification results as a list (results): the list consists of dicts; each dict's key is the corresponding classification task and its value is a triple of value, score, probability (entries whose value is none are omitted). | |
| 4 | -Improve the logging system; the important information currently printed in CLI mode should appear in the logs and in the HTTP response's details field. | |
| 5 | -Provide a stress-test script in a suitable directory of this project as the load-testing tool. It should print the result of every request and report performance metrics at the end, for reference: | |
| 6 | -#!/bin/bash | |
| 7 | - | |
| 8 | -# Defaults | |
| 9 | -concurrency=${1:-1} | |
| 10 | -top_lines=${2:-100} | |
| 11 | - | |
| 12 | -# Fixed path to the query file | |
| 13 | -query_file="/data/saas-search/scripts/evaluation/queries/queries.txt" | |
| 14 | - | |
| 15 | -# Check that the file exists | |
| 16 | -if [ ! -f "$query_file" ]; then | |
| 17 | -    echo "Error: query file not found: $query_file" >&2 | |
| 18 | - exit 1 | |
| 19 | -fi | |
| 20 | - | |
| 21 | -# Check that jq is available | |
| 22 | -if ! command -v jq &> /dev/null; then | |
| 23 | -    echo "Error: jq is required to parse JSON" >&2 | |
| 24 | - exit 1 | |
| 25 | -fi | |
| 26 | - | |
| 27 | -url="http://127.0.0.1:6001/..." | |
| 28 | -max_jobs=$concurrency | |
| 29 | -job_count=0 | |
| 30 | - | |
| 31 | -# Read the first top_lines lines of the file, one query per line | |
| 32 | -while IFS= read -r query; do | |
| 33 | -    # Skip empty lines | |
| 34 | - [ -z "$query" ] && continue | |
| 35 | - | |
| 36 | -    # Launch a subprocess for the request | |
| 37 | - ( | |
| 38 | -        # Safely build the JSON payload | |
| 39 | - payload=$(jq -n --arg q "$query" '{query: $q}') | |
| 40 | -        # Send the request and capture the response | |
| 41 | - response=$(curl -s -X POST "$url" \ | |
| 42 | - -H 'Content-Type: application/json' \ | |
| 43 | - -d "$payload") | |
| 44 | -        # Extract the results field (compact JSON) | |
| 45 | - results=$(echo "$response" | jq -c '.results') | |
| 46 | -        # Print the query and its results | |
| 47 | - printf "%s\t%s\n" "$query" "$results" | |
| 48 | - ) & | |
| 49 | - | |
| 50 | -    # Limit the number of concurrent jobs | |
| 51 | - ((job_count++)) | |
| 52 | - if (( job_count >= max_jobs )); then | |
| 53 | -        wait -n # Wait for any one background job to finish | |
| 54 | - ((job_count--)) | |
| 55 | - fi | |
| 56 | -done < <(head -n "$top_lines" "$query_file") | |
| 57 | - | |
| 58 | -# Wait for all remaining background jobs to finish | |
| 59 | -# Compute performance statistics here, with the following metrics: | |
| 60 | -    Average latency: | |
| 61 | -    Max latency: | |
| 62 | -    Min latency: | |
| 63 | - TP50: | |
| 64 | - TP90: | |
| 65 | - TP99: | |
| 66 | -wait | |
| 67 | 0 | \ No newline at end of file |
docs/issues/issue.md
scripts/evaluation/eval_framework/__init__.py
| ... | ... | @@ -20,7 +20,6 @@ from .constants import ( # noqa: E402 |
| 20 | 20 | RELEVANCE_LOW, |
| 21 | 21 | RELEVANCE_NON_IRRELEVANT, |
| 22 | 22 | VALID_LABELS, |
| 23 | - normalize_stored_label, | |
| 24 | 23 | ) |
| 25 | 24 | from .framework import SearchEvaluationFramework # noqa: E402 |
| 26 | 25 | from .store import EvalStore, QueryBuildResult # noqa: E402 |
| ... | ... | @@ -51,7 +50,6 @@ __all__ = [ |
| 51 | 50 | "create_web_app", |
| 52 | 51 | "ensure_dir", |
| 53 | 52 | "main", |
| 54 | - "normalize_stored_label", | |
| 55 | 53 | "render_batch_report_markdown", |
| 56 | 54 | "sha1_text", |
| 57 | 55 | "utc_now_iso", | ... | ... |
scripts/evaluation/eval_framework/constants.py
| ... | ... | @@ -42,20 +42,6 @@ STOP_PROB_MAP = { |
| 42 | 42 | RELEVANCE_IRRELEVANT: 0.0, |
| 43 | 43 | } |
| 44 | 44 | |
| 45 | -_LEGACY_LABEL_MAP = { | |
| 46 | - "Exact": RELEVANCE_EXACT, | |
| 47 | - "Partial": RELEVANCE_HIGH, | |
| 48 | -} | |
| 49 | - | |
| 50 | - | |
| 51 | -def normalize_stored_label(label: str) -> str: | |
| 52 | - """Map legacy 3-way SQLite labels to current 4-way strings; pass through canonical labels.""" | |
| 53 | - s = str(label).strip() | |
| 54 | - if s in VALID_LABELS: | |
| 55 | - return s | |
| 56 | - return _LEGACY_LABEL_MAP.get(s, s) | |
| 57 | - | |
| 58 | - | |
| 59 | 45 | DEFAULT_ARTIFACT_ROOT = PROJECT_ROOT / "artifacts" / "search_evaluation" |
| 60 | 46 | DEFAULT_QUERY_FILE = _SCRIPTS_EVAL_DIR / "queries" / "queries.txt" |
| 61 | 47 | ... | ... |
scripts/evaluation/eval_framework/static/eval_web.css
| ... | ... | @@ -48,9 +48,9 @@ |
| 48 | 48 | .results { display: grid; gap: 10px; } |
| 49 | 49 | .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } |
| 50 | 50 | .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } |
| 51 | - .label-exact-match { background: var(--exact); } | |
| 52 | - .label-high-relevant { background: var(--high); } | |
| 53 | - .label-low-relevant { background: var(--low); } | |
| 51 | + .label-fully-relevant { background: var(--exact); } | |
| 52 | + .label-mostly-relevant { background: var(--high); } | |
| 53 | + .label-weakly-relevant { background: var(--low); } | |
| 54 | 54 | .label-irrelevant { background: var(--irrelevant); } |
| 55 | 55 | .badge-unknown { background: #637381; } |
| 56 | 56 | .thumb { width: 100px; height: 100px; object-fit: cover; border-radius: 14px; background: #e7e1d4; } | ... | ... |
scripts/evaluation/eval_framework/store.py
| ... | ... | @@ -8,7 +8,7 @@ from dataclasses import dataclass |
| 8 | 8 | from pathlib import Path |
| 9 | 9 | from typing import Any, Dict, List, Optional, Sequence |
| 10 | 10 | |
| 11 | -from .constants import VALID_LABELS, normalize_stored_label | |
| 11 | +from .constants import VALID_LABELS | |
| 12 | 12 | from .utils import ensure_dir, safe_json_dumps, utc_now_iso |
| 13 | 13 | |
| 14 | 14 | |
| ... | ... | @@ -220,7 +220,7 @@ class EvalStore: |
| 220 | 220 | """, |
| 221 | 221 | (tenant_id, query_text), |
| 222 | 222 | ).fetchall() |
| 223 | - return {str(row["spu_id"]): normalize_stored_label(str(row["label"])) for row in rows} | |
| 223 | + return {str(row["spu_id"]): str(row["label"]) for row in rows} | |
| 224 | 224 | |
| 225 | 225 | def upsert_labels( |
| 226 | 226 | self, |
| ... | ... | @@ -379,8 +379,8 @@ class EvalStore: |
| 379 | 379 | SELECT |
| 380 | 380 | query_text, |
| 381 | 381 | COUNT(*) AS total, |
| 382 | - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count, | |
| 383 | - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, | |
| 382 | + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count, | |
| 383 | + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count, | |
| 384 | 384 | SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, |
| 385 | 385 | SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, |
| 386 | 386 | MAX(updated_at) AS updated_at |
| ... | ... | @@ -409,8 +409,8 @@ class EvalStore: |
| 409 | 409 | """ |
| 410 | 410 | SELECT |
| 411 | 411 | COUNT(*) AS total, |
| 412 | - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count, | |
| 413 | - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, | |
| 412 | + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count, | |
| 413 | + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count, | |
| 414 | 414 | SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, |
| 415 | 415 | SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, |
| 416 | 416 | MAX(updated_at) AS updated_at | ... | ... |
suggestion/ARCHITECTURE_V2.md deleted
| ... | ... | @@ -1,304 +0,0 @@ |
| 1 | -# Suggestion Architecture V2 (Suggest only, direct-result path removed) | |
| 2 | - | |
| 3 | -## 0. Summary | |
| 4 | - | |
| 5 | -This design makes Suggest a **standalone high-performance retrieval system**: it returns suggestion terms only, no longer returns product cards, and keeps no legacy compatibility. | |
| 6 | - | |
| 7 | -- Keep only the term-level autocomplete capability of `/search/suggestions` | |
| 8 | -- Completely remove the `with_results/result_size/products[]` path | |
| 9 | -- Multilingual-first; supports high concurrency, low latency, and sustainable evolution | |
| 10 | - | |
| 11 | ---- | |
| 12 | - | |
| 13 | -## 1. Key problems in the current implementation (from a review of the existing code) | |
| 14 | - | |
| 15 | -1. The online path used to include "suggest -> secondary product query", a classic N+1 amplification; as QPS rises, both latency and ES load become unstable. | |
| 16 | -2. The full build in `builder.py` relied on heavy in-memory aggregation plus `fetchall`, a high memory risk for large tenants. | |
| 17 | -3. The query parameter cap was too large (previously `size<=200`), beyond the performance envelope of an autocomplete endpoint. | |
| 18 | -4. Docs and implementation were mixed for a long time (the README still described direct results), causing inconsistent understanding. | |
| 19 | -5. Multilingual normalization is still basic (lowercasing / whitespace folding only), with insufficient handling of Unicode, diacritics, and cross-script text. | |
| 20 | - | |
| 21 | ---- | |
| 22 | - | |
| 23 | -## 2. Goals and SLOs | |
| 24 | - | |
| 25 | -### 2.1 Business goals | |
| 26 | - | |
| 27 | -- Return highly relevant query suggestions in real time as the user types | |
| 28 | -- Stable multilingual behavior (covering at least the tenant's configured `index_languages`) | |
| 29 | -- Support term-level ranking and operational governance (black/white lists, noise filtering, down-weighting) | |
| 30 | - | |
| 31 | -### 2.2 Performance targets (suggested) | |
| 32 | - | |
| 33 | -- P50 < 10ms, P95 < 25ms, P99 < 50ms (ES query time, excluding the gateway) | |
| 34 | -- High concurrency per cluster (thousands of QPS, horizontally scalable) | |
| 35 | -- Data freshness: incremental updates visible within 5-15 minutes | |
| 36 | - | |
| 37 | ---- | |
| 38 | - | |
| 39 | -## 3. Overall architecture | |
| 40 | - | |
| 41 | -### 3.1 Online path (single hop) | |
| 42 | - | |
| 43 | -Client -> API `/search/suggestions` -> ES `search_suggestions_v2` -> return suggestions | |
| 44 | - | |
| 45 | -Principles: | |
| 46 | - | |
| 47 | -- **Complete the main path in a single ES query** (optional dual-recall fusion is allowed, still within the same API request) | |
| 48 | -- Do not call `search_products`; do not return product results | |
| 49 | -- Use `routing=tenant_id` to avoid cross-shard fan-out | |
| 50 | - | |
| 51 | -### 3.2 Offline path (build) | |
| 52 | - | |
| 53 | -Data sources: | |
| 54 | - | |
| 55 | -- Product fields: `title.{lang}`, `qanchors.{lang}` | |
| 56 | -- Search logs: `shoplazza_search_log` (including `language/request_params`) | |
| 57 | -- Behavioral signals (optional enhancement): clicks, add-to-cart, orders | |
| 58 | - | |
| 59 | -Artifacts: | |
| 60 | - | |
| 61 | -- Suggest documents (unique on `tenant_id + lang + text_norm`) | |
| 62 | -- completion + prefix retrieval fields | |
| 63 | -- Ranking features (popularity, recency, quality score) | |
| 64 | - | |
| 65 | -Publishing: | |
| 66 | - | |
| 67 | -- Write to a new physical index (versioned) | |
| 68 | -- Atomically switch the alias (zero downtime) | |
| 69 | - | |
| 70 | ---- | |
| 71 | - | |
| 72 | -## 4. Index design (ES) | |
| 73 | - | |
| 74 | -### 4.1 Index organization | |
| 75 | - | |
| 76 | -Recommended two-tier strategy: | |
| 77 | - | |
| 78 | -1. Default: environment-level shared index (keeps the index count down with a huge tenant base) | |
| 79 | -2. Large tenants: can be promoted to a dedicated per-tenant index (resource isolation) | |
| 80 | - | |
| 81 | -Exposed uniformly through an alias: | |
| 82 | - | |
| 83 | -- `search_suggestions_v2_current` | |
| 84 | - | |
| 85 | -### 4.2 Mapping (core fields) | |
| 86 | - | |
| 87 | -```json | |
| 88 | -{ | |
| 89 | - "settings": { | |
| 90 | - "number_of_shards": 3, | |
| 91 | - "number_of_replicas": 1, | |
| 92 | - "refresh_interval": "30s" | |
| 93 | - }, | |
| 94 | - "mappings": { | |
| 95 | - "properties": { | |
| 96 | - "tenant_id": { "type": "keyword" }, | |
| 97 | - "lang": { "type": "keyword" }, | |
| 98 | - "text": { "type": "keyword" }, | |
| 99 | - "text_norm": { "type": "keyword" }, | |
| 100 | - "status": { "type": "byte" }, | |
| 101 | - "sources": { "type": "keyword" }, | |
| 102 | - | |
| 103 | - "query_count_7d": { "type": "integer" }, | |
| 104 | - "query_count_30d": { "type": "integer" }, | |
| 105 | - "ctr_30d": { "type": "float" }, | |
| 106 | - "order_rate_30d": { "type": "float" }, | |
| 107 | - "rank_score": { "type": "float" }, | |
| 108 | - | |
| 109 | - "suggest": { | |
| 110 | - "type": "completion", | |
| 111 | - "contexts": [ | |
| 112 | - { "name": "tenant", "type": "category" }, | |
| 113 | - { "name": "lang", "type": "category" } | |
| 114 | - ] | |
| 115 | - }, | |
| 116 | - | |
| 117 | - "sat": { | |
| 118 | - "properties": { | |
| 119 | - "zh": { "type": "search_as_you_type", "analyzer": "index_ik" }, | |
| 120 | - "en": { "type": "search_as_you_type", "analyzer": "english" }, | |
| 121 | - "ar": { "type": "search_as_you_type", "analyzer": "arabic" } | |
| 122 | - } | |
| 123 | - }, | |
| 124 | - | |
| 125 | - "updated_at": { "type": "date" } | |
| 126 | - } | |
| 127 | - } | |
| 128 | -} | |
| 129 | -``` | |
| 130 | - | |
| 131 | -Notes: | |
| 132 | - | |
| 133 | -- `completion` handles ultra-fast prefix hits (primary recall) | |
| 134 | -- `search_as_you_type` handles multi-word prefixes and fallback recall | |
| 135 | -- `contexts` enforce tenant and language isolation | |
| 136 | - | |
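For reference, a completion-suggester request against the mapping above could be built like this (a sketch: `build_suggest_body` and the suggester name `terms` are illustrative, while the `suggest` field and the `tenant`/`lang` contexts come from the mapping):

```python
def build_suggest_body(prefix, tenant_id, lang, size=10):
    """Completion suggester request with tenant/lang context filtering,
    matching the `suggest` field and `contexts` defined in the mapping above."""
    return {
        "suggest": {
            "terms": {
                "prefix": prefix,
                "completion": {
                    "field": "suggest",
                    "size": size,
                    "skip_duplicates": True,
                    "contexts": {
                        "tenant": [tenant_id],
                        "lang": [lang],
                    },
                },
            }
        }
    }
```

The context filter is what enforces tenant and language isolation at query time.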
| 137 | ---- | |
| 138 | - | |
| 139 | -## 5. Multilingual strategy | |
| 140 | - | |
| 141 | -1. Language attribution priority: `log.language > request_params.language > script detection > tenant.primary_language` | |
| 142 | -2. Unified normalization: NFKC, case folding, whitespace folding, punctuation cleanup | |
| 143 | -3. Analyzers configured per language: | |
| 144 | -   - Chinese: IK/ANSJ (consistent with the main index) | |
| 145 | -   - Latin-script languages: the corresponding built-in analyzers | |
| 146 | -   - Uncovered languages: fall back to `standard + ICU folding` | |
| 147 | -4. Ensure the written language is within the tenant's `index_languages` | |
| 148 | - | |
| 149 | ---- | |
| 150 | - | |
| 151 | -## 6. Online retrieval strategy (high performance) | |
| 152 | - | |
| 153 | -### 6.1 Dual-channel recall (recommended) | |
| 154 | - | |
| 155 | -1. Channel A: `completion suggester` (prefix, skip_duplicates) | |
| 156 | -2. Channel B: `multi_match(type=bool_prefix)` on `search_as_you_type` | |
| 157 | -3. Fusion and dedup: dedupe on `text_norm`, sort by final score, truncate | |
| 158 | - | |
| 159 | -### 6.2 Query constraints | |
| 160 | - | |
| 161 | -- Default `size=10`, maximum `size=50` | |
| 162 | -- `track_total_hits=false` | |
| 163 | -- `_source` returns only the required fields (`text/lang/rank_score/sources`) | |
| 164 | -- `routing=tenant_id` | |
| 165 | - | |
| 166 | -### 6.3 Scoring suggestion | |
| 167 | - | |
| 168 | -```text | |
| 169 | -final_score = | |
| 170 | - es_score | |
| 171 | - + a1*log1p(query_count_30d) | |
| 172 | - + a2*log1p(query_count_7d) | |
| 173 | - + a3*ctr_30d | |
| 174 | - + a4*order_rate_30d | |
| 175 | - + a5*freshness_decay | |
| 176 | -``` | |
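The scoring formula transcribes directly into Python; the coefficients `a1..a5` are placeholders to be tuned, not values from the repository:

```python
import math

def final_score(es_score, query_count_30d, query_count_7d,
                ctr_30d, order_rate_30d, freshness_decay,
                a1=1.0, a2=1.0, a3=1.0, a4=1.0, a5=1.0):
    """final_score from the formula above; a1..a5 are illustrative defaults,
    meant to be tuned offline against labeled suggestion quality."""
    return (es_score
            + a1 * math.log1p(query_count_30d)
            + a2 * math.log1p(query_count_7d)
            + a3 * ctr_30d
            + a4 * order_rate_30d
            + a5 * freshness_decay)
```

The `log1p` damping keeps very popular terms from dominating purely on raw counts.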
| 177 | - | |
| 178 | ---- | |
| 179 | - | |
| 180 | -## 7. Build and publish | |
| 181 | - | |
| 182 | -### 7.1 Build modes | |
| 183 | - | |
| 184 | -- Daily full build: rebuild all features, purge dirty terms | |
| 185 | -- Hourly incremental: process only the new log window | |
| 186 | - | |
| 187 | -### 7.2 Engineering requirements | |
| 188 | - | |
| 189 | -- No `fetchall` loading everything into memory; use streaming reads (pagination / cursors) | |
| 190 | -- ES scans use `search_after` streaming aggregation | |
| 191 | -- Bulk writes with chunking + retries + failure replay | |
| 192 | - | |
| 193 | -### 7.3 Publish strategy | |
| 194 | - | |
| 195 | -1. Finish writing `search_suggestions_v2_YYYYMMDDHHmm` | |
| 196 | -2. Validate counts / sample queries / core-term coverage | |
| 197 | -3. Atomically switch the alias to the new index | |
| 198 | -4. Keep the previous version for fast rollback | |
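Step 3, the atomic alias switch, maps to a single `_aliases` call that removes the old binding and adds the new one in one request; `alias_switch_actions` below is an illustrative helper that only builds the payload:

```python
def alias_switch_actions(alias, new_index, old_index=None):
    """Build the `_aliases` actions that atomically repoint `alias`
    from `old_index` to `new_index` (remove + add in one request,
    so readers never see the alias missing or pointing at both)."""
    actions = []
    if old_index:
        actions.append({"remove": {"index": old_index, "alias": alias}})
    actions.append({"add": {"index": new_index, "alias": alias}})
    return {"actions": actions}

# Applying it with the official Python client would look like (client setup omitted):
# es.indices.update_aliases(body=alias_switch_actions(
#     "search_suggestions_v2_current",
#     "search_suggestions_v2_202604061200",
#     "search_suggestions_v2_202604051200",
# ))
```

Keeping the previous index around (step 4) makes rollback another single `_aliases` call.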
| 199 | - | |
| 200 | ---- | |
| 201 | - | |
| 202 | -## 8. API contract (V2) | |
| 203 | - | |
| 204 | -Request: | |
| 205 | - | |
| 206 | -- `GET /search/suggestions` | |
| 207 | -- Params: `q`, `language`, `size` | |
| 208 | -- Header: `X-Tenant-ID` | |
| 209 | - | |
| 210 | -Response: | |
| 211 | - | |
| 212 | -```json | |
| 213 | -{ | |
| 214 | - "query": "iph", | |
| 215 | - "language": "en", | |
| 216 | - "resolved_language": "en", | |
| 217 | - "suggestions": [ | |
| 218 | - { | |
| 219 | - "text": "iphone 15", | |
| 220 | - "lang": "en", | |
| 221 | - "score": 8.31, | |
| 222 | - "rank_score": 6.72, | |
| 223 | - "sources": ["query_log", "qanchor"] | |
| 224 | - } | |
| 225 | - ], | |
| 226 | - "took_ms": 12 | |
| 227 | -} | |
| 228 | -``` | |
| 229 | - | |
| 230 | -Removed (explicitly unsupported): | |
| 231 | - | |
| 232 | -- `with_results` | |
| 233 | -- `result_size` | |
| 234 | -- `products[]` | |
| 235 | - | |
| 236 | ---- | |
| 237 | - | |
| 238 | -## 9. Observability and governance | |
| 239 | - | |
| 240 | -Core monitoring: | |
| 241 | - | |
| 242 | -- QPS, P50/P95/P99, error rate | |
| 243 | -- Empty-result rate (per language, per tenant) | |
| 244 | -- Suggestion coverage (whether top queries are matched) | |
| 245 | -- Language conflict rate (log vs request_params) | |
| 246 | -- Noise-term ratio, blacklist hit rate | |
| 247 | - | |
| 248 | -Governance mechanisms: | |
| 249 | - | |
| 250 | -- Blacklist: forced removal | |
| 251 | -- Whitelist: forced retention, optionally boosted | |
| 252 | -- Minimum popularity threshold: filters low-frequency junk terms | |
| 253 | -- Time decay: stale terms sink automatically | |
| 254 | - | |
| 255 | ---- | |
| 256 | - | |
| 257 | -## 10. Alignment with official best practices (ES) | |
| 258 | - | |
| 259 | -This design directly adopts the following official recommendations: | |
| 260 | - | |
| 261 | -1. `completion` suits high-performance autocomplete, supporting `skip_duplicates` and context filtering. | |
| 262 | -2. `search_as_you_type + bool_prefix` is the officially recommended as-you-type query approach. | |
| 263 | -3. `edge_ngram` is for index-time tokenization only; queries should use a regular analyzer (`search_analyzer`). | |
| 264 | -4. Use the ICU Analysis plugin for stronger Unicode handling in multilingual scenarios. | |
| 265 | -5. Route tenant requests to a single shard via `routing` to reduce fan-out. | |
| 266 | - | |
| 267 | ---- | |
| 268 | - | |
| 269 | -## 11. Phased rollout | |
| 270 | - | |
| 271 | -1. Phase 1 (this change): remove direct results; stabilize Suggest as the single capability | |
| 272 | -2. Phase 2: streaming incremental build + atomic alias publishing | |
| 273 | -3. Phase 3: behavioral ranking signals (CTR/CVR) + a governance console | |
| 274 | -4. Phase 4: automatic promotion/demotion of dedicated indices for large tenants | |
| 275 | - | |
| 276 | ---- | |
| 277 | - | |
| 278 | -## 12. Phase 2 commands (current repository) | |
| 279 | - | |
| 280 | -Full rebuild (versioned index + alias publish): | |
| 281 | - | |
| 282 | -```bash | |
| 283 | -python main.py build-suggestions \ | |
| 284 | - --tenant-id 162 \ | |
| 285 | - --mode full \ | |
| 286 | - --days 365 \ | |
| 287 | - --publish-alias \ | |
| 288 | - --keep-versions 2 | |
| 289 | -``` | |
| 290 | - | |
| 291 | -Incremental update (watermark-based): | |
| 292 | - | |
| 293 | -```bash | |
| 294 | -python main.py build-suggestions \ | |
| 295 | - --tenant-id 162 \ | |
| 296 | - --mode incremental \ | |
| 297 | - --overlap-minutes 30 | |
| 298 | -``` | |
| 299 | - | |
| 300 | -One-shot script (full + incremental + ES/API validation): | |
| 301 | - | |
| 302 | -```bash | |
| 303 | -./scripts/rebuild_suggestions.sh 162 | |
| 304 | -``` |
suggestion/README.md
| 1 | -# Suggestion Module Guide (single entry point) | |
| 1 | +# Suggestion Module Guide | |
| 2 | 2 | |
| 3 | -This document is the single entry point for the suggestion module, following the "single entry, no forks" principle in `docs/DEVELOPER_GUIDE.md`. | |
| 3 | +The `suggestion/` directory provides search-box autocomplete. The current implementation focuses on suggestion itself: build the suggestion-term index offline, and return suggestion terms for the typed prefix online. | |
| 4 | 4 | |
| 5 | -## 1. Current status (Phase 2) | |
| 5 | +This README reflects the current code implementation. It focuses on the module's current state, key design, index structure, build/publish flow, and the details of online retrieval and ranking. | |
| 6 | 6 | |
| 7 | -- Only the Suggest autocomplete capability is kept | |
| 8 | -- Direct results are not supported (`with_results` / `result_size` / `products[]` removed) | |
| 9 | -- Indexes are published with versioning: | |
| 10 | -  - Physical index: `{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>` | |
| 11 | -  - Read alias: `{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current` | |
| 12 | -- Incremental updates supported (watermark + overlap) | |
| 7 | +## 1. Current capability boundary | |
| 13 | 8 | |
| 14 | -## 2. Document navigation (the single recommended order) | |
| 9 | +- Public endpoint: `GET /search/suggestions` | |
| 10 | +- Input params: `q`, `size`, `language` | |
| 11 | +- Header: `X-Tenant-ID` | |
| 12 | +- Response: the suggestion list `suggestions[]` | |
| 13 | +- No product-result stitching, and no secondary product-query path | |
| 14 | + | |
| 15 | +The module consists of three parts: | |
| 16 | + | |
| 17 | +1. Offline build: construct suggestion documents from the product index and search logs | |
| 18 | +2. Index publish: write a versioned index and switch the alias atomically | |
| 19 | +3. Online query: prefer completion, with `search_as_you_type` fallback recall when needed | |
| 20 | + | |
| 21 | +## 2. Layout and key code | |
| 22 | + | |
| 23 | +- [builder.py](/data/saas-search/suggestion/builder.py): offline build, incremental updates, alias publishing, meta-state maintenance | |
| 24 | +- [mapping.py](/data/saas-search/suggestion/mapping.py): generates the suggestion index settings and mappings | |
| 25 | +- [service.py](/data/saas-search/suggestion/service.py): online query service; handles language normalization, dual-channel recall, dedup, and final ranking | |
| 26 | +- [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md): build, publish, and validation procedures | |
| 27 | +- [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md): common issue triage | |
| 28 | + | |
| 29 | +The command entry point is the `build-suggestions` subcommand in [main.py](/data/saas-search/main.py). | |
| 30 | + | |
| 31 | +## 3. Overall architecture | |
| 32 | + | |
| 33 | +Online path: | |
| 34 | + | |
| 35 | +`Client -> /search/suggestions -> SuggestionService -> Elasticsearch suggestion alias` | |
| 36 | + | |
| 37 | +Offline path: | |
| 38 | + | |
| 39 | +`Product index + search logs -> SuggestionIndexBuilder -> versioned suggestion index -> alias publish` | |
| 40 | + | |
| 41 | +A few core design points: | |
| 42 | + | |
| 43 | +- Suggestions have their own index; online product retrieval is not involved | |
| 44 | +- Each tenant maintains its own suggestion alias, so tenants do not affect each other | |
| 45 | +- Full builds write a new index; the alias switch is zero-downtime | |
| 46 | +- Incremental updates process only the query-log delta, reducing rebuild cost | |
| 47 | + | |
| 48 | +## 4. Index organization and publishing | |
| 49 | + | |
| 50 | +Index naming is defined centrally in [builder.py](/data/saas-search/suggestion/builder.py): | |
| 51 | + | |
| 52 | +- Read alias: `{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current` | |
| 53 | +- Versioned index: `{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v<timestamp>` | |
| 54 | +- Meta index: `{ES_INDEX_NAMESPACE}search_suggestions_meta` | |
| 55 | + | |
| 56 | +The current implementation is "one suggestion alias per tenant + multiple versioned indexes", not a shared environment-level index. | |
| 57 | + | |
| 58 | +Publish flow during a full build: | |
| 59 | + | |
| 60 | +1. Create a new versioned index | |
| 61 | +2. Write the suggestion documents produced by this build | |
| 62 | +3. Verify the new index is allocatable and readable | |
| 63 | +4. Atomically switch the alias to the new index | |
| 64 | +5. Clean up old versions, keeping only the most recent few | |
| 65 | +6. Update `search_suggestions_meta` | |
| 66 | + | |
| 67 | +The meta index records: | |
| 68 | + | |
| 69 | +- `active_alias` | |
| 70 | +- `active_index` | |
| 71 | +- `last_full_build_at` | |
| 72 | +- `last_incremental_build_at` | |
| 73 | +- `last_incremental_watermark` | |
| 74 | + | |
| 75 | +This information mainly serves incremental updates and troubleshooting. | |
| 76 | + | |
| 77 | +## 5. Mapping and index fields | |
| 78 | + | |
| 79 | +[mapping.py](/data/saas-search/suggestion/mapping.py) generates fields dynamically from the tenant's `index_languages`. | |
| 80 | + | |
| 81 | +### 5.1 Index settings | |
| 82 | + | |
| 83 | +- `number_of_shards = 1` | |
| 84 | +- `number_of_replicas = 0` | |
| 85 | +- `refresh_interval = 30s` | |
| 86 | + | |
| 87 | +Chinese uses custom analyzers: | |
| 88 | + | |
| 89 | +- `index_ik`: `ik_max_word + lowercase + asciifolding` | |
| 90 | +- `query_ik`: `ik_smart + lowercase + asciifolding` | |
| 91 | + | |
| 92 | +Other languages prefer built-in Elasticsearch analyzers such as `english`, `arabic`, `french`, `german`; uncovered languages fall back to `standard`. | |
| 93 | + | |
| 94 | +### 5.2 Core fields | |
| 95 | + | |
| 96 | +- `tenant_id`: tenant isolation | |
| 97 | +- `lang`: language the suggestion term belongs to | |
| 98 | +- `text`: original display text | |
| 99 | +- `text_norm`: normalized text, used as the unique key and for dedup | |
| 100 | +- `sources`: source set; may include `title`, `qanchor`, `tag`, `query_log` | |
| 101 | +- `title_doc_count` / `qanchor_doc_count` / `tag_doc_count`: how many product fields support the term | |
| 102 | +- `query_count_7d` / `query_count_30d`: search popularity over the last 7/30 days | |
| 103 | +- `rank_score`: offline precomputed ranking score | |
| 104 | +- `lang_confidence` / `lang_source` / `lang_conflict`: language detection and conflict info | |
| 105 | +- `status`: whether the term is currently active | |
| 106 | +- `updated_at`: last update time | |
| 107 | + | |
| 108 | +### 5.3 Two retrieval field families | |
| 109 | + | |
| 110 | +1. `completion.<lang>` | |
| 111 | +2. `sat.<lang>` | |
| 112 | + | |
| 113 | +`completion.<lang>` serves ultra-fast prefix completion and is the primary recall channel for short queries. | |
| 114 | + | |
| 115 | +`sat.<lang>` uses `search_as_you_type` for multi-word prefixes and as fallback recall when completion falls short. | |
| 116 | + | |
| 117 | +In other words, production does not rely on a single recall method: completion first, SAT to top it up. | |
| 118 | + | |
| 119 | +## 6. Where candidate terms come from | |
| 120 | + | |
| 121 | +During a full build, [builder.py](/data/saas-search/suggestion/builder.py) aggregates two major data sources. | |
| 122 | + | |
| 123 | +### 6.1 Product side | |
| 124 | + | |
| 125 | +Streamed from the tenant's product index: | |
| 126 | + | |
| 127 | +- `title` | |
| 128 | +- `qanchors` | |
| 129 | +- `enriched_tags` | |
| 130 | + | |
| 131 | +Processing: | |
| 132 | + | |
| 133 | +- `title.<lang>`: trimmed by `_prepare_title_for_suggest()` and used as a candidate term | |
| 134 | +- `qanchors.<lang>`: split on delimiters into candidate terms | |
| 135 | +- `enriched_tags`: supports multilingual objects or plain lists, with language detection when needed | |
| 136 | + | |
| 137 | +The product scan does not pull everything into memory at once; it reads in batches via `search_after`, already implemented in [_iter_products()](/data/saas-search/suggestion/builder.py#L363). | |
| 138 | + | |
| 139 | +### 6.2 Search-log side | |
| 140 | + | |
| 141 | +Streamed from MySQL `shoplazza_search_log` over a time window: | |
| 142 | + | |
| 143 | +- `query` | |
| 144 | +- `language` | |
| 145 | +- `request_params` | |
| 146 | +- `create_time` | |
| 147 | + | |
| 148 | +Reads use `stream_results=True + fetchmany()`, avoiding the memory risk of `fetchall()`; this is a key improvement over the old approach. | |
| 149 | + | |
| 150 | +Search logs mainly contribute: | |
| 151 | + | |
| 152 | +- Real user search terms | |
| 153 | +- 7/30-day popularity | |
| 154 | +- Language attribution info | |
| 155 | + | |
| 156 | +## 7. Text cleaning and language strategy | |
| 157 | + | |
| 158 | +### 7.1 Text normalization | |
| 159 | + | |
| 160 | +In [_normalize_text()](/data/saas-search/suggestion/builder.py#L176), the current implementation performs: | |
| 161 | + | |
| 162 | +- Unicode `NFKC` normalization | |
| 163 | +- Strip leading/trailing whitespace | |
| 164 | +- Lowercase | |
| 165 | +- Collapse runs of whitespace into a single space | |
| 166 | + | |
| 167 | +This `text_norm` is part of the suggestion document's stable key; the document `_id` takes the form: | |
| 168 | + | |
| 169 | +`{tenant_id}|{lang}|{text_norm}` | |
| 170 | + | |
| 171 | +This guarantees that the same tenant, language, and normalized surface form keep exactly one document. | |
| 172 | + | |
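The normalization steps and `_id` scheme above can be sketched as follows; the function names mirror the README's description but are a minimal reading of it, not the repository code:

```python
import re
import unicodedata

def normalize_text(text):
    """NFKC -> strip -> lower -> collapse whitespace, per the steps above."""
    s = unicodedata.normalize("NFKC", text)
    s = s.strip().lower()
    return re.sub(r"\s+", " ", s)

def suggestion_doc_id(tenant_id, lang, text):
    """Stable `_id`: {tenant_id}|{lang}|{text_norm}."""
    return f"{tenant_id}|{lang}|{normalize_text(text)}"
```

Because the `_id` is derived from the normalized text, re-indexing the same surface form upserts the existing document instead of creating a duplicate.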
| 173 | +### 7.2 Noise filtering | |
| 174 | + | |
| 175 | +In [_looks_noise()](/data/saas-search/suggestion/builder.py#L264), the following are filtered out: | |
| 176 | + | |
| 177 | +- Empty text | |
| 178 | +- Text longer than 120 characters | |
| 179 | +- Text consisting entirely of symbols | |
| 15 | 180 | |
| 16 | -1. `ARCHITECTURE_V2.md`: architecture and design principles | |
| 17 | -2. `RUNBOOK.md`: build/publish/validation flow | |
| 18 | -3. `TROUBLESHOOTING.md`: common issue triage | |
| 181 | +### 7.3 Language resolution priority | |
| 19 | 182 | |
| 20 | -## 3. Command entry points | |
| 183 | +Language attribution for logged queries is handled by [_resolve_query_language()](/data/saas-search/suggestion/builder.py#L299), in priority order: | |
| 21 | 184 | |
| 22 | -- Full or incremental build: | |
| 185 | +1. `shoplazza_search_log.language` | |
| 186 | +2. `request_params.language` | |
| 187 | +3. `detect_text_language_for_suggestions()` | |
| 188 | +4. Tenant `primary_language` | |
| 189 | + | |
| 190 | +It also records: | |
| 191 | + | |
| 192 | +- `lang_source`: where the language came from | |
| 193 | +- `lang_confidence`: detection confidence | |
| 194 | +- `lang_conflict`: whether the log language and request language disagree | |
| 195 | + | |
| 196 | +On the query side, [_resolve_language()](/data/saas-search/suggestion/service.py#L24) also normalizes the language, ensuring queries only hit the tenant's allowed `index_languages`. | |
| 197 | + | |
| 198 | +## 8. Ranking and rank details | |
| 199 | + | |
| 200 | +Ranking currently has two layers: the offline `rank_score` and the final online ordering. | |
| 201 | + | |
| 202 | +### 8.1 Offline `rank_score` | |
| 203 | + | |
| 204 | +In [_compute_rank_score()](/data/saas-search/suggestion/builder.py#L338), the current formula is: | |
| 205 | + | |
| 206 | +```text | |
| 207 | +rank_score = | |
| 208 | + 1.8 * log1p(query_count_30d) | |
| 209 | + + 1.2 * log1p(query_count_7d) | |
| 210 | + + 1.0 * log1p(qanchor_doc_count) | |
| 211 | + + 0.85 * log1p(tag_doc_count) | |
| 212 | + + 0.6 * log1p(title_doc_count) | |
| 213 | +``` | |
| 214 | + | |
| 215 | +The intent: | |
| 216 | + | |
| 217 | +- Search-log popularity outweighs static product fields | |
| 218 | +- 30-day popularity outweighs 7-day, while the 7-day term still reinforces recent trends | |
| 219 | +- `qanchor` reads more like a "searchable expression" than a plain title, so it gets a higher weight | |
| 220 | +- `tag` comes next | |
| 221 | +- `title` provides baseline coverage at a relatively lower weight | |
| 222 | + | |
| 223 | +This score is written to: | |
| 224 | + | |
| 225 | +- the document field `rank_score` | |
| 226 | +- `completion.<lang>.weight` | |
| 227 | + | |
| 228 | +So it affects both the completion channel and the SAT channel. | |
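The documented formula transcribes directly to Python (weights exactly as listed above; `compute_rank_score` here is a standalone sketch, not the repository function):

```python
import math

def compute_rank_score(query_count_30d=0, query_count_7d=0,
                       qanchor_doc_count=0, tag_doc_count=0,
                       title_doc_count=0):
    """rank_score with the weights documented above: log-damped counts,
    log popularity weighted above product-field coverage."""
    return (1.8 * math.log1p(query_count_30d)
            + 1.2 * math.log1p(query_count_7d)
            + 1.0 * math.log1p(qanchor_doc_count)
            + 0.85 * math.log1p(tag_doc_count)
            + 0.6 * math.log1p(title_doc_count))
```

The same count produces a higher contribution through the `qanchor` term than through the `title` term, matching the weight ordering described above.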
| 229 | + | |
| 230 | +### 8.2 Online recall and ordering | |
| 231 | + | |
| 232 | +The online strategy in [service.py](/data/saas-search/suggestion/service.py) is: | |
| 233 | + | |
| 234 | +1. Query `completion.<lang>` first | |
| 235 | +2. If the query is longer than 2 characters and completion results fall short, also query `sat.<lang>` | |
| 236 | +3. Dedupe on the normalized `text` | |
| 237 | +4. Apply the final sort, then truncate | |
| 238 | + | |
| 239 | +The completion channel relies on the ES completion `_score`; the SAT channel uses `function_score + field_value_factor(rank_score)`. | |
| 240 | + | |
| 241 | +The final ordering is handled by [_finalize_suggestion_list()](/data/saas-search/suggestion/service.py#L155); the sort key is: | |
| 242 | + | |
| 243 | +1. `score * length-penalty factor` | |
| 244 | +2. `rank_score` | |
| 245 | + | |
| 246 | +The length penalty is defined in [_suggestion_length_factor()](/data/saas-search/suggestion/service.py#L16): | |
| 247 | + | |
| 248 | +```text | |
| 249 | +length_factor = 1 / sqrt(token_len) | |
| 250 | +``` | |
| 251 | + | |
| 252 | +This means that when scores are close, shorter, more direct suggestions rank higher, preventing long-tail sentences from crowding out prefix completions. | |
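The length penalty and two-part sort key can be sketched as follows, assuming whitespace tokenization; `sort_suggestions` is an illustrative stand-in for `_finalize_suggestion_list()`:

```python
import math

def suggestion_length_factor(text):
    """1 / sqrt(token_len), per the length-penalty formula above."""
    token_len = max(1, len(text.split()))
    return 1.0 / math.sqrt(token_len)

def sort_suggestions(items):
    """Order by (score * length factor, rank_score), both descending,
    mirroring the two-part sort key described above."""
    return sorted(
        items,
        key=lambda it: (it["score"] * suggestion_length_factor(it["text"]),
                        it["rank_score"]),
        reverse=True,
    )
```

With equal raw scores, a two-token suggestion beats a five-token one because its length factor is larger.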
| 253 | + | |
| 254 | +## 9. Online query details | |
| 255 | + | |
| 256 | +The behavior of [SuggestionService.search()](/data/saas-search/suggestion/service.py#L110) can be summarized as: | |
| 257 | + | |
| 258 | +- If the alias does not exist, return an empty array directly instead of a 500 | |
| 259 | +- Short queries take the fast completion path | |
| 260 | +- Longer queries get an additional `bool_prefix` pass | |
| 261 | +- Queries always carry `routing=tenant_id` | |
| 262 | + | |
| 263 | +The `routing` here matters: it keeps suggestion queries on the target tenant's shard route as much as possible, reducing wasted fan-out. | |
| 264 | + | |
| 265 | +The SAT query also has two explicit filter conditions: | |
| 266 | + | |
| 267 | +- `lang == resolved_language` | |
| 268 | +- `status == 1` | |
| 269 | + | |
| 270 | +This guarantees recall only from documents in the current language that are currently active. | |
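A SAT fallback query with those two filters could look roughly like this (a sketch: the exact body in `service.py` may differ, and the `_2gram`/`_3gram` subfields assume default `search_as_you_type` settings):

```python
def build_sat_query(prefix, lang, size=10):
    """search_as_you_type fallback: bool_prefix multi_match over sat.<lang>,
    filtered to the resolved language and active docs, boosted by rank_score."""
    return {
        "size": size,
        "query": {
            "function_score": {
                "query": {
                    "bool": {
                        "must": [{
                            "multi_match": {
                                "query": prefix,
                                "type": "bool_prefix",
                                "fields": [
                                    f"sat.{lang}",
                                    f"sat.{lang}._2gram",
                                    f"sat.{lang}._3gram",
                                ],
                            }
                        }],
                        "filter": [
                            {"term": {"lang": lang}},
                            {"term": {"status": 1}},
                        ],
                    }
                },
                "field_value_factor": {
                    "field": "rank_score",
                    "missing": 0,
                },
            }
        },
    }
```

The filters are cheap term filters, so the language and status constraints do not affect scoring, only candidate selection.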
| 271 | + | |
| 272 | +## 10. Full build and incremental update | |
| 273 | + | |
| 274 | +### 10.1 Full build | |
| 275 | + | |
| 276 | +Entry point: `build-suggestions --mode full` in [main.py](/data/saas-search/main.py#L104). | |
| 277 | + | |
| 278 | +It: | |
| 279 | + | |
| 280 | +1. Reads `index_languages` and `primary_language` from the tenant config | |
| 281 | +2. Creates a new versioned index and waits until it is ready | |
| 282 | +3. Aggregates product data and search logs into candidate terms | |
| 283 | +4. Computes `rank_score` | |
| 284 | +5. Bulk-writes into the new index | |
| 285 | +6. Refreshes | |
| 286 | +7. Publishes the alias | |
| 287 | +8. Updates the meta info | |
| 288 | + | |
| 289 | +### 10.2 Incremental update | |
| 290 | + | |
| 291 | +Entry point: `build-suggestions --mode incremental` in [main.py](/data/saas-search/main.py#L104). | |
| 292 | + | |
| 293 | +Incremental mode currently processes only the query log; it does not rescan product data. It relies on the watermark stored in meta: | |
| 294 | + | |
| 295 | +- `last_incremental_watermark` | |
| 296 | +- falls back to `last_full_build_at` when absent | |
| 297 | +- and to `fallback_days` as a last resort | |
| 298 | + | |
| 299 | +To avoid missing records at the time boundary, `overlap_minutes` is additionally subtracted, producing an incremental window with overlap. | |
| 300 | + | |
| 301 | +Incremental writes do not rebuild whole documents; they accumulate in place via `scripted_upsert`: | |
| 302 | + | |
| 303 | +- Increment `query_count_30d` | |
| 304 | +- Increment `query_count_7d` | |
| 305 | +- Update `lang_confidence` / `lang_source` / `lang_conflict` | |
| 306 | +- Recompute `rank_score` | |
| 307 | +- Update `completion` and `sat` | |
| 308 | + | |
| 309 | +The corresponding logic is in [_build_incremental_update_script()](/data/saas-search/suggestion/builder.py#L834). | |
| 310 | + | |
| 311 | +If the alias does not exist yet and `bootstrap_if_missing=True`, the incremental job first runs a full build automatically as bootstrap. | |
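The scripted_upsert accumulation can be pictured as a per-term update action like the simplified sketch below; the real Painless script in `_build_incremental_update_script()` also refreshes language info, `rank_score`, `completion`, and `sat`:

```python
def build_increment_action(doc_id, hits_7d, hits_30d):
    """Simplified scripted_upsert sketch: create the doc if missing,
    otherwise add the windowed hit counts in place. The real script in
    the repository does more (language fields, rank_score, completion, sat)."""
    return {
        "_id": doc_id,
        "scripted_upsert": True,
        "script": {
            "lang": "painless",
            "source": (
                "ctx._source.query_count_7d = "
                "(ctx._source.query_count_7d ?: 0) + params.h7;"
                "ctx._source.query_count_30d = "
                "(ctx._source.query_count_30d ?: 0) + params.h30;"
            ),
            "params": {"h7": hits_7d, "h30": hits_30d},
        },
        "upsert": {"query_count_7d": 0, "query_count_30d": 0},
    }
```

With `scripted_upsert: true`, the script also runs on the `upsert` stub, so a brand-new term starts from zeroed counters and gets the same increments applied.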
| 312 | + | |
| 313 | +## 11. Trade-offs in the current implementation | |
| 314 | + | |
| 315 | +### 11.1 Strengths | |
| 316 | + | |
| 317 | +- A very short online path, with no amplification cost from querying products after suggest | |
| 318 | +- A clear build-and-publish flow with zero-downtime switching | |
| 319 | +- Streaming on both the product and log sides keeps memory usage under control | |
| 320 | +- completion + SAT dual-channel recall balances low latency with recall coverage | |
| 321 | +- Language, popularity, and source info are all stored in the index for later optimization | |
| 322 | + | |
| 323 | +### 11.2 Current limits | |
| 324 | + | |
| 325 | +- Incremental updates only cover the query log; changes to product titles, qanchors, and tags still rely on a full rebuild | |
| 326 | +- `rank_score` currently uses only popularity and product-field coverage, without click or conversion quality signals | |
| 327 | +- Text normalization is mainly `NFKC + lower + whitespace fold`, with no more aggressive cross-script merging yet | |
| 328 | +- `number_of_replicas=0` leans toward development or cost-first setups; whether production needs replicas depends on cluster policy | |
| 329 | + | |
| 330 | +## 12. Common commands | |
| 331 | + | |
| 332 | +Full build: | |
| 23 | 333 | |
| 24 | 334 | ```bash |
| 25 | 335 | ./scripts/build_suggestions.sh <tenant_id> --mode full |
| 26 | -./scripts/build_suggestions.sh <tenant_id> --mode incremental | |
| 27 | 336 | ``` |
| 28 | 337 | |
| 29 | -- One-shot rebuild + validation: | |
| 338 | +Incremental build: | |
| 30 | 339 | |
| 31 | 340 | ```bash |
| 32 | -./scripts/rebuild_suggestions.sh <tenant_id> | |
| 341 | +./scripts/build_suggestions.sh <tenant_id> --mode incremental | |
| 33 | 342 | ``` |
| 34 | 343 | |
| 35 | -## 4. API contract (short form) | |
| 344 | +One-shot rebuild and validation: | |
| 36 | 345 | |
| 37 | -- Endpoint: `GET /search/suggestions` | |
| 38 | -- Params: `q`, `size`, `language` | |
| 39 | -- Header: `X-Tenant-ID` | |
| 346 | +```bash | |
| 347 | +./scripts/rebuild_suggestions.sh <tenant_id> | |
| 348 | +``` | |
| 40 | 349 | |
| 41 | -Example: | |
| 350 | +API example: | |
| 42 | 351 | |
| 43 | 352 | ```bash |
| 44 | 353 | curl "http://localhost:6002/search/suggestions?q=shi&size=10&language=en" \ |
| 45 | 354 | -H "X-Tenant-ID: 162" |
| 46 | 355 | ``` |
| 356 | + | |
| 357 | +See [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md) for more operational details and [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md) for troubleshooting. | |
suggestion/RUNBOOK.md
| ... | ... | @@ -34,7 +34,7 @@ DB_PASSWORD=... |
| 34 | 34 | ### 3.1 Run |
| 35 | 35 | |
| 36 | 36 | ```bash |
| 37 | -./scripts/build_suggestions.sh 162 \ | |
| 37 | +./scripts/build_suggestions.sh 163 \ | |
| 38 | 38 | --mode full \ |
| 39 | 39 | --days 365 \ |
| 40 | 40 | --publish-alias \ |
| ... | ... | @@ -56,7 +56,7 @@ DB_PASSWORD=... |
| 56 | 56 | ### 4.1 Run |
| 57 | 57 | |
| 58 | 58 | ```bash |
| 59 | -./scripts/build_suggestions.sh 162 \ | |
| 59 | +./scripts/build_suggestions.sh 163 \ | |
| 60 | 60 | --mode incremental \ |
| 61 | 61 | --overlap-minutes 30 |
| 62 | 62 | ``` |
| ... | ... | @@ -76,7 +76,7 @@ DB_PASSWORD=... |
| 76 | 76 | > If ES auth is enabled, add `-u "$ES_USERNAME:$ES_PASSWORD"`. |
| 77 | 77 | |
| 78 | 78 | ```bash |
| 79 | -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_162_current" | |
| 79 | +ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_163_current" | |
| 80 | 80 | |
| 81 | 81 | curl "$ES_HOST/$ALIAS_NAME/_count?pretty" |
| 82 | 82 | |
| ... | ... | @@ -97,10 +97,10 @@ curl "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' - |
| 97 | 97 | |
| 98 | 98 | ```bash |
| 99 | 99 | curl "http://localhost:6002/search/suggestions?q=shirt&size=10&language=en" \ |
| 100 | - -H "X-Tenant-ID: 162" | |
| 100 | + -H "X-Tenant-ID: 163" | |
| 101 | 101 | |
| 102 | 102 | curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \ |
| 103 | - -H "X-Tenant-ID: 162" | |
| 103 | + -H "X-Tenant-ID: 163" | |
| 104 | 104 | ``` |
| 105 | 105 | |
| 106 | 106 | Acceptance criteria: |
| ... | ... | @@ -112,7 +112,7 @@ curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \ |
| 112 | 112 | ## 7. One-shot validation script |
| 113 | 113 | |
| 114 | 114 | ```bash |
| 115 | -./scripts/rebuild_suggestions.sh 162 | |
| 115 | +./scripts/rebuild_suggestions.sh 163 | |
| 116 | 116 | ``` |
| 117 | 117 | |
| 118 | 118 | The script performs: | ... | ... |