Compare View
Commits (2)

- 2. issues docs
- consolidate suggestion rebuild flow into build_suggestions.sh via --rebuild and remove the redundant rebuild_suggestions.sh wrapper
  - make suggestion versioned index names use microseconds and handle index-create retries/timeouts without false already_exists failures
  - treat create requests as successful when the index was created server-side, then explicitly wait for shard readiness and surface allocation diagnostics
  - clean up freshly created suggestion indices on rebuild failure to avoid leaving red orphan indices behind
  - make rebuild smoke tests target the local backend by default, with SUGGESTIONS_SMOKE_BASE_URL as the explicit override
  - add unit coverage for microsecond versioned index names and cleanup on unallocatable index failures

19 changed files
README.md

@@ -42,15 +42,15 @@ source activate.sh
 - `docs/Usage-Guide.md` -> service management overview
 
 Core ports:
--
+- `6001` qp
 - `6002` backend (`/search/*`, `/admin/*`)
-- `6004` indexer (`/indexer/*`)
 - `6003` frontend
-- `6010` eval-web (search-evaluation UI; `./scripts/service_ctl.sh` service name `eval-web`)
+- `6004` indexer (`/indexer/*`)
 - `6005` embedding-text (optional, `POST /embed/text`; a common backend is TEI, default `8080`)
-- `6008` embedding-image (optional, `POST /embed/image`, etc.)
 - `6006` translator (optional)
 - `6007` reranker (optional, `POST /rerank`; fine ranking can use a `service_profile` separate from the main reranker, see `config.yaml` -> `fine_rank` / `services.rerank`)
+- `6008` embedding-image (optional, `POST /embed/image`, etc.)
+- `6010` eval-web (search-evaluation UI; `./scripts/service_ctl.sh` service name `eval-web`)
 
 See `docs/QUICKSTART.md` for fuller examples.
 
@@ -0,0 +1,28 @@
+Add multimodal annotation to product_enrich_prompts:
+
+Apparel category:
+
+
+Model options:
+Weigh price against performance across the two models below:
+Qwen3.5 Plus
+qwen3-vl-plus
+  (Batch calls are half price, but check whether batch calling is supported in the US region)
+
+
+Category path
+Target audience (demographic / size)
+Style
+Usage scene
+Applicable season
+Functional features
+Fit
+Silhouette
+Sleeve type
+Neckline type
+Material description
+Pattern / design
+Color family
+
+
+There must be a cache: hash(image + title) -> LLM result
@@ -0,0 +1,69 @@
+This is the first version of my requirements. You have completed most of it; please review it against the current state of the code:
+
+Overall requirement: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). The prompt and the label vocabulary must be configurable.
+Focus on inference optimization first; do not consider serving yet. Supporting a CLI that reads queries is enough. At program startup, make sure everything that should be loaded is fully loaded, with no lazy loading of any kind, so that real requests get the fastest possible response. For each query the CLI receives, run inference and output every label's score plus the total prediction latency (per-stage latencies too, if that is easy to add).
+
+Example prompt and its 9 labels:
+Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}
+
+Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency and an exact score for every label.
+
+We are on a Tesla T4, therefore:
+1. Use FP16. Do not use BF16 as the main path.
+2. Do not use FlashAttention-3 as the main path. Pick the best attention implementation available on the Tesla T4.
+3. Gather plenty of material; study open-source inference-optimization projects and practices that suit the Tesla T4, including but not limited to Python/C++ runtime and TensorRT engine toolkits; check whether the hardware support matrix covers Turing/T4.
+
+
+Main optimization directions to consider:
+hidden_last -> N-class scorer -> argmax
+Follow vLLM's automatic prefix caching: reuse identical prefixes and manage KV by block hash
+Skip full-vocab logits
+Skip decode / constrained decode
+Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8)
+Dedicated tail kernel (outputs N raw class scores)
+
+The following is for reference only; base the concrete approach on an open-source project you find and use it as the baseline to optimize:
+15.1 Pre-allocation
+Pre-allocate per bucket:
+input ids buffer
+attention scratch buffer
+output hidden buffer
+class id buffer
+15.2 Graph capture
+
+Capture one graph per bucket in advance:
+embedding
+transformer backbone
+compact last-state
+9-way tail kernel
+
+Study Triton / CUDA C++ extreme-optimization methods and open-source projects for this pattern (single-step decode with fixed-vocabulary scoring), find a suitable baseline, and develop against it.
+
+You have sudo; you may install your own environment for this project.
+
+Use the qwen3:4b-instruct-2507-q4_K_M model. (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M.)
+Alternatively, find another model on Hugging Face, deploy it, and benchmark inference latency.
+
+
+Analyze per-stage latency in depth, keep researching, and see whether performance can be pushed to the limit.
+Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M), deploy it, and benchmark inference latency.
+Note: use the FP16 version, not a quantized version.
+
+You have satisfied most of the requirements above,
+but one important problem is still unsolved: the current version appears to be built for single-token decode, yet some labels are not a single token (even though they look like one word). So handle labels that are not a single token; do not design only for single tokens, though each label is still just a few tokens.
+Multi-token labels need extreme performance optimization; avoid serial decode.
+One direction to consider:
+Replace HF's multi-token candidate_reduce path with a fused/vectorized implementation (batch padding, one model forward for the whole batch).
+
+Do only one model forward, with input length total_tokens (usually far smaller than the looped total of batch_size * max_label_len).
+
+All subsequent aggregation is vectorized (nearly free on the GPU).
+
+Also search related material carefully, especially Triton / Ollama / CUDA C++ best practices for this scenario, put them into practice, and reach SOTA, fully optimized performance for query classification on the T4.
+
+Reference docs:
+
+vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
+PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
+TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
+Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
\ No newline at end of file
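The vLLM-style "manage KV by block hash" idea referenced above can be sketched in plain Python. This is a hypothetical helper, not vLLM's actual implementation: each block's key chains the previous block's hash, so two prompts share cache keys exactly for their common prefix blocks.

```python
import hashlib


def block_hashes(token_ids: list[int], block_size: int = 16) -> list[str]:
    """Chained per-block hashes: block i's key depends on all tokens in blocks 0..i.

    Two prompts that share a token prefix produce identical leading hashes, so
    their KV-cache blocks can be looked up and reused from a shared pool.
    """
    hashes: list[str] = []
    prev = b""
    for start in range(0, len(token_ids), block_size):
        block = token_ids[start:start + block_size]
        payload = prev + b"|" + b",".join(str(t).encode() for t in block)
        digest = hashlib.sha256(payload).hexdigest()
        hashes.append(digest)
        prev = digest.encode()
    return hashes
```

Chaining (rather than hashing each block independently) matters: it makes a block's key unambiguous about everything before it, so a cache hit on block i guarantees the whole prefix up to block i matches.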
@@ -0,0 +1,98 @@
+This is the first version of my requirements. You have completed most of it; please review it against the current state of the code:
+
+Overall requirement: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). The prompt and the label vocabulary must be configurable.
+Focus on inference optimization first; do not consider serving yet. Supporting a CLI that reads queries is enough. At program startup, make sure everything that should be loaded is fully loaded, with no lazy loading of any kind, so that real requests get the fastest possible response. For each query the CLI receives, run inference and output every label's score plus the total prediction latency (per-stage latencies too, if that is easy to add).
+
+Example prompt and its 9 labels:
+Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query}
+
+Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency and an exact score for every label.
+
+We are on a Tesla T4, therefore:
+1. Use FP16. Do not use BF16 as the main path.
+2. Do not use FlashAttention-3 as the main path. Pick the best attention implementation available on the Tesla T4.
+3. Gather plenty of material; study open-source inference-optimization projects and practices that suit the Tesla T4, including but not limited to Python/C++ runtime and TensorRT engine toolkits; check whether the hardware support matrix covers Turing/T4.
+
+
+Main optimization directions to consider:
+hidden_last -> N-class scorer -> argmax
+Follow vLLM's automatic prefix caching: reuse identical prefixes and manage KV by block hash
+Skip full-vocab logits
+Skip decode / constrained decode
+Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8)
+Dedicated tail kernel (outputs N raw class scores)
+
+The following is for reference only; base the concrete approach on an open-source project you find and use it as the baseline to optimize:
+15.1 Pre-allocation
+Pre-allocate per bucket:
+input ids buffer
+attention scratch buffer
+output hidden buffer
+class id buffer
+15.2 Graph capture
+
+Capture one graph per bucket in advance:
+embedding
+transformer backbone
+compact last-state
+9-way tail kernel
+
+Study Triton / CUDA C++ extreme-optimization methods and open-source projects for this pattern (single-step decode with fixed-vocabulary scoring), find a suitable baseline, and develop against it.
+
+You have sudo; you may install your own environment for this project.
+
+Use the qwen3:4b-instruct-2507-q4_K_M model. (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M.)
+Alternatively, find another model on Hugging Face, deploy it, and benchmark inference latency.
+
+
+Analyze per-stage latency in depth, keep researching, and see whether performance can be pushed to the limit.
+Use the qwen3:4b-instruct-2507-q4_K_M model (qwen3:4b cannot disable thinking mode, hence qwen3:4b-instruct-2507-q4_K_M), deploy it, and benchmark inference latency.
+Note: use the FP16 version, not a quantized version.
+
+You have satisfied most of the requirements above,
+but one important problem is still unsolved: the current version appears to be built for single-token decode, yet some labels are not a single token (even though they look like one word). So handle labels that are not a single token; do not design only for single tokens, though each label is still just a few tokens.
+Multi-token labels need extreme performance optimization; avoid serial decode.
+One direction to consider:
+Replace HF's multi-token candidate_reduce path with a fused/vectorized implementation (batch padding, one model forward for the whole batch).
+
+Do only one model forward, with input length total_tokens (usually far smaller than the looped total of batch_size * max_label_len).
+
+All subsequent aggregation is vectorized (nearly free on the GPU).
+
+Also search related material carefully, especially Triton / Ollama / CUDA C++ best practices for this scenario, put them into practice, and reach SOTA, fully optimized performance for query classification on the T4.
+
+Reference docs:
+
+vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
+PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
+TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
+Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
+
+
+Remaining issues:
+1. Please satisfy:
+Focus on inference optimization first; do not consider serving yet. Supporting a CLI that reads queries is enough. At program startup, make sure everything that should be loaded is fully loaded, with no lazy loading of any kind, so that real requests get the fastest possible response. For each query the CLI receives, run inference and output every label's score plus the total prediction latency (per-stage latencies too, if that is easy to add).
+
+2. The scores are wrong: whatever the input, the output is jeans. Even a skirt query outputs jeans; only the literal input "skirt" yields skirt.
+
+3. There is currently one prompt; I need to add six more, all configurable. For each input query, run batched inference over the 7 prompts and, for each prompt, output the score distribution over its label vocabulary:
+Audience
+Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query}
+
+Size
+Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query}
+
+Neckline
+Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query}
+
+Sleeve length
+Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query}
+
+Top length
+Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query}
+
+Fabric
+Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. Output nothing else query:{query}
+
+
+Be sure to test with test cases, covering whether the scores match expectations as well as performance.
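One way to keep multi-token and single-token label scores comparable (which the "everything scores as jeans" failure in issue 2 suggests is needed) is length normalization: compare the mean per-token log-probability instead of the raw sum, so multi-token labels are not penalized simply for being longer. A minimal sketch over hypothetical precomputed per-token log-probs (the function name and input shape are assumptions, not project code):

```python
import math


def pick_label(label_token_logprobs: dict[str, list[float]]) -> tuple[str, float]:
    """Pick the label with the highest mean per-token log-prob.

    Raw summed log-probs shrink with label length, so a 1-token label would
    beat a 2-token label almost by construction; dividing by token count
    keeps scores of different-length labels on one scale.
    """
    best_label, best_score = "", -math.inf
    for label, logprobs in label_token_logprobs.items():
        score = sum(logprobs) / len(logprobs)
        if score > best_score:
            best_label, best_score = label, score
    return best_label, best_score
```

With force_single_token_labels every list has length 1 and this reduces to plain argmax over first-token scores.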
@@ -0,0 +1,4 @@
+
+1. "At startup the repo pre-builds buckets for every batch_sizes x continuation_buckets x prompt combination." That design probably followed from the earlier requirements ("extreme performance", "no lazy loading, fastest possible response on real requests"), but it clearly goes too far. I do want the basics that can be loaded up front to be loaded up front, but trading huge GPU-memory usage for a marginal latency gain is not encouraged. Please take a higher-level view of my intent: load the model and the cross-session KV cache first, and optimize aggressively for the special usage pattern (scoring rather than step-by-step decode), so that a single online query gets classification results for all 7 prompts in the shortest possible time.
+
+2. The 7 prompts currently run serially: MultiPromptRunner.score_query() just loops `for runner in self.runners` one at a time. The execution model needs to change to batched parallelism grouped by prompt. But there are currently 2 fast-path and 5 multi-token-path prompts: should each group be fused into its own forward, or can everything possibly be fused into a single forward? Since multi-token labels can also be prefilled in one shot, can they reach the same performance class as the fast path? Please think at a higher level and reduce complexity while preserving performance.
\ No newline at end of file
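The "fuse both paths into one forward" question above mostly reduces to batch layout: fast-path rows are length-1 continuations and multi-token rows are a few tokens, so both can share one rectangular batch if padded to a common length with an attention mask. A minimal sketch of that packing step (the helper name and pad id are illustrative assumptions):

```python
def pad_batch(rows: list[list[int]], pad_id: int = 0) -> tuple[list[list[int]], list[list[int]]]:
    """Pad variable-length continuation rows into one rectangular batch.

    Fast-path rows (length 1) and multi-token rows (a few tokens) end up in
    the same input tensor, so a single forward covers both paths; the mask
    marks which positions hold real tokens versus padding.
    """
    max_len = max(len(r) for r in rows)
    input_ids = [r + [pad_id] * (max_len - len(r)) for r in rows]
    mask = [[1] * len(r) + [0] * (max_len - len(r)) for r in rows]
    return input_ids, mask
```

The cost of fusing is the padding waste, bounded by num_rows * (max_label_len - 1) tokens; with labels of at most 2-3 tokens that overhead is small, which argues for a single fused forward over two per-path forwards.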
@@ -0,0 +1,1003 @@
+Overall requirement: query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). Prompts and label vocabularies are configurable.
+Focus on inference optimization first and serving last; support a modest level of concurrent requests (e.g. 4). At program startup, make sure everything that should be loaded is fully loaded, with no lazy loading of any kind, so that real requests get the fastest possible response. For each query the CLI receives, classify along N dimensions with N prompts; for each prompt, output every label's score plus the total prediction latency (per-stage latencies too, if that is easy to add).
+
+Some reference material follows, but you need not follow it strictly; keep some flexibility in pursuit of extreme performance.
+
+On a Tesla T4, use a 3B-6B open-source decoder-only base model for query classification.
+At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor.
+For each input query, output the score distribution over every label under every prompt, plus prediction latency and per-stage latency.
+Do not use the general generation path; no decode, no full-vocab logits, no constrained decode.
+Optimize multi-token labels specifically; avoid Python-side serial decode.
+The prompt and label sets must be configurable (there are only two for now; I will grow this to 8; each request takes one query and calls the 8 prompts in parallel to get the top-scoring label):
+
+
+prompts:
+  - name: category
+    prompt_template: |
+      <|im_start|>system
+      Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|>
+      <|im_start|>user
+      query: {query}<|im_end|>
+      <|im_start|>assistant
+      label:
+    label_prefix: " "
+    labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]
+
+  - name: audience
+    prompt_template: |
+      <|im_start|>system
+      Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|>
+      <|im_start|>user
+      query: {query}<|im_end|>
+      <|im_start|>assistant
+      label:
+    label_prefix: " "
+    labels: [boy, girl, man, woman, pregnant, none]
+
+
+Build a dedicated execution path rather than tuning configuration on a general-purpose generation engine; pursue the absolute lowest classification latency.
+
+Main optimization directions to consider:
+1. hidden_last -> N-class scorer -> argmax
+2. Follow vLLM's automatic prefix caching: reuse identical prefixes and manage KV by block hash
+3. Skip full-vocab logits
+4. Skip decode / constrained decode
+5. Dedicated tail kernel (outputs N raw class scores)
+6. The N configured prompts must be inferred in parallel (2-8 prompts)
+7. We are on a Tesla T4, so do not use FlashAttention-3 as the main path; pick the best attention implementation available on the Tesla T4.
+Gather plenty of material; study open-source inference-optimization projects and practices that suit the Tesla T4, including but not limited to Python/C++ runtime and TensorRT engine toolkits; check whether the hardware support matrix covers Turing/T4. Study Triton / CUDA C++ extreme-optimization methods and open-source projects for this pattern (single-step decode with fixed-vocabulary scoring), find a suitable baseline, and develop against it.
+
+
+You have sudo; you may install your own environment for this project.
+
+Use a Q4 or Q8 build of Qwen/Qwen3-8B; research Hugging Face to decide which version, deploy it, and benchmark inference latency.
+
+Analyze per-stage latency in depth, keep researching, and see whether performance can be pushed to the limit.
+
+One important problem: some labels are not a single token (even though they look like one word), so handle labels that are not a single token.
+Multi-token labels need extreme performance optimization; avoid serial decode.
+Our final goal is to find which label scores highest, not to get exact probabilities. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may make performance hard to optimize, so exact probabilities can be sacrificed. Keep the final goal in mind: classification is enough; prioritize performance and set exact probabilities aside.
+How to handle the whole batch, multi-token labels included, in a single model forward is the question you need to explore.
+
+The single-token fast path is fairly settled: last_hidden -> small class scorer -> argmax. Take only the LM-head rows for the target labels; never produce full-vocab output.
+How to handle multi-token labels needs research and deliberation, ideally at the same cost as single-token (given that exact log-probs are abandoned; but multi-token and single-token label scores must remain comparable for classification to be correct, balancing performance and scoring accuracy).
+
+Also add a config option, force_single_token_labels: treat every label by its first token, because when labels' first tokens differ, the first-token score approximates the whole-label score.
+You need to find the best practice for multi-label scoring in both performance and accuracy, while also supporting force_single_token_labels for extreme performance.
+
+Also search related material carefully, especially Triton / Ollama / CUDA C++ best practices for this scenario, put them into practice, and reach SOTA, fully optimized performance for query classification on the T4.
+Some reference examples:
+vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
+PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
+TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
+Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
+
+TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
+TensorRT-LLM support matrix: the current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
+FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
+vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
+SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
+FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
+xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers
+
+Note: there are already two projects, llm-qp and llm-qp2, whose handling of single-token labels is sound:
+SDPA
+prefix cache
+prebuilt bucket + CUDA graph
+Their core code is:
+from __future__ import annotations
+
+import hashlib
+import time
+from dataclasses import asdict, dataclass
+from typing import Iterable
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
+
+try:
+    from transformers import BitsAndBytesConfig
+except ImportError:  # pragma: no cover
+    BitsAndBytesConfig = None
+
+from llm_qp.config import PromptTaskConfig, RuntimeConfig
+from llm_qp.scorer import SmallClassScorer
+
+try:
+    from torch.nn.attention import SDPBackend, sdpa_kernel
+except ImportError:  # pragma: no cover
+    SDPBackend = None
+    sdpa_kernel = None
+
+
+@dataclass(slots=True)
+class EncodedLabel:
+    text: str
+    token_ids: list[int]
+
+
+@dataclass(slots=True)
+class PrefixCache:
+    prefix_ids: list[int]
+    prefix_hashes: list[str]
+    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]
+
+    @property
+    def prefix_len(self) -> int:
+        return len(self.prefix_ids)
+
+
+@dataclass(slots=True)
+class MultiTokenTables:
+    label_token_ids: torch.Tensor
+    label_token_mask: torch.Tensor
+    label_prefix_ids: torch.Tensor
+    label_prefix_mask: torch.Tensor
+    label_position_offsets: torch.Tensor
+
+    @property
+    def max_label_len(self) -> int:
+        return self.label_token_ids.shape[1]
+
+    @property
+    def max_label_prefix_len(self) -> int:
+        return self.label_prefix_ids.shape[1]
+
+
+@dataclass(slots=True)
+class QueryScoreResult:
+    task_name: str
+    query: str
+    predicted_label: str
+    scores: list[tuple[str, float, float]]
+    total_ms: float
+    stage_ms: dict[str, float]
+    fast_path: bool
+    prefix_tokens: int
+    continuation_tokens: int
+    label_token_lengths: dict[str, int]
+
+    @property
+    def predicted_prob(self) -> float:
+        for label, _score, prob in self.scores:
+            if label == self.predicted_label:
+                return prob
+        return 0.0
+
+
+@dataclass(slots=True)
+class MultiPromptScoreResult:
+    query: str
+    total_ms: float
+    details: list[QueryScoreResult]
+    stage_ms: dict[str, float]
+
+    def http_json(self) -> dict[str, object]:
+        return {
+            "query": self.query,
+            "total_ms": self.total_ms,
+            "stage_ms": self.stage_ms,
+            "details": [asdict(t) for t in self.details],
+            "task_results": {
+                t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob]
+                for t in self.details
+                if t.predicted_label != "none"
+            },
+        }
+
+
+@dataclass(slots=True)
+class BatchScoreResult:
+    batch_size: int
+    total_ms: float
+    results: list[MultiPromptScoreResult]
+    stage_ms: dict[str, float]
+
+
+@dataclass(slots=True)
+class SharedRuntime:
+    device: torch.device
+    dtype: torch.dtype
+    tokenizer: object
+    model: object
+    backbone: object
+    hidden_size: int
+    graph_capture_pool: object | None = None
+    graph_capture_stream: torch.cuda.Stream | None = None
+
+
+@dataclass(slots=True)
+class PromptBatchPlan:
+    runner: "PromptClassifierRunner"
+    row_start: int
+    row_count: int
+    score_buffer: torch.Tensor
+
+    @property
+    def row_stop(self) -> int:
+        return self.row_start + self.row_count
+
+
+@dataclass(slots=True)
+class MixedPrefixCache:
+    batch_size: int
+    total_rows: int
+    prefix_lengths: torch.Tensor
+    attention_mask: torch.Tensor
+    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]
+
+    @property
+    def max_prefix_len(self) -> int:
+        return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0
+
+
+@dataclass(slots=True)
+class BatchLayout:
+    batch_size: int
+    total_rows: int
+    plans: list[PromptBatchPlan]
+
+
+@dataclass(slots=True)
+class MixedBucketRuntime:
+    batch_size: int
+    total_rows: int
+    continuation_len: int
+    max_input_len: int
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    position_ids: torch.Tensor
+    last_hidden_state: torch.Tensor
+    graph: torch.cuda.CUDAGraph | None = None
+
+
+@dataclass(slots=True)
+class PreloadReport:
+    total_ms: float
+    stage_ms: dict[str, float]
+    runtime: dict[str, object]
+
+
+def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
+    token_list = list(token_ids)
+    hashes: list[str] = []
+    for start in range(0, len(token_list), block_size):
+        block = token_list[start : start + block_size]
+        payload = ",".join(str(x) for x in block).encode("utf-8")
+        hashes.append(hashlib.sha1(payload).hexdigest())
+    return hashes
+
+
+def _expand_legacy_cache(
+    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
+    batch_size: int,
+) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
+    expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
+    for key, value in raw_cache:
+        expanded.append(
+            (
+                key.expand(batch_size, *key.shape[1:]).contiguous(),
+                value.expand(batch_size, *value.shape[1:]).contiguous(),
+            )
+        )
+    return tuple(expanded)
+
+
+class PromptClassifierRunner:
+    def __init__(
+        self,
+        cfg: RuntimeConfig,
+        task_cfg: PromptTaskConfig,
+        shared_runtime: SharedRuntime,
+    ):
+        self.cfg = cfg
+        self.task_cfg = task_cfg
+        self.device = shared_runtime.device
+        self.dtype = shared_runtime.dtype
+        self.tokenizer = shared_runtime.tokenizer
+        self.model = shared_runtime.model
+        self.backbone = shared_runtime.backbone
+        self.hidden_size = shared_runtime.hidden_size
+        self.prefix_text, self.suffix_text = task_cfg.prompt_parts
+        self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
+        self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
+        self.labels = list(task_cfg.labels)
+        self.encoded_labels = [
+            EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
+            for label in self.labels
+        ]
+        self.num_labels = len(self.labels)
+        self.lm_head = self.model.get_output_embeddings()
+        self.lm_head_weight = self.lm_head.weight.detach()
+        self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
+        if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
+            raise ValueError(
+                f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
+            )
+        self.fast_path = self._has_unique_single_token_labels()
+        self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
+        self.scorer = self._build_scorer() if self.fast_path else None
+        self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
+        self.prefix_cache = self._build_prefix_cache()
+
+    def _encode_label_token_ids(self, label: str) -> list[int]:
+        token_ids = self.tokenizer.encode(
+            f"{self.task_cfg.label_prefix}{label}",
+            add_special_tokens=False,
+        )
+        if not token_ids:
+            raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
+        if self.cfg.force_single_token_labels:
+            return token_ids[:1]
+        return token_ids
+
+    def _has_unique_single_token_labels(self) -> bool:
+        token_ids: list[int] = []
+        for item in self.encoded_labels:
+            if len(item.token_ids) != 1:
+                return False
+            token_ids.append(item.token_ids[0])
+        return len(token_ids) == len(set(token_ids))
+
+    def _build_scorer(self) -> SmallClassScorer:
+        token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
+        weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
+        bias = None
+        if self.lm_head_bias is not None:
+            bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
+        return SmallClassScorer(weights=weights, bias=bias)
+
+    def _build_multi_token_tables(self) -> MultiTokenTables:
+        max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
+        max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
+        label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
+        label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
+        label_prefix_ids = torch.full(
+            (self.num_labels, max_prefix_len),
+            fill_value=self.tokenizer.pad_token_id,
+            device=self.device,
+            dtype=torch.long,
+        )
+        label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
+        for idx, item in enumerate(self.encoded_labels):
+            token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
+            token_len = token_ids.numel()
+            label_token_ids[idx, :token_len] = token_ids
+            label_token_mask[idx, :token_len] = 1.0
+            if token_len > 1:
+                prefix_len = token_len - 1
+                label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
+                label_prefix_mask[idx, :prefix_len] = 1
+        return MultiTokenTables(
+            label_token_ids=label_token_ids.contiguous(),
+            label_token_mask=label_token_mask.contiguous(),
+            label_prefix_ids=label_prefix_ids.contiguous(),
+            label_prefix_mask=label_prefix_mask.contiguous(),
+            label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
+        )
+
+    @torch.inference_mode()
+    def _build_prefix_cache(self) -> PrefixCache:
+        if not self.prefix_ids:
+            return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
+        prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
+        attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
+        outputs = self.model(
+            input_ids=prefix_tensor,
+            attention_mask=attention_mask,
+            use_cache=True,
+            return_dict=True,
+        )
+        raw_cache = tuple(
+            (layer.keys.detach(), layer.values.detach())
+            for layer in outputs.past_key_values.layers
+        )
+        return PrefixCache(
+            prefix_ids=list(self.prefix_ids),
+            prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
+            raw_cache=raw_cache,
+        )
+
+    def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
+        if not self.prefix_cache.raw_cache:
+            return tuple()
+        return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)
+
+    def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
+        continuation = query_ids + self.suffix_ids
+        if not continuation:
+            raise ValueError("prompt continuation is empty after substituting query")
+        if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
+            raise ValueError(
+                f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
+            )
+        return continuation
+
+    @torch.inference_mode()
+    def reduce_fast_scores(
+        self,
+        hidden: torch.Tensor,
+        out_scores: torch.Tensor,
+    ) -> None:
+        assert self.scorer is not None
+        out_scores.copy_(self.scorer(hidden))
+
+    @torch.inference_mode()
+    def reduce_multi_token_scores(
+        self,
+        last_hidden_state: torch.Tensor,
+        batch_size: int,
+        max_input_len: int,
+        score_positions: torch.Tensor,
+        out_scores: torch.Tensor,
+    ) -> None:
+        assert self.multi_token_tables is not None
+        hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
+        gather_index = score_positions[:, None, :, None].expand(
+            batch_size,
+            self.num_labels,
+            self.multi_token_tables.max_label_len,
+            self.hidden_size,
+        )
+        gathered_hidden = torch.gather(hidden, 2, gather_index)
+        used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()
+
+        token_log_probs = torch.zeros(
+            (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
+            device=self.device,
| 447 | + dtype=torch.float32, | ||
| 448 | + ) | ||
| 449 | + if used_mask.any(): | ||
| 450 | + flat_hidden = gathered_hidden[used_mask] | ||
| 451 | + flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask] | ||
| 452 | + linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float() | ||
| 453 | + linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float() | ||
| 454 | + linear_bias = self.lm_head_bias | ||
| 455 | + if linear_bias is not None and self.device.type != "cuda": | ||
| 456 | + linear_bias = linear_bias.float() | ||
| 457 | + flat_logits = F.linear(linear_hidden, linear_weight, linear_bias) | ||
| 458 | + flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float() | ||
| 459 | + flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1) | ||
| 460 | + token_log_probs[used_mask] = flat_selected - flat_log_norm | ||
| 461 | + out_scores.copy_(token_log_probs.sum(dim=-1)) | ||
| 462 | + | ||
| 463 | + def build_score_result( | ||
| 464 | + self, | ||
| 465 | + query: str, | ||
| 466 | + scores: torch.Tensor, | ||
| 467 | + stage_ms: dict[str, float], | ||
| 468 | + continuation_tokens: int, | ||
| 469 | + ) -> QueryScoreResult: | ||
| 470 | + score_values = scores.detach().float().cpu().tolist() | ||
| 471 | + best_idx = max(range(len(score_values)), key=score_values.__getitem__) | ||
| 472 | + probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist() | ||
| 473 | + return QueryScoreResult( | ||
| 474 | + task_name=self.task_cfg.name, | ||
| 475 | + query=query, | ||
| 476 | + predicted_label=self.labels[best_idx], | ||
| 477 | + scores=[ | ||
| 478 | + (label, score, prob) | ||
| 479 | + for label, score, prob in zip(self.labels, score_values, probs, strict=True) | ||
| 480 | + ], | ||
| 481 | + total_ms=sum(stage_ms.values()), | ||
| 482 | + stage_ms=stage_ms, | ||
| 483 | + fast_path=self.fast_path, | ||
| 484 | + prefix_tokens=self.prefix_cache.prefix_len, | ||
| 485 | + continuation_tokens=continuation_tokens, | ||
| 486 | + label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels}, | ||
| 487 | + ) | ||
| 488 | + | ||
| 489 | + | ||
| 490 | +class MultiPromptRunner: | ||
| 491 | + def __init__(self, cfg: RuntimeConfig): | ||
| 492 | + self.cfg = cfg | ||
| 493 | + t0 = time.perf_counter() | ||
| 494 | + self.shared_runtime = self.build_shared_runtime(cfg) | ||
| 495 | + t1 = time.perf_counter() | ||
| 496 | + self.device = self.shared_runtime.device | ||
| 497 | + self.dtype = self.shared_runtime.dtype | ||
| 498 | + self.tokenizer = self.shared_runtime.tokenizer | ||
| 499 | + self.model = self.shared_runtime.model | ||
| 500 | + self.backbone = self.shared_runtime.backbone | ||
| 501 | + self.hidden_size = self.shared_runtime.hidden_size | ||
| 502 | + self.graph_capture_pool = self.shared_runtime.graph_capture_pool | ||
| 503 | + self.graph_capture_stream = self.shared_runtime.graph_capture_stream | ||
| 504 | + self.runners = [ | ||
| 505 | + PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime) | ||
| 506 | + for task_cfg in cfg.tasks | ||
| 507 | + ] | ||
| 508 | + t2 = time.perf_counter() | ||
| 509 | + self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes} | ||
| 510 | + t3 = time.perf_counter() | ||
| 511 | + self.mixed_prefix_caches = { | ||
| 512 | + batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size]) | ||
| 513 | + for batch_size in self.cfg.batch_sizes | ||
| 514 | + } | ||
| 515 | + t4 = time.perf_counter() | ||
| 516 | + self.max_label_prefix_len = max( | ||
| 517 | + (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0) | ||
| 518 | + for runner in self.runners | ||
| 519 | + ) | ||
| 520 | + self.mixed_buckets = { | ||
| 521 | + (batch_size, continuation_len): self._build_mixed_bucket( | ||
| 522 | + self.batch_layouts[batch_size], | ||
| 523 | + self.mixed_prefix_caches[batch_size], | ||
| 524 | + continuation_len, | ||
| 525 | + ) | ||
| 526 | + for batch_size in self.cfg.batch_sizes | ||
| 527 | + for continuation_len in self.cfg.continuation_buckets | ||
| 528 | + } | ||
| 529 | + t5 = time.perf_counter() | ||
| 530 | + self._warmup_results: dict[int, BatchScoreResult] = {} | ||
| 531 | + self._preload_report: PreloadReport | None = None | ||
| 532 | + self._init_stage_ms = { | ||
| 533 | + "load_model_and_tokenizer": (t1 - t0) * 1000.0, | ||
| 534 | + "build_prompt_runtimes": (t2 - t1) * 1000.0, | ||
| 535 | + "build_batch_layouts": (t3 - t2) * 1000.0, | ||
| 536 | + "build_mixed_prefix_caches": (t4 - t3) * 1000.0, | ||
| 537 | + "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0, | ||
| 538 | + } | ||
| 539 | + self._init_total_ms = sum(self._init_stage_ms.values()) | ||
| 540 | + | ||
| 541 | + @staticmethod | ||
| 542 | + def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime: | ||
| 543 | + device = torch.device(cfg.device) | ||
| 544 | + dtype = torch.float16 | ||
| 545 | + tokenizer = AutoTokenizer.from_pretrained( | ||
| 546 | + cfg.resolved_model_source, | ||
| 547 | + trust_remote_code=cfg.resolved_trust_remote_code, | ||
| 548 | + token=cfg.hf_token, | ||
| 549 | + cache_dir=cfg.hf_cache_dir, | ||
| 550 | + local_files_only=cfg.resolved_local_files_only, | ||
| 551 | + ) | ||
| 552 | + if tokenizer.pad_token_id is None: | ||
| 553 | + tokenizer.pad_token = tokenizer.eos_token | ||
| 554 | + attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend) | ||
| 555 | + quantization_config = None | ||
| 556 | + model_kwargs: dict[str, object] = { | ||
| 557 | + "trust_remote_code": cfg.resolved_trust_remote_code, | ||
| 558 | + "attn_implementation": attn_impl, | ||
| 559 | + "token": cfg.hf_token, | ||
| 560 | + "cache_dir": cfg.hf_cache_dir, | ||
| 561 | + "local_files_only": cfg.resolved_local_files_only, | ||
| 562 | + } | ||
| 563 | + if cfg.load_in_4bit: | ||
| 564 | + if BitsAndBytesConfig is None: | ||
| 565 | + raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first") | ||
| 566 | + quantization_config = BitsAndBytesConfig( | ||
| 567 | + load_in_4bit=True, | ||
| 568 | + bnb_4bit_compute_dtype=dtype, | ||
| 569 | + bnb_4bit_quant_type=cfg.bnb_4bit_quant_type, | ||
| 570 | + bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant, | ||
| 571 | + ) | ||
| 572 | + model_kwargs["quantization_config"] = quantization_config | ||
| 573 | + model_kwargs["device_map"] = {"": device.index or 0} | ||
| 574 | + else: | ||
| 575 | + model_kwargs["dtype"] = dtype | ||
| 576 | + model_kwargs["device_map"] = None | ||
| 577 | + model = AutoModelForCausalLM.from_pretrained( | ||
| 578 | + cfg.resolved_model_source, | ||
| 579 | + **model_kwargs, | ||
| 580 | + ).eval() | ||
| 581 | + if not cfg.load_in_4bit: | ||
| 582 | + model = model.to(device) | ||
| 583 | + backbone = model.get_submodule(model.base_model_prefix) | ||
| 584 | + hidden_size = model.get_output_embeddings().weight.shape[1] | ||
| 585 | + graph_capture_pool = None | ||
| 586 | + graph_capture_stream = None | ||
| 587 | + if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit: | ||
| 588 | + graph_capture_pool = torch.cuda.graph_pool_handle() | ||
| 589 | + graph_capture_stream = torch.cuda.Stream(device=device) | ||
| 590 | + return SharedRuntime( | ||
| 591 | + device=device, | ||
| 592 | + dtype=dtype, | ||
| 593 | + tokenizer=tokenizer, | ||
| 594 | + model=model, | ||
| 595 | + backbone=backbone, | ||
| 596 | + hidden_size=hidden_size, | ||
| 597 | + graph_capture_pool=graph_capture_pool, | ||
| 598 | + graph_capture_stream=graph_capture_stream, | ||
| 599 | + ) | ||
| 600 | + | ||
| 601 | + @staticmethod | ||
| 602 | + def _resolve_attn_impl(requested: str) -> str: | ||
| 603 | + if requested in {"sdpa", "eager"}: | ||
| 604 | + return requested | ||
| 605 | + if requested == "auto": | ||
| 606 | + return "sdpa" | ||
| 607 | + raise ValueError(f"unsupported attn_backend: {requested}") | ||
| 608 | + | ||
| 609 | + def _attn_context(self): | ||
| 610 | + if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}: | ||
| 611 | + return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) | ||
| 612 | + return torch.no_grad() | ||
| 613 | + | ||
| 614 | + def _sync(self) -> None: | ||
| 615 | + if self.device.type == "cuda": | ||
| 616 | + torch.cuda.synchronize() | ||
| 617 | + | ||
| 618 | + def _pick_bucket(self, continuation_len: int) -> int: | ||
| 619 | + for bucket in self.cfg.continuation_buckets: | ||
| 620 | + if continuation_len <= bucket: | ||
| 621 | + return bucket | ||
| 622 | + if self.cfg.pad_to_bucket: | ||
| 623 | + raise ValueError( | ||
| 624 | + f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets" | ||
| 625 | + ) | ||
| 626 | + return continuation_len | ||
| 627 | + | ||
| 628 | + def _build_batch_layout(self, batch_size: int) -> BatchLayout: | ||
| 629 | + plans: list[PromptBatchPlan] = [] | ||
| 630 | + row_start = 0 | ||
| 631 | + for runner in self.runners: | ||
| 632 | + row_count = batch_size if runner.fast_path else batch_size * runner.num_labels | ||
| 633 | + plans.append( | ||
| 634 | + PromptBatchPlan( | ||
| 635 | + runner=runner, | ||
| 636 | + row_start=row_start, | ||
| 637 | + row_count=row_count, | ||
| 638 | + score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32), | ||
| 639 | + ) | ||
| 640 | + ) | ||
| 641 | + row_start += row_count | ||
| 642 | + return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans) | ||
| 643 | + | ||
| 644 | + def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache: | ||
| 645 | + prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long) | ||
| 646 | + non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache] | ||
| 647 | + if not non_empty: | ||
| 648 | + return MixedPrefixCache( | ||
| 649 | + batch_size=layout.batch_size, | ||
| 650 | + total_rows=layout.total_rows, | ||
| 651 | + prefix_lengths=prefix_lengths, | ||
| 652 | + attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long), | ||
| 653 | + raw_cache=tuple(), | ||
| 654 | + ) | ||
| 655 | + | ||
| 656 | + max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans) | ||
| 657 | + num_layers = len(non_empty[0]) | ||
| 658 | + attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long) | ||
| 659 | + raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = [] | ||
| 660 | + for layer_idx in range(num_layers): | ||
| 661 | + sample_key, sample_value = non_empty[0][layer_idx] | ||
| 662 | + merged_key = sample_key.new_zeros( | ||
| 663 | + (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3]) | ||
| 664 | + ) | ||
| 665 | + merged_value = sample_value.new_zeros( | ||
| 666 | + (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3]) | ||
| 667 | + ) | ||
| 668 | + raw_layers.append((merged_key, merged_value)) | ||
| 669 | + | ||
| 670 | + for plan in layout.plans: | ||
| 671 | + runner = plan.runner | ||
| 672 | + prefix_len = runner.prefix_cache.prefix_len | ||
| 673 | + row_slice = slice(plan.row_start, plan.row_stop) | ||
| 674 | + prefix_lengths[row_slice] = prefix_len | ||
| 675 | + if prefix_len == 0: | ||
| 676 | + continue | ||
| 677 | + attention_mask[row_slice, :prefix_len] = 1 | ||
| 678 | + raw_cache = runner.expand_prefix_raw_cache(plan.row_count) | ||
| 679 | + for layer_idx, (key, value) in enumerate(raw_cache): | ||
| 680 | + merged_key, merged_value = raw_layers[layer_idx] | ||
| 681 | + merged_key[row_slice, :, :prefix_len, :] = key | ||
| 682 | + merged_value[row_slice, :, :prefix_len, :] = value | ||
| 683 | + | ||
| 684 | + return MixedPrefixCache( | ||
| 685 | + batch_size=layout.batch_size, | ||
| 686 | + total_rows=layout.total_rows, | ||
| 687 | + prefix_lengths=prefix_lengths, | ||
| 688 | + attention_mask=attention_mask.contiguous(), | ||
| 689 | + raw_cache=tuple(raw_layers), | ||
| 690 | + ) | ||
| 691 | + | ||
| 692 | + def _build_mixed_bucket( | ||
| 693 | + self, | ||
| 694 | + layout: BatchLayout, | ||
| 695 | + prefix_cache: MixedPrefixCache, | ||
| 696 | + continuation_len: int, | ||
| 697 | + ) -> MixedBucketRuntime: | ||
| 698 | + max_input_len = continuation_len + self.max_label_prefix_len | ||
| 699 | + total_len = prefix_cache.max_prefix_len + max_input_len | ||
| 700 | + input_ids = torch.full( | ||
| 701 | + (layout.total_rows, max_input_len), | ||
| 702 | + fill_value=self.tokenizer.pad_token_id, | ||
| 703 | + device=self.device, | ||
| 704 | + dtype=torch.long, | ||
| 705 | + ) | ||
| 706 | + attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long) | ||
| 707 | + if prefix_cache.max_prefix_len: | ||
| 708 | + attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask | ||
| 709 | + position_ids = ( | ||
| 710 | + prefix_cache.prefix_lengths[:, None] | ||
| 711 | + + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0) | ||
| 712 | + ).contiguous() | ||
| 713 | + last_hidden_state = torch.empty( | ||
| 714 | + (layout.total_rows, max_input_len, self.hidden_size), | ||
| 715 | + device=self.device, | ||
| 716 | + dtype=self.dtype, | ||
| 717 | + ) | ||
| 718 | + bucket = MixedBucketRuntime( | ||
| 719 | + batch_size=layout.batch_size, | ||
| 720 | + total_rows=layout.total_rows, | ||
| 721 | + continuation_len=continuation_len, | ||
| 722 | + max_input_len=max_input_len, | ||
| 723 | + input_ids=input_ids, | ||
| 724 | + attention_mask=attention_mask, | ||
| 725 | + position_ids=position_ids, | ||
| 726 | + last_hidden_state=last_hidden_state, | ||
| 727 | + ) | ||
| 728 | + if self.cfg.cuda_graphs: | ||
| 729 | + self._capture_mixed_bucket(bucket, prefix_cache) | ||
| 730 | + return bucket | ||
| 731 | + | ||
| 732 | + @torch.inference_mode() | ||
| 733 | + def _run_mixed_backbone( | ||
| 734 | + self, | ||
| 735 | + bucket: MixedBucketRuntime, | ||
| 736 | + prefix_cache: MixedPrefixCache, | ||
| 737 | + ) -> None: | ||
| 738 | + cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config) | ||
| 739 | + with self._attn_context(): | ||
| 740 | + outputs = self.backbone( | ||
| 741 | + input_ids=bucket.input_ids, | ||
| 742 | + attention_mask=bucket.attention_mask, | ||
| 743 | + position_ids=bucket.position_ids, | ||
| 744 | + past_key_values=cache, | ||
| 745 | + use_cache=False, | ||
| 746 | + return_dict=True, | ||
| 747 | + ) | ||
| 748 | + bucket.last_hidden_state.copy_(outputs.last_hidden_state) | ||
| 749 | + | ||
| 750 | + def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None: | ||
| 751 | + if not torch.cuda.is_available(): | ||
| 752 | + return | ||
| 753 | + try: | ||
| 754 | + torch.cuda.synchronize() | ||
| 755 | + stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device) | ||
| 756 | + with torch.cuda.stream(stream): | ||
| 757 | + for _ in range(self.cfg.graph_warmups): | ||
| 758 | + self._run_mixed_backbone(bucket, prefix_cache) | ||
| 759 | + stream.synchronize() | ||
| 760 | + graph = torch.cuda.CUDAGraph() | ||
| 761 | + with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream): | ||
| 762 | + self._run_mixed_backbone(bucket, prefix_cache) | ||
| 763 | + bucket.graph = graph | ||
| 764 | + except RuntimeError: | ||
| 765 | + bucket.graph = None | ||
| 766 | + | ||
| 767 | + def _prepare_bucket( | ||
| 768 | + self, | ||
| 769 | + layout: BatchLayout, | ||
| 770 | + prefix_cache: MixedPrefixCache, | ||
| 771 | + bucket: MixedBucketRuntime, | ||
| 772 | + query_ids_batch: list[list[int]], | ||
| 773 | + ) -> tuple[list[list[int]], dict[str, list[object]]]: | ||
| 774 | + del prefix_cache | ||
| 775 | + bucket.input_ids.fill_(self.tokenizer.pad_token_id) | ||
| 776 | + bucket.attention_mask.zero_() | ||
| 777 | + if self.mixed_prefix_caches[layout.batch_size].max_prefix_len: | ||
| 778 | + bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = ( | ||
| 779 | + self.mixed_prefix_caches[layout.batch_size].attention_mask | ||
| 780 | + ) | ||
| 781 | + continuation_lengths_per_task: dict[str, list[int]] = {} | ||
| 782 | + continuation_tokens_per_task: dict[str, list[list[int]]] = {} | ||
| 783 | + prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len | ||
| 784 | + for plan in layout.plans: | ||
| 785 | + runner = plan.runner | ||
| 786 | + per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch] | ||
| 787 | + continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations | ||
| 788 | + continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations] | ||
| 789 | + if runner.fast_path: | ||
| 790 | + for batch_idx, continuation in enumerate(per_query_continuations): | ||
| 791 | + cont_len = len(continuation) | ||
| 792 | + row_idx = plan.row_start + batch_idx | ||
| 793 | + bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long) | ||
| 794 | + bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1 | ||
| 795 | + continue | ||
| 796 | + | ||
| 797 | + assert runner.multi_token_tables is not None | ||
| 798 | + for batch_idx, continuation in enumerate(per_query_continuations): | ||
| 799 | + cont_len = len(continuation) | ||
| 800 | + row_start = plan.row_start + batch_idx * runner.num_labels | ||
| 801 | + row_stop = row_start + runner.num_labels | ||
| 802 | + row_slice = slice(row_start, row_stop) | ||
| 803 | + cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long) | ||
| 804 | + bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1) | ||
| 805 | + bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1 | ||
| 806 | + if runner.multi_token_tables.max_label_prefix_len: | ||
| 807 | + bucket.input_ids[ | ||
| 808 | + row_slice, | ||
| 809 | + cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len, | ||
| 810 | + ] = runner.multi_token_tables.label_prefix_ids | ||
| 811 | + bucket.attention_mask[ | ||
| 812 | + row_slice, | ||
| 813 | + prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len, | ||
| 814 | + ] = runner.multi_token_tables.label_prefix_mask | ||
| 815 | + return query_ids_batch, { | ||
| 816 | + "continuation_lengths_per_task": continuation_lengths_per_task, | ||
| 817 | + "continuation_tokens_per_task": continuation_tokens_per_task, | ||
| 818 | + } | ||
| 819 | + | ||
| 820 | + def _reduce_prompt_scores( | ||
| 821 | + self, | ||
| 822 | + layout: BatchLayout, | ||
| 823 | + bucket: MixedBucketRuntime, | ||
| 824 | + query_texts: list[str], | ||
| 825 | + prep_meta: dict[str, list[object]], | ||
| 826 | + shared_stage_ms: dict[str, float], | ||
| 827 | + ) -> list[MultiPromptScoreResult]: | ||
| 828 | + result_rows = [[] for _ in range(layout.batch_size)] | ||
| 829 | + prompt_reduce_total_ms = 0.0 | ||
| 830 | + for plan in layout.plans: | ||
| 831 | + runner = plan.runner | ||
| 832 | + continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name] | ||
| 833 | + reduce_start = time.perf_counter() | ||
| 834 | + if runner.fast_path: | ||
| 835 | + hidden_rows = [] | ||
| 836 | + row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size] | ||
| 837 | + for batch_idx, cont_len in enumerate(continuation_lengths): | ||
| 838 | + hidden_rows.append(row_slice[batch_idx, cont_len - 1]) | ||
| 839 | + hidden = torch.stack(hidden_rows, dim=0) | ||
| 840 | + runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer) | ||
| 841 | + stage_name = "tail_scorer" | ||
| 842 | + else: | ||
| 843 | + assert runner.multi_token_tables is not None | ||
| 844 | + score_positions = torch.stack( | ||
| 845 | + [ | ||
| 846 | + cont_len - 1 + runner.multi_token_tables.label_position_offsets | ||
| 847 | + for cont_len in continuation_lengths | ||
| 848 | + ], | ||
| 849 | + dim=0, | ||
| 850 | + ) | ||
| 851 | + runner.reduce_multi_token_scores( | ||
| 852 | + last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop], | ||
| 853 | + batch_size=layout.batch_size, | ||
| 854 | + max_input_len=bucket.max_input_len, | ||
| 855 | + score_positions=score_positions, | ||
| 856 | + out_scores=plan.score_buffer, | ||
| 857 | + ) | ||
| 858 | + stage_name = "candidate_reduce" | ||
| 859 | + self._sync() | ||
| 860 | + reduce_end = time.perf_counter() | ||
| 861 | + reduce_ms = (reduce_end - reduce_start) * 1000.0 | ||
| 862 | + prompt_reduce_total_ms += reduce_ms | ||
| 863 | + for batch_idx, query in enumerate(query_texts): | ||
| 864 | + stage_ms = dict(shared_stage_ms) | ||
| 865 | + stage_ms[stage_name] = reduce_ms / layout.batch_size | ||
| 866 | + result_rows[batch_idx].append( | ||
| 867 | + runner.build_score_result( | ||
| 868 | + query=query, | ||
| 869 | + scores=plan.score_buffer[batch_idx], | ||
| 870 | + stage_ms=stage_ms, | ||
| 871 | + continuation_tokens=continuation_lengths[batch_idx], | ||
| 872 | + ) | ||
| 873 | + ) | ||
| 874 | + | ||
| 875 | + batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms | ||
| 876 | + shared_plus_reduce = dict(shared_stage_ms) | ||
| 877 | + shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms | ||
| 878 | + results: list[MultiPromptScoreResult] = [] | ||
| 879 | + for batch_idx, query in enumerate(query_texts): | ||
| 880 | + results.append( | ||
| 881 | + MultiPromptScoreResult( | ||
| 882 | + query=query, | ||
| 883 | + total_ms=batch_total_ms / layout.batch_size, | ||
| 884 | + details=result_rows[batch_idx], | ||
| 885 | + stage_ms={ | ||
| 886 | + **shared_plus_reduce, | ||
| 887 | + "per_query_total_estimate": batch_total_ms / layout.batch_size, | ||
| 888 | + }, | ||
| 889 | + ) | ||
| 890 | + ) | ||
| 891 | + return results | ||
| 892 | + | ||
| 893 | + @torch.inference_mode() | ||
| 894 | + def score_queries(self, queries: list[str]) -> BatchScoreResult: | ||
| 895 | + if not queries: | ||
| 896 | + raise ValueError("queries must not be empty") | ||
| 897 | + batch_size = len(queries) | ||
| 898 | + if batch_size not in self.batch_layouts: | ||
| 899 | + raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}") | ||
| 900 | + layout = self.batch_layouts[batch_size] | ||
| 901 | + prefix_cache = self.mixed_prefix_caches[batch_size] | ||
| 902 | + | ||
| 903 | + self._sync() | ||
| 904 | + t0 = time.perf_counter() | ||
| 905 | + query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries] | ||
| 906 | + self._sync() | ||
| 907 | + t1 = time.perf_counter() | ||
| 908 | + | ||
| 909 | + max_continuation_len = max( | ||
| 910 | + len(plan.runner.build_continuation_from_query_ids(query_ids)) | ||
| 911 | + for plan in layout.plans | ||
| 912 | + for query_ids in query_ids_batch | ||
| 913 | + ) | ||
| 914 | + picked_bucket = self._pick_bucket(max_continuation_len) | ||
| 915 | + bucket = self.mixed_buckets[(batch_size, picked_bucket)] | ||
| 916 | + _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch) | ||
| 917 | + self._sync() | ||
| 918 | + t2 = time.perf_counter() | ||
| 919 | + | ||
| 920 | + if bucket.graph is not None: | ||
| 921 | + bucket.graph.replay() | ||
| 922 | + else: | ||
| 923 | + self._run_mixed_backbone(bucket, prefix_cache) | ||
| 924 | + self._sync() | ||
| 925 | + t3 = time.perf_counter() | ||
| 926 | + | ||
| 927 | + shared_stage_ms = { | ||
| 928 | + "encode_queries_shared": (t1 - t0) * 1000.0, | ||
| 929 | + "prepare_batch_shared": (t2 - t1) * 1000.0, | ||
| 930 | + "backbone_shared": (t3 - t2) * 1000.0, | ||
| 931 | + } | ||
| 932 | + results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms) | ||
| 933 | + total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"] | ||
| 934 | + return BatchScoreResult( | ||
| 935 | + batch_size=batch_size, | ||
| 936 | + total_ms=total_ms, | ||
| 937 | + results=results, | ||
| 938 | + stage_ms={ | ||
| 939 | + **shared_stage_ms, | ||
| 940 | + "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"], | ||
| 941 | + }, | ||
| 942 | + ) | ||
| 943 | + | ||
| 944 | + def score_query(self, query: str) -> MultiPromptScoreResult: | ||
| 945 | + return self.score_queries([query]).results[0] | ||
| 946 | + | ||
| 947 | + def preload(self) -> PreloadReport: | ||
| 948 | + if self._preload_report is not None: | ||
| 949 | + return self._preload_report | ||
| 950 | + stage_ms: dict[str, float] = dict(self._init_stage_ms) | ||
| 951 | + start = time.perf_counter() | ||
| 952 | + self._sync() | ||
| 953 | + t0 = time.perf_counter() | ||
| 954 | + warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes | ||
| 955 | + for batch_size in warmup_batch_sizes: | ||
| 956 | + queries = [self.cfg.warmup_query] * batch_size | ||
| 957 | + self._warmup_results[batch_size] = self.score_queries(queries) | ||
| 958 | + self._sync() | ||
| 959 | + t1 = time.perf_counter() | ||
| 960 | + stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0 | ||
| 961 | + stage_ms["startup_total_before_warmup"] = self._init_total_ms | ||
| 962 | + total_ms = self._init_total_ms + (t1 - start) * 1000.0 | ||
| 963 | + runtime = self.preload_report() | ||
| 964 | + self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime) | ||
| 965 | + return self._preload_report | ||
| 966 | + | ||
| 967 | + def preload_report(self) -> dict[str, object]: | ||
| 968 | + return { | ||
| 969 | + "model_name": self.cfg.resolved_model_name, | ||
| 970 | + "model_source": self.cfg.resolved_model_source, | ||
| 971 | + "device": str(self.device), | ||
| 972 | + "dtype": self.cfg.dtype, | ||
| 973 | + "attn_backend": self.cfg.attn_backend, | ||
| 974 | + "execution_model": "single_mixed_backbone_per_batch", | ||
| 975 | + "num_tasks": len(self.runners), | ||
| 976 | + "task_names": [runner.task_cfg.name for runner in self.runners], | ||
| 977 | + "batch_sizes": list(self.cfg.batch_sizes), | ||
| 978 | + "continuation_buckets": list(self.cfg.continuation_buckets), | ||
| 979 | + "mixed_bucket_count": len(self.mixed_buckets), | ||
| 980 | + "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()), | ||
| 981 | + "all_configured_buckets_preloaded": True, | ||
| 982 | + "init_stage_ms": dict(self._init_stage_ms), | ||
| 983 | + "init_total_ms": self._init_total_ms, | ||
| 984 | + "force_single_token_labels": self.cfg.force_single_token_labels, | ||
| 985 | + "warmup_query": self.cfg.warmup_query, | ||
| 986 | + "tasks": [ | ||
| 987 | + { | ||
| 988 | + "task_name": runner.task_cfg.name, | ||
| 989 | + "fast_path": runner.fast_path, | ||
| 990 | + "num_labels": runner.num_labels, | ||
| 991 | + "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels}, | ||
| 992 | + "prefix_tokens": runner.prefix_cache.prefix_len, | ||
| 993 | + "prefix_hashes": runner.prefix_cache.prefix_hashes, | ||
| 994 | + "label_prefix": runner.task_cfg.label_prefix, | ||
| 995 | + } | ||
| 996 | + for runner in self.runners | ||
| 997 | + ], | ||
| 998 | + } | ||
| 999 | + | ||
| 1000 | +However, the multi-token handling above is inefficient and incorrect; do not use it as a reference, and re-implement it instead. | ||
| 1001 | +The local .venv is already set up (reused from llm-qp); use that environment. | ||
| 1002 | +The useful code has already been pulled in. Based on the requirements, research the relevant material and focus on developing a brand-new version. | ||
| 1003 | +This is a heavy task; work through it patiently, step by step, and keep iterating until the service is fully built and the tests completely meet expectations. | ||
| @@ -0,0 +1,59 @@ | @@ -0,0 +1,59 @@ | ||
| 1 | + | ||
| 2 | +Overall requirement: query classification on a Tesla T4 GPU using an open-source base model (3-6B class). The prompt and the label words must be configurable. | ||
| 3 | +Focus on inference optimization first; serving can wait. After startup, the program may read queries from stdin and print the predicted label. | ||
| 4 | + | ||
| 5 | +Example prompt and its 9 corresponding label words: | ||
| 6 | +Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} | ||
| 7 | + | ||
| 8 | +Build a dedicated execution path, not configuration-level tuning on top of a general-purpose generation engine. | ||
| 9 | + | ||
| 10 | +The prompt is fixed (do not consider distillation, fine-tuning, or shortening the prompt. I can shorten the prompt myself, and the prompt may be swapped for other scenarios, so that is unrelated to your dedicated inference optimization. The prompt above is only an example; focus on the dedicated-inference framework, supporting a configurable prompt plus label-list scoring). | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +The target is a Tesla T4, therefore: | ||
| 14 | +1. Use FP16. Do not make BF16 the main path. | ||
| 15 | +2. Do not make FlashAttention-3 the main path. Choose the best attention implementation available on the Tesla T4. | ||
| 16 | +3. Research widely: open-source inference-optimization projects suited to the Tesla T4 and optimization practices usable on T4, including but not limited to Python/C++ runtimes and TensorRT engine toolkits; make sure the hardware support matrix covers Turing/T4. | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +Main optimization directions to consider: | ||
| 20 | +hidden_last -> N-class scorer -> argmax | ||
| 21 | +Reuse identical prefixes following vLLM-style automatic prefix caching, managing KV by block hash | ||
| 22 | +Drop full-vocab logits | ||
| 23 | +Drop decode / constrained decode | ||
| 24 | +Query-length bucketing + small micro-batches; capture graphs only for a few batch modes (batch = 1, 2, 4, 8) | ||
| 25 | +Dedicated tail kernel (outputs the N raw class scores) | ||
| 26 | + | ||
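The hidden_last -> N-class scorer -> argmax direction above can be illustrated with a minimal numpy sketch. It assumes single-token labels whose vocab ids are known in advance; `lm_head` stands in for the model's output embedding matrix, and all names here are illustrative:

```python
import numpy as np

def score_labels(hidden_last, lm_head, label_token_ids):
    """Score only the N label rows of the LM head instead of the full vocab.

    hidden_last:     (H,) last-position hidden state
    lm_head:         (V, H) output embedding matrix
    label_token_ids: vocab ids of the N single-token labels
    """
    label_rows = lm_head[label_token_ids]   # (N, H): an N-row slice replaces the (V, H) matmul
    scores = label_rows @ hidden_last       # (N,) raw class scores
    return scores, int(np.argmax(scores))
```

Because argmax over the N label logits equals argmax over a softmax restricted to those labels, both the full-vocab logits and any decoding step can be skipped entirely.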
| 27 | +The following is for reference only; decide the concrete approach yourself, starting from an open-source project you find as a baseline: | ||
| 28 | +15.1 Preallocation | ||
| 29 | +Preallocate for each bucket: | ||
| 30 | +input ids buffer | ||
| 31 | +attention scratch buffer | ||
| 32 | +output hidden buffer | ||
| 33 | +class id buffer | ||
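Fixed-shape preallocation only pays off if every request is mapped onto one of the preallocated shapes. A minimal sketch of that mapping, with bucket and batch values that are purely illustrative:

```python
BUCKETS = (16, 32, 64)      # continuation-length buckets (illustrative values)
BATCH_MODES = (1, 2, 4, 8)  # batch sizes that get preallocated buffers / captured graphs

def pick_shape(query_len, batch):
    """Map (query length, batch size) to the smallest preallocated (bucket, batch mode)."""
    bucket = next((b for b in BUCKETS if query_len <= b), None)
    mode = next((m for m in BATCH_MODES if batch <= m), None)
    if bucket is None or mode is None:
        raise ValueError("request exceeds preallocated shapes; extend the buckets")
    return bucket, mode  # padded rows/positions are masked out at runtime
```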
| 34 | +15.2 Graph capture | ||
| 35 | + | ||
| 36 | +Pre-capture one graph per bucket covering: | ||
| 37 | +embedding | ||
| 38 | +transformer backbone | ||
| 39 | +compact last-state | ||
| 40 | +9-way tail kernel | ||
| 41 | + | ||
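Per-bucket graph capture with static buffers can be sketched roughly as below using the PyTorch CUDA Graph API. This is an assumption-laden sketch, not the implementation: `step_fn` must read and write only preallocated static tensors, and on a machine without CUDA it simply falls back to eager execution:

```python
import torch

def make_replayable(step_fn, warmups=3):
    """Capture step_fn into a CUDA graph; on non-CUDA devices return the eager callable."""
    if not torch.cuda.is_available():
        return step_fn
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        for _ in range(warmups):  # warm up allocator and kernels on a side stream first
            step_fn()
    stream.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        step_fn()                 # recorded once with fixed shapes ...
    return graph.replay           # ... then replayed with near-zero launch overhead
```

Replay only reruns the recorded kernels, which is why inputs must be copied into the same static buffers before each call.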
| 42 | +Be sure to study Triton / CUDA C++ techniques (and their open-source projects) for extreme optimization of this single-step-decode, fixed-vocabulary scoring pattern; find a suitable baseline and develop against it. | ||
| 43 | + | ||
| 44 | +You have sudo access; you may install your own environment for this project. | ||
| 45 | + | ||
| 46 | +使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型) | ||
| 47 | +或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。 | ||
| 48 | + | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + | ||
| 52 | + | ||
| 53 | + | ||
| 54 | + | ||
| 55 | +另外,我想要有个命令行工具,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。 | ||
| 56 | +输入query,输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下),目前是否已经满足了。 | ||
| 57 | + | ||
| 58 | +请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。 | ||
| 59 | +使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。 |
| @@ -0,0 +1,66 @@ | @@ -0,0 +1,66 @@ | ||
| 1 | +彻底清除掉命令行交互方式的代码,改造为提供 HTTP 接口,端口6001。 | ||
| 2 | +并提供scripts/service_ctl.sh的方式管理服务的启停 | ||
| 3 | +返回分类结果的list(results):list由多个dict组成,每个dict的key为对应的分类任务,value是三元组,即值、打分、概率(如果值为none则该项不输出) | ||
| 4 | +完善日志系统,当前cli方式输出的重要信息,在日志以及http接口的details字段体现。 | ||
| 5 | +提供一个压测脚本放到本项目的合适的目录下,作为压测工具。该压测工具针对每条请求打印出结果,并最后给出性能指标,参考: | ||
| 6 | +#!/bin/bash | ||
| 7 | + | ||
| 8 | +# 默认值 | ||
| 9 | +concurrency=${1:-1} | ||
| 10 | +top_lines=${2:-100} | ||
| 11 | + | ||
| 12 | +# 固定查询文件路径 | ||
| 13 | +query_file="/data/saas-search/scripts/evaluation/queries/queries.txt" | ||
| 14 | + | ||
| 15 | +# 检查文件是否存在 | ||
| 16 | +if [ ! -f "$query_file" ]; then | ||
| 17 | + echo "错误: 查询文件不存在: $query_file" >&2 | ||
| 18 | + exit 1 | ||
| 19 | +fi | ||
| 20 | + | ||
| 21 | +# 检查 jq 是否可用 | ||
| 22 | +if ! command -v jq &> /dev/null; then | ||
| 23 | + echo "错误: 需要安装 jq 来解析 JSON" >&2 | ||
| 24 | + exit 1 | ||
| 25 | +fi | ||
| 26 | + | ||
| 27 | +url="http://127.0.0.1:6001/..." | ||
| 28 | +max_jobs=$concurrency | ||
| 29 | +job_count=0 | ||
| 30 | + | ||
| 31 | +# 读取文件前 top_lines 行,每行作为一个 query | ||
| 32 | +while IFS= read -r query; do | ||
| 33 | + # 跳过空行 | ||
| 34 | + [ -z "$query" ] && continue | ||
| 35 | + | ||
| 36 | + # 启动子进程执行请求 | ||
| 37 | + ( | ||
| 38 | + # 安全构建 JSON payload | ||
| 39 | + payload=$(jq -n --arg q "$query" '{query: $q}') | ||
| 40 | + # 发送请求并获取响应 | ||
| 41 | + response=$(curl -s -X POST "$url" \ | ||
| 42 | + -H 'Content-Type: application/json' \ | ||
| 43 | + -d "$payload") | ||
| 44 | + # 提取 results 字段(紧凑 JSON 格式) | ||
| 45 | + results=$(echo "$response" | jq -c '.results') | ||
| 46 | + # 输出 query 和对应的 results | ||
| 47 | + printf "%s\t%s\n" "$query" "$results" | ||
| 48 | + ) & | ||
| 49 | + | ||
| 50 | + # 控制并发数量 | ||
| 51 | + ((job_count++)) | ||
| 52 | + if (( job_count >= max_jobs )); then | ||
| 53 | + wait -n # 等待任意一个后台进程完成 | ||
| 54 | + ((job_count--)) | ||
| 55 | + fi | ||
| 56 | +done < <(head -n "$top_lines" "$query_file") | ||
| 57 | + | ||
| 58 | +# 等待所有剩余后台进程完成 | ||
| 59 | +# 在这里统计处性能情况,指标: | ||
| 60 | + 平均耗时: | ||
| 61 | + 最大耗时: | ||
| 62 | + 最小耗时: | ||
| 63 | + TP50: | ||
| 64 | + TP90: | ||
| 65 | + TP99: | ||
| 66 | +wait | ||
| 0 | \ No newline at end of file | 67 | \ No newline at end of file |
docs/相关性检索优化说明.md
| @@ -862,17 +862,14 @@ rerank_score:0.0564 | @@ -862,17 +862,14 @@ rerank_score:0.0564 | ||
| 862 | "en": "Judy Blue Women's High Waist Button Fly Skinny Jeans 82319", | 862 | "en": "Judy Blue Women's High Waist Button Fly Skinny Jeans 82319", |
| 863 | "zh": "Judy Blue 女士高腰纽扣开叉修身牛仔裤 82319" | 863 | "zh": "Judy Blue 女士高腰纽扣开叉修身牛仔裤 82319" |
| 864 | 864 | ||
| 865 | - | ||
| 866 | rerank_score:0.0790 | 865 | rerank_score:0.0790 |
| 867 | "en": "2025 New Fashion European and American Women's Jeans High-Waisted Slim Straight Denim Pants Popular Floor-Length Pants", | 866 | "en": "2025 New Fashion European and American Women's Jeans High-Waisted Slim Straight Denim Pants Popular Floor-Length Pants", |
| 868 | "zh": "2025新款欧美风女式高腰显瘦直筒牛仔裤 时尚及地长裤" | 867 | "zh": "2025新款欧美风女式高腰显瘦直筒牛仔裤 时尚及地长裤" |
| 869 | 868 | ||
| 870 | - | ||
| 871 | rerank_score:0.0822 | 869 | rerank_score:0.0822 |
| 872 | "en": "roswear Women's Trendy Stretchy Flare Jeans Mid Rise Bootcut Curvy Denim Pants", | 870 | "en": "roswear Women's Trendy Stretchy Flare Jeans Mid Rise Bootcut Curvy Denim Pants", |
| 873 | "zh": "Roswear 女士时尚弹力喇叭牛仔裤 中腰高腰修身直筒牛仔裤" | 871 | "zh": "Roswear 女士时尚弹力喇叭牛仔裤 中腰高腰修身直筒牛仔裤" |
| 874 | 872 | ||
| 875 | - | ||
| 876 | rerank_score:0.0956 | 873 | rerank_score:0.0956 |
| 877 | "en": "POSHGLAM Women's Maternity Jeans Over Belly 29'' Skinny Denim Jeggings Comfy Stretch Clearance Pregnancy Pants", | 874 | "en": "POSHGLAM Women's Maternity Jeans Over Belly 29'' Skinny Denim Jeggings Comfy Stretch Clearance Pregnancy Pants", |
| 878 | "zh": "POSHGLAM 女士孕产期高腰显瘦牛仔紧身裤 29英寸 紧身弹力孕妇裤 休闲舒适 清仓特价" | 875 | "zh": "POSHGLAM 女士孕产期高腰显瘦牛仔紧身裤 29英寸 紧身弹力孕妇裤 休闲舒适 清仓特价" |
indexer/document_transformer.py
| @@ -151,15 +151,15 @@ class SPUDocumentTransformer: | @@ -151,15 +151,15 @@ class SPUDocumentTransformer: | ||
| 151 | self._fill_title_embedding(doc) | 151 | self._fill_title_embedding(doc) |
| 152 | 152 | ||
| 153 | # Tags:统一转成与 mapping 一致的 core-language object | 153 | # Tags:统一转成与 mapping 一致的 core-language object |
| 154 | - if pd.notna(spu_row.get('tags')): | ||
| 155 | - tags_str = str(spu_row['tags']) | 154 | + if pd.notna(spu_row.get('enriched_tags')): |
| 155 | + tags_str = str(spu_row['enriched_tags']) | ||
| 156 | tags_obj = self._build_core_language_text_object( | 156 | tags_obj = self._build_core_language_text_object( |
| 157 | tags_str, | 157 | tags_str, |
| 158 | source_lang=primary_lang, | 158 | source_lang=primary_lang, |
| 159 | scene="general", | 159 | scene="general", |
| 160 | ) | 160 | ) |
| 161 | if tags_obj: | 161 | if tags_obj: |
| 162 | - doc['tags'] = tags_obj | 162 | + doc['enriched_tags'] = tags_obj |
| 163 | 163 | ||
| 164 | # Category相关字段 | 164 | # Category相关字段 |
| 165 | self._fill_category_fields(doc, spu_row) | 165 | self._fill_category_fields(doc, spu_row) |
| @@ -240,7 +240,7 @@ class SPUDocumentTransformer: | @@ -240,7 +240,7 @@ class SPUDocumentTransformer: | ||
| 240 | """ | 240 | """ |
| 241 | 批量调用 LLM,为一批 doc 填充: | 241 | 批量调用 LLM,为一批 doc 填充: |
| 242 | - qanchors.{lang} | 242 | - qanchors.{lang} |
| 243 | - - tags.{lang} | 243 | + - enriched_tags.{lang} |
| 244 | - enriched_attributes[].value.{lang} | 244 | - enriched_attributes[].value.{lang} |
| 245 | 245 | ||
| 246 | 设计目标: | 246 | 设计目标: |
| @@ -292,8 +292,8 @@ class SPUDocumentTransformer: | @@ -292,8 +292,8 @@ class SPUDocumentTransformer: | ||
| 292 | try: | 292 | try: |
| 293 | if enrichment.get("qanchors"): | 293 | if enrichment.get("qanchors"): |
| 294 | doc["qanchors"] = enrichment["qanchors"] | 294 | doc["qanchors"] = enrichment["qanchors"] |
| 295 | - if enrichment.get("tags"): | ||
| 296 | - doc["tags"] = enrichment["tags"] | 295 | + if enrichment.get("enriched_tags"): |
| 296 | + doc["enriched_tags"] = enrichment["enriched_tags"] | ||
| 297 | if enrichment.get("enriched_attributes"): | 297 | if enrichment.get("enriched_attributes"): |
| 298 | doc["enriched_attributes"] = enrichment["enriched_attributes"] | 298 | doc["enriched_attributes"] = enrichment["enriched_attributes"] |
| 299 | except Exception as e: | 299 | except Exception as e: |
| @@ -656,7 +656,7 @@ class SPUDocumentTransformer: | @@ -656,7 +656,7 @@ class SPUDocumentTransformer: | ||
| 656 | """ | 656 | """ |
| 657 | 调用 indexer.product_enrich 的高层内容理解入口,为当前 SPU 填充: | 657 | 调用 indexer.product_enrich 的高层内容理解入口,为当前 SPU 填充: |
| 658 | - qanchors.{lang} | 658 | - qanchors.{lang} |
| 659 | - - tags.{lang} | 659 | + - enriched_tags.{lang} |
| 660 | - enriched_attributes[].value.{lang} | 660 | - enriched_attributes[].value.{lang} |
| 661 | """ | 661 | """ |
| 662 | spu_id = str(spu_row.get("id") or "").strip() | 662 | spu_id = str(spu_row.get("id") or "").strip() |
indexer/product_enrich_prompts.py
| @@ -10,16 +10,16 @@ SYSTEM_MESSAGE = ( | @@ -10,16 +10,16 @@ SYSTEM_MESSAGE = ( | ||
| 10 | 10 | ||
| 11 | SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns: | 11 | SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns: |
| 12 | 12 | ||
| 13 | -1. Product title: a natural localized product name derived from the input product text | ||
| 14 | -2. Category path: broad to fine-grained category, separated by ">" | ||
| 15 | -3. Fine-grained tags: style, features, functions, or notable attributes | ||
| 16 | -4. Target audience: gender, age group, or suitable users | ||
| 17 | -5. Usage scene | ||
| 18 | -6. Applicable season | ||
| 19 | -7. Key attributes | ||
| 20 | -8. Material description | ||
| 21 | -9. Functional features | ||
| 22 | -10. Anchor text: a search-focused set of keywords, selling points, and phrases covering categories, attributes, usage scenarios, and user intent | 13 | +1. Product title: a natural, localized product name based on the input text |
| 14 | +2. Category path: a concise category hierarchy from broad to specific, separated by ">" | ||
| 15 | +3. Fine-grained tags: concise tags for style, features, design details, function, or standout selling points | ||
| 16 | +4. Target audience: gender, age group, body type, or suitable users when clearly implied | ||
| 17 | +5. Usage scene: likely occasions, settings, or use cases | ||
| 18 | +6. Applicable season: relevant season(s) based on the product text | ||
| 19 | +7. Key attributes: core product attributes and specifications. Depending on the item type, this may include fit, silhouette, length, sleeve type, neckline, waistline, closure, pattern, design details, structure, or other relevant attribute dimensions | ||
| 20 | +8. Material description: material, fabric, texture, or construction description | ||
| 21 | +9. Functional features: practical or performance-related functions such as stretch, breathability, warmth, support, storage, protection, or ease of wear | ||
| 22 | +10. Anchor text: a search-oriented keyword string covering product type, category intent, attributes, design cues, usage scenarios, and strong shopping phrases | ||
| 23 | 23 | ||
| 24 | Rules: | 24 | Rules: |
| 25 | - Keep the input order and row count exactly the same. | 25 | - Keep the input order and row count exactly the same. |
perf_reports/20260311/reranker_1000docs/report.md
| @@ -4,7 +4,7 @@ Workload profile: | @@ -4,7 +4,7 @@ Workload profile: | ||
| 4 | - backend: `qwen3_vllm` (`Qwen/Qwen3-Reranker-0.6B`) | 4 | - backend: `qwen3_vllm` (`Qwen/Qwen3-Reranker-0.6B`) |
| 5 | - query: short e-commerce text (<100 tokens) | 5 | - query: short e-commerce text (<100 tokens) |
| 6 | - docs/request: 1000 short titles/title+brief | 6 | - docs/request: 1000 short titles/title+brief |
| 7 | -- options: `sort_by_doc_length=true`, `length_sort_mode=char` | 7 | +- options: `sort_by_doc_length=true` |
| 8 | 8 | ||
| 9 | ## Results | 9 | ## Results |
| 10 | 10 |
reranker/DEPLOYMENT_AND_TUNING.md
| 1 | -# Reranker 部署与性能调优手册(Qwen3-vLLM / Qwen3-GGUF) | 1 | +# Reranker 部署与性能调优手册(Qwen3-vLLM) |
| 2 | 2 | ||
| 3 | 本文档沉淀当前项目在电商搜索重排场景下的可复用实践,覆盖: | 3 | 本文档沉淀当前项目在电商搜索重排场景下的可复用实践,覆盖: |
| 4 | 4 | ||
| 5 | - 环境准备与安装部署 | 5 | - 环境准备与安装部署 |
| 6 | -- `qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` 配置项与优化思路 | 6 | +- `qwen3_vllm` 配置项与优化思路 |
| 7 | - 1000-doc 场景压测流程 | 7 | - 1000-doc 场景压测流程 |
| 8 | - 关键结论与推荐默认参数 | 8 | - 关键结论与推荐默认参数 |
| 9 | - 常见故障排查 | 9 | - 常见故障排查 |
| 10 | 10 | ||
| 11 | 适用范围: | 11 | 适用范围: |
| 12 | 12 | ||
| 13 | -- 重排后端:`services.rerank.backend: qwen3_vllm` / `qwen3_gguf` / `qwen3_gguf_06b` | ||
| 14 | -- 模型:`Qwen/Qwen3-Reranker-0.6B` / `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` / `ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF` | 13 | +- 重排后端:`services.rerank.backend: qwen3_vllm` |
| 14 | +- 模型:`Qwen/Qwen3-Reranker-0.6B` | ||
| 15 | - 场景:query 较短(通常 < 100 tokens),doc 为商品标题或标题+简短描述,单请求 docs 约 1000 条 | 15 | - 场景:query 较短(通常 < 100 tokens),doc 为商品标题或标题+简短描述,单请求 docs 约 1000 条 |
| 16 | 16 | ||
| 17 | ## 1. 环境基线 | 17 | ## 1. 环境基线 |
| 18 | 18 | ||
| 19 | -当前验证环境(2026-03-25): | 19 | +当前验证环境(2026-03-11): |
| 20 | 20 | ||
| 21 | - GPU:`Tesla T4 16GB` | 21 | - GPU:`Tesla T4 16GB` |
| 22 | - Driver / CUDA:`570.158.01 / 12.8` | 22 | - Driver / CUDA:`570.158.01 / 12.8` |
| 23 | - Python:`3.12.3` | 23 | - Python:`3.12.3` |
| 24 | -- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`llama-cpp-python>=0.3.16`、`fastapi==0.135.1`、`uvicorn==0.41.0` | 24 | +- 关键依赖:`vllm==0.17.0`、`torch==2.10.0+cu128`、`transformers==4.57.6`、`fastapi==0.135.1`、`uvicorn==0.41.0` |
| 25 | 25 | ||
| 26 | ## 2. 环境准备与安装 | 26 | ## 2. 环境准备与安装 |
| 27 | 27 | ||
| 28 | ### 2.1 准备 reranker 独立虚拟环境 | 28 | ### 2.1 准备 reranker 独立虚拟环境 |
| 29 | 29 | ||
| 30 | ```bash | 30 | ```bash |
| 31 | -./scripts/setup_reranker_venv.sh qwen3_vllm | ||
| 32 | -``` | ||
| 33 | - | ||
| 34 | -若使用 GGUF 并需要 CUDA: | ||
| 35 | - | ||
| 36 | -```bash | ||
| 37 | -./scripts/setup_reranker_venv.sh qwen3_gguf | ||
| 38 | -PATH=/usr/local/cuda/bin:$PATH \ | ||
| 39 | -CUDACXX=/usr/local/cuda/bin/nvcc \ | ||
| 40 | -CMAKE_ARGS="-DGGML_CUDA=on" \ | ||
| 41 | -FORCE_CMAKE=1 \ | ||
| 42 | -./.venv-reranker-gguf/bin/pip install --no-cache-dir --force-reinstall --no-build-isolation llama-cpp-python==0.3.18 | 31 | +./scripts/setup_reranker_venv.sh |
| 43 | ``` | 32 | ``` |
| 44 | 33 | ||
| 45 | ### 2.2 基础检查 | 34 | ### 2.2 基础检查 |
| @@ -48,7 +37,6 @@ FORCE_CMAKE=1 \ | @@ -48,7 +37,6 @@ FORCE_CMAKE=1 \ | ||
| 48 | nvidia-smi | 37 | nvidia-smi |
| 49 | ./.venv-reranker/bin/python -c "import torch; print(torch.cuda.is_available())" | 38 | ./.venv-reranker/bin/python -c "import torch; print(torch.cuda.is_available())" |
| 50 | ./.venv-reranker/bin/python -c "import vllm, transformers; print(vllm.__version__, transformers.__version__)" | 39 | ./.venv-reranker/bin/python -c "import vllm, transformers; print(vllm.__version__, transformers.__version__)" |
| 51 | -./.venv-reranker-gguf/bin/python -c "import llama_cpp; print(llama_cpp.__version__)" | ||
| 52 | ``` | 40 | ``` |
| 53 | 41 | ||
| 54 | ## 3. 部署与运行 | 42 | ## 3. 部署与运行 |
| @@ -73,30 +61,6 @@ services: | @@ -73,30 +61,6 @@ services: | ||
| 73 | enforce_eager: false | 61 | enforce_eager: false |
| 74 | infer_batch_size: 64 | 62 | infer_batch_size: 64 |
| 75 | sort_by_doc_length: true | 63 | sort_by_doc_length: true |
| 76 | - length_sort_mode: "char" # char | token | ||
| 77 | -``` | ||
| 78 | - | ||
| 79 | -GGUF / T4 剩余显存约 `4.8~6GB` 时,推荐基线: | ||
| 80 | - | ||
| 81 | -```yaml | ||
| 82 | -services: | ||
| 83 | - rerank: | ||
| 84 | - backend: "qwen3_gguf" | ||
| 85 | - backends: | ||
| 86 | - qwen3_gguf: | ||
| 87 | - repo_id: "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF" | ||
| 88 | - filename: "*Q8_0.gguf" | ||
| 89 | - local_dir: "./models/reranker/qwen3-reranker-4b-gguf" | ||
| 90 | - cache_dir: "./model_cache" | ||
| 91 | - n_ctx: 384 | ||
| 92 | - n_batch: 384 | ||
| 93 | - n_ubatch: 128 | ||
| 94 | - n_gpu_layers: 24 | ||
| 95 | - flash_attn: true | ||
| 96 | - offload_kqv: true | ||
| 97 | - infer_batch_size: 8 | ||
| 98 | - sort_by_doc_length: true | ||
| 99 | - length_sort_mode: "char" | ||
| 100 | ``` | 64 | ``` |
| 101 | 65 | ||
| 102 | ### 3.2 启停命令 | 66 | ### 3.2 启停命令 |
| @@ -140,13 +104,6 @@ curl -sS http://127.0.0.1:6007/health | @@ -140,13 +104,6 @@ curl -sS http://127.0.0.1:6007/health | ||
| 140 | - `service_ctl.sh` 对 reranker 使用独立启动路径 | 104 | - `service_ctl.sh` 对 reranker 使用独立启动路径 |
| 141 | - 增加“稳定健康检查”(连续健康探测)避免“刚 healthy 即退出”的假阳性 | 105 | - 增加“稳定健康检查”(连续健康探测)避免“刚 healthy 即退出”的假阳性 |
| 142 | 106 | ||
| 143 | -### 4.4 GGUF / T4 小显存优化原则 | ||
| 144 | - | ||
| 145 | -- `Q8_0` 权重约 `4.28GB`,但还要给 KV cache、CUDA 工作区和运行时碎片预留空间,不能按“模型大小 < 剩余显存”直接判断可行。 | ||
| 146 | -- 当前业务是短 query + 商品标题,优先压缩 `n_ctx`;`384` 通常比默认长上下文更划算。 | ||
| 147 | -- T4 小显存下先扫 `n_gpu_layers`,再尝试提高 `n_ctx`;`infer_batch_size` 在当前 GGUF 接入里主要是服务侧 work chunk,不是 llama.cpp 的真实算子 batch。 | ||
| 148 | -- `flash_attn: true`、`offload_kqv: true` 默认保持开启;若 OOM,优先降低 `n_gpu_layers`。 | ||
| 149 | - | ||
| 150 | ## 5. 性能调优流程(标准流程) | 107 | ## 5. 性能调优流程(标准流程) |
| 151 | 108 | ||
| 152 | ### 5.1 使用一键压测脚本 | 109 | ### 5.1 使用一键压测脚本 |
| @@ -167,13 +124,6 @@ curl -sS http://127.0.0.1:6007/health | @@ -167,13 +124,6 @@ curl -sS http://127.0.0.1:6007/health | ||
| 167 | - `infer_batch_size`: `24 32 48 64` | 124 | - `infer_batch_size`: `24 32 48 64` |
| 168 | - 并发组:`c=1`(看单请求延迟)、`c=4`(看并发吞吐与尾延迟) | 125 | - 并发组:`c=1`(看单请求延迟)、`c=4`(看并发吞吐与尾延迟) |
| 169 | 126 | ||
| 170 | -GGUF 建议扫描: | ||
| 171 | - | ||
| 172 | -- `n_gpu_layers`: `20 24 28` | ||
| 173 | -- `n_ctx`: `320 384 448` | ||
| 174 | -- `infer_batch_size`: `4 8 12`(次要,仅影响服务侧 work chunk) | ||
| 175 | -- 扫描顺序:先固定 `n_ctx=384`,找能稳定启动的最大 `n_gpu_layers`;再在显存允许时尝试 `n_ctx=448`;最后才微调 `infer_batch_size` | ||
| 176 | - | ||
| 177 | 可通过环境变量覆盖: | 127 | 可通过环境变量覆盖: |
| 178 | 128 | ||
| 179 | - `BATCH_SIZES` | 129 | - `BATCH_SIZES` |
| @@ -189,28 +139,23 @@ GGUF 建议扫描: | @@ -189,28 +139,23 @@ GGUF 建议扫描: | ||
| 189 | - `RERANK_VLLM_INFER_BATCH_SIZE` | 139 | - `RERANK_VLLM_INFER_BATCH_SIZE` |
| 190 | - `RERANK_VLLM_SORT_BY_DOC_LENGTH` | 140 | - `RERANK_VLLM_SORT_BY_DOC_LENGTH` |
| 191 | 141 | ||
| 192 | -## 6. 本轮关键结论 | 142 | +## 6. 本轮关键结论(2026-03-11) |
| 143 | + | ||
| 144 | +基于报告: | ||
| 145 | + | ||
| 146 | +- `perf_reports/20260311/reranker_1000docs/report.md` | ||
| 193 | 147 | ||
| 194 | -vLLM(2026-03-11,见 `perf_reports/20260311/reranker_1000docs/report.md`): | 148 | +结论: |
| 195 | 149 | ||
| 196 | - 对在线重排更重要的单请求延迟(`c=1`)指标,`infer_batch_size=64` 最优 | 150 | - 对在线重排更重要的单请求延迟(`c=1`)指标,`infer_batch_size=64` 最优 |
| 197 | - `infer_batch_size=96` 在更高并发下吞吐略高,但会牺牲单请求延迟稳定性 | 151 | - `infer_batch_size=96` 在更高并发下吞吐略高,但会牺牲单请求延迟稳定性 |
| 198 | - 当前默认选择 `infer_batch_size=64` 作为平衡点 | 152 | - 当前默认选择 `infer_batch_size=64` 作为平衡点 |
| 199 | 153 | ||
| 200 | -GGUF(2026-03-25,本次接入): | ||
| 201 | - | ||
| 202 | -- `DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF` 的 `Q8_0` 体积约 `4.28GB`,结合当前机器实测剩余显存约 `4823 MiB`,默认不采用激进的全量 GPU offload。 | ||
| 203 | -- 当前推荐默认值:`n_ctx=384`、`n_batch=384`、`n_ubatch=128`、`n_gpu_layers=24`、`infer_batch_size=8`。 | ||
| 204 | -- 若现场剩余显存更接近 `6GB` 且碎片较少,可优先尝试 `n_gpu_layers=28`;若启动失败,回退到 `24` 或 `20`。 | ||
| 205 | -- 由于当前工作区尚未缓存该 GGUF 权重,本次尚未完成真实吞吐压测;上线前需在部署机复跑一轮参数扫描并归档报告。 | ||
| 206 | - | ||
| 207 | ## 7. 生产建议 | 154 | ## 7. 生产建议 |
| 208 | 155 | ||
| 209 | - 默认保持:`infer_batch_size: 64`、`sort_by_doc_length: true` | 156 | - 默认保持:`infer_batch_size: 64`、`sort_by_doc_length: true` |
| 210 | - 满足以下条件时可考虑提高到 `96`:业务以吞吐优先、可接受更高单请求延迟、已通过同机同数据压测验证收益 | 157 | - 满足以下条件时可考虑提高到 `96`:业务以吞吐优先、可接受更高单请求延迟、已通过同机同数据压测验证收益 |
| 211 | - 每次改动后都必须复跑 `benchmark_reranker_1000docs.sh` 并归档结果 | 158 | - 每次改动后都必须复跑 `benchmark_reranker_1000docs.sh` 并归档结果 |
| 212 | -- GGUF 默认保持:`n_ctx: 384`、`n_gpu_layers: 24`、`infer_batch_size: 8`、`flash_attn: true`、`offload_kqv: true` | ||
| 213 | -- GGUF 若 OOM:先降 `n_gpu_layers`,再降 `n_ctx`,最后再降 `infer_batch_size` | ||
| 214 | 159 | ||
| 215 | ## 8. 故障排查 | 160 | ## 8. 故障排查 |
| 216 | 161 | ||
| @@ -248,13 +193,6 @@ lsof -i :6007 -P -n | @@ -248,13 +193,6 @@ lsof -i :6007 -P -n | ||
| 248 | - 降低 `infer_batch_size` | 193 | - 降低 `infer_batch_size` |
| 249 | - 检查是否有其他进程占用同卡 | 194 | - 检查是否有其他进程占用同卡 |
| 250 | 195 | ||
| 251 | -GGUF 优先调整: | ||
| 252 | - | ||
| 253 | -- 降低 `n_gpu_layers` | ||
| 254 | -- 降低 `n_ctx` | ||
| 255 | -- 降低 `infer_batch_size` | ||
| 256 | -- 检查是否有其他进程占用同卡 | ||
| 257 | - | ||
| 258 | ## 9. 变更与验证清单 | 196 | ## 9. 变更与验证清单 |
| 259 | 197 | ||
| 260 | 每次 reranker 调优改动后,至少完成: | 198 | 每次 reranker 调优改动后,至少完成: |
scripts/build_suggestions.sh
| @@ -9,19 +9,38 @@ | @@ -9,19 +9,38 @@ | ||
| 9 | # # incremental update from watermark | 9 | # # incremental update from watermark |
| 10 | # ./scripts/build_suggestions.sh <tenant_id> --mode incremental | 10 | # ./scripts/build_suggestions.sh <tenant_id> --mode incremental |
| 11 | # | 11 | # |
| 12 | +# # full rebuild + incremental + ES/API smoke checks (same as legacy rebuild_suggestions.sh) | ||
| 13 | +# ./scripts/build_suggestions.sh <tenant_id> --rebuild | ||
| 14 | +# | ||
| 12 | 15 | ||
| 13 | set -euo pipefail | 16 | set -euo pipefail |
| 14 | 17 | ||
| 15 | if [ $# -lt 1 ]; then | 18 | if [ $# -lt 1 ]; then |
| 16 | - echo "Usage: $0 <tenant_id> [extra args...]" | ||
| 17 | - echo "Example (full): $0 162 --mode full --days 30 --publish-alias" | ||
| 18 | - echo "Example (incremental): $0 162 --mode incremental --overlap-minutes 30" | 19 | + echo "Usage: $0 <tenant_id> [--rebuild | extra args for main.py build-suggestions...]" |
| 20 | + echo "Example (full): $0 163 --mode full --days 30 --publish-alias" | ||
| 21 | + echo "Example (incremental): $0 163 --mode incremental --overlap-minutes 30" | ||
| 22 | + echo "Example (pipeline + smoke): $0 163 --rebuild" | ||
| 19 | exit 1 | 23 | exit 1 |
| 20 | fi | 24 | fi |
| 21 | 25 | ||
| 22 | TENANT_ID="$1" | 26 | TENANT_ID="$1" |
| 23 | shift || true | 27 | shift || true |
| 24 | 28 | ||
| 29 | +RUN_REBUILD_PIPELINE=false | ||
| 30 | +PASSTHROUGH_ARGS=() | ||
| 31 | +for arg in "$@"; do | ||
| 32 | + if [ "$arg" = "--rebuild" ]; then | ||
| 33 | + RUN_REBUILD_PIPELINE=true | ||
| 34 | + else | ||
| 35 | + PASSTHROUGH_ARGS+=("$arg") | ||
| 36 | + fi | ||
| 37 | +done | ||
| 38 | + | ||
| 39 | +if [ "$RUN_REBUILD_PIPELINE" = true ] && [ ${#PASSTHROUGH_ARGS[@]} -gt 0 ]; then | ||
| 40 | + echo "Error: --rebuild cannot be combined with other build-suggestions arguments." | ||
| 41 | + exit 1 | ||
| 42 | +fi | ||
| 43 | + | ||
| 25 | ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" | 44 | ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" |
| 26 | 45 | ||
| 27 | cd "$ROOT_DIR" | 46 | cd "$ROOT_DIR" |
| @@ -31,6 +50,79 @@ if [ ! -x "$PY_BIN" ]; then | @@ -31,6 +50,79 @@ if [ ! -x "$PY_BIN" ]; then | ||
| 31 | PY_BIN="python3" | 50 | PY_BIN="python3" |
| 32 | fi | 51 | fi |
| 33 | 52 | ||
| 53 | +if [ "$RUN_REBUILD_PIPELINE" = true ]; then | ||
| 54 | + SAMPLE_QUERIES=(s sh dress tshirt) | ||
| 55 | + SAMPLE_LANGS=(en zh) | ||
| 56 | + # This script validates the locally rebuilt index, so default the smoke target | ||
| 57 | + # to the local backend. A remote/public API base must be opted into explicitly. | ||
| 58 | + API_BASE="${SUGGESTIONS_SMOKE_BASE_URL:-http://localhost:6002}" | ||
| 59 | + | ||
| 60 | + if [ -z "${ES_HOST:-}" ]; then | ||
| 61 | + ES_HOST="$("$PY_BIN" - <<'PY' | ||
| 62 | +from dotenv import dotenv_values | ||
| 63 | +print(dotenv_values('.env').get('ES_HOST') or 'http://localhost:9200') | ||
| 64 | +PY | ||
| 65 | +)" | ||
| 66 | + fi | ||
| 67 | + | ||
| 68 | + if [ -z "${ES_USERNAME:-}" ] || [ -z "${ES_PASSWORD:-}" ]; then | ||
| 69 | + readarray -t _ES_CREDS < <("$PY_BIN" - <<'PY' | ||
| 70 | +from dotenv import dotenv_values | ||
| 71 | +cfg = dotenv_values('.env') | ||
| 72 | +print(cfg.get('ES_USERNAME') or '') | ||
| 73 | +print(cfg.get('ES_PASSWORD') or '') | ||
| 74 | +PY | ||
| 75 | +) | ||
| 76 | + ES_USERNAME="${ES_USERNAME:-${_ES_CREDS[0]}}" | ||
| 77 | + ES_PASSWORD="${ES_PASSWORD:-${_ES_CREDS[1]}}" | ||
| 78 | + fi | ||
| 79 | + | ||
| 80 | + if [ -n "${ES_USERNAME:-}" ] && [ -n "${ES_PASSWORD:-}" ]; then | ||
| 81 | + AUTH=(-u "${ES_USERNAME}:${ES_PASSWORD}") | ||
| 82 | + else | ||
| 83 | + AUTH=() | ||
| 84 | + fi | ||
| 85 | + | ||
| 86 | + ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_${TENANT_ID}_current" | ||
| 87 | + | ||
| 88 | + echo "[1/4] Full rebuild tenant=${TENANT_ID} (versioned + alias publish)" | ||
| 89 | + "$PY_BIN" main.py build-suggestions \ | ||
| 90 | + --tenant-id "$TENANT_ID" \ | ||
| 91 | + --es-host "$ES_HOST" \ | ||
| 92 | + --mode full \ | ||
| 93 | + --days 365 \ | ||
| 94 | + --batch-size 500 \ | ||
| 95 | + --publish-alias \ | ||
| 96 | + --keep-versions 2 | ||
| 97 | + | ||
| 98 | + echo "[2/4] Incremental update tenant=${TENANT_ID}" | ||
| 99 | + "$PY_BIN" main.py build-suggestions \ | ||
| 100 | + --tenant-id "$TENANT_ID" \ | ||
| 101 | + --es-host "$ES_HOST" \ | ||
| 102 | + --mode incremental \ | ||
| 103 | + --overlap-minutes 30 | ||
| 104 | + | ||
| 105 | + echo "[3/4] ES count + sample" | ||
| 106 | + curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_count?pretty" | ||
| 107 | + echo | ||
| 108 | + curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' -d '{ | ||
| 109 | + "size": 5, | ||
| 110 | + "query": {"match_all": {}}, | ||
| 111 | + "_source": ["lang", "text", "rank_score", "sources", "query_count_30d"] | ||
| 112 | +}' | ||
| 113 | + echo | ||
| 114 | + | ||
| 115 | + echo "[4/4] API smoke test (base=${API_BASE})" | ||
| 116 | + for lang in "${SAMPLE_LANGS[@]}"; do | ||
| 117 | + for q in "${SAMPLE_QUERIES[@]}"; do | ||
| 118 | + echo "--- GET /search/suggestions?q=${q}&language=${lang} ---" | ||
| 119 | + curl -sS "$API_BASE/search/suggestions?q=${q}&size=10&language=${lang}" -H "X-Tenant-ID: ${TENANT_ID}" | ||
| 120 | + echo | ||
| 121 | + done | ||
| 122 | + done | ||
| 123 | + exit 0 | ||
| 124 | +fi | ||
| 125 | + | ||
| 34 | "$PY_BIN" main.py build-suggestions \ | 126 | "$PY_BIN" main.py build-suggestions \ |
| 35 | --tenant-id "$TENANT_ID" \ | 127 | --tenant-id "$TENANT_ID" \ |
| 36 | - "$@" | 128 | + "${PASSTHROUGH_ARGS[@]}" |
scripts/rebuild_suggestions.sh deleted
| @@ -1,86 +0,0 @@ | @@ -1,86 +0,0 @@ | ||
| 1 | -#!/usr/bin/env bash | ||
| 2 | -set -euo pipefail | ||
| 3 | - | ||
| 4 | -if [ $# -lt 1 ]; then | ||
| 5 | - echo "Usage: $0 <tenant_id>" | ||
| 6 | - echo "Example: $0 162" | ||
| 7 | - exit 1 | ||
| 8 | -fi | ||
| 9 | - | ||
| 10 | -ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" | ||
| 11 | -TENANT_ID="$1" | ||
| 12 | -# Fixed smoke-test queries and languages (no CLI args). | ||
| 13 | -SAMPLE_QUERIES=(s sh dress tshirt) | ||
| 14 | -SAMPLE_LANGS=(en zh) | ||
| 15 | -API_BASE="${API_BASE_URL:-http://localhost:6002}" | ||
| 16 | - | ||
| 17 | -cd "$ROOT_DIR" | ||
| 18 | - | ||
| 19 | -PY_BIN="${PYTHON_BIN:-$ROOT_DIR/.venv/bin/python}" | ||
| 20 | -if [ ! -x "$PY_BIN" ]; then | ||
| 21 | - PY_BIN="python3" | ||
| 22 | -fi | ||
| 23 | - | ||
| 24 | -if [ -z "${ES_HOST:-}" ]; then | ||
| 25 | - ES_HOST="$("$PY_BIN" - <<'PY' | ||
| 26 | -from dotenv import dotenv_values | ||
| 27 | -print(dotenv_values('.env').get('ES_HOST') or 'http://localhost:9200') | ||
| 28 | -PY | ||
| 29 | -)" | ||
| 30 | -fi | ||
| 31 | - | ||
| 32 | -if [ -z "${ES_USERNAME:-}" ] || [ -z "${ES_PASSWORD:-}" ]; then | ||
| 33 | - readarray -t _ES_CREDS < <("$PY_BIN" - <<'PY' | ||
| 34 | -from dotenv import dotenv_values | ||
| 35 | -cfg = dotenv_values('.env') | ||
| 36 | -print(cfg.get('ES_USERNAME') or '') | ||
| 37 | -print(cfg.get('ES_PASSWORD') or '') | ||
| 38 | -PY | ||
| 39 | -) | ||
| 40 | - ES_USERNAME="${ES_USERNAME:-${_ES_CREDS[0]}}" | ||
| 41 | - ES_PASSWORD="${ES_PASSWORD:-${_ES_CREDS[1]}}" | ||
| 42 | -fi | ||
| 43 | - | ||
| 44 | -if [ -n "${ES_USERNAME:-}" ] && [ -n "${ES_PASSWORD:-}" ]; then | ||
| 45 | - AUTH=(-u "${ES_USERNAME}:${ES_PASSWORD}") | ||
| 46 | -else | ||
| 47 | - AUTH=() | ||
| 48 | -fi | ||
| 49 | - | ||
| 50 | -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_${TENANT_ID}_current" | ||
| 51 | - | ||
| 52 | -echo "[1/4] Full rebuild tenant=${TENANT_ID} (versioned + alias publish)" | ||
| 53 | -"$PY_BIN" main.py build-suggestions \ | ||
| 54 | - --tenant-id "$TENANT_ID" \ | ||
| 55 | - --es-host "$ES_HOST" \ | ||
| 56 | - --mode full \ | ||
| 57 | - --days 365 \ | ||
| 58 | - --batch-size 500 \ | ||
| 59 | - --publish-alias \ | ||
| 60 | - --keep-versions 2 | ||
| 61 | - | ||
| 62 | -echo "[2/4] Incremental update tenant=${TENANT_ID}" | ||
| 63 | -"$PY_BIN" main.py build-suggestions \ | ||
| 64 | - --tenant-id "$TENANT_ID" \ | ||
| 65 | - --es-host "$ES_HOST" \ | ||
| 66 | - --mode incremental \ | ||
| 67 | - --overlap-minutes 30 | ||
| 68 | - | ||
| 69 | -echo "[3/4] ES count + sample" | ||
| 70 | -curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_count?pretty" | ||
| 71 | -echo | ||
| 72 | -curl -sS "${AUTH[@]}" "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' -d '{ | ||
| 73 | - "size": 5, | ||
| 74 | - "query": {"match_all": {}}, | ||
| 75 | - "_source": ["lang", "text", "rank_score", "sources", "query_count_30d"] | ||
| 76 | -}' | ||
| 77 | -echo | ||
| 78 | - | ||
| 79 | -echo "[4/4] API smoke test" | ||
| 80 | -for lang in "${SAMPLE_LANGS[@]}"; do | ||
| 81 | - for q in "${SAMPLE_QUERIES[@]}"; do | ||
| 82 | - echo "--- GET /search/suggestions?q=${q}&language=${lang} ---" | ||
| 83 | - curl -sS "$API_BASE/search/suggestions?q=${q}&size=10&language=${lang}" -H "X-Tenant-ID: ${TENANT_ID}" | ||
| 84 | - echo | ||
| 85 | - done | ||
| 86 | -done |
scripts/service_ctl.sh
| @@ -16,9 +16,10 @@ mkdir -p "${LOG_DIR}" | @@ -16,9 +16,10 @@ mkdir -p "${LOG_DIR}" | ||
| 16 | source "${PROJECT_ROOT}/scripts/lib/load_env.sh" | 16 | source "${PROJECT_ROOT}/scripts/lib/load_env.sh" |
| 17 | 17 | ||
| 18 | CORE_SERVICES=("backend" "indexer" "frontend" "eval-web") | 18 | CORE_SERVICES=("backend" "indexer" "frontend" "eval-web") |
| 19 | -OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker" "reranker-fine") | 19 | +# reranker-fine 暂时不用,因此暂时从OPTIONAL_SERVICES中删除 |
| 20 | +OPTIONAL_SERVICES=("tei" "cnclip" "embedding" "embedding-image" "translator" "reranker") | ||
| 20 | FULL_SERVICES=("${OPTIONAL_SERVICES[@]}" "${CORE_SERVICES[@]}") | 21 | FULL_SERVICES=("${OPTIONAL_SERVICES[@]}" "${CORE_SERVICES[@]}") |
| 21 | -STOP_ORDER_SERVICES=("frontend" "eval-web" "indexer" "backend" "reranker-fine" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") | 22 | +STOP_ORDER_SERVICES=("frontend" "eval-web" "indexer" "backend" "reranker" "translator" "embedding-image" "embedding" "cnclip" "tei") |
| 22 | 23 | ||
| 23 | all_services() { | 24 | all_services() { |
| 24 | echo "${FULL_SERVICES[@]}" | 25 | echo "${FULL_SERVICES[@]}" |
suggestion/builder.py
| @@ -38,7 +38,7 @@ def get_suggestion_alias_name(tenant_id: str) -> str: | @@ -38,7 +38,7 @@ def get_suggestion_alias_name(tenant_id: str) -> str: | ||
| 38 | 38 | ||
| 39 | def get_suggestion_versioned_index_name(tenant_id: str, build_at: Optional[datetime] = None) -> str: | 39 | def get_suggestion_versioned_index_name(tenant_id: str, build_at: Optional[datetime] = None) -> str: |
| 40 | """Versioned suggestion index name.""" | 40 | """Versioned suggestion index name.""" |
| 41 | - ts = (build_at or datetime.now(timezone.utc)).strftime("%Y%m%d%H%M%S") | 41 | + ts = (build_at or datetime.now(timezone.utc)).strftime("%Y%m%d%H%M%S%f") |
| 42 | return f"{_index_prefix()}search_suggestions_tenant_{tenant_id}_v{ts}" | 42 | return f"{_index_prefix()}search_suggestions_tenant_{tenant_id}_v{ts}" |
| 43 | 43 | ||
| 44 | 44 | ||
| @@ -101,6 +101,79 @@ class SuggestionIndexBuilder: | @@ -101,6 +101,79 @@ class SuggestionIndexBuilder: | ||
| 101 | self.es_client = es_client | 101 | self.es_client = es_client |
| 102 | self.db_engine = db_engine | 102 | self.db_engine = db_engine |
| 103 | 103 | ||
| 104 | + def _format_allocation_failure(self, index_name: str) -> str: | ||
| 105 | + health = self.es_client.wait_for_index_ready(index_name=index_name, timeout="5s") | ||
| 106 | + explain = self.es_client.get_allocation_explain(index_name=index_name) | ||
| 107 | + | ||
| 108 | + parts = [ | ||
| 109 | + f"Suggestion index '{index_name}' was created but is not allocatable/readable yet", | ||
| 110 | + f"health_status={health.get('status')}", | ||
| 111 | + f"timed_out={health.get('timed_out')}", | ||
| 112 | + ] | ||
| 113 | + if health.get("error"): | ||
| 114 | + parts.append(f"health_error={health['error']}") | ||
| 115 | + | ||
| 116 | + if explain: | ||
| 117 | + unassigned = explain.get("unassigned_info") or {} | ||
| 118 | + if unassigned.get("reason"): | ||
| 119 | + parts.append(f"unassigned_reason={unassigned['reason']}") | ||
| 120 | + if unassigned.get("last_allocation_status"): | ||
| 121 | + parts.append(f"last_allocation_status={unassigned['last_allocation_status']}") | ||
| 122 | + | ||
| 123 | + for node in explain.get("node_allocation_decisions") or []: | ||
| 124 | + node_name = node.get("node_name") or node.get("node_id") or "unknown-node" | ||
| 125 | + for decider in node.get("deciders") or []: | ||
| 126 | + if decider.get("decision") == "NO": | ||
| 127 | + parts.append( | ||
| 128 | + f"{node_name}:{decider.get('decider')}={decider.get('explanation')}" | ||
| 129 | + ) | ||
| 130 | + return "; ".join(parts) | ||
| 131 | + | ||
| 132 | + return "; ".join(parts) | ||
| 133 | + | ||
| 134 | + def _create_fresh_versioned_index( | ||
| 135 | + self, | ||
| 136 | + tenant_id: str, | ||
| 137 | + mapping: Dict[str, Any], | ||
| 138 | + max_attempts: int = 5, | ||
| 139 | + ) -> str: | ||
| 140 | + for attempt in range(1, max_attempts + 1): | ||
| 141 | + index_name = get_suggestion_versioned_index_name(tenant_id) | ||
| 142 | + if self.es_client.index_exists(index_name): | ||
| 143 | + logger.warning( | ||
| 144 | + "Suggestion index name collision before create for tenant=%s index=%s attempt=%s/%s", | ||
| 145 | + tenant_id, | ||
| 146 | + index_name, | ||
| 147 | + attempt, | ||
| 148 | + max_attempts, | ||
| 149 | + ) | ||
| 150 | + continue | ||
| 151 | + | ||
| 152 | + if self.es_client.create_index(index_name, mapping): | ||
| 153 | + return index_name | ||
| 154 | + | ||
| 155 | + if self.es_client.index_exists(index_name): | ||
| 156 | + logger.warning( | ||
| 157 | + "Suggestion index name collision during create for tenant=%s index=%s attempt=%s/%s", | ||
| 158 | + tenant_id, | ||
| 159 | + index_name, | ||
| 160 | + attempt, | ||
| 161 | + max_attempts, | ||
| 162 | + ) | ||
| 163 | + continue | ||
| 164 | + | ||
| 165 | + raise RuntimeError(f"Failed to create suggestion index: {index_name}") | ||
| 166 | + | ||
| 167 | + raise RuntimeError( | ||
| 168 | + f"Failed to allocate a unique suggestion index name for tenant={tenant_id} after {max_attempts} attempts" | ||
| 169 | + ) | ||
| 170 | + | ||
| 171 | + def _ensure_new_index_ready(self, index_name: str) -> None: | ||
| 172 | + health = self.es_client.wait_for_index_ready(index_name=index_name, timeout="5s") | ||
| 173 | + if health.get("ok"): | ||
| 174 | + return | ||
| 175 | + raise RuntimeError(self._format_allocation_failure(index_name)) | ||
| 176 | + | ||
| 104 | @staticmethod | 177 | @staticmethod |
| 105 | def _to_utc(dt: Any) -> Optional[datetime]: | 178 | def _to_utc(dt: Any) -> Optional[datetime]: |
| 106 | if dt is None: | 179 | if dt is None: |
| @@ -297,7 +370,7 @@ class SuggestionIndexBuilder: | @@ -297,7 +370,7 @@ class SuggestionIndexBuilder: | ||
| 297 | while True: | 370 | while True: |
| 298 | body: Dict[str, Any] = { | 371 | body: Dict[str, Any] = { |
| 299 | "size": batch_size, | 372 | "size": batch_size, |
| 300 | - "_source": ["id", "spu_id", "title", "qanchors", "tags"], | 373 | + "_source": ["id", "spu_id", "title", "qanchors", "enriched_tags"], |
| 301 | "sort": [ | 374 | "sort": [ |
| 302 | {"spu_id": {"order": "asc", "missing": "_last"}}, | 375 | {"spu_id": {"order": "asc", "missing": "_last"}}, |
| 303 | {"id.keyword": {"order": "asc", "missing": "_last"}}, | 376 | {"id.keyword": {"order": "asc", "missing": "_last"}}, |
| @@ -511,7 +584,7 @@ class SuggestionIndexBuilder: | @@ -511,7 +584,7 @@ class SuggestionIndexBuilder: | ||
| 511 | c.add_product("qanchor", spu_id=product_id) | 584 | c.add_product("qanchor", spu_id=product_id) |
| 512 | 585 | ||
| 513 | for tag_lang, tag in self._iter_multilang_product_tags( | 586 | for tag_lang, tag in self._iter_multilang_product_tags( |
| 514 | - src.get("tags"), | 587 | + src.get("enriched_tags"), |
| 515 | index_languages=index_languages, | 588 | index_languages=index_languages, |
| 516 | primary_language=primary_language, | 589 | primary_language=primary_language, |
| 517 | ): | 590 | ): |
| @@ -609,62 +682,65 @@ class SuggestionIndexBuilder: | @@ -609,62 +682,65 @@ class SuggestionIndexBuilder: | ||
| 609 | index_languages: List[str] = tenant_cfg.get("index_languages") or ["en", "zh"] | 682 | index_languages: List[str] = tenant_cfg.get("index_languages") or ["en", "zh"] |
| 610 | primary_language: str = tenant_cfg.get("primary_language") or "en" | 683 | primary_language: str = tenant_cfg.get("primary_language") or "en" |
| 611 | 684 | ||
| 612 | - # Always write to a fresh versioned index; legacy concrete index is no longer supported. | ||
| 613 | - index_name = get_suggestion_versioned_index_name(tenant_id) | ||
| 614 | - | ||
| 615 | - if self.es_client.index_exists(index_name): | ||
| 616 | - raise RuntimeError(f"Target suggestion index already exists: {index_name}") | ||
| 617 | - | ||
| 618 | - mapping = build_suggestion_mapping(index_languages=index_languages) | ||
| 619 | - if not self.es_client.create_index(index_name, mapping): | ||
| 620 | - raise RuntimeError(f"Failed to create suggestion index: {index_name}") | 685 | + alias_publish: Optional[Dict[str, Any]] = None |
| 686 | + index_name: Optional[str] = None | ||
| 687 | + try: | ||
| 688 | + mapping = build_suggestion_mapping(index_languages=index_languages) | ||
| 689 | + index_name = self._create_fresh_versioned_index( | ||
| 690 | + tenant_id=tenant_id, | ||
| 691 | + mapping=mapping, | ||
| 692 | + ) | ||
| 693 | + self._ensure_new_index_ready(index_name) | ||
| 621 | 694 | ||
| 622 | - key_to_candidate = self._build_full_candidates( | ||
| 623 | - tenant_id=tenant_id, | ||
| 624 | - index_languages=index_languages, | ||
| 625 | - primary_language=primary_language, | ||
| 626 | - days=days, | ||
| 627 | - batch_size=batch_size, | ||
| 628 | - min_query_len=min_query_len, | ||
| 629 | - ) | 695 | + key_to_candidate = self._build_full_candidates( |
| 696 | + tenant_id=tenant_id, | ||
| 697 | + index_languages=index_languages, | ||
| 698 | + primary_language=primary_language, | ||
| 699 | + days=days, | ||
| 700 | + batch_size=batch_size, | ||
| 701 | + min_query_len=min_query_len, | ||
| 702 | + ) | ||
| 630 | 703 | ||
| 631 | - now_iso = datetime.now(timezone.utc).isoformat() | ||
| 632 | - docs = [self._candidate_to_doc(tenant_id, c, now_iso) for c in key_to_candidate.values()] | 704 | + now_iso = datetime.now(timezone.utc).isoformat() |
| 705 | + docs = [self._candidate_to_doc(tenant_id, c, now_iso) for c in key_to_candidate.values()] | ||
| 633 | 706 | ||
| 634 | - if docs: | ||
| 635 | - bulk_result = self.es_client.bulk_index(index_name=index_name, docs=docs) | ||
| 636 | - self.es_client.refresh(index_name) | ||
| 637 | - else: | ||
| 638 | - bulk_result = {"success": 0, "failed": 0, "errors": []} | 707 | + if docs: |
| 708 | + bulk_result = self.es_client.bulk_index(index_name=index_name, docs=docs) | ||
| 709 | + self.es_client.refresh(index_name) | ||
| 710 | + else: | ||
| 711 | + bulk_result = {"success": 0, "failed": 0, "errors": []} | ||
| 639 | 712 | ||
| 640 | - alias_publish: Optional[Dict[str, Any]] = None | ||
| 641 | - if publish_alias: | ||
| 642 | - alias_publish = self._publish_alias( | ||
| 643 | - tenant_id=tenant_id, | ||
| 644 | - index_name=index_name, | ||
| 645 | - keep_versions=keep_versions, | ||
| 646 | - ) | 713 | + if publish_alias: |
| 714 | + alias_publish = self._publish_alias( | ||
| 715 | + tenant_id=tenant_id, | ||
| 716 | + index_name=index_name, | ||
| 717 | + keep_versions=keep_versions, | ||
| 718 | + ) | ||
| 647 | 719 | ||
| 648 | - now_utc = datetime.now(timezone.utc).isoformat() | ||
| 649 | - meta_patch: Dict[str, Any] = { | ||
| 650 | - "last_full_build_at": now_utc, | ||
| 651 | - "last_incremental_watermark": now_utc, | ||
| 652 | - } | ||
| 653 | - if publish_alias: | ||
| 654 | - meta_patch["active_index"] = index_name | ||
| 655 | - meta_patch["active_alias"] = get_suggestion_alias_name(tenant_id) | ||
| 656 | - self._upsert_meta(tenant_id, meta_patch) | 720 | + now_utc = datetime.now(timezone.utc).isoformat() |
| 721 | + meta_patch: Dict[str, Any] = { | ||
| 722 | + "last_full_build_at": now_utc, | ||
| 723 | + "last_incremental_watermark": now_utc, | ||
| 724 | + } | ||
| 725 | + if publish_alias: | ||
| 726 | + meta_patch["active_index"] = index_name | ||
| 727 | + meta_patch["active_alias"] = get_suggestion_alias_name(tenant_id) | ||
| 728 | + self._upsert_meta(tenant_id, meta_patch) | ||
| 657 | 729 | ||
| 658 | - return { | ||
| 659 | - "mode": "full", | ||
| 660 | - "tenant_id": str(tenant_id), | ||
| 661 | - "index_name": index_name, | ||
| 662 | - "alias_published": bool(alias_publish), | ||
| 663 | - "alias_publish": alias_publish, | ||
| 664 | - "total_candidates": len(key_to_candidate), | ||
| 665 | - "indexed_docs": len(docs), | ||
| 666 | - "bulk_result": bulk_result, | ||
| 667 | - } | 730 | + return { |
| 731 | + "mode": "full", | ||
| 732 | + "tenant_id": str(tenant_id), | ||
| 733 | + "index_name": index_name, | ||
| 734 | + "alias_published": bool(alias_publish), | ||
| 735 | + "alias_publish": alias_publish, | ||
| 736 | + "total_candidates": len(key_to_candidate), | ||
| 737 | + "indexed_docs": len(docs), | ||
| 738 | + "bulk_result": bulk_result, | ||
| 739 | + } | ||
| 740 | + except Exception: | ||
| 741 | + if index_name and not alias_publish: | ||
| 742 | + self.es_client.delete_index(index_name) | ||
| 743 | + raise | ||
| 668 | 744 | ||
| 669 | def _build_incremental_deltas( | 745 | def _build_incremental_deltas( |
| 670 | self, | 746 | self, |
tests/test_suggestions.py
| @@ -8,6 +8,7 @@ from suggestion.builder import ( | @@ -8,6 +8,7 @@ from suggestion.builder import ( | ||
| 8 | QueryDelta, | 8 | QueryDelta, |
| 9 | SuggestionIndexBuilder, | 9 | SuggestionIndexBuilder, |
| 10 | get_suggestion_alias_name, | 10 | get_suggestion_alias_name, |
| 11 | + get_suggestion_versioned_index_name, | ||
| 11 | ) | 12 | ) |
| 12 | from suggestion.service import SuggestionService | 13 | from suggestion.service import SuggestionService |
| 13 | 14 | ||
| @@ -121,6 +122,16 @@ class FakeESClient: | @@ -121,6 +122,16 @@ class FakeESClient: | ||
| 121 | self.indices.add(index_name) | 122 | self.indices.add(index_name) |
| 122 | return True | 123 | return True |
| 123 | 124 | ||
| 125 | + def wait_for_index_ready(self, index_name: str, timeout: str = "10s") -> Dict[str, Any]: | ||
| 126 | + self.calls.append({"op": "wait_for_index_ready", "index": index_name, "timeout": timeout}) | ||
| 127 | + return {"ok": True, "status": "green", "timed_out": False} | ||
| 128 | + | ||
| 129 | + def get_allocation_explain(self, index_name: str, shard: int = 0, primary: bool = True) -> Optional[Dict[str, Any]]: | ||
| 130 | + self.calls.append( | ||
| 131 | + {"op": "get_allocation_explain", "index": index_name, "shard": shard, "primary": primary} | ||
| 132 | + ) | ||
| 133 | + return None | ||
| 134 | + | ||
| 124 | def refresh(self, index_name: str) -> bool: | 135 | def refresh(self, index_name: str) -> bool: |
| 125 | self.calls.append({"op": "refresh", "index": index_name}) | 136 | self.calls.append({"op": "refresh", "index": index_name}) |
| 126 | return True | 137 | return True |
| @@ -150,6 +161,67 @@ class FakeESClient: | @@ -150,6 +161,67 @@ class FakeESClient: | ||
| 150 | 161 | ||
| 151 | 162 | ||
| 152 | @pytest.mark.unit | 163 | @pytest.mark.unit |
| 164 | +def test_versioned_index_name_uses_microseconds(): | ||
| 165 | + build_at = datetime(2026, 4, 7, 3, 52, 26, 123456, tzinfo=timezone.utc) | ||
| 166 | + assert ( | ||
| 167 | + get_suggestion_versioned_index_name("163", build_at) | ||
| 168 | + == "search_suggestions_tenant_163_v20260407035226123456" | ||
| 169 | + ) | ||
| 170 | + | ||
| 171 | + | ||
| 172 | +@pytest.mark.unit | ||
| 173 | +def test_rebuild_cleans_up_unallocatable_new_index(): | ||
| 174 | + fake_es = FakeESClient() | ||
| 175 | + | ||
| 176 | + def _wait_fail(index_name: str, timeout: str = "10s") -> Dict[str, Any]: | ||
| 177 | + fake_es.calls.append({"op": "wait_for_index_ready", "index": index_name, "timeout": timeout}) | ||
| 178 | + return {"ok": False, "status": "red", "timed_out": True} | ||
| 179 | + | ||
| 180 | + def _allocation_explain(index_name: str, shard: int = 0, primary: bool = True) -> Dict[str, Any]: | ||
| 181 | + fake_es.calls.append( | ||
| 182 | + {"op": "get_allocation_explain", "index": index_name, "shard": shard, "primary": primary} | ||
| 183 | + ) | ||
| 184 | + return { | ||
| 185 | + "unassigned_info": {"reason": "INDEX_CREATED", "last_allocation_status": "no"}, | ||
| 186 | + "node_allocation_decisions": [ | ||
| 187 | + { | ||
| 188 | + "node_name": "node-1", | ||
| 189 | + "deciders": [ | ||
| 190 | + { | ||
| 191 | + "decider": "disk_threshold", | ||
| 192 | + "decision": "NO", | ||
| 193 | + "explanation": "node is above high watermark", | ||
| 194 | + } | ||
| 195 | + ], | ||
| 196 | + } | ||
| 197 | + ], | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + fake_es.wait_for_index_ready = _wait_fail # type: ignore[method-assign] | ||
| 201 | + fake_es.get_allocation_explain = _allocation_explain # type: ignore[method-assign] | ||
| 202 | + | ||
| 203 | + builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) | ||
| 204 | + | ||
| 205 | + from config import tenant_config_loader as tcl | ||
| 206 | + | ||
| 207 | + loader = tcl.get_tenant_config_loader() | ||
| 208 | + loader._config = { | ||
| 209 | + "default": {"primary_language": "en", "index_languages": ["en", "zh"]}, | ||
| 210 | + "tenants": { | ||
| 211 | + "163": {"primary_language": "en", "index_languages": ["en", "zh"]}, | ||
| 212 | + }, | ||
| 213 | + } | ||
| 214 | + | ||
| 215 | + with pytest.raises(RuntimeError, match="disk_threshold"): | ||
| 216 | + builder.rebuild_tenant_index(tenant_id="163") | ||
| 217 | + | ||
| 218 | + create_calls = [x for x in fake_es.calls if x.get("op") == "create_index"] | ||
| 219 | + assert len(create_calls) == 1 | ||
| 220 | + created_index = create_calls[0]["index"] | ||
| 221 | + assert created_index not in fake_es.indices | ||
| 222 | + | ||
| 223 | + | ||
| 224 | +@pytest.mark.unit | ||
| 153 | def test_resolve_query_language_prefers_log_field(): | 225 | def test_resolve_query_language_prefers_log_field(): |
| 154 | fake_es = FakeESClient() | 226 | fake_es = FakeESClient() |
| 155 | builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) | 227 | builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None) |
| @@ -406,7 +478,7 @@ def test_build_full_candidates_tags_and_qanchor_phrases(monkeypatch): | @@ -406,7 +478,7 @@ def test_build_full_candidates_tags_and_qanchor_phrases(monkeypatch): | ||
| 406 | "en": ["slim fit", "sporty casual"], | 478 | "en": ["slim fit", "sporty casual"], |
| 407 | "zh": ["修身", "显瘦"], | 479 | "zh": ["修身", "显瘦"], |
| 408 | }, | 480 | }, |
| 409 | - "tags": { | 481 | + "enriched_tags": { |
| 410 | "en": ["Classic", "ribbed neckline"], | 482 | "en": ["Classic", "ribbed neckline"], |
| 411 | "zh": ["辣妹风"], | 483 | "zh": ["辣妹风"], |
| 412 | }, | 484 | }, |
utils/es_client.py
| @@ -76,13 +76,70 @@ class ESClient: | @@ -76,13 +76,70 @@ class ESClient: | ||
| 76 | True if successful, False otherwise | 76 | True if successful, False otherwise |
| 77 | """ | 77 | """ |
| 78 | try: | 78 | try: |
| 79 | - self.client.indices.create(index=index_name, body=body) | 79 | + client = self.client.options(request_timeout=30, max_retries=0) |
| 80 | + client.indices.create( | ||
| 81 | + index=index_name, | ||
| 82 | + body=body, | ||
| 83 | + wait_for_active_shards="0", | ||
| 84 | + ) | ||
| 80 | logger.info(f"Index '{index_name}' created successfully") | 85 | logger.info(f"Index '{index_name}' created successfully") |
| 81 | return True | 86 | return True |
| 82 | except Exception as e: | 87 | except Exception as e: |
| 88 | + if self.index_exists(index_name): | ||
| 89 | + logger.warning( | ||
| 90 | + "Create index request for '%s' raised %s, but the index now exists; treating it as created", | ||
| 91 | + index_name, | ||
| 92 | + type(e).__name__, | ||
| 93 | + exc_info=True, | ||
| 94 | + ) | ||
| 95 | + return True | ||
| 83 | logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) | 96 | logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True) |
| 84 | return False | 97 | return False |
| 85 | 98 | ||
| 99 | + def wait_for_index_ready(self, index_name: str, timeout: str = "10s") -> Dict[str, Any]: | ||
| 100 | + """Wait until an index primary shard is allocated and searchable.""" | ||
| 101 | + try: | ||
| 102 | + resp = self.client.cluster.health( | ||
| 103 | + index=index_name, | ||
| 104 | + wait_for_status="yellow", | ||
| 105 | + timeout=timeout, | ||
| 106 | + level="indices", | ||
| 107 | + ) | ||
| 108 | + index_info = ((resp.get("indices") or {}).get(index_name) or {}) | ||
| 109 | + status = index_info.get("status") or resp.get("status") | ||
| 110 | + timed_out = bool(resp.get("timed_out")) | ||
| 111 | + return { | ||
| 112 | + "ok": (not timed_out) and status in {"yellow", "green"}, | ||
| 113 | + "status": status, | ||
| 114 | + "timed_out": timed_out, | ||
| 115 | + "response": resp, | ||
| 116 | + } | ||
| 117 | + except Exception as e: | ||
| 118 | + logger.error("Failed waiting for index '%s' readiness: %s", index_name, e, exc_info=True) | ||
| 119 | + return { | ||
| 120 | + "ok": False, | ||
| 121 | + "status": "unknown", | ||
| 122 | + "timed_out": False, | ||
| 123 | + "error": str(e), | ||
| 124 | + } | ||
| 125 | + | ||
| 126 | + def get_allocation_explain(self, index_name: str, shard: int = 0, primary: bool = True) -> Optional[Dict[str, Any]]: | ||
| 127 | + """Explain why a shard can or cannot be allocated.""" | ||
| 128 | + try: | ||
| 129 | + return self.client.cluster.allocation_explain( | ||
| 130 | + body={"index": index_name, "shard": shard, "primary": primary} | ||
| 131 | + ) | ||
| 132 | + except Exception as e: | ||
| 133 | + logger.warning( | ||
| 134 | + "Failed to get allocation explain for index '%s' shard=%s primary=%s: %s", | ||
| 135 | + index_name, | ||
| 136 | + shard, | ||
| 137 | + primary, | ||
| 138 | + e, | ||
| 139 | + exc_info=True, | ||
| 140 | + ) | ||
| 141 | + return None | ||
| 142 | + | ||
| 86 | def put_alias(self, index_name: str, alias_name: str) -> bool: | 143 | def put_alias(self, index_name: str, alias_name: str) -> bool: |
| 87 | """Add alias for an index.""" | 144 | """Add alias for an index.""" |
| 88 | try: | 145 | try: |