Compare View
Commits (2)
2. issues docs
- consolidate suggestion rebuild flow into build_suggestions.sh via --rebuild and remove the redundant rebuild_suggestions.sh wrapper
- make suggestion versioned index names use microseconds and handle index-create retries/timeouts without false already_exists failures
- treat create requests as successful when the index was created server-side, then explicitly wait for shard readiness and surface allocation diagnostics
- clean up freshly created suggestion indices on rebuild failure to avoid leaving red orphan indices behind
- make rebuild smoke tests target the local backend by default, with SUGGESTIONS_SMOKE_BASE_URL as the explicit override
- add unit coverage for microsecond versioned index names and cleanup on unallocatable index failures
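The microsecond-versioned index naming mentioned above can be sketched in shell. This is an illustrative guess at the scheme, not the repo's actual code; it assumes GNU date with `%N` nanosecond support:

```shell
# Illustrative sketch only: a versioned index name with microsecond resolution,
# so two rebuilds within the same second cannot collide and trip a false
# "already_exists" failure. Assumes GNU date (%N nanoseconds, %6N truncated).
base="suggestions"
version="$(date +%s%6N)"        # epoch seconds + first 6 digits of nanoseconds
index_name="${base}_v${version}"
echo "$index_name"
```

A rebuild can then create the new index under this name, swap an alias over to it on success, and delete it on failure.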
Showing 19 changed files
README.md
| ... | ... | @@ -42,15 +42,15 @@ source activate.sh |
| 42 | 42 | - `docs/Usage-Guide.md` -> `Service management overview` |
| 43 | 43 | |
| 44 | 44 | Core ports: |
| 45 | - | |
| 45 | +- `6001` qp | |
| 46 | 46 | - `6002` backend (`/search/*`, `/admin/*`) |
| 47 | -- `6004` indexer (`/indexer/*`) | |
| 48 | 47 | - `6003` frontend |
| 49 | -- `6010` eval-web (search evaluation UI; service name `eval-web` in `./scripts/service_ctl.sh`) | |
| 48 | +- `6004` indexer (`/indexer/*`) | |
| 50 | 49 | - `6005` embedding-text (optional, `POST /embed/text`; a common backend is TEI, default port `8080`) |
| 51 | -- `6008` embedding-image (optional, `POST /embed/image` etc.) | |
| 52 | 50 | - `6006` translator (optional) |
| 53 | 51 | - `6007` reranker (optional, `POST /rerank`; fine ranking can use a `service_profile` separate from the main reranker, see `config.yaml` → `fine_rank` / `services.rerank`) |
| 52 | +- `6008` embedding-image (optional, `POST /embed/image` etc.) | |
| 53 | +- `6010` eval-web (search evaluation UI; service name `eval-web` in `./scripts/service_ctl.sh`) | |
| 54 | 54 | |
| 55 | 55 | See `docs/QUICKSTART.md` for more complete examples. |
| 56 | 56 | ... | ... |
| ... | ... | @@ -0,0 +1,28 @@ |
| 1 | +Add multimodal annotation to product_enrich_prompts: | |
| 2 | + | |
| 3 | +Apparel category: | |
| 4 | + | |
| 5 | + | |
| 6 | +Model options: | |
| 7 | +Compare the two models below, weighing both price and performance: | |
| 8 | +Qwen3.5 Plus | |
| 9 | +qwen3-vl-plus | |
| 10 | + (batch calls are half price, but check whether batch calling is supported in the US region) | |
| 11 | + | |
| 12 | + | |
| 13 | +Category path | |
| 14 | +Target audience / size | |
| 15 | +Style | |
| 16 | +Usage scene | |
| 17 | +Applicable season | |
| 18 | +Functional features | |
| 19 | +Fit | |
| 20 | +Silhouette | |
| 21 | +Sleeve type | |
| 22 | +Neckline type | |
| 23 | +Material description | |
| 24 | +Pattern / design | |
| 25 | +Color family | |
| 26 | + | |
| 27 | + | |
| 28 | +There must be a cache: hash(image + title) → LLM result | ... | ... |
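The cache requirement above can be sketched as a content-addressed lookup. This is a minimal illustration, not the project's implementation; `annotate` is a hypothetical stand-in for the real multimodal LLM call:

```python
import hashlib

def cache_key(image_bytes: bytes, title: str) -> str:
    # Key = hash(image + title), so re-annotating an identical product is free.
    h = hashlib.sha256()
    h.update(image_bytes)
    h.update(b"\x00")  # separator so ("ab", "c") and ("a", "bc") differ
    h.update(title.encode("utf-8"))
    return h.hexdigest()

class AnnotationCache:
    """In-memory cache; a real one would persist to disk or a KV store."""

    def __init__(self):
        self._store = {}

    def get_or_annotate(self, image_bytes, title, annotate):
        key = cache_key(image_bytes, title)
        if key not in self._store:
            self._store[key] = annotate(image_bytes, title)
        return self._store[key]
```

A second call with the same image and title returns the stored result without invoking the model again.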
| ... | ... | @@ -0,0 +1,69 @@ |
| 1 | +This is the first version of my requirements; you have already completed most of it. Please check it against the current state of the code: | |
| 2 | + | |
| 3 | +Overall requirement: do query classification on a Tesla T4 GPU with an open-source base LLM (3-6B scale). The prompt and the label words must be configurable. | |
| 4 | +Focus on inference optimization first; don't consider serving yet. A CLI mode that reads queries is enough. At program startup, make sure everything that should be loaded is fully loaded; do no lazy loading at all, so that real requests get the fastest possible response time. For each query the CLI receives, run inference and output each label's score plus the total prediction latency (per-stage timings are even better, if easy to add). | |
| 5 | + | |
| 6 | +Example prompt and its 9 label words: | |
| 7 | +Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | |
| 8 | + | |
| 9 | +Build a dedicated execution path, not config tuning on a general-purpose generation engine; pursue the absolute lowest classification latency and an exact score for every label. | |
| 10 | + | |
| 11 | +We are on a Tesla T4, so: | |
| 12 | +1. Use FP16. Do not use BF16 as the main path. | |
| 13 | +2. Do not use FlashAttention-3 as the main path. Pick the best attention implementation on the Tesla T4. | |
| 14 | +3. Do plenty of research: look at open-source inference-optimization projects and practices suited to the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; note which hardware support matrices cover Turing/T4. | |
| 15 | + | |
| 16 | + | |
| 17 | +Main optimization directions to consider: | |
| 18 | +hidden_last -> N-class scorer -> argmax | |
| 19 | +Reuse identical prefixes following vLLM's automatic prefix caching, managing KV by block hash | |
| 20 | +Drop full-vocab logits | |
| 21 | +Drop decode / constrained decode | |
| 22 | +Bucket queries by length + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8) | |
| 23 | +Dedicated tail kernel (outputs raw scores for the N classes) | |
| 24 | + | |
| 25 | +The following is for reference only; base the actual approach on an open-source project you find as a baseline and optimize from there: | |
| 26 | +15.1 Preallocation | |
| 27 | +Preallocate for each bucket: | |
| 28 | +input ids buffer | |
| 29 | +attention scratch buffer | |
| 30 | +output hidden buffer | |
| 31 | +class id buffer | |
| 32 | +15.2 Graph capture | |
| 33 | + | |
| 34 | +Capture one graph per bucket ahead of time: | |
| 35 | +embedding | |
| 36 | +transformer backbone | |
| 37 | +compact last-state | |
| 38 | +9-way tail kernel | |
| 39 | + | |
| 40 | +Pay particular attention to Triton / CUDA C++ techniques and open-source projects that push this kind of single-step decode with fixed-vocabulary scoring to the limit; find a suitable baseline and develop against it. | |
| 41 | + | |
| 42 | +You have sudo; you may install your own environment for this project. | |
| 43 | + | |
| 44 | +Use the qwen3:4b-instruct-2507-q4_K_M model. (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M.) | |
| 45 | +Alternatively, find another model on Hugging Face, deploy it, and benchmark inference latency. | |
| 46 | + | |
| 47 | + | |
| 48 | +Please analyze per-stage latency in depth and keep researching whether performance can be pushed to the limit. | |
| 49 | +Use the qwen3:4b-instruct-2507-q4_K_M model (again, because qwen3:4b cannot disable thinking mode); deploy it and benchmark inference latency. | |
| 50 | +Note: use the FP16 version, not a quantized one. | |
| 51 | + | |
| 52 | +You have already satisfied most of the requirements above, | |
| 53 | +but one important, still-unresolved issue remains: the current version appears built for single-token decode, yet some label words are not a single token (even though they look like one word). So handle labels that are not single tokens; do not design only for the single-token case, though each label is still only a few tokens. | |
| 54 | +Multi-token labels need extreme performance optimization; avoid serial decode. | |
| 55 | +One direction to consider: | |
| 56 | +Replace HF's multi-token candidate_reduce path with a fused / vectorized implementation (batch padding, one model forward handling the whole batch). | |
| 57 | + | |
| 58 | +Do only one model forward, with input length total_tokens (usually far smaller than the looped total of batch_size * max_label_len). | |
| 59 | + | |
| 60 | +All subsequent aggregation is vectorized (nearly free on the GPU). | |
| 61 | + | |
| 62 | +Also search the literature carefully, especially Triton / Ollama / CUDA C++ best practices for this scenario; put them into practice and reach SOTA, fully optimized performance for query classification on the T4. | |
| 63 | + | |
| 64 | +Reference docs: | |
| 65 | + | |
| 66 | +vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ | |
| 67 | +PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ | |
| 68 | +TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 69 | +Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq | |
| 0 | 70 | \ No newline at end of file | ... | ... |
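The one-forward idea above can be illustrated on the CPU: each label is scored by summing per-token log-probabilities, where in the real system every needed logit row would come out of a single padded batch forward rather than serial decode, leaving only cheap aggregation. This is a toy sketch; `logits_at` is a hypothetical stand-in for the model:

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over one logit row.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def score_labels(label_token_ids, logits_at):
    """Score each label as the sum of its per-token log-probabilities.
    label_token_ids: {label: [token id, ...]} (multi-token labels allowed).
    logits_at(label, step): the full logit row for the label's step-th token."""
    scores = {}
    for label, ids in label_token_ids.items():
        scores[label] = sum(
            log_softmax(logits_at(label, step))[tok]
            for step, tok in enumerate(ids)
        )
    return scores
```

Because every score is a log-probability sum, single-token and multi-token labels land on the same scale and remain directly comparable at argmax time.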
| ... | ... | @@ -0,0 +1,98 @@ |
| 1 | +(rows 1-71 of this file repeat the previous requirements file verbatim) | |
| 72 | +The following problems remain: | |
| 73 | +1. Please satisfy: | |
| 74 | +Focus on inference optimization first; don't consider serving yet. A CLI mode that reads queries is enough. At program startup, make sure everything that should be loaded is fully loaded; do no lazy loading at all, so that real requests get the fastest possible response time. For each query the CLI receives, run inference and output each label's score plus the total prediction latency (per-stage timings are even better, if easy to add). | |
| 75 | + | |
| 76 | +2. The scoring is wrong: whatever the input, it outputs jeans; a skirt query (in Chinese) also outputs jeans; only the exact input "skirt" produces skirt. | |
| 77 | + | |
| 78 | +3. There is one prompt now; I need to add six more, all configurable. For each input query, run batched inference over the 7 prompts and, for each prompt, output the score distribution over every label in its dictionary. | |
| 79 | +Audience | |
| 80 | +Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query} | |
| 81 | + | |
| 82 | +Size | |
| 83 | +Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query} | |
| 84 | + | |
| 85 | +Neckline | |
| 86 | +Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query} | |
| 87 | + | |
| 88 | +Sleeve length | |
| 89 | +Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query} | |
| 90 | + | |
| 91 | +Top length | |
| 92 | +Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query} | |
| 93 | + | |
| 94 | +Fabric | |
| 95 | +Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. Output nothing else query:{query} | |
| 96 | + | |
| 97 | + | |
| 98 | +Make sure to test with test cases, covering both whether scores match expectations and performance. | ... | ... |
| ... | ... | @@ -0,0 +1,4 @@ |
| 1 | + | |
| 2 | +1. "At startup the repo pre-builds buckets for all batch_sizes x continuation_buckets x prompts." That design presumably came from the earlier "extreme performance" and "no lazy loading, fastest possible response on real requests" requirements, but it clearly goes too far. I do want the basics that can be preloaded to be preloaded, but trading huge GPU-memory usage for a marginal latency gain is not encouraged. Please take a higher-level view of my intent: preload the model and a cross-session KV cache, and aggressively optimize the special usage pattern (scoring rather than step-by-step decode) so that a single online query gets its 7 prompt classification results in the shortest possible time. | |
| 3 | + | |
| 4 | +2. The 7 prompts currently run serially: MultiPromptRunner.score_query() just loops for runner in self.runners. The execution model needs to become "batched in parallel, grouped by prompt". But with 2 fast-path and 5 multi-token-path prompts, should each group be fused into its own forward, or can everything be fused into one? Since multi-token labels can also be prefilled in a single shot, can they reach the same performance tier as the fast path? Please think about this at a higher level, preserving performance while reducing complexity. | |
| 0 | 5 | \ No newline at end of file | ... | ... |
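Whether the 2 fast-path and 5 multi-token prompts can share one forward is mostly a batch-layout question: give every prompt a contiguous block of rows in a single padded batch and record the offsets. A minimal CPU-side sketch of that layout (names like `PromptRows` and `plan_fused_batch` are illustrative, not the project's API):

```python
from dataclasses import dataclass

@dataclass
class PromptRows:
    # Where one prompt's rows live inside the fused batch.
    name: str
    row_start: int   # first row index in the fused batch
    row_count: int   # 1 for a fast-path prompt, one row per label otherwise

def plan_fused_batch(prompts):
    """prompts: list of (name, continuation_lengths).
    Lays every prompt's rows into one batch padded to the longest continuation,
    so a single forward can cover fast-path and multi-token prompts together.
    Returns (per-prompt plans, padded length, total rows)."""
    plans, row, padded_len = [], 0, 0
    for name, lengths in prompts:
        plans.append(PromptRows(name, row, len(lengths)))
        row += len(lengths)
        padded_len = max(padded_len, max(lengths))
    return plans, padded_len, row
```

After the shared forward, each prompt slices rows [row_start, row_start + row_count) out of the fused output and runs its own reduction.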
| ... | ... | @@ -0,0 +1,1003 @@ |
| 1 | +Overall requirement: do query classification on a Tesla T4 GPU with an open-source base LLM (3-6B scale). The prompt and label words must be configurable. | |
| 2 | +Focus on inference optimization first and consider serving last; support a modest level of concurrency (say 4 concurrent requests). At program startup, make sure everything that should be loaded is fully loaded; do no lazy loading, so that real requests get the fastest possible response time. For each query the CLI receives, classify along N dimensions with N prompts; for each prompt, output every label's score plus the total prediction latency (per-stage timings are even better, if easy to add). | |
| 3 | + | |
| 4 | +Some technical reference material follows, but you need not follow it strictly; keep some flexibility in pursuit of extreme performance. | |
| 5 | + | |
| 6 | +On a Tesla T4, do query classification with a 3B-6B open-source decoder-only base model. | |
| 7 | +At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor. | |
| 8 | +For each input query, output the score distribution of every label under every prompt, plus the prediction latency and per-stage latencies. | |
| 9 | +No general-purpose generation path, no decode, no full-vocab logits, no constrained decode. | |
| 10 | +Optimize multi-token labels specifically; avoid Python-side serial decode. | |
| 11 | +The prompt and label sets must be configurable (there are only two now; I will later grow this to 8. Each request takes one query and calls the 8 prompts in parallel to get the top-scoring label): | |
| 12 | + | |
| 13 | + | |
| 14 | +prompts: | |
| 15 | + - name: category | |
| 16 | + prompt_template: | | |
| 17 | + <|im_start|>system | |
| 18 | + Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|> | |
| 19 | + <|im_start|>user | |
| 20 | + query: {query}<|im_end|> | |
| 21 | + <|im_start|>assistant | |
| 22 | + label: | |
| 23 | + label_prefix: " " | |
| 24 | + labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none] | |
| 25 | + | |
| 26 | + - name: audience | |
| 27 | + prompt_template: | | |
| 28 | + <|im_start|>system | |
| 29 | + Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|> | |
| 30 | + <|im_start|>user | |
| 31 | + query: {query}<|im_end|> | |
| 32 | + <|im_start|>assistant | |
| 33 | + label: | |
| 34 | + label_prefix: " " | |
| 35 | + labels: [boy, girl, man, woman, pregnant, none] | |
| 36 | + | |
| 37 | + | |
| 38 | +Build a dedicated execution path, not config tuning on a general-purpose generation engine; pursue the absolute lowest classification latency. | |
| 39 | + | |
| 40 | +Main optimization directions to consider: | |
| 41 | +1. hidden_last -> N-class scorer -> argmax | |
| 42 | +2. Reuse identical prefixes following vLLM's automatic prefix caching, managing KV by block hash | |
| 43 | +3. Drop full-vocab logits | |
| 44 | +4. Drop decode / constrained decode | |
| 45 | +5. Dedicated tail kernel (outputs raw scores for the N classes) | |
| 46 | +6. The N configured prompts (2-8) must run inference in parallel | |
| 47 | +7. We are on a Tesla T4, so do not use FlashAttention-3 as the main path; pick the best attention implementation on the T4 | |
| 48 | +Do plenty of research: look at open-source inference-optimization projects and practices suited to the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; note which hardware support matrices cover Turing/T4. Pay particular attention to Triton / CUDA C++ techniques and open-source projects that push this kind of single-step decode with fixed-vocabulary scoring to the limit; find a suitable baseline and develop against it. | |
| 49 | + | |
| 50 | + | |
| 51 | +You have sudo; you may install your own environment for this project. | |
| 52 | + | |
| 53 | +Use a Q4 or Q8 build of Qwen/Qwen3-8B; research on Hugging Face which version is appropriate, pick one, deploy it, and benchmark inference latency. | |
| 54 | + | |
| 55 | +Please analyze per-stage latency in depth and keep researching whether performance can be pushed to the limit. | |
| 56 | + | |
| 57 | +One important issue: some label words are not a single token (even though they look like one word), so labels that are not single tokens must be handled. | |
| 58 | +Multi-token labels need extreme performance optimization; avoid serial decode. | |
| 59 | +Our end goal is only which label scores highest; exact probabilities are not required. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may make performance hard to optimize, so exact probabilities can be given up. Stay clear about the end goal: classification. As long as we get the class, performance comes first and exact probabilities can be set aside. | |
| 60 | +How to handle the whole batch, including multi-token labels, in one model forward is the problem you need to explore. | |
| 61 | + | |
| 62 | +The single-token fast path is fairly settled: last_hidden -> small class scorer -> argmax, taking only the LM-head rows for the target labels and never producing full-vocab output. | |
| 63 | +How to handle multi-token needs research and deliberation; ideally it should cost the same as single-token (given that exact log-probs are given up; however, multi-token and single-token label scores must stay comparable for classification to be correct, balancing performance and scoring accuracy). | |
| 64 | + | |
| 65 | +Also add a config option, force_single_token_labels: treat every label by its first token only, because when labels' first tokens differ, the first-token score can be taken as an approximation of the whole label's score. | |
| 66 | +Find the best practice for multi-label scoring in both performance and accuracy, while also supporting force_single_token_labels for extreme performance. | |
| 67 | + | |
| 68 | +Also search the literature carefully, especially Triton / Ollama / CUDA C++ best practices for this scenario; put them into practice and reach SOTA, fully optimized performance for query classification on the T4. | |
| 69 | +Some reference examples: | |
| 70 | +vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ | |
| 71 | +PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ | |
| 72 | +TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 73 | +Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq | |
| 74 | + | |
| 75 | +TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf | |
| 76 | +TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html | |
| 77 | +FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention | |
| 78 | +vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/ | |
| 79 | +SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html | |
| 80 | +FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer | |
| 81 | +xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers | |
| 82 | + | |
| 83 | +Note: two projects already exist, llm-qp and llm-qp2; their handling of the single-token case is fine: | |
| 84 | +SDPA | |
| 85 | +prefix cache | |
| 86 | +prebuilt bucket + CUDA graph | |
| 87 | +Their core code is: | |
| 88 | +from __future__ import annotations | |
| 89 | + | |
| 90 | +import hashlib | |
| 91 | +import time | |
| 92 | +from dataclasses import asdict, dataclass | |
| 93 | +from typing import Iterable | |
| 94 | + | |
| 95 | +import torch | |
| 96 | +import torch.nn.functional as F | |
| 97 | +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache | |
| 98 | + | |
| 99 | +try: | |
| 100 | + from transformers import BitsAndBytesConfig | |
| 101 | +except ImportError: # pragma: no cover | |
| 102 | + BitsAndBytesConfig = None | |
| 103 | + | |
| 104 | +from llm_qp.config import PromptTaskConfig, RuntimeConfig | |
| 105 | +from llm_qp.scorer import SmallClassScorer | |
| 106 | + | |
| 107 | +try: | |
| 108 | + from torch.nn.attention import SDPBackend, sdpa_kernel | |
| 109 | +except ImportError: # pragma: no cover | |
| 110 | + SDPBackend = None | |
| 111 | + sdpa_kernel = None | |
| 112 | + | |
| 113 | + | |
| 114 | +@dataclass(slots=True) | |
| 115 | +class EncodedLabel: | |
| 116 | + text: str | |
| 117 | + token_ids: list[int] | |
| 118 | + | |
| 119 | + | |
| 120 | +@dataclass(slots=True) | |
| 121 | +class PrefixCache: | |
| 122 | + prefix_ids: list[int] | |
| 123 | + prefix_hashes: list[str] | |
| 124 | + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] | |
| 125 | + | |
| 126 | + @property | |
| 127 | + def prefix_len(self) -> int: | |
| 128 | + return len(self.prefix_ids) | |
| 129 | + | |
| 130 | + | |
| 131 | +@dataclass(slots=True) | |
| 132 | +class MultiTokenTables: | |
| 133 | + label_token_ids: torch.Tensor | |
| 134 | + label_token_mask: torch.Tensor | |
| 135 | + label_prefix_ids: torch.Tensor | |
| 136 | + label_prefix_mask: torch.Tensor | |
| 137 | + label_position_offsets: torch.Tensor | |
| 138 | + | |
| 139 | + @property | |
| 140 | + def max_label_len(self) -> int: | |
| 141 | + return self.label_token_ids.shape[1] | |
| 142 | + | |
| 143 | + @property | |
| 144 | + def max_label_prefix_len(self) -> int: | |
| 145 | + return self.label_prefix_ids.shape[1] | |
| 146 | + | |
| 147 | + | |
| 148 | +@dataclass(slots=True) | |
| 149 | +class QueryScoreResult: | |
| 150 | + task_name: str | |
| 151 | + query: str | |
| 152 | + predicted_label: str | |
| 153 | + scores: list[tuple[str, float, float]] | |
| 154 | + total_ms: float | |
| 155 | + stage_ms: dict[str, float] | |
| 156 | + fast_path: bool | |
| 157 | + prefix_tokens: int | |
| 158 | + continuation_tokens: int | |
| 159 | + label_token_lengths: dict[str, int] | |
| 160 | + | |
| 161 | + @property | |
| 162 | + def predicted_prob(self) -> float: | |
| 163 | + for label, _score, prob in self.scores: | |
| 164 | + if label == self.predicted_label: | |
| 165 | + return prob | |
| 166 | + return 0.0 | |
| 167 | + | |
| 168 | + | |
| 169 | +@dataclass(slots=True) | |
| 170 | +class MultiPromptScoreResult: | |
| 171 | + query: str | |
| 172 | + total_ms: float | |
| 173 | + details: list[QueryScoreResult] | |
| 174 | + stage_ms: dict[str, float] | |
| 175 | + | |
| 176 | + def http_json(self) -> dict[str, object]: | |
| 177 | + return { | |
| 178 | + "query": self.query, | |
| 179 | + "total_ms": self.total_ms, | |
| 180 | + "stage_ms": self.stage_ms, | |
| 181 | + "details": [asdict(t) for t in self.details], | |
| 182 | + "task_results": { | |
| 183 | + t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none' | |
| 184 | + }, | |
| 185 | + } | |
| 186 | + | |
| 187 | + | |
| 188 | +@dataclass(slots=True) | |
| 189 | +class BatchScoreResult: | |
| 190 | + batch_size: int | |
| 191 | + total_ms: float | |
| 192 | + results: list[MultiPromptScoreResult] | |
| 193 | + stage_ms: dict[str, float] | |
| 194 | + | |
| 195 | + | |
| 196 | +@dataclass(slots=True) | |
| 197 | +class SharedRuntime: | |
| 198 | + device: torch.device | |
| 199 | + dtype: torch.dtype | |
| 200 | + tokenizer: object | |
| 201 | + model: object | |
| 202 | + backbone: object | |
| 203 | + hidden_size: int | |
| 204 | + graph_capture_pool: object | None = None | |
| 205 | + graph_capture_stream: torch.cuda.Stream | None = None | |
| 206 | + | |
| 207 | + | |
| 208 | +@dataclass(slots=True) | |
| 209 | +class PromptBatchPlan: | |
| 210 | + runner: "PromptClassifierRunner" | |
| 211 | + row_start: int | |
| 212 | + row_count: int | |
| 213 | + score_buffer: torch.Tensor | |
| 214 | + | |
| 215 | + @property | |
| 216 | + def row_stop(self) -> int: | |
| 217 | + return self.row_start + self.row_count | |
| 218 | + | |
| 219 | + | |
| 220 | +@dataclass(slots=True) | |
| 221 | +class MixedPrefixCache: | |
| 222 | + batch_size: int | |
| 223 | + total_rows: int | |
| 224 | + prefix_lengths: torch.Tensor | |
| 225 | + attention_mask: torch.Tensor | |
| 226 | + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] | |
| 227 | + | |
| 228 | + @property | |
| 229 | + def max_prefix_len(self) -> int: | |
| 230 | + return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0 | |
| 231 | + | |
| 232 | + | |
| 233 | +@dataclass(slots=True) | |
| 234 | +class BatchLayout: | |
| 235 | + batch_size: int | |
| 236 | + total_rows: int | |
| 237 | + plans: list[PromptBatchPlan] | |
| 238 | + | |
| 239 | + | |
| 240 | +@dataclass(slots=True) | |
| 241 | +class MixedBucketRuntime: | |
| 242 | + batch_size: int | |
| 243 | + total_rows: int | |
| 244 | + continuation_len: int | |
| 245 | + max_input_len: int | |
| 246 | + input_ids: torch.Tensor | |
| 247 | + attention_mask: torch.Tensor | |
| 248 | + position_ids: torch.Tensor | |
| 249 | + last_hidden_state: torch.Tensor | |
| 250 | + graph: torch.cuda.CUDAGraph | None = None | |
| 251 | + | |
| 252 | + | |
| 253 | +@dataclass(slots=True) | |
| 254 | +class PreloadReport: | |
| 255 | + total_ms: float | |
| 256 | + stage_ms: dict[str, float] | |
| 257 | + runtime: dict[str, object] | |
| 258 | + | |
| 259 | + | |
| 260 | +def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]: | |
| 261 | + token_list = list(token_ids) | |
| 262 | + hashes: list[str] = [] | |
| 263 | + for start in range(0, len(token_list), block_size): | |
| 264 | + block = token_list[start : start + block_size] | |
| 265 | + payload = ",".join(str(x) for x in block).encode("utf-8") | |
| 266 | + hashes.append(hashlib.sha1(payload).hexdigest()) | |
| 267 | + return hashes | |
| 268 | + | |
| 269 | + | |
| 270 | +def _expand_legacy_cache( | |
| 271 | + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...], | |
| 272 | + batch_size: int, | |
| 273 | +) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: | |
| 274 | + expanded: list[tuple[torch.Tensor, torch.Tensor]] = [] | |
| 275 | + for key, value in raw_cache: | |
| 276 | + expanded.append( | |
| 277 | + ( | |
| 278 | + key.expand(batch_size, *key.shape[1:]).contiguous(), | |
| 279 | + value.expand(batch_size, *value.shape[1:]).contiguous(), | |
| 280 | + ) | |
| 281 | + ) | |
| 282 | + return tuple(expanded) | |
| 283 | + | |
| 284 | + | |
| 285 | +class PromptClassifierRunner: | |
| 286 | + def __init__( | |
| 287 | + self, | |
| 288 | + cfg: RuntimeConfig, | |
| 289 | + task_cfg: PromptTaskConfig, | |
| 290 | + shared_runtime: SharedRuntime, | |
| 291 | + ): | |
| 292 | + self.cfg = cfg | |
| 293 | + self.task_cfg = task_cfg | |
| 294 | + self.device = shared_runtime.device | |
| 295 | + self.dtype = shared_runtime.dtype | |
| 296 | + self.tokenizer = shared_runtime.tokenizer | |
| 297 | + self.model = shared_runtime.model | |
| 298 | + self.backbone = shared_runtime.backbone | |
| 299 | + self.hidden_size = shared_runtime.hidden_size | |
| 300 | + self.prefix_text, self.suffix_text = task_cfg.prompt_parts | |
| 301 | + self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False) | |
| 302 | + self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False) | |
| 303 | + self.labels = list(task_cfg.labels) | |
| 304 | + self.encoded_labels = [ | |
| 305 | + EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label)) | |
| 306 | + for label in self.labels | |
| 307 | + ] | |
| 308 | + self.num_labels = len(self.labels) | |
| 309 | + self.lm_head = self.model.get_output_embeddings() | |
| 310 | + self.lm_head_weight = self.lm_head.weight.detach() | |
| 311 | + self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None | |
| 312 | + if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels(): | |
| 313 | + raise ValueError( | |
| 314 | + f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide" | |
| 315 | + ) | |
| 316 | + self.fast_path = self._has_unique_single_token_labels() | |
| 317 | + self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else [] | |
| 318 | + self.scorer = self._build_scorer() if self.fast_path else None | |
| 319 | + self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None | |
| 320 | + self.prefix_cache = self._build_prefix_cache() | |
| 321 | + | |
| 322 | + def _encode_label_token_ids(self, label: str) -> list[int]: | |
| 323 | + token_ids = self.tokenizer.encode( | |
| 324 | + f"{self.task_cfg.label_prefix}{label}", | |
| 325 | + add_special_tokens=False, | |
| 326 | + ) | |
| 327 | + if not token_ids: | |
| 328 | + raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence") | |
| 329 | + if self.cfg.force_single_token_labels: | |
| 330 | + return token_ids[:1] | |
| 331 | + return token_ids | |
| 332 | + | |
| 333 | + def _has_unique_single_token_labels(self) -> bool: | |
| 334 | + token_ids: list[int] = [] | |
| 335 | + for item in self.encoded_labels: | |
| 336 | + if len(item.token_ids) != 1: | |
| 337 | + return False | |
| 338 | + token_ids.append(item.token_ids[0]) | |
| 339 | + return len(token_ids) == len(set(token_ids)) | |
| 340 | + | |
| 341 | + def _build_scorer(self) -> SmallClassScorer: | |
| 342 | + token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device) | |
| 343 | + weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous() | |
| 344 | + bias = None | |
| 345 | + if self.lm_head_bias is not None: | |
| 346 | + bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous() | |
| 347 | + return SmallClassScorer(weights=weights, bias=bias) | |
| 348 | + | |
| 349 | + def _build_multi_token_tables(self) -> MultiTokenTables: | |
| 350 | + max_label_len = max(len(item.token_ids) for item in self.encoded_labels) | |
| 351 | + max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels) | |
| 352 | + label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long) | |
| 353 | + label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32) | |
| 354 | + label_prefix_ids = torch.full( | |
| 355 | + (self.num_labels, max_prefix_len), | |
| 356 | + fill_value=self.tokenizer.pad_token_id, | |
| 357 | + device=self.device, | |
| 358 | + dtype=torch.long, | |
| 359 | + ) | |
| 360 | + label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long) | |
| 361 | + for idx, item in enumerate(self.encoded_labels): | |
| 362 | + token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long) | |
| 363 | + token_len = token_ids.numel() | |
| 364 | + label_token_ids[idx, :token_len] = token_ids | |
| 365 | + label_token_mask[idx, :token_len] = 1.0 | |
| 366 | + if token_len > 1: | |
| 367 | + prefix_len = token_len - 1 | |
| 368 | + label_prefix_ids[idx, :prefix_len] = token_ids[:-1] | |
| 369 | + label_prefix_mask[idx, :prefix_len] = 1 | |
| 370 | + return MultiTokenTables( | |
| 371 | + label_token_ids=label_token_ids.contiguous(), | |
| 372 | + label_token_mask=label_token_mask.contiguous(), | |
| 373 | + label_prefix_ids=label_prefix_ids.contiguous(), | |
| 374 | + label_prefix_mask=label_prefix_mask.contiguous(), | |
| 375 | + label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long), | |
| 376 | + ) | |
| 377 | + | |
| 378 | + @torch.inference_mode() | |
| 379 | + def _build_prefix_cache(self) -> PrefixCache: | |
| 380 | + if not self.prefix_ids: | |
| 381 | + return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple()) | |
| 382 | + prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device) | |
| 383 | + attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device) | |
| 384 | + outputs = self.model( | |
| 385 | + input_ids=prefix_tensor, | |
| 386 | + attention_mask=attention_mask, | |
| 387 | + use_cache=True, | |
| 388 | + return_dict=True, | |
| 389 | + ) | |
| 390 | + raw_cache = tuple( | |
| 391 | + (layer.keys.detach(), layer.values.detach()) | |
| 392 | + for layer in outputs.past_key_values.layers | |
| 393 | + ) | |
| 394 | + return PrefixCache( | |
| 395 | + prefix_ids=list(self.prefix_ids), | |
| 396 | + prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size), | |
| 397 | + raw_cache=raw_cache, | |
| 398 | + ) | |
| 399 | + | |
| 400 | + def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: | |
| 401 | + if not self.prefix_cache.raw_cache: | |
| 402 | + return tuple() | |
| 403 | + return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size) | |
| 404 | + | |
| 405 | + def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]: | |
| 406 | + continuation = query_ids + self.suffix_ids | |
| 407 | + if not continuation: | |
| 408 | + raise ValueError("prompt continuation is empty after substituting query") | |
| 409 | + if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length: | |
| 410 | + raise ValueError( | |
| 411 | + f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}" | |
| 412 | + ) | |
| 413 | + return continuation | |
| 414 | + | |
| 415 | + @torch.inference_mode() | |
| 416 | + def reduce_fast_scores( | |
| 417 | + self, | |
| 418 | + hidden: torch.Tensor, | |
| 419 | + out_scores: torch.Tensor, | |
| 420 | + ) -> None: | |
| 421 | + assert self.scorer is not None | |
| 422 | + out_scores.copy_(self.scorer(hidden)) | |
| 423 | + | |
| 424 | + @torch.inference_mode() | |
| 425 | + def reduce_multi_token_scores( | |
| 426 | + self, | |
| 427 | + last_hidden_state: torch.Tensor, | |
| 428 | + batch_size: int, | |
| 429 | + max_input_len: int, | |
| 430 | + score_positions: torch.Tensor, | |
| 431 | + out_scores: torch.Tensor, | |
| 432 | + ) -> None: | |
| 433 | + assert self.multi_token_tables is not None | |
| 434 | + hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size) | |
| 435 | + gather_index = score_positions[:, None, :, None].expand( | |
| 436 | + batch_size, | |
| 437 | + self.num_labels, | |
| 438 | + self.multi_token_tables.max_label_len, | |
| 439 | + self.hidden_size, | |
| 440 | + ) | |
| 441 | + gathered_hidden = torch.gather(hidden, 2, gather_index) | |
| 442 | + used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool() | |
| 443 | + | |
| 444 | + token_log_probs = torch.zeros( | |
| 445 | + (batch_size, self.num_labels, self.multi_token_tables.max_label_len), | |
| 446 | + device=self.device, | |
| 447 | + dtype=torch.float32, | |
| 448 | + ) | |
| 449 | + if used_mask.any(): | |
| 450 | + flat_hidden = gathered_hidden[used_mask] | |
| 451 | + flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask] | |
| 452 | + linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float() | |
| 453 | + linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float() | |
| 454 | + linear_bias = self.lm_head_bias | |
| 455 | + if linear_bias is not None and self.device.type != "cuda": | |
| 456 | + linear_bias = linear_bias.float() | |
| 457 | + flat_logits = F.linear(linear_hidden, linear_weight, linear_bias) | |
| 458 | + flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float() | |
| 459 | + flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1) | |
| 460 | + token_log_probs[used_mask] = flat_selected - flat_log_norm | |
| 461 | + out_scores.copy_(token_log_probs.sum(dim=-1)) | |
| 462 | + | |
| 463 | + def build_score_result( | |
| 464 | + self, | |
| 465 | + query: str, | |
| 466 | + scores: torch.Tensor, | |
| 467 | + stage_ms: dict[str, float], | |
| 468 | + continuation_tokens: int, | |
| 469 | + ) -> QueryScoreResult: | |
| 470 | + score_values = scores.detach().float().cpu().tolist() | |
| 471 | + best_idx = max(range(len(score_values)), key=score_values.__getitem__) | |
| 472 | + probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist() | |
| 473 | + return QueryScoreResult( | |
| 474 | + task_name=self.task_cfg.name, | |
| 475 | + query=query, | |
| 476 | + predicted_label=self.labels[best_idx], | |
| 477 | + scores=[ | |
| 478 | + (label, score, prob) | |
| 479 | + for label, score, prob in zip(self.labels, score_values, probs, strict=True) | |
| 480 | + ], | |
| 481 | + total_ms=sum(stage_ms.values()), | |
| 482 | + stage_ms=stage_ms, | |
| 483 | + fast_path=self.fast_path, | |
| 484 | + prefix_tokens=self.prefix_cache.prefix_len, | |
| 485 | + continuation_tokens=continuation_tokens, | |
| 486 | + label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels}, | |
| 487 | + ) | |
| 488 | + | |
| 489 | + | |
| 490 | +class MultiPromptRunner: | |
| 491 | + def __init__(self, cfg: RuntimeConfig): | |
| 492 | + self.cfg = cfg | |
| 493 | + t0 = time.perf_counter() | |
| 494 | + self.shared_runtime = self.build_shared_runtime(cfg) | |
| 495 | + t1 = time.perf_counter() | |
| 496 | + self.device = self.shared_runtime.device | |
| 497 | + self.dtype = self.shared_runtime.dtype | |
| 498 | + self.tokenizer = self.shared_runtime.tokenizer | |
| 499 | + self.model = self.shared_runtime.model | |
| 500 | + self.backbone = self.shared_runtime.backbone | |
| 501 | + self.hidden_size = self.shared_runtime.hidden_size | |
| 502 | + self.graph_capture_pool = self.shared_runtime.graph_capture_pool | |
| 503 | + self.graph_capture_stream = self.shared_runtime.graph_capture_stream | |
| 504 | + self.runners = [ | |
| 505 | + PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime) | |
| 506 | + for task_cfg in cfg.tasks | |
| 507 | + ] | |
| 508 | + t2 = time.perf_counter() | |
| 509 | + self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes} | |
| 510 | + t3 = time.perf_counter() | |
| 511 | + self.mixed_prefix_caches = { | |
| 512 | + batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size]) | |
| 513 | + for batch_size in self.cfg.batch_sizes | |
| 514 | + } | |
| 515 | + t4 = time.perf_counter() | |
| 516 | + self.max_label_prefix_len = max( | |
| 517 | + (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0) | |
| 518 | + for runner in self.runners | |
| 519 | + ) | |
| 520 | + self.mixed_buckets = { | |
| 521 | + (batch_size, continuation_len): self._build_mixed_bucket( | |
| 522 | + self.batch_layouts[batch_size], | |
| 523 | + self.mixed_prefix_caches[batch_size], | |
| 524 | + continuation_len, | |
| 525 | + ) | |
| 526 | + for batch_size in self.cfg.batch_sizes | |
| 527 | + for continuation_len in self.cfg.continuation_buckets | |
| 528 | + } | |
| 529 | + t5 = time.perf_counter() | |
| 530 | + self._warmup_results: dict[int, BatchScoreResult] = {} | |
| 531 | + self._preload_report: PreloadReport | None = None | |
| 532 | + self._init_stage_ms = { | |
| 533 | + "load_model_and_tokenizer": (t1 - t0) * 1000.0, | |
| 534 | + "build_prompt_runtimes": (t2 - t1) * 1000.0, | |
| 535 | + "build_batch_layouts": (t3 - t2) * 1000.0, | |
| 536 | + "build_mixed_prefix_caches": (t4 - t3) * 1000.0, | |
| 537 | + "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0, | |
| 538 | + } | |
| 539 | + self._init_total_ms = sum(self._init_stage_ms.values()) | |
| 540 | + | |
| 541 | + @staticmethod | |
| 542 | + def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime: | |
| 543 | + device = torch.device(cfg.device) | |
| 544 | + dtype = torch.float16 | |
| 545 | + tokenizer = AutoTokenizer.from_pretrained( | |
| 546 | + cfg.resolved_model_source, | |
| 547 | + trust_remote_code=cfg.resolved_trust_remote_code, | |
| 548 | + token=cfg.hf_token, | |
| 549 | + cache_dir=cfg.hf_cache_dir, | |
| 550 | + local_files_only=cfg.resolved_local_files_only, | |
| 551 | + ) | |
| 552 | + if tokenizer.pad_token_id is None: | |
| 553 | + tokenizer.pad_token = tokenizer.eos_token | |
| 554 | + attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend) | |
| 555 | + quantization_config = None | |
| 556 | + model_kwargs: dict[str, object] = { | |
| 557 | + "trust_remote_code": cfg.resolved_trust_remote_code, | |
| 558 | + "attn_implementation": attn_impl, | |
| 559 | + "token": cfg.hf_token, | |
| 560 | + "cache_dir": cfg.hf_cache_dir, | |
| 561 | + "local_files_only": cfg.resolved_local_files_only, | |
| 562 | + } | |
| 563 | + if cfg.load_in_4bit: | |
| 564 | + if BitsAndBytesConfig is None: | |
| 565 | + raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first") | |
| 566 | + quantization_config = BitsAndBytesConfig( | |
| 567 | + load_in_4bit=True, | |
| 568 | + bnb_4bit_compute_dtype=dtype, | |
| 569 | + bnb_4bit_quant_type=cfg.bnb_4bit_quant_type, | |
| 570 | + bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant, | |
| 571 | + ) | |
| 572 | + model_kwargs["quantization_config"] = quantization_config | |
| 573 | + model_kwargs["device_map"] = {"": device.index or 0} | |
| 574 | + else: | |
| 575 | + model_kwargs["dtype"] = dtype | |
| 576 | + model_kwargs["device_map"] = None | |
| 577 | + model = AutoModelForCausalLM.from_pretrained( | |
| 578 | + cfg.resolved_model_source, | |
| 579 | + **model_kwargs, | |
| 580 | + ).eval() | |
| 581 | + if not cfg.load_in_4bit: | |
| 582 | + model = model.to(device) | |
| 583 | + backbone = model.get_submodule(model.base_model_prefix) | |
| 584 | + hidden_size = model.get_output_embeddings().weight.shape[1] | |
| 585 | + graph_capture_pool = None | |
| 586 | + graph_capture_stream = None | |
| 587 | + if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit: | |
| 588 | + graph_capture_pool = torch.cuda.graph_pool_handle() | |
| 589 | + graph_capture_stream = torch.cuda.Stream(device=device) | |
| 590 | + return SharedRuntime( | |
| 591 | + device=device, | |
| 592 | + dtype=dtype, | |
| 593 | + tokenizer=tokenizer, | |
| 594 | + model=model, | |
| 595 | + backbone=backbone, | |
| 596 | + hidden_size=hidden_size, | |
| 597 | + graph_capture_pool=graph_capture_pool, | |
| 598 | + graph_capture_stream=graph_capture_stream, | |
| 599 | + ) | |
| 600 | + | |
| 601 | + @staticmethod | |
| 602 | + def _resolve_attn_impl(requested: str) -> str: | |
| 603 | + if requested in {"sdpa", "eager"}: | |
| 604 | + return requested | |
| 605 | + if requested == "auto": | |
| 606 | + return "sdpa" | |
| 607 | + raise ValueError(f"unsupported attn_backend: {requested}") | |
| 608 | + | |
| 609 | + def _attn_context(self): | |
| 610 | + if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}: | |
| 611 | + return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) | |
| 612 | + return torch.no_grad() | |
| 613 | + | |
| 614 | + def _sync(self) -> None: | |
| 615 | + if self.device.type == "cuda": | |
| 616 | + torch.cuda.synchronize() | |
| 617 | + | |
| 618 | + def _pick_bucket(self, continuation_len: int) -> int: | |
| 619 | + for bucket in self.cfg.continuation_buckets: | |
| 620 | + if continuation_len <= bucket: | |
| 621 | + return bucket | |
| 622 | + if self.cfg.pad_to_bucket: | |
| 623 | + raise ValueError( | |
| 624 | + f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets" | |
| 625 | + ) | |
| 626 | + return continuation_len | |
| 627 | + | |
| 628 | + def _build_batch_layout(self, batch_size: int) -> BatchLayout: | |
| 629 | + plans: list[PromptBatchPlan] = [] | |
| 630 | + row_start = 0 | |
| 631 | + for runner in self.runners: | |
| 632 | + row_count = batch_size if runner.fast_path else batch_size * runner.num_labels | |
| 633 | + plans.append( | |
| 634 | + PromptBatchPlan( | |
| 635 | + runner=runner, | |
| 636 | + row_start=row_start, | |
| 637 | + row_count=row_count, | |
| 638 | + score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32), | |
| 639 | + ) | |
| 640 | + ) | |
| 641 | + row_start += row_count | |
| 642 | + return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans) | |
| 643 | + | |
| 644 | + def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache: | |
| 645 | + prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long) | |
| 646 | + non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache] | |
| 647 | + if not non_empty: | |
| 648 | + return MixedPrefixCache( | |
| 649 | + batch_size=layout.batch_size, | |
| 650 | + total_rows=layout.total_rows, | |
| 651 | + prefix_lengths=prefix_lengths, | |
| 652 | + attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long), | |
| 653 | + raw_cache=tuple(), | |
| 654 | + ) | |
| 655 | + | |
| 656 | + max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans) | |
| 657 | + num_layers = len(non_empty[0]) | |
| 658 | + attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long) | |
| 659 | + raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = [] | |
| 660 | + for layer_idx in range(num_layers): | |
| 661 | + sample_key, sample_value = non_empty[0][layer_idx] | |
| 662 | + merged_key = sample_key.new_zeros( | |
| 663 | + (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3]) | |
| 664 | + ) | |
| 665 | + merged_value = sample_value.new_zeros( | |
| 666 | + (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3]) | |
| 667 | + ) | |
| 668 | + raw_layers.append((merged_key, merged_value)) | |
| 669 | + | |
| 670 | + for plan in layout.plans: | |
| 671 | + runner = plan.runner | |
| 672 | + prefix_len = runner.prefix_cache.prefix_len | |
| 673 | + row_slice = slice(plan.row_start, plan.row_stop) | |
| 674 | + prefix_lengths[row_slice] = prefix_len | |
| 675 | + if prefix_len == 0: | |
| 676 | + continue | |
| 677 | + attention_mask[row_slice, :prefix_len] = 1 | |
| 678 | + raw_cache = runner.expand_prefix_raw_cache(plan.row_count) | |
| 679 | + for layer_idx, (key, value) in enumerate(raw_cache): | |
| 680 | + merged_key, merged_value = raw_layers[layer_idx] | |
| 681 | + merged_key[row_slice, :, :prefix_len, :] = key | |
| 682 | + merged_value[row_slice, :, :prefix_len, :] = value | |
| 683 | + | |
| 684 | + return MixedPrefixCache( | |
| 685 | + batch_size=layout.batch_size, | |
| 686 | + total_rows=layout.total_rows, | |
| 687 | + prefix_lengths=prefix_lengths, | |
| 688 | + attention_mask=attention_mask.contiguous(), | |
| 689 | + raw_cache=tuple(raw_layers), | |
| 690 | + ) | |
| 691 | + | |
| 692 | + def _build_mixed_bucket( | |
| 693 | + self, | |
| 694 | + layout: BatchLayout, | |
| 695 | + prefix_cache: MixedPrefixCache, | |
| 696 | + continuation_len: int, | |
| 697 | + ) -> MixedBucketRuntime: | |
| 698 | + max_input_len = continuation_len + self.max_label_prefix_len | |
| 699 | + total_len = prefix_cache.max_prefix_len + max_input_len | |
| 700 | + input_ids = torch.full( | |
| 701 | + (layout.total_rows, max_input_len), | |
| 702 | + fill_value=self.tokenizer.pad_token_id, | |
| 703 | + device=self.device, | |
| 704 | + dtype=torch.long, | |
| 705 | + ) | |
| 706 | + attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long) | |
| 707 | + if prefix_cache.max_prefix_len: | |
| 708 | + attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask | |
| 709 | + position_ids = ( | |
| 710 | + prefix_cache.prefix_lengths[:, None] | |
| 711 | + + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0) | |
| 712 | + ).contiguous() | |
| 713 | + last_hidden_state = torch.empty( | |
| 714 | + (layout.total_rows, max_input_len, self.hidden_size), | |
| 715 | + device=self.device, | |
| 716 | + dtype=self.dtype, | |
| 717 | + ) | |
| 718 | + bucket = MixedBucketRuntime( | |
| 719 | + batch_size=layout.batch_size, | |
| 720 | + total_rows=layout.total_rows, | |
| 721 | + continuation_len=continuation_len, | |
| 722 | + max_input_len=max_input_len, | |
| 723 | + input_ids=input_ids, | |
| 724 | + attention_mask=attention_mask, | |
| 725 | + position_ids=position_ids, | |
| 726 | + last_hidden_state=last_hidden_state, | |
| 727 | + ) | |
| 728 | + if self.cfg.cuda_graphs: | |
| 729 | + self._capture_mixed_bucket(bucket, prefix_cache) | |
| 730 | + return bucket | |
| 731 | + | |
| 732 | + @torch.inference_mode() | |
| 733 | + def _run_mixed_backbone( | |
| 734 | + self, | |
| 735 | + bucket: MixedBucketRuntime, | |
| 736 | + prefix_cache: MixedPrefixCache, | |
| 737 | + ) -> None: | |
| 738 | + cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config) | |
| 739 | + with self._attn_context(): | |
| 740 | + outputs = self.backbone( | |
| 741 | + input_ids=bucket.input_ids, | |
| 742 | + attention_mask=bucket.attention_mask, | |
| 743 | + position_ids=bucket.position_ids, | |
| 744 | + past_key_values=cache, | |
| 745 | + use_cache=False, | |
| 746 | + return_dict=True, | |
| 747 | + ) | |
| 748 | + bucket.last_hidden_state.copy_(outputs.last_hidden_state) | |
| 749 | + | |
| 750 | + def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None: | |
| 751 | + if not torch.cuda.is_available(): | |
| 752 | + return | |
| 753 | + try: | |
| 754 | + torch.cuda.synchronize() | |
| 755 | + stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device) | |
| 756 | + with torch.cuda.stream(stream): | |
| 757 | + for _ in range(self.cfg.graph_warmups): | |
| 758 | + self._run_mixed_backbone(bucket, prefix_cache) | |
| 759 | + stream.synchronize() | |
| 760 | + graph = torch.cuda.CUDAGraph() | |
| 761 | + with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream): | |
| 762 | + self._run_mixed_backbone(bucket, prefix_cache) | |
| 763 | + bucket.graph = graph | |
| 764 | + except RuntimeError: | |
| 765 | + bucket.graph = None | |
| 766 | + | |
| 767 | + def _prepare_bucket( | |
| 768 | + self, | |
| 769 | + layout: BatchLayout, | |
| 770 | + prefix_cache: MixedPrefixCache, | |
| 771 | + bucket: MixedBucketRuntime, | |
| 772 | + query_ids_batch: list[list[int]], | |
| 773 | + ) -> tuple[list[list[int]], dict[str, list[object]]]: | |
| 774 | + del prefix_cache | |
| 775 | + bucket.input_ids.fill_(self.tokenizer.pad_token_id) | |
| 776 | + bucket.attention_mask.zero_() | |
| 777 | + if self.mixed_prefix_caches[layout.batch_size].max_prefix_len: | |
| 778 | + bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = ( | |
| 779 | + self.mixed_prefix_caches[layout.batch_size].attention_mask | |
| 780 | + ) | |
| 781 | + continuation_lengths_per_task: dict[str, list[int]] = {} | |
| 782 | + continuation_tokens_per_task: dict[str, list[list[int]]] = {} | |
| 783 | + prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len | |
| 784 | + for plan in layout.plans: | |
| 785 | + runner = plan.runner | |
| 786 | + per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch] | |
| 787 | + continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations | |
| 788 | + continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations] | |
| 789 | + if runner.fast_path: | |
| 790 | + for batch_idx, continuation in enumerate(per_query_continuations): | |
| 791 | + cont_len = len(continuation) | |
| 792 | + row_idx = plan.row_start + batch_idx | |
| 793 | + bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long) | |
| 794 | + bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1 | |
| 795 | + continue | |
| 796 | + | |
| 797 | + assert runner.multi_token_tables is not None | |
| 798 | + for batch_idx, continuation in enumerate(per_query_continuations): | |
| 799 | + cont_len = len(continuation) | |
| 800 | + row_start = plan.row_start + batch_idx * runner.num_labels | |
| 801 | + row_stop = row_start + runner.num_labels | |
| 802 | + row_slice = slice(row_start, row_stop) | |
| 803 | + cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long) | |
| 804 | + bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1) | |
| 805 | + bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1 | |
| 806 | + if runner.multi_token_tables.max_label_prefix_len: | |
| 807 | + bucket.input_ids[ | |
| 808 | + row_slice, | |
| 809 | + cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len, | |
| 810 | + ] = runner.multi_token_tables.label_prefix_ids | |
| 811 | + bucket.attention_mask[ | |
| 812 | + row_slice, | |
| 813 | + prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len, | |
| 814 | + ] = runner.multi_token_tables.label_prefix_mask | |
| 815 | + return query_ids_batch, { | |
| 816 | + "continuation_lengths_per_task": continuation_lengths_per_task, | |
| 817 | + "continuation_tokens_per_task": continuation_tokens_per_task, | |
| 818 | + } | |
| 819 | + | |
| 820 | + def _reduce_prompt_scores( | |
| 821 | + self, | |
| 822 | + layout: BatchLayout, | |
| 823 | + bucket: MixedBucketRuntime, | |
| 824 | + query_texts: list[str], | |
| 825 | + prep_meta: dict[str, list[object]], | |
| 826 | + shared_stage_ms: dict[str, float], | |
| 827 | + ) -> list[MultiPromptScoreResult]: | |
| 828 | + result_rows = [[] for _ in range(layout.batch_size)] | |
| 829 | + prompt_reduce_total_ms = 0.0 | |
| 830 | + for plan in layout.plans: | |
| 831 | + runner = plan.runner | |
| 832 | + continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name] | |
| 833 | + reduce_start = time.perf_counter() | |
| 834 | + if runner.fast_path: | |
| 835 | + hidden_rows = [] | |
| 836 | + row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size] | |
| 837 | + for batch_idx, cont_len in enumerate(continuation_lengths): | |
| 838 | + hidden_rows.append(row_slice[batch_idx, cont_len - 1]) | |
| 839 | + hidden = torch.stack(hidden_rows, dim=0) | |
| 840 | + runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer) | |
| 841 | + stage_name = "tail_scorer" | |
| 842 | + else: | |
| 843 | + assert runner.multi_token_tables is not None | |
| 844 | + score_positions = torch.stack( | |
| 845 | + [ | |
| 846 | + cont_len - 1 + runner.multi_token_tables.label_position_offsets | |
| 847 | + for cont_len in continuation_lengths | |
| 848 | + ], | |
| 849 | + dim=0, | |
| 850 | + ) | |
| 851 | + runner.reduce_multi_token_scores( | |
| 852 | + last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop], | |
| 853 | + batch_size=layout.batch_size, | |
| 854 | + max_input_len=bucket.max_input_len, | |
| 855 | + score_positions=score_positions, | |
| 856 | + out_scores=plan.score_buffer, | |
| 857 | + ) | |
| 858 | + stage_name = "candidate_reduce" | |
| 859 | + self._sync() | |
| 860 | + reduce_end = time.perf_counter() | |
| 861 | + reduce_ms = (reduce_end - reduce_start) * 1000.0 | |
| 862 | + prompt_reduce_total_ms += reduce_ms | |
| 863 | + for batch_idx, query in enumerate(query_texts): | |
| 864 | + stage_ms = dict(shared_stage_ms) | |
| 865 | + stage_ms[stage_name] = reduce_ms / layout.batch_size | |
| 866 | + result_rows[batch_idx].append( | |
| 867 | + runner.build_score_result( | |
| 868 | + query=query, | |
| 869 | + scores=plan.score_buffer[batch_idx], | |
| 870 | + stage_ms=stage_ms, | |
| 871 | + continuation_tokens=continuation_lengths[batch_idx], | |
| 872 | + ) | |
| 873 | + ) | |
| 874 | + | |
| 875 | + batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms | |
| 876 | + shared_plus_reduce = dict(shared_stage_ms) | |
| 877 | + shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms | |
| 878 | + results: list[MultiPromptScoreResult] = [] | |
| 879 | + for batch_idx, query in enumerate(query_texts): | |
| 880 | + results.append( | |
| 881 | + MultiPromptScoreResult( | |
| 882 | + query=query, | |
| 883 | + total_ms=batch_total_ms / layout.batch_size, | |
| 884 | + details=result_rows[batch_idx], | |
| 885 | + stage_ms={ | |
| 886 | + **shared_plus_reduce, | |
| 887 | + "per_query_total_estimate": batch_total_ms / layout.batch_size, | |
| 888 | + }, | |
| 889 | + ) | |
| 890 | + ) | |
| 891 | + return results | |
| 892 | + | |
| 893 | + @torch.inference_mode() | |
| 894 | + def score_queries(self, queries: list[str]) -> BatchScoreResult: | |
| 895 | + if not queries: | |
| 896 | + raise ValueError("queries must not be empty") | |
| 897 | + batch_size = len(queries) | |
| 898 | + if batch_size not in self.batch_layouts: | |
| 899 | + raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}") | |
| 900 | + layout = self.batch_layouts[batch_size] | |
| 901 | + prefix_cache = self.mixed_prefix_caches[batch_size] | |
| 902 | + | |
| 903 | + self._sync() | |
| 904 | + t0 = time.perf_counter() | |
| 905 | + query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries] | |
| 906 | + self._sync() | |
| 907 | + t1 = time.perf_counter() | |
| 908 | + | |
| 909 | + max_continuation_len = max( | |
| 910 | + len(plan.runner.build_continuation_from_query_ids(query_ids)) | |
| 911 | + for plan in layout.plans | |
| 912 | + for query_ids in query_ids_batch | |
| 913 | + ) | |
| 914 | + picked_bucket = self._pick_bucket(max_continuation_len) | |
| 915 | + bucket = self.mixed_buckets[(batch_size, picked_bucket)] | |
| 916 | + _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch) | |
| 917 | + self._sync() | |
| 918 | + t2 = time.perf_counter() | |
| 919 | + | |
| 920 | + if bucket.graph is not None: | |
| 921 | + bucket.graph.replay() | |
| 922 | + else: | |
| 923 | + self._run_mixed_backbone(bucket, prefix_cache) | |
| 924 | + self._sync() | |
| 925 | + t3 = time.perf_counter() | |
| 926 | + | |
| 927 | + shared_stage_ms = { | |
| 928 | + "encode_queries_shared": (t1 - t0) * 1000.0, | |
| 929 | + "prepare_batch_shared": (t2 - t1) * 1000.0, | |
| 930 | + "backbone_shared": (t3 - t2) * 1000.0, | |
| 931 | + } | |
| 932 | + results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms) | |
| 933 | + total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"] | |
| 934 | + return BatchScoreResult( | |
| 935 | + batch_size=batch_size, | |
| 936 | + total_ms=total_ms, | |
| 937 | + results=results, | |
| 938 | + stage_ms={ | |
| 939 | + **shared_stage_ms, | |
| 940 | + "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"], | |
| 941 | + }, | |
| 942 | + ) | |
| 943 | + | |
| 944 | + def score_query(self, query: str) -> MultiPromptScoreResult: | |
| 945 | + return self.score_queries([query]).results[0] | |
| 946 | + | |
| 947 | + def preload(self) -> PreloadReport: | |
| 948 | + if self._preload_report is not None: | |
| 949 | + return self._preload_report | |
| 950 | + stage_ms: dict[str, float] = dict(self._init_stage_ms) | |
| 951 | + start = time.perf_counter() | |
| 952 | + self._sync() | |
| 953 | + t0 = time.perf_counter() | |
| 954 | + warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes | |
| 955 | + for batch_size in warmup_batch_sizes: | |
| 956 | + queries = [self.cfg.warmup_query] * batch_size | |
| 957 | + self._warmup_results[batch_size] = self.score_queries(queries) | |
| 958 | + self._sync() | |
| 959 | + t1 = time.perf_counter() | |
| 960 | + stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0 | |
| 961 | + stage_ms["startup_total_before_warmup"] = self._init_total_ms | |
| 962 | + total_ms = self._init_total_ms + (t1 - start) * 1000.0 | |
| 963 | + runtime = self.preload_report() | |
| 964 | + self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime) | |
| 965 | + return self._preload_report | |
| 966 | + | |
| 967 | + def preload_report(self) -> dict[str, object]: | |
| 968 | + return { | |
| 969 | + "model_name": self.cfg.resolved_model_name, | |
| 970 | + "model_source": self.cfg.resolved_model_source, | |
| 971 | + "device": str(self.device), | |
| 972 | + "dtype": self.cfg.dtype, | |
| 973 | + "attn_backend": self.cfg.attn_backend, | |
| 974 | + "execution_model": "single_mixed_backbone_per_batch", | |
| 975 | + "num_tasks": len(self.runners), | |
| 976 | + "task_names": [runner.task_cfg.name for runner in self.runners], | |
| 977 | + "batch_sizes": list(self.cfg.batch_sizes), | |
| 978 | + "continuation_buckets": list(self.cfg.continuation_buckets), | |
| 979 | + "mixed_bucket_count": len(self.mixed_buckets), | |
| 980 | + "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()), | |
| 981 | + "all_configured_buckets_preloaded": True, | |
| 982 | + "init_stage_ms": dict(self._init_stage_ms), | |
| 983 | + "init_total_ms": self._init_total_ms, | |
| 984 | + "force_single_token_labels": self.cfg.force_single_token_labels, | |
| 985 | + "warmup_query": self.cfg.warmup_query, | |
| 986 | + "tasks": [ | |
| 987 | + { | |
| 988 | + "task_name": runner.task_cfg.name, | |
| 989 | + "fast_path": runner.fast_path, | |
| 990 | + "num_labels": runner.num_labels, | |
| 991 | + "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels}, | |
| 992 | + "prefix_tokens": runner.prefix_cache.prefix_len, | |
| 993 | + "prefix_hashes": runner.prefix_cache.prefix_hashes, | |
| 994 | + "label_prefix": runner.task_cfg.label_prefix, | |
| 995 | + } | |
| 996 | + for runner in self.runners | |
| 997 | + ], | |
| 998 | + } | |
| 999 | + | |
| 1000 | +However, the multi-token handling above is inefficient and incorrect; do not use it as a reference. Please reimplement it. | |
| 1001 | +The local .venv is already set up (reused from llm-qp); use that environment. | |
| 1002 | +The useful code has already been pulled in for reference. Based on the requirements, research the relevant material and focus on developing a brand-new version. | |
| 1003 | +This is a heavy task, so work through it patiently, step by step, iterating for as long as it takes until the service is fully built and the tests fully meet expectations. | |
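For reference when reimplementing the multi-token path: one standard way to score a multi-token label is teacher forcing, that is, a single forward pass over the label tokens, log-softmax over the vocabulary at each position, gathering each label token's log-probability, and summing (this matches the `flat_selected - flat_log_norm` reduction in the code above). A minimal NumPy sketch with illustrative names; the `logits` array stands in for real model output:

```python
import numpy as np

def label_log_prob(logits: np.ndarray, label_token_ids: list[int]) -> float:
    """Sum of per-token log-probabilities for one label under teacher forcing.

    logits: (num_label_tokens, vocab_size); row t holds the logits produced at
    the position *before* the t-th label token, i.e. the distribution that
    predicts it.
    """
    assert logits.shape[0] == len(label_token_ids)
    # numerically stable log-softmax: selected logit minus log-sum-exp over vocab
    max_l = logits.max(axis=-1, keepdims=True)
    log_norm = max_l.squeeze(-1) + np.log(np.exp(logits - max_l).sum(axis=-1))
    selected = logits[np.arange(len(label_token_ids)), label_token_ids]
    return float((selected - log_norm).sum())

# toy example: a 3-token label over a 5-token vocabulary
rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 5))
score = label_log_prob(logits, [1, 4, 2])
```

Per-label scores computed this way are directly comparable via argmax or softmax across labels, which is what the classifier needs.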
| ... | ... | @@ -0,0 +1,59 @@ |
| 1 | + | |
| 2 | +Overall requirement: query classification on a Tesla T4 GPU using an open-source base LLM (3-6B scale). The prompt and the label words must be configurable. | |
| 3 | +Focus on inference optimization first; serving can wait. After startup the program may simply read queries from stdin and print the predicted label. | |
| 4 | + | |
| 5 | +Example prompt with its 9 labels: | |
| 6 | +Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | |
| 7 | + | |
| 8 | +Build a dedicated execution path; this is not configuration tuning on top of a general-purpose generation engine. | |
| 9 | + | |
| 10 | +The prompt is fixed (do not consider distillation, fine-tuning, or shortening the prompt. I can shorten the prompt myself, and it may change for other scenarios, which is unrelated to your dedicated inference optimization. The prompt I gave is only an example; focus on a dedicated inference framework that supports a configurable prompt plus label-list scoring.) | |
| 11 | + | |
| 12 | + | |
| 13 | +The target is a Tesla T4, therefore: | |
| 14 | +1. Use FP16; do not make BF16 the primary path. | |
| 15 | +2. Do not make FlashAttention-3 the primary path; pick the attention implementation that performs best on the Tesla T4. | |
| 16 | +3. Gather more material: study open-source inference-optimization projects and practices suitable for the Tesla T4, including but not limited to Python/C++ runtimes and TensorRT engine toolkits; make sure the hardware support matrix covers Turing/T4. | |
| 17 | + | |
| 18 | + | |
| 19 | +Main optimization directions: | |
| 20 | +hidden_last -> N-class scorer -> argmax | |
| 21 | +Reuse identical prefixes via vLLM-style automatic prefix caching, managing the KV cache by block hash | |
| 22 | +Drop full-vocab logits | |
| 23 | +Drop decode / constrained decode | |
| 24 | +Query-length bucketing + small micro-batches, capturing graphs only for a few batch sizes (batch = 1, 2, 4, 8) | |
| 25 | +Dedicated tail kernel (emits the N raw class scores) | |
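The prefix-caching direction above can be sketched the way vLLM's automatic prefix caching does it: split the token-id prefix into fixed-size blocks and chain each block's hash with its predecessor's, so that a block key identifies the entire prefix up to and including that block. A minimal sketch; the block size and the use of SHA-256 are assumptions:

```python
import hashlib

def hash_prefix_blocks(token_ids: list[int], block_size: int = 16) -> list[str]:
    """Chained block hashes: hashes[i] identifies token_ids[0 : (i+1)*block_size]."""
    hashes: list[str] = []
    parent = ""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = token_ids[start : start + block_size]
        # include the parent hash so a key encodes the whole preceding prefix
        digest = hashlib.sha256(
            (parent + ",".join(map(str, block))).encode()
        ).hexdigest()
        hashes.append(digest)
        parent = digest
    return hashes

# two prompts that share their first 16 tokens share the first block hash,
# so the KV cache for that block can be reused
a = hash_prefix_blocks(list(range(32)))
b = hash_prefix_blocks(list(range(16)) + list(range(100, 116)))
```

A KV-block pool keyed by these hashes lets every request with the same fixed prompt prefix skip recomputing it.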
| 26 | + | |
| 27 | +The following is for reference only; decide the concrete approach yourself, using the open-source projects you find as a baseline: | |
| 28 | +15.1 Preallocation | |
| 29 | +Preallocate for each bucket: | |
| 30 | +input ids buffer | |
| 31 | +attention scratch buffer | |
| 32 | +output hidden buffer | |
| 33 | +class id buffer | |
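Preallocation only pays off if every query maps onto one of the fixed shapes; a minimal sketch of the pick-and-pad step (the bucket boundaries and pad id are assumed values):

```python
BUCKETS = [16, 32, 64, 128]   # preallocated continuation lengths (assumed)
PAD_ID = 0                    # tokenizer pad token id (assumed)

def pick_bucket(length: int, buckets: list[int] = BUCKETS) -> int:
    """Smallest preallocated bucket that fits, so buffers/graphs are reusable."""
    for b in buckets:
        if length <= b:
            return b
    raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")

def pad_to_bucket(ids: list[int]) -> tuple[list[int], list[int]]:
    """Right-pad token ids to the bucket size; the mask marks real tokens."""
    bucket = pick_bucket(len(ids))
    mask = [1] * len(ids) + [0] * (bucket - len(ids))
    return ids + [PAD_ID] * (bucket - len(ids)), mask

padded, mask = pad_to_bucket([5, 7, 9])
```

Because each query lands in a fixed (batch, bucket) shape, the preallocated buffers can be written in place and the captured graph for that shape replayed.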
| 34 | +15.2 Graph capture | |
| 35 | + | |
| 36 | +Capture one graph per bucket ahead of time, covering: | |
| 37 | +embedding | |
| 38 | +transformer backbone | |
| 39 | +compact last-state | |
| 40 | +9-way tail kernel | |
| 41 | + | |
| 42 | +Study Triton / CUDA C++ techniques and open-source projects that push this pattern (single-step decode with fixed-vocabulary scoring) to the limit; find a suitable baseline and develop against it. | |
| 43 | + | |
| 44 | +You have sudo privileges and may install your own environment for this project. | |
| 45 | + | |
| 46 | +Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence the q4_K_M instruct variant). | |
| 47 | +Alternatively, find another model on Hugging Face, deploy it, and benchmark its inference latency. | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | +In addition, I want a command-line tool that loads everything up front at program startup: no lazy loading whatsoever, so real requests get the fastest possible response time. | |
| 56 | +Given a query, it outputs each label's score and the total prediction latency (per-stage latencies would be even better; support them if that is easy to do). Check whether this is already satisfied. | |
| 57 | + | |
| 58 | +Analyze the per-stage latency in depth and keep researching to see whether performance has been pushed to the limit. | |
| 59 | +Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode), complete the deployment, and benchmark its inference latency. | |
| ... | ... | @@ -0,0 +1,66 @@ |
| 1 | +Completely remove the command-line interactive code and convert the program to serve an HTTP interface on port 6001. | |
| 2 | +Also provide service start/stop management via scripts/service_ctl.sh. | |
| 3 | +Return a list of classification results (results): the list consists of dicts; each dict's key is the corresponding classification task and its value is a triple of (label, score, probability). If the label is none, omit that entry from the output. | |
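As an illustrative example of the results shape described above (the task names and values are made up):

```python
# each dict maps one classification task to its (label, score, probability) triple;
# a task whose predicted label is none is simply omitted from the list
results = [
    {"category": ("jeans", 7.91, 0.83)},
    {"gender": ("women", 5.02, 0.61)},
]
```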
| 4 | +Improve the logging system: the important information currently printed by the CLI should show up in the logs and in the details field of the HTTP response. | |
| 5 | +Provide a load-test script in an appropriate directory of this project to serve as the benchmarking tool. It prints the result of every request and reports performance metrics at the end; for reference: | |
| 6 | +#!/bin/bash | |
| 7 | + | |
| 8 | +# Defaults | |
| 9 | +concurrency=${1:-1} | |
| 10 | +top_lines=${2:-100} | |
| 11 | + | |
| 12 | +# Fixed query-file path | |
| 13 | +query_file="/data/saas-search/scripts/evaluation/queries/queries.txt" | |
| 14 | + | |
| 15 | +# Check that the query file exists | |
| 16 | +if [ ! -f "$query_file" ]; then | |
| 17 | +    echo "Error: query file not found: $query_file" >&2 | |
| 18 | + exit 1 | |
| 19 | +fi | |
| 20 | + | |
| 21 | +# Check that jq is available for JSON parsing | |
| 22 | +if ! command -v jq &> /dev/null; then | |
| 23 | +    echo "Error: jq is required to parse JSON" >&2 | |
| 24 | + exit 1 | |
| 25 | +fi | |
| 26 | + | |
| 27 | +url="http://127.0.0.1:6001/..." | |
| 28 | +max_jobs=$concurrency | |
| 29 | +job_count=0 | |
| 30 | + | |
| 31 | +# Read the first top_lines lines of the file; each line is one query | |
| 32 | +while IFS= read -r query; do | |
| 33 | + # Skip blank lines | |
| 34 | + [ -z "$query" ] && continue | |
| 35 | + | |
| 36 | + # Launch a subprocess to issue the request | |
| 37 | + ( | |
| 38 | + # Safely build the JSON payload | |
| 39 | + payload=$(jq -n --arg q "$query" '{query: $q}') | |
| 40 | + # Send the request and capture the response | |
| 41 | + response=$(curl -s -X POST "$url" \ | |
| 42 | + -H 'Content-Type: application/json' \ | |
| 43 | + -d "$payload") | |
| 44 | + # Extract the results field (compact JSON) | |
| 45 | + results=$(echo "$response" | jq -c '.results') | |
| 46 | + # Print the query and its results | |
| 47 | + printf "%s\t%s\n" "$query" "$results" | |
| 48 | + ) & | |
| 49 | + | |
| 50 | + # Cap the number of concurrent jobs | |
| 51 | + ((job_count++)) | |
| 52 | + if (( job_count >= max_jobs )); then | |
| 53 | + wait -n # wait for any one background job to finish | |
| 54 | + ((job_count--)) | |
| 55 | + fi | |
| 56 | +done < <(head -n "$top_lines" "$query_file") | |
| 57 | + | |
| 58 | +# Wait for all remaining background jobs to finish | |
| 59 | +# TODO: collect per-request latencies and report these metrics here: | |
| 60 | +#   mean latency: | |
| 61 | +#   max latency: | |
| 62 | +#   min latency: | |
| 63 | +#   TP50: | |
| 64 | +#   TP90: | |
| 65 | +#   TP99: | |
| 66 | +wait | |
| 0 | 67 | \ No newline at end of file | ... | ... |
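The metrics placeholder at the end of the script can be filled by collecting one latency value per request and summarizing offline. A minimal sketch, assuming nearest-rank percentiles (an implementation choice; the script itself does not mandate a percentile method):

```python
import math
from typing import Dict, List

def latency_summary(latencies_ms: List[float]) -> Dict[str, float]:
    """Mean/min/max plus nearest-rank TP50/TP90/TP99 over per-request latencies."""
    if not latencies_ms:
        raise ValueError("no latency samples collected")
    ordered = sorted(latencies_ms)
    n = len(ordered)

    def tp(p: float) -> float:
        # Nearest-rank percentile: the value at ceil(p% * n), 1-indexed.
        rank = max(1, math.ceil(p / 100.0 * n))
        return ordered[rank - 1]

    return {
        "avg": sum(ordered) / n,
        "min": ordered[0],
        "max": ordered[-1],
        "tp50": tp(50),
        "tp90": tp(90),
        "tp99": tp(99),
    }

stats = latency_summary([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0])
```

The bash side only needs to emit one latency per request (for example via `curl -w '%{time_total}'`) for this summary to work.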
docs/相关性检索优化说明.md
| ... | ... | @@ -862,17 +862,14 @@ rerank_score:0.0564 |
| 862 | 862 | "en": "Judy Blue Women's High Waist Button Fly Skinny Jeans 82319", |
| 863 | 863 | "zh": "Judy Blue 女士高腰纽扣开叉修身牛仔裤 82319" |
| 864 | 864 | |
| 865 | - | |
| 866 | 865 | rerank_score:0.0790 |
| 867 | 866 | "en": "2025 New Fashion European and American Women's Jeans High-Waisted Slim Straight Denim Pants Popular Floor-Length Pants", |
| 868 | 867 | "zh": "2025新款欧美风女式高腰显瘦直筒牛仔裤 时尚及地长裤" |
| 869 | 868 | |
| 870 | - | |
| 871 | 869 | rerank_score:0.0822 |
| 872 | 870 | "en": "roswear Women's Trendy Stretchy Flare Jeans Mid Rise Bootcut Curvy Denim Pants", |
| 873 | 871 | "zh": "Roswear 女士时尚弹力喇叭牛仔裤 中腰高腰修身直筒牛仔裤" |
| 874 | 872 | |
| 875 | - | |
| 876 | 873 | rerank_score:0.0956 |
| 877 | 874 | "en": "POSHGLAM Women's Maternity Jeans Over Belly 29'' Skinny Denim Jeggings Comfy Stretch Clearance Pregnancy Pants", |
| 878 | 875 | "zh": "POSHGLAM 女士孕产期高腰显瘦牛仔紧身裤 29英寸 紧身弹力孕妇裤 休闲舒适 清仓特价" | ... | ... |
indexer/document_transformer.py
| ... | ... | @@ -151,15 +151,15 @@ class SPUDocumentTransformer: |
| 151 | 151 | self._fill_title_embedding(doc) |
| 152 | 152 | |
| 153 | 153 | # Tags:统一转成与 mapping 一致的 core-language object |
| 154 | - if pd.notna(spu_row.get('tags')): | |
| 155 | - tags_str = str(spu_row['tags']) | |
| 154 | + if pd.notna(spu_row.get('enriched_tags')): | |
| 155 | + tags_str = str(spu_row['enriched_tags']) | |
| 156 | 156 | tags_obj = self._build_core_language_text_object( |
| 157 | 157 | tags_str, |
| 158 | 158 | source_lang=primary_lang, |
| 159 | 159 | scene="general", |
| 160 | 160 | ) |
| 161 | 161 | if tags_obj: |
| 162 | - doc['tags'] = tags_obj | |
| 162 | + doc['enriched_tags'] = tags_obj | |
| 163 | 163 | |
| 164 | 164 | # Category相关字段 |
| 165 | 165 | self._fill_category_fields(doc, spu_row) |
| ... | ... | @@ -240,7 +240,7 @@ class SPUDocumentTransformer: |
| 240 | 240 | """ |
| 241 | 241 | 批量调用 LLM,为一批 doc 填充: |
| 242 | 242 | - qanchors.{lang} |
| 243 | - - tags.{lang} | |
| 243 | + - enriched_tags.{lang} | |
| 244 | 244 | - enriched_attributes[].value.{lang} |
| 245 | 245 | |
| 246 | 246 | 设计目标: |
| ... | ... | @@ -292,8 +292,8 @@ class SPUDocumentTransformer: |
| 292 | 292 | try: |
| 293 | 293 | if enrichment.get("qanchors"): |
| 294 | 294 | doc["qanchors"] = enrichment["qanchors"] |
| 295 | - if enrichment.get("tags"): | |
| 296 | - doc["tags"] = enrichment["tags"] | |
| 295 | + if enrichment.get("enriched_tags"): | |
| 296 | + doc["enriched_tags"] = enrichment["enriched_tags"] | |
| 297 | 297 | if enrichment.get("enriched_attributes"): |
| 298 | 298 | doc["enriched_attributes"] = enrichment["enriched_attributes"] |
| 299 | 299 | except Exception as e: |
| ... | ... | @@ -656,7 +656,7 @@ class SPUDocumentTransformer: |
| 656 | 656 | """ |
| 657 | 657 | 调用 indexer.product_enrich 的高层内容理解入口,为当前 SPU 填充: |
| 658 | 658 | - qanchors.{lang} |
| 659 | - - tags.{lang} | |
| 659 | + - enriched_tags.{lang} | |
| 660 | 660 | - enriched_attributes[].value.{lang} |
| 661 | 661 | """ |
| 662 | 662 | spu_id = str(spu_row.get("id") or "").strip() | ... | ... |
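Because the rename above means documents indexed before this change still carry `tags` while newly built ones carry `enriched_tags`, any reader that must handle both generations can use a small fallback accessor. This helper is an illustration, not part of the transformer:

```python
from typing import Any, Dict

def get_enriched_tags(doc: Dict[str, Any]) -> Dict[str, Any]:
    """Prefer the new `enriched_tags` field, falling back to the legacy `tags` field."""
    tags = doc.get("enriched_tags") or doc.get("tags")
    return tags if isinstance(tags, dict) else {}

# A pre-rename document and a post-rename document resolve to the same value.
old_doc = {"tags": {"en": "denim, high waist"}}
new_doc = {"enriched_tags": {"en": "denim, high waist"}}
```

Once all tenants have been re-indexed with the new field, the fallback branch can be dropped.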
indexer/product_enrich_prompts.py
| ... | ... | @@ -10,16 +10,16 @@ SYSTEM_MESSAGE = ( |
| 10 | 10 | |
| 11 | 11 | SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns: |
| 12 | 12 | |
| 13 | -1. Product title: a natural localized product name derived from the input product text | |
| 14 | -2. Category path: broad to fine-grained category, separated by ">" | |
| 15 | -3. Fine-grained tags: style, features, functions, or notable attributes | |
| 16 | -4. Target audience: gender, age group, or suitable users | |
| 17 | -5. Usage scene | |
| 18 | -6. Applicable season | |
| 19 | -7. Key attributes | |
| 20 | -8. Material description | |
| 21 | -9. Functional features | |
| 22 | -10. Anchor text: a search-focused set of keywords, selling points, and phrases covering categories, attributes, usage scenarios, and user intent | |
| 13 | +1. Product title: a natural, localized product name based on the input text | |
| 14 | +2. Category path: a concise category hierarchy from broad to specific, separated by ">" | |
| 15 | +3. Fine-grained tags: concise tags for style, features, design details, function, or standout selling points | |
| 16 | +4. Target audience: gender, age group, body type, or suitable users when clearly implied | |
| 17 | +5. Usage scene: likely occasions, settings, or use cases | |
| 18 | +6. Applicable season: relevant season(s) based on the product text | |
| 19 | +7. Key attributes: core product attributes and specifications. Depending on the item type, this may include fit, silhouette, length, sleeve type, neckline, waistline, closure, pattern, design details, structure, or other relevant attribute dimensions | |
| 20 | +8. Material description: material, fabric, texture, or construction description | |
| 21 | +9. Functional features: practical or performance-related functions such as stretch, breathability, warmth, support, storage, protection, or ease of wear | |
| 22 | +10. Anchor text: a search-oriented keyword string covering product type, category intent, attributes, design cues, usage scenarios, and strong shopping phrases | |
| 23 | 23 | |
| 24 | 24 | Rules: |
| 25 | 25 | - Keep the input order and row count exactly the same. | ... | ... |
perf_reports/20260311/reranker_1000docs/report.md
| ... | ... | @@ -4,7 +4,7 @@ Workload profile: |
| 4 | 4 | - backend: `qwen3_vllm` (`Qwen/Qwen3-Reranker-0.6B`) |
| 5 | 5 | - query: short e-commerce text (<100 tokens) |
| 6 | 6 | - docs/request: 1000 short titles/title+brief |
| 7 | -- options: `sort_by_doc_length=true`, `length_sort_mode=char` | |
| 7 | +- options: `sort_by_doc_length=true` | |
| 8 | 8 | |
| 9 | 9 | ## Results |
| 10 | 10 | ... | ... |
reranker/DEPLOYMENT_AND_TUNING.md
| 1 | -# Reranker 部署与性能调优手册(Qwen3-vLLM / Qwen3-GGUF) | |
| 1 | +# Reranker 部署与性能调优手册(Qwen3-vLLM) | |
| 2 | 2 | |
| 3 | 3 | 本文档沉淀当前项目在电商搜索重排场景下的可复用实践,覆盖: |
| 4 | 4 | |
| 5 | 5 | - 环境准备与安装部署 |
| 6 | -- `qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` 配置项与优化思路 | |
| 6 | +- `qwen3_vllm` 配置项与优化思路 | |
| 7 | 7 | - 1000-doc 场景压测流程 |
| 8 | 8 | - 关键结论与推荐默认参数 |
| 9 | 9 | - 常见故障排查 |
| 10 | 10 | |
| 11 | 11 | 适用范围: |
| 12 | 12 | |
| 13 | -- 重排后端:`services.rerank.backend: qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` | |
| 14 | -- 模型:`Qwen/Qwen3-Reranker-0.6B` / `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` / `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF` | |
| 13 | +- 重排后端:`services.rerank.backend: qwen3_vllm` | |
| 14 | +- 模型:`Qwen/Qwen3-Reranker-0.6B` | |
| 15 | 15 | - 场景:query 较短(通常 < 100 tokens),doc 为商品标题或标题+简短描述,单请求 docs 约 1000 条 |
| 16 | 16 | |
| 17 | 17 | ## 1. 环境基线 |
| 18 | 18 | |
| 19 | -当前验证环境(2026-03-25): | |
| 19 | +当前验证环境(2026-03-11): | |
| 20 | 20 | |
| 21 | 21 | - GPU:`Tesla T4 16GB` |
| 22 | 22 | - Driver / CUDA:`570.158.01 / 12.8` |
| 23 | 23 | - Python:`3.12.3` |
| 24 | -- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`llama-cpp-python>=0.3.16`、`fastapi==0.135.1`、`uvicorn==0.41.0` | |
| 24 | +- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`fastapi==0.135.1`、`uvicorn==0.41.0` | |
| 25 | 25 | |
| 26 | 26 | ## 2. 环境准备与安装 |
| 27 | 27 | |
| 28 | 28 | ### 2.1 准备 reranker 独立虚拟环境 |
| 29 | 29 | |
| 30 | 30 | ```bash |
| 31 | -./scripts/setup_reranker_venv.sh qwen3_vllm | |
| 32 | -``` | |
| 33 | - | |
| 34 | -若使用 GGUF 并需要 CUDA: | |
| 35 | - | |
| 36 | -```bash | |
| 37 | -./scripts/setup_reranker_venv.sh qwen3_gguf | |
| 38 | -PATH=/usr/local/cuda/bin:$PATH \ | |
| 39 | -CUDACXX=/usr/local/cuda/bin/nvcc \ | |
| 40 | -CMAKE_ARGS="-DGGML_CUDA=on" \ | |
| 41 | -FORCE_CMAKE=1 \ | |
| 42 | -./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18 | |
| 31 | +./scripts/setup_reranker_venv.sh | |
| 43 | 32 | ``` |
| 44 | 33 | |
| 45 | 34 | ### 2.2 基础检查 |
| ... | ... | @@ -48,7 +37,6 @@ FORCE_CMAKE=1 \ |
| 48 | 37 | nvidia-smi |
| 49 | 38 | ./.venv-reranker/bin/python -c "import torch; print(torch.cuda.is_available())" |
| 50 | 39 | ./.venv-reranker/bin/python -c "import vllm, transformers; print(vllm.__version__, transformers.__version__)" |
| 51 | -./.venv-reranker-gguf/bin/python -c "import llama_cpp; print(llama_cpp.__version__)" | |
| 52 | 40 | ``` |
| 53 | 41 | |
| 54 | 42 | ## 3. 部署与运行 |
| ... | ... | @@ -73,30 +61,6 @@ services: |
| 73 | 61 | enforce_eager: false |
| 74 | 62 | infer_batch_size: 64 |
| 75 | 63 | sort_by_doc_length: true |
| 76 | - length_sort_mode: "char" # char | token | |
| 77 | -``` | |
| 78 | - | |
| 79 | -GGUF / T4 剩余显存约 `4.8~6GB` 时,推荐基线: | |
| 80 | - | |
| 81 | -```yaml | |
| 82 | -services: | |
| 83 | - rerank: | |
| 84 | - backend: "qwen3_gguf" | |
| 85 | - backends: | |
| 86 | - qwen3_gguf: | |
| 87 | - repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF" | |
| 88 | - filename: "*Q8_0.gguf" | |
| 89 | - local_dir: "./models/reranker/qwen3-reranker-4b-gguf" | |
| 90 | - cache_dir: "./model_cache" | |
| 91 | - n_ctx: 384 | |
| 92 | - n_batch: 384 | |
| 93 | - n_ubatch: 128 | |
| 94 | - n_gpu_layers: 24 | |
| 95 | - flash_attn: true | |
| 96 | - offload_kqv: true | |
| 97 | - infer_batch_size: 8 | |
| 98 | - sort_by_doc_length: true | |
| 99 | - length_sort_mode: "char" | |
| 100 | 64 | ``` |
| 101 | 65 | |
| 102 | 66 | ### 3.2 启停命令 |
| ... | ... | @@ -140,13 +104,6 @@ curl -sS http://127.0.0.1:6007/health |
| 140 | 104 | - `service_ctl.sh` 对 reranker 使用独立启动路径 |
| 141 | 105 | - 增加“稳定健康检查”(连续健康探测)避免“刚 healthy 即退出”的假阳性 |
| 142 | 106 | |
| 143 | -### 4.4 GGUF / T4 小显存优化原则 | |
| 144 | - | |
| 145 | -- `Q8_0` 权重约 `4.28GB`,但还要给 KV cache、CUDA 工作区和运行时碎片预留空间,不能按“模型大小 < 剩余显存”直接判断可行。 | |
| 146 | -- 当前业务是短 query + 商品标题,优先压缩 `n_ctx`;`384` 通常比默认长上下文更划算。 | |
| 147 | -- T4 小显存下先扫 `n_gpu_layers`,再尝试提高 `n_ctx`;`infer_batch_size` 在当前 GGUF 接入里主要是服务侧 work chunk,不是 llama.cpp 的真实算子 batch。 | |
| 148 | -- `flash_attn: true`、`offload_kqv: true` 默认保持开启;若 OOM,优先降低 `n_gpu_layers`。 | |
| 149 | - | |
| 150 | 107 | ## 5. 性能调优流程(标准流程) |
| 151 | 108 | |
| 152 | 109 | ### 5.1 使用一键压测脚本 |
| ... | ... | @@ -167,13 +124,6 @@ curl -sS http://127.0.0.1:6007/health |
| 167 | 124 | - `infer_batch_size`: `24 32 48 64` |
| 168 | 125 | - 并发组:`c=1`(看单请求延迟)、`c=4`(看并发吞吐与尾延迟) |
| 169 | 126 | |
| 170 | -GGUF 建议扫描: | |
| 171 | - | |
| 172 | -- `n_gpu_layers`: `20 24 28` | |
| 173 | -- `n_ctx`: `320 384 448` | |
| 174 | -- `infer_batch_size`: `4 8 12`(次要,仅影响服务侧 work chunk) | |
| 175 | -- 扫描顺序:先固定 `n_ctx=384`,找能稳定启动的最大 `n_gpu_layers`;再在显存允许时尝试 `n_ctx=448`;最后才微调 `infer_batch_size` | |
| 176 | - | |
| 177 | 127 | 可通过环境变量覆盖: |
| 178 | 128 | |
| 179 | 129 | - `BATCH_SIZES` |
| ... | ... | @@ -189,28 +139,23 @@ GGUF 建议扫描: |
| 189 | 139 | - `RERANK_VLLM_INFER_BATCH_SIZE` |
| 190 | 140 | - `RERANK_VLLM_SORT_BY_DOC_LENGTH` |
| 191 | 141 | |
| 192 | -## 6. 本轮关键结论 | |
| 142 | +## 6. 本轮关键结论(2026-03-11) | |
| 143 | + | |
| 144 | +基于报告: | |
| 145 | + | |
| 146 | +- `perf_reports/20260311/reranker_1000docs/report.md` | |
| 193 | 147 | |
| 194 | -vLLM(2026-03-11,见 `perf_reports/20260311/reranker_1000docs/report.md`): | |
| 148 | +结论: | |
| 195 | 149 | |
| 196 | 150 | - 对在线重排更重要的单请求延迟(`c=1`)指标,`infer_batch_size=64` 最优 |
| 197 | 151 | - `infer_batch_size=96` 在更高并发下吞吐略高,但会牺牲单请求延迟稳定性 |
| 198 | 152 | - 当前默认选择 `infer_batch_size=64` 作为平衡点 |
| 199 | 153 | |
| 200 | -GGUF(2026-03-25,本次接入): | |
| 201 | - | |
| 202 | -- `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` 的 `Q8_0` 体积约 `4.28GB`,结合当前机器实测剩余显存约 `4823 MiB`,默认不采用激进的全量 GPU offload。 | |
| 203 | -- 当前推荐默认值:`n_ctx=384`、`n_batch=384`、`n_ubatch=128`、`n_gpu_layers=24`、`infer_batch_size=8`。 | |
| 204 | -- 若现场剩余显存更接近 `6GB` 且碎片较少,可优先尝试 `n_gpu_layers=28`;若启动失败,回退到 `24` 或 `20`。 | |
| 205 | -- 由于当前工作区尚未缓存该 GGUF 权重,本次尚未完成真实吞吐压测;上线前需在部署机复跑一轮参数扫描并归档报告。 | |
| 206 | - | |
| 207 | 154 | ## 7. 生产建议 |
| 208 | 155 | |
| 209 | 156 | - 默认保持:`infer_batch_size: 64`、`sort_by_doc_length: true` |
| 210 | 157 | - 满足以下条件时可考虑提高到 `96`:业务以吞吐优先、可接受更高单请求延迟、已通过同机同数据压测验证收益 |
| 211 | 158 | - 每次改动后都必须复跑 `benchmark_reranker_1000docs.sh` 并归档结果 |
| 212 | -- GGUF 默认保持:`n_ctx: 384`、`n_gpu_layers: 24`、`infer_batch_size: 8`、`flash_attn: true`、`offload_kqv: true` | |
| 213 | -- GGUF 若 OOM:先降 `n_gpu_layers`,再降 `n_ctx`,最后再降 `infer_batch_size` | |
| 214 | 159 | |
| 215 | 160 | ## 8. 故障排查 |
| 216 | 161 | |
| ... | ... | @@ -248,13 +193,6 @@ lsof -i :6007 -P -n |
| 248 | 193 | - 降低 `infer_batch_size` |
| 249 | 194 | - 检查是否有其他进程占用同卡 |
| 250 | 195 | |
| 251 | -GGUF 优先调整: | |
| 252 | - | |
| 253 | -- 降低 `n_gpu_layers` | |
| 254 | -- 降低 `n_ctx` | |
| 255 | -- 降低 `infer_batch_size` | |
| 256 | -- 检查是否有其他进程占用同卡 | |
| 257 | - | |
| 258 | 196 | ## 9. 变更与验证清单 |
| 259 | 197 | |
| 260 | 198 | 每次 reranker 调优改动后,至少完成: | ... | ... |
scripts/build_suggestions.sh
| ... | ... | @@ -9,19 +9,38 @@ |
| 9 | 9 | # # incremental update from watermark |
| 10 | 10 | # ./scripts/build_suggestions.sh <tenant_id> --mode incremental |
| 11 | 11 | # |
| 12 | +# # full rebuild + incremental + ES/API smoke checks (same as legacy rebuild_suggestions.sh) | |
| 13 | +# ./scripts/build_suggestions.sh <tenant_id> --rebuild | |
| 14 | +# | |
| 12 | 15 | |
| 13 | 16 | set -euo pipefail |
| 14 | 17 | |
| 15 | 18 | if [ $# -lt 1 ]; then |
| 16 | - echo "Usage: $0 <tenant_id> [extra args...]" | |
| 17 | - echo "Example (full): $0 162 --mode full --days 30 --publish-alias" | |
| 18 | - echo "Example (incremental): $0 162 --mode incremental --overlap-minutes 30" | |
| 19 | + echo "Usage: $0 <tenant_id> [--rebuild | extra args for main.py build-suggestions...]" | |
| 20 | + echo "Example (full): $0 163 --mode full --days 30 --publish-alias" | |
| 21 | + echo "Example (incremental): $0 163 --mode incremental --overlap-minutes 30" | |
| 22 | + echo "Example (pipeline + smoke): $0 163 --rebuild" | |
| 19 | 23 | exit 1 |
| 20 | 24 | fi |
| 21 | 25 | |
| 22 | 26 | TENANT_ID="$1" |
| 23 | 27 | shift || true |
| 24 | 28 | |
| 29 | +RUN_REBUILD_PIPELINE=false | |
| 30 | +PASSTHROUGH_ARGS=() | |
| 31 | +for arg in "$@"; do | |
| 32 | + if [ "$arg" = "--rebuild" ]; then | |
| 33 | + RUN_REBUILD_PIPELINE=true | |
| 34 | + else | |
| 35 | + PASSTHROUGH_ARGS+=("$arg") | |
| 36 | + fi | |
| 37 | +done | |
| 38 | + | |
| 39 | +if [ "$RUN_REBUILD_PIPELINE" = true ] && [ ${#PASSTHROUGH_ARGS[@]} -gt 0 ]; then | |
| 40 | + echo "Error: --rebuild cannot be combined with other build-suggestions arguments." | |
| 41 | + exit 1 | |
| 42 | +fi | |
| 43 | + | |
| 25 | 44 | ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" |
| 26 | 45 | |
| 27 | 46 | cd "$ROOT_DIR" |
| ... | ... | @@ -31,6 +50,79 @@ if [ ! -x "$PY_BIN" ]; then |
| 31 | 50 | PY_BIN="python3" |
| 32 | 51 | fi |
| 33 | 52 | |
| 53 | +if [ "$RUN_REBUILD_PIPELINE" = true ]; then | |
| 54 | + SAMPLE_QUERIES=(s sh dress tshirt) | |
| 55 | + SAMPLE_LANGS=(en zh) | |
| 56 | + # This script validates the locally rebuilt index, so default the smoke target | |
| 57 | + # to the local backend. A remote/public API base must be opted into explicitly. | |
| 58 | + API_BASE="${SUGGESTIONS_SMOKE_BASE_URL:-http://localhost:6002}" | |
| 59 | + | |
| 60 | + if [ -z "${ES_HOST:-}" ]; then | |
| 61 | + ES_HOST="$("$PY_BIN" - <<'PY' | |
| 62 | +from dotenv import dotenv_values | |
| 63 | +print(dotenv_values('.env').get('ES_HOST') or 'http://localhost:9200') | |
| 64 | +PY | |
| 65 | +)" | |
| 66 | + fi | |
| 67 | + | |
| 68 | + if [ -z "${ES_USERNAME:-}" ] || [ -z "${ES_PASSWORD:-}" ]; then | |
| 69 | + readarray -t _ES_CREDS < <("$PY_BIN" - <<'PY' | |
| 70 | +from dotenv import dotenv_values | |
| 71 | +cfg = dotenv_values('.env') | |
| 72 | +print(cfg.get('ES_USERNAME') or '') | |
| 73 | +print(cfg.get('ES_PASSWORD') or '') | |
| 74 | +PY | |
| 75 | +) | |
| 76 | + ES_USERNAME="${ES_USERNAME:-${_ES_CREDS[0]}}" | |
| 77 | + ES_PASSWORD="${ES_PASSWORD:-${_ES_CREDS[1]}}" | |
| 78 | + fi | |
| 79 | + | |
| 80 | + if [ -n "${ES_USERNAME:-}" ] && [ -n "${ES_PASSWORD:-}" ]; then | |
| 81 | + AUTH=(-u "${ES_USERNAME}:${ES_PASSWORD}") | |
| 82 | + else | |
| 83 | + AUTH=() | |
| 84 | + fi | |
| 85 | + | |
| 86 | + ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_${TENANT_ID}_current" | |
| 87 | + | |
| 88 | + echo "[1/4] Full rebuild tenant=${TENANT_ID} (versioned + alias publish)" | |
| 89 | + "$PY_BIN" main.py build-suggestions \ | |
| 90 | + --tenant-id "$TENANT_ID" \ | |
| 91 | + --es-host "$ES_HOST" \ | |
| 92 | + --mode full \ | |
| 93 | + --days 365 \ | |
| 94 | + --batch-size 500 \ | |
| 95 | + --publish-alias \ | |
| 96 | + --keep-versions 2 | |
| 97 | + | |
| 98 | + echo "[2/4] Incremental update tenant=${TENANT_ID}" | |
| 99 | + "$PY_BIN" main.py build-suggestions \ | |
| 100 | + --tenant-id "$TENANT_ID" \ | |
| 101 | + --es-host "$ES_HOST" \ | |
| 102 | + --mode incremental \ | |
| 103 | + --overlap-minutes 30 | |
| 104 | + | |
| 105 | + echo "[3/4] ES count + sample" | |
| 106 | + curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_count?pretty" | |
| 107 | + echo | |
| 108 | + curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' -d '{ | |
| 109 | + "size": 5, | |
| 110 | + "query": {"match_all": {}}, | |
| 111 | + "_source": ["lang", "text", "rank_score", "sources", "query_count_30d"] | |
| 112 | +}' | |
| 113 | + echo | |
| 114 | + | |
| 115 | + echo "[4/4] API smoke test (base=${API_BASE})" | |
| 116 | + for lang in "${SAMPLE_LANGS[@]}"; do | |
| 117 | + for q in "${SAMPLE_QUERIES[@]}"; do | |
| 118 | + echo "--- GET /search/suggestions?q=${q}&language=${lang} ---" | |
| 119 | + curl -sS "$API_BASE/search/suggestions?q=${q}&size=10&language=${lang}" -H "X-Tenant-ID: ${TENANT_ID}" | |
| 120 | + echo | |
| 121 | + done | |
| 122 | + done | |
| 123 | + exit 0 | |
| 124 | +fi | |
| 125 | + | |
| 34 | 126 | "$PY_BIN" main.py build-suggestions \ |
| 35 | 127 | --tenant-id "$TENANT_ID" \ |
| 36 | - "$@" | |
| 128 | + "${PASSTHROUGH_ARGS[@]}" | ... | ... |
scripts/rebuild_suggestions.sh deleted
| ... | ... | @@ -1,86 +0,0 @@ |
| 1 | -#!/usr/bin/env bash | |
| 2 | -set -euo pipefail | |
| 3 | - | |
| 4 | -if [ $# -lt 1 ]; then | |
| 5 | - echo "Usage: $0 <tenant_id>" | |
| 6 | - echo "Example: $0 162" | |
| 7 | - exit 1 | |
| 8 | -fi | |
| 9 | - | |
| 10 | -ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" | |
| 11 | -TENANT_ID="$1" | |
| 12 | -# Fixed smoke-test queries and languages (no CLI args). | |
| 13 | -SAMPLE_QUERIES=(s sh dress tshirt) | |
| 14 | -SAMPLE_LANGS=(en zh) | |
| 15 | -API_BASE="${API_BASE_URL:-http://localhost:6002}" | |
| 16 | - | |
| 17 | -cd "$ROOT_DIR" | |
| 18 | - | |
| 19 | -PY_BIN="${PYTHON_BIN:-$ROOT_DIR/.venv/bin/python}" | |
| 20 | -if [ ! -x "$PY_BIN" ]; then | |
| 21 | - PY_BIN="python3" | |
| 22 | -fi | |
| 23 | - | |
| 24 | -if [ -z "${ES_HOST:-}" ]; then | |
| 25 | - ES_HOST="$("$PY_BIN" - <<'PY' | |
| 26 | -from dotenv import dotenv_values | |
| 27 | -print(dotenv_values('.env').get('ES_HOST') or 'http://localhost:9200') | |
| 28 | -PY | |
| 29 | -)" | |
| 30 | -fi | |
| 31 | - | |
| 32 | -if [ -z "${ES_USERNAME:-}" ] || [ -z "${ES_PASSWORD:-}" ]; then | |
| 33 | - readarray -t _ES_CREDS < <("$PY_BIN" - <<'PY' | |
| 34 | -from dotenv import dotenv_values | |
| 35 | -cfg = dotenv_values('.env') | |
| 36 | -print(cfg.get('ES_USERNAME') or '') | |
| 37 | -print(cfg.get('ES_PASSWORD') or '') | |
| 38 | -PY | |
| 39 | -) | |
| 40 | - ES_USERNAME="${ES_USERNAME:-${_ES_CREDS[0]}}" | |
| 41 | - ES_PASSWORD="${ES_PASSWORD:-${_ES_CREDS[1]}}" | |
| 42 | -fi | |
| 43 | - | |
| 44 | -if [ -n "${ES_USERNAME:-}" ] && [ -n "${ES_PASSWORD:-}" ]; then | |
| 45 | - AUTH=(-u "${ES_USERNAME}:${ES_PASSWORD}") | |
| 46 | -else | |
| 47 | - AUTH=() | |
| 48 | -fi | |
| 49 | - | |
| 50 | -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_${TENANT_ID}_current" | |
| 51 | - | |
| 52 | -echo "[1/4] Full rebuild tenant=${TENANT_ID} (versioned + alias publish)" | |
| 53 | -"$PY_BIN" main.py build-suggestions \ | |
| 54 | - --tenant-id "$TENANT_ID" \ | |
| 55 | - --es-host "$ES_HOST" \ | |
| 56 | - --mode full \ | |
| 57 | - --days 365 \ | |
| 58 | - --batch-size 500 \ | |
| 59 | - --publish-alias \ | |
| 60 | - --keep-versions 2 | |
| 61 | - | |
| 62 | -echo "[2/4] Incremental update tenant=${TENANT_ID}" | |
| 63 | -"$PY_BIN" main.py build-suggestions \ | |
| 64 | - --tenant-id "$TENANT_ID" \ | |
| 65 | - --es-host "$ES_HOST" \ | |
| 66 | - --mode incremental \ | |
| 67 | - --overlap-minutes 30 | |
| 68 | - | |
| 69 | -echo "[3/4] ES count + sample" | |
| 70 | -curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_count?pretty" | |
| 71 | -echo | |
| 72 | -curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' -d '{ | |
| 73 | - "size": 5, | |
| 74 | - "query": {"match_all": {}}, | |
| 75 | - "_source": ["lang", "text", "rank_score", "sources", "query_count_30d"] | |
| 76 | -}' | |
| 77 | -echo | |
| 78 | - | |
| 79 | -echo "[4/4] API smoke test" | |
| 80 | -for lang in "${SAMPLE_LANGS[@]}"; do | |
| 81 | - for q in "${SAMPLE_QUERIES[@]}"; do | |
| 82 | - echo "--- GET /search/suggestions?q=${q}&language=${lang} ---" | |
| 83 | - curl -sS "$API_BASE/search/suggestions?q=${q}&size=10&language=${lang}" -H "X-Tenant-ID: ${TENANT_ID}" | |
| 84 | - echo | |
| 85 | - done | |
| 86 | -done |
scripts/service_ctl.sh
| ... | ... | @@ -16,9 +16,10 @@ mkdir -p "${LOG_DIR}" |
| 16 | 16 | source "${PROJECT_ROOT}/scripts/lib/load_env.sh" |
| 17 | 17 | |
| 18 | 18 | CORE_SERVICES=("backend" "indexer" "frontend" "eval-web") |
| 19 | -OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker" "reranker-fine") | |
| 19 | +# reranker-fine is not needed for now, so it is temporarily removed from OPTIONAL_SERVICES | |
| 20 | +OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker") | |
| 20 | 21 | FULL_SERVICES=("${OPTIONAL_SERVICES[@]}" "${CORE_SERVICES[@]}") |
| 21 | -STOP_ORDER_SERVICES=("frontend" "eval-web" "indexer" "backend" "reranker-fine" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") | |
| 22 | +STOP_ORDER_SERVICES=("frontend" "eval-web" "indexer" "backend" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") | |
| 22 | 23 | |
| 23 | 24 | all_services() { |
| 24 | 25 | echo "${FULL_SERVICES[@]}" | ... | ... |
suggestion/builder.py
| ... | ... | @@ -38,7 +38,7 @@ def get_suggestion_alias_name(tenant_id: str) -> str: |
| 38 | 38 | |
| 39 | 39 | def get_suggestion_versioned_index_name(tenant_id: str, build_at: Optional[datetime] = None) -> str: |
| 40 | 40 | """Versioned suggestion index name.""" |
| 41 | - ts = (build_at or datetime.now(timezone.utc)).strftime("%Y%m%d%H%M%S") | |
| 41 | + ts = (build_at or datetime.now(timezone.utc)).strftime("%Y%m%d%H%M%S%f") | |
| 42 | 42 | return f"{_index_prefix()}search_suggestions_tenant_{tenant_id}_v{ts}" |
| 43 | 43 | |
| 44 | 44 | |
| ... | ... | @@ -101,6 +101,79 @@ class SuggestionIndexBuilder: |
| 101 | 101 | self.es_client = es_client |
| 102 | 102 | self.db_engine = db_engine |
| 103 | 103 | |
| 104 | + def _format_allocation_failure(self, index_name: str) -> str: | |
| 105 | + health = self.es_client.wait_for_index_ready(index_name=index_name, timeout="5s") | |
| 106 | + explain = self.es_client.get_allocation_explain(index_name=index_name) | |
| 107 | + | |
| 108 | + parts = [ | |
| 109 | + f"Suggestion index '{index_name}' was created but is not allocatable/readable yet", | |
| 110 | + f"health_status={health.get('status')}", | |
| 111 | + f"timed_out={health.get('timed_out')}", | |
| 112 | + ] | |
| 113 | + if health.get("error"): | |
| 114 | + parts.append(f"health_error={health['error']}") | |
| 115 | + | |
| 116 | + if explain: | |
| 117 | + unassigned = explain.get("unassigned_info") or {} | |
| 118 | + if unassigned.get("reason"): | |
| 119 | + parts.append(f"unassigned_reason={unassigned['reason']}") | |
| 120 | + if unassigned.get("last_allocation_status"): | |
| 121 | + parts.append(f"last_allocation_status={unassigned['last_allocation_status']}") | |
| 122 | + | |
| 123 | + for node in explain.get("node_allocation_decisions") or []: | |
| 124 | + node_name = node.get("node_name") or node.get("node_id") or "unknown-node" | |
| 125 | + for decider in node.get("deciders") or []: | |
| 126 | + if decider.get("decision") == "NO": | |
| 127 | + parts.append( | |
| 128 | + f"{node_name}:{decider.get('decider')}={decider.get('explanation')}" | |
| 129 | + ) | |
| 132 | + return "; ".join(parts) | |
| 133 | + | |
| 134 | + def _create_fresh_versioned_index( | |
| 135 | + self, | |
| 136 | + tenant_id: str, | |
| 137 | + mapping: Dict[str, Any], | |
| 138 | + max_attempts: int = 5, | |
| 139 | + ) -> str: | |
| 140 | + for attempt in range(1, max_attempts + 1): | |
| 141 | + index_name = get_suggestion_versioned_index_name(tenant_id) | |
| 142 | + if self.es_client.index_exists(index_name): | |
| 143 | + logger.warning( | |
| 144 | + "Suggestion index name collision before create for tenant=%s index=%s attempt=%s/%s", | |
| 145 | + tenant_id, | |
| 146 | + index_name, | |
| 147 | + attempt, | |
| 148 | + max_attempts, | |
| 149 | + ) | |
| 150 | + continue | |
| 151 | + | |
| 152 | + if self.es_client.create_index(index_name, mapping): | |
| 153 | + return index_name | |
| 154 | + | |
| 155 | + if self.es_client.index_exists(index_name): | |
| 156 | + logger.warning( | |
| 157 | + "Suggestion index name collision during create for tenant=%s index=%s attempt=%s/%s", | |
| 158 | + tenant_id, | |
| 159 | + index_name, | |
| 160 | + attempt, | |
| 161 | + max_attempts, | |
| 162 | + ) | |
| 163 | + continue | |
| 164 | + | |
| 165 | + raise RuntimeError(f"Failed to create suggestion index: {index_name}") | |
| 166 | + | |
| 167 | + raise RuntimeError( | |
| 168 | + f"Failed to allocate a unique suggestion index name for tenant={tenant_id} after {max_attempts} attempts" | |
| 169 | + ) | |
| 170 | + | |
| 171 | + def _ensure_new_index_ready(self, index_name: str) -> None: | |
| 172 | + health = self.es_client.wait_for_index_ready(index_name=index_name, timeout="5s") | |
| 173 | + if health.get("ok"): | |
| 174 | + return | |
| 175 | + raise RuntimeError(self._format_allocation_failure(index_name)) | |
| 176 | + | |
| 104 | 177 | @staticmethod |
| 105 | 178 | def _to_utc(dt: Any) -> Optional[datetime]: |
| 106 | 179 | if dt is None: |
| ... | ... | @@ -297,7 +370,7 @@ class SuggestionIndexBuilder: |
| 297 | 370 | while True: |
| 298 | 371 | body: Dict[str, Any] = { |
| 299 | 372 | "size": batch_size, |
| 300 | - "_source": ["id", "spu_id", "title", "qanchors", "tags"], | |
| 373 | + "_source": ["id", "spu_id", "title", "qanchors", "enriched_tags"], | |
| 301 | 374 | "sort": [ |
| 302 | 375 | {"spu_id": {"order": "asc", "missing": "_last"}}, |
| 303 | 376 | {"id.keyword": {"order": "asc", "missing": "_last"}}, |
| ... | ... | @@ -511,7 +584,7 @@ class SuggestionIndexBuilder: |
| 511 | 584 | c.add_product("qanchor", spu_id=product_id) |
| 512 | 585 | |
| 513 | 586 | for tag_lang, tag in self._iter_multilang_product_tags( |
| 514 | - src.get("tags"), | |
| 587 | + src.get("enriched_tags"), | |
| 515 | 588 | index_languages=index_languages, |
| 516 | 589 | primary_language=primary_language, |
| 517 | 590 | ): |
| ... | ... | @@ -609,62 +682,65 @@ class SuggestionIndexBuilder: |
| 609 | 682 | index_languages: List[str] = tenant_cfg.get("index_languages") or ["en", "zh"] |
| 610 | 683 | primary_language: str = tenant_cfg.get("primary_language") or "en" |
| 611 | 684 | |
| 612 | - # Always write to a fresh versioned index; legacy concrete index is no longer supported. | |
| 613 | - index_name = get_suggestion_versioned_index_name(tenant_id) | |
| 614 | - | |
| 615 | - if self.es_client.index_exists(index_name): | |
| 616 | - raise RuntimeError(f"Target suggestion index already exists: {index_name}") | |
| 617 | - | |
| 618 | - mapping = build_suggestion_mapping(index_languages=index_languages) | |
| 619 | - if not self.es_client.create_index(index_name, mapping): | |
| 620 | - raise RuntimeError(f"Failed to create suggestion index: {index_name}") | |
| 685 | + alias_publish: Optional[Dict[str, Any]] = None | |
| 686 | + index_name: Optional[str] = None | |
| 687 | + try: | |
| 688 | + mapping = build_suggestion_mapping(index_languages=index_languages) | |
| 689 | + index_name = self._create_fresh_versioned_index( | |
| 690 | + tenant_id=tenant_id, | |
| 691 | + mapping=mapping, | |
| 692 | + ) | |
| 693 | + self._ensure_new_index_ready(index_name) | |
| 621 | 694 | |
| 622 | - key_to_candidate = self._build_full_candidates( | |
| 623 | - tenant_id=tenant_id, | |
| 624 | - index_languages=index_languages, | |
| 625 | - primary_language=primary_language, | |
| 626 | - days=days, | |
| 627 | - batch_size=batch_size, | |
| 628 | - min_query_len=min_query_len, | |
| 629 | - ) | |
| 695 | + key_to_candidate = self._build_full_candidates( | |
| 696 | + tenant_id=tenant_id, | |
| 697 | + index_languages=index_languages, | |
| 698 | + primary_language=primary_language, | |
| 699 | + days=days, | |
| 700 | + batch_size=batch_size, | |
| 701 | + min_query_len=min_query_len, | |
| 702 | + ) | |
| 630 | 703 | |
| 631 | - now_iso = datetime.now(timezone.utc).isoformat() | |
| 632 | - docs = [self._candidate_to_doc(tenant_id, c, now_iso) for c in key_to_candidate.values()] | |
| 704 | + now_iso = datetime.now(timezone.utc).isoformat() | |
| 705 | + docs = [self._candidate_to_doc(tenant_id, c, now_iso) for c in key_to_candidate.values()] | |
| 633 | 706 | |
| 634 | - if docs: | |
| 635 | - bulk_result = self.es_client.bulk_index(index_name=index_name, docs=docs) | |
| 636 | - self.es_client.refresh(index_name) | |
| 637 | - else: | |
| 638 | - bulk_result = {"success": 0, "failed": 0, "errors": []} | |
| 707 | + if docs: | |
| 708 | + bulk_result = self.es_client.bulk_index(index_name=index_name, docs=docs) | |
| 709 | + self.es_client.refresh(index_name) | |
| 710 | + else: | |
| 711 | + bulk_result = {"success": 0, "failed": 0, "errors": []} | |
| 639 | 712 | |
| 640 | - alias_publish: Optional[Dict[str, Any]] = None | |
| 641 | - if publish_alias: | |
| 642 | - alias_publish = self._publish_alias( | |
| 643 | - tenant_id=tenant_id, | |
| 644 | - index_name=index_name, | |
| 645 | - keep_versions=keep_versions, | |
| 646 | - ) | |
| 713 | + if publish_alias: | |
| 714 | + alias_publish = self._publish_alias( | |
| 715 | + tenant_id=tenant_id, | |
| 716 | + index_name=index_name, | |
| 717 | + keep_versions=keep_versions, | |
| 718 | + ) | |
| 647 | 719 | |
| 648 | - now_utc = datetime.now(timezone.utc).isoformat() | |
| 649 | - meta_patch: Dict[str, Any] = { | |
| 650 | - "last_full_build_at": now_utc, | |
| 651 | - "last_incremental_watermark": now_utc, | |
| 652 | - } | |
| 653 | - if publish_alias: | |
| 654 | - meta_patch["active_index"] = index_name | |
| 655 | - meta_patch["active_alias"] = get_suggestion_alias_name(tenant_id) | |
| 656 | - self._upsert_meta(tenant_id, meta_patch) | |
| 720 | + now_utc = datetime.now(timezone.utc).isoformat() | |
| 721 | + meta_patch: Dict[str, Any] = { | |
| 722 | + "last_full_build_at": now_utc, | |
| 723 | + "last_incremental_watermark": now_utc, | |
| 724 | + } | |
| 725 | + if publish_alias: | |
| 726 | + meta_patch["active_index"] = index_name | |
| 727 | + meta_patch["active_alias"] = get_suggestion_alias_name(tenant_id) | |
| 728 | + self._upsert_meta(tenant_id, meta_patch) | |
| 657 | 729 | |
| 658 | - return { | |
| 659 | - "mode": "full", | |
| 660 | - "tenant_id": str(tenant_id), | |
| 661 | - "index_name": index_name, | |
| 662 | - "alias_published": bool(alias_publish), | |
| 663 | - "alias_publish": alias_publish, | |
| 664 | - "total_candidates": len(key_to_candidate), | |
| 665 | - "indexed_docs": len(docs), | |
| 666 | - "bulk_result": bulk_result, | |
| 667 | - } | |
| 730 | + return { | |
| 731 | + "mode": "full", | |
| 732 | + "tenant_id": str(tenant_id), | |
| 733 | + "index_name": index_name, | |
| 734 | + "alias_published": bool(alias_publish), | |
| 735 | + "alias_publish": alias_publish, | |
| 736 | + "total_candidates": len(key_to_candidate), | |
| 737 | + "indexed_docs": len(docs), | |
| 738 | + "bulk_result": bulk_result, | |
| 739 | + } | |
| 740 | + except Exception: | |
| 741 | + if index_name and not alias_publish: | |
| 742 | + self.es_client.delete_index(index_name) | |
| 743 | + raise | |
| 668 | 744 | |
| 669 | 745 | def _build_incremental_deltas( |
| 670 | 746 | self, | ... | ... |
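The microsecond-resolution versioned index name referenced above can be sketched as follows. The function name matches the one imported in the unit tests; the `strftime` format string is an assumption consistent with the expected value asserted there:

```python
from datetime import datetime, timezone


def get_suggestion_versioned_index_name(tenant_id: str, build_at: datetime) -> str:
    # %f appends microseconds, so back-to-back rebuilds within the same
    # second no longer collide on the same versioned index name.
    return f"search_suggestions_tenant_{tenant_id}_v{build_at.strftime('%Y%m%d%H%M%S%f')}"


# Example: datetime(2026, 4, 7, 3, 52, 26, 123456, tzinfo=timezone.utc)
# yields "search_suggestions_tenant_163_v20260407035226123456" for tenant "163".
```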
tests/test_suggestions.py
| ... | ... | @@ -8,6 +8,7 @@ from suggestion.builder import ( |
| 8 | 8 | QueryDelta, |
| 9 | 9 | SuggestionIndexBuilder, |
| 10 | 10 | get_suggestion_alias_name, |
| 11 | + get_suggestion_versioned_index_name, | |
| 11 | 12 | ) |
| 12 | 13 | from suggestion.service import SuggestionService |
| 13 | 14 | |
| ... | ... | @@ -121,6 +122,16 @@ class FakeESClient: |
| 121 | 122 | self.indices.add(index_name) |
| 122 | 123 | return True |
| 123 | 124 | |
| 125 | + def wait_for_index_ready(self, index_name: str, timeout: str = "10s") -> Dict[str, Any]: | |
| 126 | + self.calls.append({"op": "wait_for_index_ready", "index": index_name, "timeout": timeout}) | |
| 127 | + return {"ok": True, "status": "green", "timed_out": False} | |
| 128 | + | |
| 129 | + def get_allocation_explain(self, index_name: str, shard: int = 0, primary: bool = True) -> Dict[str, Any] | None: | |
| 130 | + self.calls.append( | |
| 131 | + {"op": "get_allocation_explain", "index": index_name, "shard": shard, "primary": primary} | |
| 132 | + ) | |
| 133 | + return None | |
| 134 | + | |
| 124 | 135 | def refresh(self, index_name: str) -> bool: |
| 125 | 136 | self.calls.append({"op": "refresh", "index": index_name}) |
| 126 | 137 | return True |
| ... | ... | @@ -150,6 +161,67 @@ class FakeESClient: |
| 150 | 161 | |
| 151 | 162 | |
| 152 | 163 | @pytest.mark.unit |
| 164 | +def test_versioned_index_name_uses_microseconds(): | |
| 165 | + build_at = datetime(2026, 4, 7, 3, 52, 26, 123456, tzinfo=timezone.utc) | |
| 166 | + assert ( | |
| 167 | + get_suggestion_versioned_index_name("163", build_at) | |
| 168 | + == "search_suggestions_tenant_163_v20260407035226123456" | |
| 169 | + ) | |
| 170 | + | |
| 171 | + | |
| 172 | +@pytest.mark.unit | |
| 173 | +def test_rebuild_cleans_up_unallocatable_new_index(): | |
| 174 | + fake_es = FakeESClient() | |
| 175 | + | |
| 176 | + def _wait_fail(index_name: str, timeout: str = "10s") -> Dict[str, Any]: | |
| 177 | + fake_es.calls.append({"op": "wait_for_index_ready", "index": index_name, "timeout": timeout}) | |
| 178 | + return {"ok": False, "status": "red", "timed_out": True} | |
| 179 | + | |
| 180 | + def _allocation_explain(index_name: str, shard: int = 0, primary: bool = True) -> Dict[str, Any]: | |
| 181 | + fake_es.calls.append( | |
| 182 | + {"op": "get_allocation_explain", "index": index_name, "shard": shard, "primary": primary} | |
| 183 | + ) | |
| 184 | + return { | |
| 185 | + "unassigned_info": {"reason": "INDEX_CREATED", "last_allocation_status": "no"}, | |
| 186 | + "node_allocation_decisions": [ | |
| 187 | + { | |
| 188 | + "node_name": "node-1", | |
| 189 | + "deciders": [ | |
| 190 | + { | |
| 191 | + "decider": "disk_threshold", | |
| 192 | + "decision": "NO", | |
| 193 | + "explanation": "node is above high watermark", | |
| 194 | + } | |
| 195 | + ], | |
| 196 | + } | |
| 197 | + ], | |
| 198 | + } | |
| 199 | + | |
| 200 | + fake_es.wait_for_index_ready = _wait_fail # type: ignore[method-assign] | |
| 201 | + fake_es.get_allocation_explain = _allocation_explain # type: ignore[method-assign] | |
| 202 | + | |
| 203 | + builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) | |
| 204 | + | |
| 205 | + from config import tenant_config_loader as tcl | |
| 206 | + | |
| 207 | + loader = tcl.get_tenant_config_loader() | |
| 208 | + loader._config = { | |
| 209 | + "default": {"primary_language": "en", "index_languages": ["en", "zh"]}, | |
| 210 | + "tenants": { | |
| 211 | + "163": {"primary_language": "en", "index_languages": ["en", "zh"]}, | |
| 212 | + }, | |
| 213 | + } | |
| 214 | + | |
| 215 | + with pytest.raises(RuntimeError, match="disk_threshold"): | |
| 216 | + builder.rebuild_tenant_index(tenant_id="163") | |
| 217 | + | |
| 218 | + create_calls = [x for x in fake_es.calls if x.get("op") == "create_index"] | |
| 219 | + assert len(create_calls) == 1 | |
| 220 | + created_index = create_calls[0]["index"] | |
| 221 | + assert created_index not in fake_es.indices | |
| 222 | + | |
| 223 | + | |
| 224 | +@pytest.mark.unit | |
| 153 | 225 | def test_resolve_query_language_prefers_log_field(): |
| 154 | 226 | fake_es = FakeESClient() |
| 155 | 227 | builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) |
| ... | ... | @@ -406,7 +478,7 @@ def test_build_full_candidates_tags_and_qanchor_phrases(monkeypatch): |
| 406 | 478 | "en": ["slim fit", "sporty casual"], |
| 407 | 479 | "zh": ["修身", "显瘦"], |
| 408 | 480 | }, |
| 409 | - "tags": { | |
| 481 | + "enriched_tags": { | |
| 410 | 482 | "en": ["Classic", "ribbed neckline"], |
| 411 | 483 | "zh": ["辣妹风"], |
| 412 | 484 | }, | ... | ... |
utils/es_client.py
| ... | ... | @@ -76,13 +76,70 @@ class ESClient: |
| 76 | 76 | True if successful, False otherwise |
| 77 | 77 | """ |
| 78 | 78 | try: |
| 79 | - self.client.indices.create(index=index_name, body=body) | |
| 79 | + client = self.client.options(request_timeout=30, max_retries=0) | |
| 80 | + client.indices.create( | |
| 81 | + index=index_name, | |
| 82 | + body=body, | |
| 83 | + wait_for_active_shards="0", | |
| 84 | + ) | |
| 80 | 85 | logger.info(f"Index '{index_name}' created successfully") |
| 81 | 86 | return True |
| 82 | 87 | except Exception as e: |
| 88 | + if self.index_exists(index_name): | |
| 89 | + logger.warning( | |
| 90 | + "Create index request for '%s' raised %s, but the index now exists; treating it as created", | |
| 91 | + index_name, | |
| 92 | + type(e).__name__, | |
| 93 | + exc_info=True, | |
| 94 | + ) | |
| 95 | + return True | |
| 83 | 96 | logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) |
| 84 | 97 | return False |
| 85 | 98 | |
| 99 | + def wait_for_index_ready(self, index_name: str, timeout: str = "10s") -> Dict[str, Any]: | |
| 100 | + """Wait until the index's primary shards are allocated and searchable (cluster status at least yellow).""" | |
| 101 | + try: | |
| 102 | + resp = self.client.cluster.health( | |
| 103 | + index=index_name, | |
| 104 | + wait_for_status="yellow", | |
| 105 | + timeout=timeout, | |
| 106 | + level="indices", | |
| 107 | + ) | |
| 108 | + index_info = ((resp.get("indices") or {}).get(index_name) or {}) | |
| 109 | + status = index_info.get("status") or resp.get("status") | |
| 110 | + timed_out = bool(resp.get("timed_out")) | |
| 111 | + return { | |
| 112 | + "ok": (not timed_out) and status in {"yellow", "green"}, | |
| 113 | + "status": status, | |
| 114 | + "timed_out": timed_out, | |
| 115 | + "response": resp, | |
| 116 | + } | |
| 117 | + except Exception as e: | |
| 118 | + logger.error("Failed waiting for index '%s' readiness: %s", index_name, e, exc_info=True) | |
| 119 | + return { | |
| 120 | + "ok": False, | |
| 121 | + "status": "unknown", | |
| 122 | + "timed_out": False, | |
| 123 | + "error": str(e), | |
| 124 | + } | |
| 125 | + | |
| 126 | + def get_allocation_explain(self, index_name: str, shard: int = 0, primary: bool = True) -> Optional[Dict[str, Any]]: | |
| 127 | + """Explain why a shard can or cannot be allocated.""" | |
| 128 | + try: | |
| 129 | + return self.client.cluster.allocation_explain( | |
| 130 | + body={"index": index_name, "shard": shard, "primary": primary} | |
| 131 | + ) | |
| 132 | + except Exception as e: | |
| 133 | + logger.warning( | |
| 134 | + "Failed to get allocation explain for index '%s' shard=%s primary=%s: %s", | |
| 135 | + index_name, | |
| 136 | + shard, | |
| 137 | + primary, | |
| 138 | + e, | |
| 139 | + exc_info=True, | |
| 140 | + ) | |
| 141 | + return None | |
| 142 | + | |
| 86 | 143 | def put_alias(self, index_name: str, alias_name: str) -> bool: |
| 87 | 144 | """Add alias for an index.""" |
| 88 | 145 | try: | ... | ... |