diff --git a/docs/issues/issue-2026-04-06-推理优化-2.md b/docs/issues/issue-2026-04-06-推理优化-2.md deleted file mode 100644 index 60b5aa3..0000000 --- a/docs/issues/issue-2026-04-06-推理优化-2.md +++ /dev/null @@ -1,69 +0,0 @@ -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查: - -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。 -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。 - -示例prompt和对应的9个分类词: -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} - -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。 - -使用Tesla T4,因此: -1. 使用FP16。不用 BF16 作为主路径 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。 - - -主要考虑优化方向为: -hidden_last -> N-class scorer -> argmax -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV -去 full vocab logits -去 decode / constrained decode -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8) -专用 tail kernel(输出 N 类原始分数) - -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化: -15.1 预分配 -对每个 bucket 预分配: -input ids buffer -attention scratch buffer -output hidden buffer -class id buffer -15.2 Graph capture - -每个 bucket 预先 capture 一张图: -embedding -transformer 主干 -compact last-state -9-way tail kernel - -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。 - -你有sudo权限,你可以执行为本项目安装自己的环境 - -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型) -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。 - - -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。 -注意使用fp16版本,不要使用量化版本。 - -上面的需求你已经满足了大部分, -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。 -需要通过多 token 标签做极致的性能优化,避免串行decode。 -一个考虑的方向是: -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch) - -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。 - -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。 - -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。 - -参考文档: - -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq \ No newline at end of file diff --git a/docs/issues/issue-2026-04-06-推理优化-3.md b/docs/issues/issue-2026-04-06-推理优化-3.md deleted file mode 100644 index 3a6e4fa..0000000 --- a/docs/issues/issue-2026-04-06-推理优化-3.md +++ /dev/null @@ -1,98 +0,0 @@ -这是我的第一个版本的需求,你已经完成了大部分,请你结合代码现状进行检查: - -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。 -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。 - -示例prompt和对应的9个分类词: -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. 
Output nothing else query:{query} - -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟以及每个标签的精确得分。 - -使用Tesla T4,因此: -1. 使用FP16。不用 BF16 作为主路径 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。 - - -主要考虑优化方向为: -hidden_last -> N-class scorer -> argmax -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV -去 full vocab logits -去 decode / constrained decode -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8) -专用 tail kernel(输出 N 类原始分数) - -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化: -15.1 预分配 -对每个 bucket 预分配: -input ids buffer -attention scratch buffer -output hidden buffer -class id buffer -15.2 Graph capture - -每个 bucket 预先 capture 一张图: -embedding -transformer 主干 -compact last-state -9-way tail kernel - -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。 - -你有sudo权限,你可以执行为本项目安装自己的环境 - -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型) -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。 - - -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。 -注意使用fp16版本,不要使用量化版本。 - -上面的需求你已经满足了大部分, -但是一个重要的、还未处理好的问题:目前的版本应该是针对decode 单token来做的,但是,一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况,不需要为单 token 设计,但是每个分类词仍然是少数token。 -需要通过多 token 标签做极致的性能优化,避免串行decode。 -一个考虑的方向是: -将 HF 的多 token candidate_reduce 路径,改为 融合/向量化实现、向量化方案(batch padding,一次模型 forward,处理整个 batch) - -只做一次模型 forward,输入长度为 total_tokens(通常远小于 batch_size * max_label_len 的循环总和)。 - -所有后续聚合均为向量化操作(GPU 上几乎不花时间)。 - -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。 - -参考文档: - -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq - - -现在还有以下问题: -1. 请满足: -先专注于推理的优化,不用考虑服务化,先支持cli模式读取query即可,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,则进行推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。 - -2. 打分结果不对:不管输入什么都输出jeans,输入裙子也输出jeans,只有精确的输入skirt才能输出skirt。 - -3. 现在是一个prompt,我需要增加6个prompt,都进行配置化,每次输入一个query、对7个prompt进行批量的推理,对每个prompt都输出词典中每个分类的打分分布 -人群 -Analyze the target user group of the given query. Output exactly one label from: boy, girl, man, woman, pregnant. Output nothing else query:{query} - -尺寸 -Analyze the size intent of the given query. Output exactly one label from: xs, s, m, l, xl, xxl, plus, petite. Output nothing else query:{query} - -领口 -Analyze the neckline type of the given query. Output exactly one label from: round, v, turtleneck, collared, scoop, offshoulder. Output nothing else query:{query} - -袖长类型 -Analyze the sleeve length of the given query. Output exactly one label from: long, short, sleeveless, half, threequarter, cap. Output nothing else query:{query} - -上衣长度类型 -Analyze the top length of the given query. Output exactly one label from: long, short, regular, cropped, tunic, midi. Output nothing else query:{query} - -面料 -Analyze the fabric material of the given query. Output exactly one label from: cotton, linen, silk, wool, polyester, denim, leather, chiffon, fleece. 
Output nothing else query:{query} - - -注意要使用测试用例进行测试。包括打分结果是否符合预期、性能测试。 diff --git a/docs/issues/issue-2026-04-06-推理优化-4.md b/docs/issues/issue-2026-04-06-推理优化-4.md deleted file mode 100644 index 025a763..0000000 --- a/docs/issues/issue-2026-04-06-推理优化-4.md +++ /dev/null @@ -1,4 +0,0 @@ - -1. “仓库启动时会为所有 batch_sizes x continuation_buckets x prompt 预建 bucket”,之前可能是根据“极致的性能要求”、“不要做任何懒加载,确保真实请求发生时得到极致的响应时间”所做的设计,这显然太过了,我是希望一些基本的可以事先加载的应该先加载,但是牺牲巨大的显存占用来换取微弱的耗时提升是不提倡的,请你站在更高的角度,理会我的需求(先加载好模型、跨session的KV cache、并针对特殊用法 即score方式而不是逐步decode方式,来极致的优化性能,用于对线上的单个query,以最短耗时得到7个prompt的分类结果) - -2. 现在的 7 个 prompt推理是串行的,MultiPromptRunner.score_query() 里就是 for runner in self.runners 一个一个跑,需要把执行模型改成“按 prompt 分组批量并行”,但是现在有 2 个 fast path 和 5 个 multi-token path,是各合成一次 forward,还是有可能合成一个?因为multi-token也可以一次线prefill进去,是否能做到跟fast path同级别的性能?请你站在更高的角度进行思考,保证性能的同时降低复杂性。 \ No newline at end of file diff --git a/docs/issues/issue-2026-04-06-推理优化-重建.md b/docs/issues/issue-2026-04-06-推理优化-重建.md deleted file mode 100644 index 4d948d9..0000000 --- a/docs/issues/issue-2026-04-06-推理优化-重建.md +++ /dev/null @@ -1,1003 +0,0 @@ -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。 -先专注于推理的优化,最后再考虑服务化,支持一定程度的并发(比如4)的请求,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。cli每得到一个query输入,使用N个prompt进行N个维度的分类,对于每个prompt,推理输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下)。 - -下面有一些参考技术资料,但是你并不需要严格,你应该有一定的灵活度,来追求极致的性能。 - -在 Tesla T4 上,用 3B 到 6B 级别的开源 decoder-only 基座模型做 query 分类。 -启动时完成 tokenizer、权重、prefix cache 和共享执行器准备工作。 -每次输入一个 query,输出每个 prompt 下每个 label 的分数分布,以及预测耗时和阶段耗时。 -不走通用生成路径,不做 decode,不取 full vocab logits,不做 constrained decode。 -对 multi-token label 做专门优化,避免 Python 侧串行 decode。 -prompt 和 label 集合必须可配置(目前只有两个,以后我会加到8个,每次请求输入一个query,并行的调用8个prompt进行推理得到得分最高的label: - - -prompts: - - name: category - prompt_template: | - <|im_start|>system - Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|> - <|im_start|>user - query: {query}<|im_end|> - <|im_start|>assistant - label: - label_prefix: " " - labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none] - - - name: audience - prompt_template: | - <|im_start|>system - Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|> - <|im_start|>user - query: {query}<|im_end|> - <|im_start|>assistant - label: - label_prefix: " " - labels: [boy, girl, man, woman, pregnant, none] - - -做专用执行路径,不是在通用生成引擎上做配置优化,追求绝对最低的分类延迟。 - -主要考虑优化方向为: -1. hidden_last -> N-class scorer -> argmax -2. 参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV -3. 去 full vocab logits -4. 去 decode / constrained decode -5. 专用 tail kernel(输出 N 类原始分数) -6. 配置的N个 prompt推理要并行推理(2-8个) -7. 
使用Tesla T4,因此不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention -多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。 - - -你有sudo权限,你可以执行为本项目安装自己的环境 - -使用Qwen/Qwen3-8B的Q4或Q8模型,具体用哪个版本,请你查找huggingface相关资料,选择合适的版本完成部署,并进行推理耗时的测试。 - -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。 - -一个重要的问题:一些分类词并不是单token(虽然看起来是一个单词),所以,要考虑一些分类词并不是单token的情况。 -需要通过多 token 标签做极致的性能优化,避免串行decode。 -我们最终目的是得到哪个label的得分最高,不一定要精确的概率,计算log P(id1 | query, prompt) + log P(id2 | query, prompt, id1)有可能导致难以优化性能,精确的概率是可以考虑放弃的,要清楚我们的最终目的,达到分类的目的即可,只要得到分类,优先考虑性能,精确的概率可以放下。 -如何通过一次模型 forward处理包括多token label的整个 batch,是你需要探索的问题。 - -单 token fast path 的做法比较确定: last_hidden -> small class scorer -> argmax。 只取目标 label 对应 LM head 行,不做 full vocab 输出。 -multi-token 怎么做需要搜索相关资料进行考量,最好要做到跟单token开销相同(放弃精确的log-prob的前提下。但是:多token和单token的label的打分的对比,一定要是可比的才能正确的分类,兼顾性能和打分的准确性) - -还需要增加一个配置:force_single_token_labels,所有 label 都按首 token 处理,因为,如果各个label收token不同,那么可以近似的认为首token打分代表整个label打分。 -你需要找到多label打分性能和准确性上面的最佳实践。同时也支持force_single_token_labels以达到极致的性能。 - -也请你仔细搜寻相关资料,特别是技术框架所用到的Triton / Ollama / CUDA C++ 在该场景上的最佳实践,进行实践,找到在T4上面query分类需求的sota、做到极致的性能优化。 -以下是一些参考示例: -vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ -PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ -TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html -Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq - -TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf -TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html -FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention -vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/ -SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html -FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer -xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. 
https://github.com/facebookresearch/xformers - -注意:已经有一个项目 llm-qp, llm-qp2,这两个项目,对于单token的处理方式是可以的: -SDPA -prefix cache -prebuilt bucket + CUDA graph -他的核心代码是: -from __future__ import annotations - -import hashlib -import time -from dataclasses import asdict, dataclass -from typing import Iterable - -import torch -import torch.nn.functional as F -from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache - -try: - from transformers import BitsAndBytesConfig -except ImportError: # pragma: no cover - BitsAndBytesConfig = None - -from llm_qp.config import PromptTaskConfig, RuntimeConfig -from llm_qp.scorer import SmallClassScorer - -try: - from torch.nn.attention import SDPBackend, sdpa_kernel -except ImportError: # pragma: no cover - SDPBackend = None - sdpa_kernel = None - - -@dataclass(slots=True) -class EncodedLabel: - text: str - token_ids: list[int] - - -@dataclass(slots=True) -class PrefixCache: - prefix_ids: list[int] - prefix_hashes: list[str] - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] - - @property - def prefix_len(self) -> int: - return len(self.prefix_ids) - - -@dataclass(slots=True) -class MultiTokenTables: - label_token_ids: torch.Tensor - label_token_mask: torch.Tensor - label_prefix_ids: torch.Tensor - label_prefix_mask: torch.Tensor - label_position_offsets: torch.Tensor - - @property - def max_label_len(self) -> int: - return self.label_token_ids.shape[1] - - @property - def max_label_prefix_len(self) -> int: - return self.label_prefix_ids.shape[1] - - -@dataclass(slots=True) -class QueryScoreResult: - task_name: str - query: str - predicted_label: str - scores: list[tuple[str, float, float]] - total_ms: float - stage_ms: dict[str, float] - fast_path: bool - prefix_tokens: int - continuation_tokens: int - label_token_lengths: dict[str, int] - - @property - def predicted_prob(self) -> float: - for label, _score, prob in self.scores: - if label == self.predicted_label: - return prob - return 0.0 - - -@dataclass(slots=True) -class MultiPromptScoreResult: - query: str - total_ms: float - details: list[QueryScoreResult] - stage_ms: dict[str, float] - - def http_json(self) -> dict[str, object]: - return { - "query": self.query, - "total_ms": self.total_ms, - "stage_ms": self.stage_ms, - "details": [asdict(t) for t in self.details], - "task_results": { - t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none' - }, - } - - -@dataclass(slots=True) -class BatchScoreResult: - batch_size: int - total_ms: float - results: list[MultiPromptScoreResult] - stage_ms: dict[str, float] - - -@dataclass(slots=True) -class SharedRuntime: - device: torch.device - dtype: torch.dtype - tokenizer: object - model: object - backbone: object - hidden_size: int - graph_capture_pool: object | None = None - graph_capture_stream: torch.cuda.Stream | None = None - - -@dataclass(slots=True) -class PromptBatchPlan: - runner: "PromptClassifierRunner" - row_start: int - row_count: int - score_buffer: torch.Tensor - - @property - def row_stop(self) -> int: - return self.row_start + self.row_count - - -@dataclass(slots=True) -class MixedPrefixCache: - batch_size: int - total_rows: int - prefix_lengths: torch.Tensor - attention_mask: torch.Tensor - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] 
- - @property - def max_prefix_len(self) -> int: - return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0 - - -@dataclass(slots=True) -class BatchLayout: - batch_size: int - total_rows: int - plans: list[PromptBatchPlan] - - -@dataclass(slots=True) -class MixedBucketRuntime: - batch_size: int - total_rows: int - continuation_len: int - max_input_len: int - input_ids: torch.Tensor - attention_mask: torch.Tensor - position_ids: torch.Tensor - last_hidden_state: torch.Tensor - graph: torch.cuda.CUDAGraph | None = None - - -@dataclass(slots=True) -class PreloadReport: - total_ms: float - stage_ms: dict[str, float] - runtime: dict[str, object] - - -def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]: - token_list = list(token_ids) - hashes: list[str] = [] - for start in range(0, len(token_list), block_size): - block = token_list[start : start + block_size] - payload = ",".join(str(x) for x in block).encode("utf-8") - hashes.append(hashlib.sha1(payload).hexdigest()) - return hashes - - -def _expand_legacy_cache( - raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...], - batch_size: int, -) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - expanded: list[tuple[torch.Tensor, torch.Tensor]] = [] - for key, value in raw_cache: - expanded.append( - ( - key.expand(batch_size, *key.shape[1:]).contiguous(), - value.expand(batch_size, *value.shape[1:]).contiguous(), - ) - ) - return tuple(expanded) - - -class PromptClassifierRunner: - def __init__( - self, - cfg: RuntimeConfig, - task_cfg: PromptTaskConfig, - shared_runtime: SharedRuntime, - ): - self.cfg = cfg - self.task_cfg = task_cfg - self.device = shared_runtime.device - self.dtype = shared_runtime.dtype - self.tokenizer = shared_runtime.tokenizer - self.model = shared_runtime.model - self.backbone = shared_runtime.backbone - self.hidden_size = shared_runtime.hidden_size - self.prefix_text, self.suffix_text = task_cfg.prompt_parts - self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False) - self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False) - self.labels = list(task_cfg.labels) - self.encoded_labels = [ - EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label)) - for label in self.labels - ] - self.num_labels = len(self.labels) - self.lm_head = self.model.get_output_embeddings() - self.lm_head_weight = self.lm_head.weight.detach() - self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None - if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels(): - raise ValueError( - f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide" - ) - self.fast_path = self._has_unique_single_token_labels() - self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else [] - self.scorer = self._build_scorer() if self.fast_path else None - self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None - self.prefix_cache = self._build_prefix_cache() - - def _encode_label_token_ids(self, label: str) -> list[int]: - token_ids = self.tokenizer.encode( - f"{self.task_cfg.label_prefix}{label}", - add_special_tokens=False, - ) - if not token_ids: - raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence") - if self.cfg.force_single_token_labels: - return token_ids[:1] - return token_ids - - def 
_has_unique_single_token_labels(self) -> bool: - token_ids: list[int] = [] - for item in self.encoded_labels: - if len(item.token_ids) != 1: - return False - token_ids.append(item.token_ids[0]) - return len(token_ids) == len(set(token_ids)) - - def _build_scorer(self) -> SmallClassScorer: - token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device) - weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous() - bias = None - if self.lm_head_bias is not None: - bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous() - return SmallClassScorer(weights=weights, bias=bias) - - def _build_multi_token_tables(self) -> MultiTokenTables: - max_label_len = max(len(item.token_ids) for item in self.encoded_labels) - max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels) - label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long) - label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32) - label_prefix_ids = torch.full( - (self.num_labels, max_prefix_len), - fill_value=self.tokenizer.pad_token_id, - device=self.device, - dtype=torch.long, - ) - label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long) - for idx, item in enumerate(self.encoded_labels): - token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long) - token_len = token_ids.numel() - label_token_ids[idx, :token_len] = token_ids - label_token_mask[idx, :token_len] = 1.0 - if token_len > 1: - prefix_len = token_len - 1 - label_prefix_ids[idx, :prefix_len] = token_ids[:-1] - label_prefix_mask[idx, :prefix_len] = 1 - return MultiTokenTables( - label_token_ids=label_token_ids.contiguous(), - label_token_mask=label_token_mask.contiguous(), - label_prefix_ids=label_prefix_ids.contiguous(), - label_prefix_mask=label_prefix_mask.contiguous(), - label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long), - ) - - @torch.inference_mode() - def _build_prefix_cache(self) -> PrefixCache: - if not self.prefix_ids: - return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple()) - prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device) - attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device) - outputs = self.model( - input_ids=prefix_tensor, - attention_mask=attention_mask, - use_cache=True, - return_dict=True, - ) - raw_cache = tuple( - (layer.keys.detach(), layer.values.detach()) - for layer in outputs.past_key_values.layers - ) - return PrefixCache( - prefix_ids=list(self.prefix_ids), - prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size), - raw_cache=raw_cache, - ) - - def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: - if not self.prefix_cache.raw_cache: - return tuple() - return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size) - - def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]: - continuation = query_ids + self.suffix_ids - if not continuation: - raise ValueError("prompt continuation is empty after substituting query") - if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length: - raise ValueError( - f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}" - ) - return continuation - - 
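# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original llm-qp listing or the deleted
# issue file): the requirement describes a fast path that scores every label
# from the last prompt hidden state using only the lm_head rows of each
# label's first token -- no full-vocab logits, no decode step. Under
# force_single_token_labels this is also the approximation applied to
# multi-token labels. A minimal standalone version, with illustrative tensor
# names (not actual llm-qp APIs), might look like this:
import torch

@torch.inference_mode()
def score_labels_by_first_token(
    last_hidden: torch.Tensor,      # [hidden_size] hidden state at the final prompt token
    lm_head_weight: torch.Tensor,   # [vocab_size, hidden_size]
    first_token_ids: torch.Tensor,  # [num_labels] first token id of each label
    lm_head_bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Keep only the lm_head rows belonging to the label first tokens.
    rows = lm_head_weight.index_select(0, first_token_ids)   # [num_labels, hidden_size]
    scores = rows @ last_hidden                               # [num_labels] raw class scores
    if lm_head_bias is not None:
        scores = scores + lm_head_bias.index_select(0, first_token_ids)
    return scores  # argmax over this vector is the predicted label
# ---------------------------------------------------------------------------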
@torch.inference_mode() - def reduce_fast_scores( - self, - hidden: torch.Tensor, - out_scores: torch.Tensor, - ) -> None: - assert self.scorer is not None - out_scores.copy_(self.scorer(hidden)) - - @torch.inference_mode() - def reduce_multi_token_scores( - self, - last_hidden_state: torch.Tensor, - batch_size: int, - max_input_len: int, - score_positions: torch.Tensor, - out_scores: torch.Tensor, - ) -> None: - assert self.multi_token_tables is not None - hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size) - gather_index = score_positions[:, None, :, None].expand( - batch_size, - self.num_labels, - self.multi_token_tables.max_label_len, - self.hidden_size, - ) - gathered_hidden = torch.gather(hidden, 2, gather_index) - used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool() - - token_log_probs = torch.zeros( - (batch_size, self.num_labels, self.multi_token_tables.max_label_len), - device=self.device, - dtype=torch.float32, - ) - if used_mask.any(): - flat_hidden = gathered_hidden[used_mask] - flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask] - linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float() - linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float() - linear_bias = self.lm_head_bias - if linear_bias is not None and self.device.type != "cuda": - linear_bias = linear_bias.float() - flat_logits = F.linear(linear_hidden, linear_weight, linear_bias) - flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float() - flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1) - token_log_probs[used_mask] = flat_selected - flat_log_norm - out_scores.copy_(token_log_probs.sum(dim=-1)) - - def build_score_result( - self, - query: str, - scores: torch.Tensor, - stage_ms: dict[str, float], - continuation_tokens: int, - ) -> QueryScoreResult: - score_values = scores.detach().float().cpu().tolist() - best_idx = max(range(len(score_values)), key=score_values.__getitem__) - probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist() - return QueryScoreResult( - task_name=self.task_cfg.name, - query=query, - predicted_label=self.labels[best_idx], - scores=[ - (label, score, prob) - for label, score, prob in zip(self.labels, score_values, probs, strict=True) - ], - total_ms=sum(stage_ms.values()), - stage_ms=stage_ms, - fast_path=self.fast_path, - prefix_tokens=self.prefix_cache.prefix_len, - continuation_tokens=continuation_tokens, - label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels}, - ) - - -class MultiPromptRunner: - def __init__(self, cfg: RuntimeConfig): - self.cfg = cfg - t0 = time.perf_counter() - self.shared_runtime = self.build_shared_runtime(cfg) - t1 = time.perf_counter() - self.device = self.shared_runtime.device - self.dtype = self.shared_runtime.dtype - self.tokenizer = self.shared_runtime.tokenizer - self.model = self.shared_runtime.model - self.backbone = self.shared_runtime.backbone - self.hidden_size = self.shared_runtime.hidden_size - self.graph_capture_pool = self.shared_runtime.graph_capture_pool - self.graph_capture_stream = self.shared_runtime.graph_capture_stream - self.runners = [ - PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime) - for task_cfg in cfg.tasks - ] - t2 = time.perf_counter() - self.batch_layouts = {batch_size: 
self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes} - t3 = time.perf_counter() - self.mixed_prefix_caches = { - batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size]) - for batch_size in self.cfg.batch_sizes - } - t4 = time.perf_counter() - self.max_label_prefix_len = max( - (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0) - for runner in self.runners - ) - self.mixed_buckets = { - (batch_size, continuation_len): self._build_mixed_bucket( - self.batch_layouts[batch_size], - self.mixed_prefix_caches[batch_size], - continuation_len, - ) - for batch_size in self.cfg.batch_sizes - for continuation_len in self.cfg.continuation_buckets - } - t5 = time.perf_counter() - self._warmup_results: dict[int, BatchScoreResult] = {} - self._preload_report: PreloadReport | None = None - self._init_stage_ms = { - "load_model_and_tokenizer": (t1 - t0) * 1000.0, - "build_prompt_runtimes": (t2 - t1) * 1000.0, - "build_batch_layouts": (t3 - t2) * 1000.0, - "build_mixed_prefix_caches": (t4 - t3) * 1000.0, - "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0, - } - self._init_total_ms = sum(self._init_stage_ms.values()) - - @staticmethod - def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime: - device = torch.device(cfg.device) - dtype = torch.float16 - tokenizer = AutoTokenizer.from_pretrained( - cfg.resolved_model_source, - trust_remote_code=cfg.resolved_trust_remote_code, - token=cfg.hf_token, - cache_dir=cfg.hf_cache_dir, - local_files_only=cfg.resolved_local_files_only, - ) - if tokenizer.pad_token_id is None: - tokenizer.pad_token = tokenizer.eos_token - attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend) - quantization_config = None - model_kwargs: dict[str, object] = { - "trust_remote_code": cfg.resolved_trust_remote_code, - "attn_implementation": attn_impl, - "token": cfg.hf_token, - "cache_dir": cfg.hf_cache_dir, - "local_files_only": cfg.resolved_local_files_only, - } - if cfg.load_in_4bit: - if BitsAndBytesConfig is None: - raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first") - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=dtype, - bnb_4bit_quant_type=cfg.bnb_4bit_quant_type, - bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant, - ) - model_kwargs["quantization_config"] = quantization_config - model_kwargs["device_map"] = {"": device.index or 0} - else: - model_kwargs["dtype"] = dtype - model_kwargs["device_map"] = None - model = AutoModelForCausalLM.from_pretrained( - cfg.resolved_model_source, - **model_kwargs, - ).eval() - if not cfg.load_in_4bit: - model = model.to(device) - backbone = model.get_submodule(model.base_model_prefix) - hidden_size = model.get_output_embeddings().weight.shape[1] - graph_capture_pool = None - graph_capture_stream = None - if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit: - graph_capture_pool = torch.cuda.graph_pool_handle() - graph_capture_stream = torch.cuda.Stream(device=device) - return SharedRuntime( - device=device, - dtype=dtype, - tokenizer=tokenizer, - model=model, - backbone=backbone, - hidden_size=hidden_size, - graph_capture_pool=graph_capture_pool, - graph_capture_stream=graph_capture_stream, - ) - - @staticmethod - def _resolve_attn_impl(requested: str) -> str: - if requested in {"sdpa", "eager"}: - return requested - if requested == "auto": - return "sdpa" - raise ValueError(f"unsupported 
attn_backend: {requested}") - - def _attn_context(self): - if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}: - return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) - return torch.no_grad() - - def _sync(self) -> None: - if self.device.type == "cuda": - torch.cuda.synchronize() - - def _pick_bucket(self, continuation_len: int) -> int: - for bucket in self.cfg.continuation_buckets: - if continuation_len <= bucket: - return bucket - if self.cfg.pad_to_bucket: - raise ValueError( - f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets" - ) - return continuation_len - - def _build_batch_layout(self, batch_size: int) -> BatchLayout: - plans: list[PromptBatchPlan] = [] - row_start = 0 - for runner in self.runners: - row_count = batch_size if runner.fast_path else batch_size * runner.num_labels - plans.append( - PromptBatchPlan( - runner=runner, - row_start=row_start, - row_count=row_count, - score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32), - ) - ) - row_start += row_count - return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans) - - def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache: - prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long) - non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache] - if not non_empty: - return MixedPrefixCache( - batch_size=layout.batch_size, - total_rows=layout.total_rows, - prefix_lengths=prefix_lengths, - attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long), - raw_cache=tuple(), - ) - - max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans) - num_layers = len(non_empty[0]) - attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long) - raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = [] - for layer_idx in range(num_layers): - sample_key, sample_value = non_empty[0][layer_idx] - merged_key = sample_key.new_zeros( - (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3]) - ) - merged_value = sample_value.new_zeros( - (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3]) - ) - raw_layers.append((merged_key, merged_value)) - - for plan in layout.plans: - runner = plan.runner - prefix_len = runner.prefix_cache.prefix_len - row_slice = slice(plan.row_start, plan.row_stop) - prefix_lengths[row_slice] = prefix_len - if prefix_len == 0: - continue - attention_mask[row_slice, :prefix_len] = 1 - raw_cache = runner.expand_prefix_raw_cache(plan.row_count) - for layer_idx, (key, value) in enumerate(raw_cache): - merged_key, merged_value = raw_layers[layer_idx] - merged_key[row_slice, :, :prefix_len, :] = key - merged_value[row_slice, :, :prefix_len, :] = value - - return MixedPrefixCache( - batch_size=layout.batch_size, - total_rows=layout.total_rows, - prefix_lengths=prefix_lengths, - attention_mask=attention_mask.contiguous(), - raw_cache=tuple(raw_layers), - ) - - def _build_mixed_bucket( - self, - layout: BatchLayout, - prefix_cache: MixedPrefixCache, - continuation_len: int, - ) -> MixedBucketRuntime: - max_input_len = continuation_len + self.max_label_prefix_len - total_len = prefix_cache.max_prefix_len + max_input_len - input_ids = torch.full( - (layout.total_rows, max_input_len), - fill_value=self.tokenizer.pad_token_id, - 
device=self.device, - dtype=torch.long, - ) - attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long) - if prefix_cache.max_prefix_len: - attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask - position_ids = ( - prefix_cache.prefix_lengths[:, None] - + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0) - ).contiguous() - last_hidden_state = torch.empty( - (layout.total_rows, max_input_len, self.hidden_size), - device=self.device, - dtype=self.dtype, - ) - bucket = MixedBucketRuntime( - batch_size=layout.batch_size, - total_rows=layout.total_rows, - continuation_len=continuation_len, - max_input_len=max_input_len, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - last_hidden_state=last_hidden_state, - ) - if self.cfg.cuda_graphs: - self._capture_mixed_bucket(bucket, prefix_cache) - return bucket - - @torch.inference_mode() - def _run_mixed_backbone( - self, - bucket: MixedBucketRuntime, - prefix_cache: MixedPrefixCache, - ) -> None: - cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config) - with self._attn_context(): - outputs = self.backbone( - input_ids=bucket.input_ids, - attention_mask=bucket.attention_mask, - position_ids=bucket.position_ids, - past_key_values=cache, - use_cache=False, - return_dict=True, - ) - bucket.last_hidden_state.copy_(outputs.last_hidden_state) - - def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None: - if not torch.cuda.is_available(): - return - try: - torch.cuda.synchronize() - stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device) - with torch.cuda.stream(stream): - for _ in range(self.cfg.graph_warmups): - self._run_mixed_backbone(bucket, prefix_cache) - stream.synchronize() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream): - self._run_mixed_backbone(bucket, prefix_cache) - bucket.graph = graph - except RuntimeError: - bucket.graph = None - - def _prepare_bucket( - self, - layout: BatchLayout, - prefix_cache: MixedPrefixCache, - bucket: MixedBucketRuntime, - query_ids_batch: list[list[int]], - ) -> tuple[list[list[int]], dict[str, list[object]]]: - del prefix_cache - bucket.input_ids.fill_(self.tokenizer.pad_token_id) - bucket.attention_mask.zero_() - if self.mixed_prefix_caches[layout.batch_size].max_prefix_len: - bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = ( - self.mixed_prefix_caches[layout.batch_size].attention_mask - ) - continuation_lengths_per_task: dict[str, list[int]] = {} - continuation_tokens_per_task: dict[str, list[list[int]]] = {} - prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len - for plan in layout.plans: - runner = plan.runner - per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch] - continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations - continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations] - if runner.fast_path: - for batch_idx, continuation in enumerate(per_query_continuations): - cont_len = len(continuation) - row_idx = plan.row_start + batch_idx - bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long) - bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1 - continue - - assert 
runner.multi_token_tables is not None - for batch_idx, continuation in enumerate(per_query_continuations): - cont_len = len(continuation) - row_start = plan.row_start + batch_idx * runner.num_labels - row_stop = row_start + runner.num_labels - row_slice = slice(row_start, row_stop) - cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long) - bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1) - bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1 - if runner.multi_token_tables.max_label_prefix_len: - bucket.input_ids[ - row_slice, - cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len, - ] = runner.multi_token_tables.label_prefix_ids - bucket.attention_mask[ - row_slice, - prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len, - ] = runner.multi_token_tables.label_prefix_mask - return query_ids_batch, { - "continuation_lengths_per_task": continuation_lengths_per_task, - "continuation_tokens_per_task": continuation_tokens_per_task, - } - - def _reduce_prompt_scores( - self, - layout: BatchLayout, - bucket: MixedBucketRuntime, - query_texts: list[str], - prep_meta: dict[str, list[object]], - shared_stage_ms: dict[str, float], - ) -> list[MultiPromptScoreResult]: - result_rows = [[] for _ in range(layout.batch_size)] - prompt_reduce_total_ms = 0.0 - for plan in layout.plans: - runner = plan.runner - continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name] - reduce_start = time.perf_counter() - if runner.fast_path: - hidden_rows = [] - row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size] - for batch_idx, cont_len in enumerate(continuation_lengths): - hidden_rows.append(row_slice[batch_idx, cont_len - 1]) - hidden = torch.stack(hidden_rows, dim=0) - runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer) - stage_name = "tail_scorer" - else: - assert runner.multi_token_tables is not None - score_positions = torch.stack( - [ - cont_len - 1 + runner.multi_token_tables.label_position_offsets - for cont_len in continuation_lengths - ], - dim=0, - ) - runner.reduce_multi_token_scores( - last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop], - batch_size=layout.batch_size, - max_input_len=bucket.max_input_len, - score_positions=score_positions, - out_scores=plan.score_buffer, - ) - stage_name = "candidate_reduce" - self._sync() - reduce_end = time.perf_counter() - reduce_ms = (reduce_end - reduce_start) * 1000.0 - prompt_reduce_total_ms += reduce_ms - for batch_idx, query in enumerate(query_texts): - stage_ms = dict(shared_stage_ms) - stage_ms[stage_name] = reduce_ms / layout.batch_size - result_rows[batch_idx].append( - runner.build_score_result( - query=query, - scores=plan.score_buffer[batch_idx], - stage_ms=stage_ms, - continuation_tokens=continuation_lengths[batch_idx], - ) - ) - - batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms - shared_plus_reduce = dict(shared_stage_ms) - shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms - results: list[MultiPromptScoreResult] = [] - for batch_idx, query in enumerate(query_texts): - results.append( - MultiPromptScoreResult( - query=query, - total_ms=batch_total_ms / layout.batch_size, - details=result_rows[batch_idx], - stage_ms={ - **shared_plus_reduce, - "per_query_total_estimate": batch_total_ms / layout.batch_size, - }, - ) - ) - return results - - 
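# ---------------------------------------------------------------------------
# Editor's sketch (not from llm-qp): the requirements allow dropping the exact
# per-token log-prob, so the multi-token reduction used above could avoid the
# full-vocab logsumexp entirely. Assuming each label row already holds the
# query continuation followed by that label's prefix tokens (one backbone
# forward for the whole batch), one cheaper reduction is to sum only the
# selected-token logits and length-average them so 1-token and k-token labels
# stay on a comparable scale. All names below are illustrative assumptions.
import torch

@torch.inference_mode()
def reduce_label_scores_no_logsumexp(
    label_hidden: torch.Tensor,      # [num_labels, max_label_len, hidden] hidden states at score positions
    label_token_ids: torch.Tensor,   # [num_labels, max_label_len] padded label token ids
    label_token_mask: torch.Tensor,  # [num_labels, max_label_len] 1.0 for real label tokens, 0.0 for padding
    lm_head_weight: torch.Tensor,    # [vocab_size, hidden]
) -> torch.Tensor:
    rows = lm_head_weight[label_token_ids]                    # [num_labels, max_label_len, hidden]
    token_logits = (label_hidden * rows).sum(dim=-1).float()  # raw logit of each label token at its position
    token_logits = token_logits * label_token_mask            # zero out padding positions
    lengths = label_token_mask.sum(dim=-1).clamp_min(1.0)
    return token_logits.sum(dim=-1) / lengths                 # length-averaged label scores
# ---------------------------------------------------------------------------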
@torch.inference_mode() - def score_queries(self, queries: list[str]) -> BatchScoreResult: - if not queries: - raise ValueError("queries must not be empty") - batch_size = len(queries) - if batch_size not in self.batch_layouts: - raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}") - layout = self.batch_layouts[batch_size] - prefix_cache = self.mixed_prefix_caches[batch_size] - - self._sync() - t0 = time.perf_counter() - query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries] - self._sync() - t1 = time.perf_counter() - - max_continuation_len = max( - len(plan.runner.build_continuation_from_query_ids(query_ids)) - for plan in layout.plans - for query_ids in query_ids_batch - ) - picked_bucket = self._pick_bucket(max_continuation_len) - bucket = self.mixed_buckets[(batch_size, picked_bucket)] - _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch) - self._sync() - t2 = time.perf_counter() - - if bucket.graph is not None: - bucket.graph.replay() - else: - self._run_mixed_backbone(bucket, prefix_cache) - self._sync() - t3 = time.perf_counter() - - shared_stage_ms = { - "encode_queries_shared": (t1 - t0) * 1000.0, - "prepare_batch_shared": (t2 - t1) * 1000.0, - "backbone_shared": (t3 - t2) * 1000.0, - } - results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms) - total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"] - return BatchScoreResult( - batch_size=batch_size, - total_ms=total_ms, - results=results, - stage_ms={ - **shared_stage_ms, - "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"], - }, - ) - - def score_query(self, query: str) -> MultiPromptScoreResult: - return self.score_queries([query]).results[0] - - def preload(self) -> PreloadReport: - if self._preload_report is not None: - return self._preload_report - stage_ms: dict[str, float] = dict(self._init_stage_ms) - start = time.perf_counter() - self._sync() - t0 = time.perf_counter() - warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes - for batch_size in warmup_batch_sizes: - queries = [self.cfg.warmup_query] * batch_size - self._warmup_results[batch_size] = self.score_queries(queries) - self._sync() - t1 = time.perf_counter() - stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0 - stage_ms["startup_total_before_warmup"] = self._init_total_ms - total_ms = self._init_total_ms + (t1 - start) * 1000.0 - runtime = self.preload_report() - self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime) - return self._preload_report - - def preload_report(self) -> dict[str, object]: - return { - "model_name": self.cfg.resolved_model_name, - "model_source": self.cfg.resolved_model_source, - "device": str(self.device), - "dtype": self.cfg.dtype, - "attn_backend": self.cfg.attn_backend, - "execution_model": "single_mixed_backbone_per_batch", - "num_tasks": len(self.runners), - "task_names": [runner.task_cfg.name for runner in self.runners], - "batch_sizes": list(self.cfg.batch_sizes), - "continuation_buckets": list(self.cfg.continuation_buckets), - "mixed_bucket_count": len(self.mixed_buckets), - "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()), - "all_configured_buckets_preloaded": True, - "init_stage_ms": dict(self._init_stage_ms), - "init_total_ms": self._init_total_ms, - "force_single_token_labels": self.cfg.force_single_token_labels, - 
"warmup_query": self.cfg.warmup_query, - "tasks": [ - { - "task_name": runner.task_cfg.name, - "fast_path": runner.fast_path, - "num_labels": runner.num_labels, - "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels}, - "prefix_tokens": runner.prefix_cache.prefix_len, - "prefix_hashes": runner.prefix_cache.prefix_hashes, - "label_prefix": runner.task_cfg.label_prefix, - } - for runner in self.runners - ], - } - -但是关于多token的处理方式是低效的、是错的,不要参考他,请你重新实现。 -本地的.venv已经创建好,是复用llm-qp的,请使用该环境 -有用的代码我已经引用过来了,请你基于需求,搜寻相关资料,专注于重新开发全新的版本。 -任务较重,请耐心逐步完成,慢慢来,要长时间迭代一直到服务建设完成、测试完全符合预期。 diff --git a/docs/issues/issue-2026-04-06-推理优化.md b/docs/issues/issue-2026-04-06-推理优化.md deleted file mode 100644 index 4b27b24..0000000 --- a/docs/issues/issue-2026-04-06-推理优化.md +++ /dev/null @@ -1,59 +0,0 @@ - -总体需求,是基于Tesla T4 GPU,用开源的基座大模型(3-6b级别),做query分类。prompt和分类词可配置。 -先专注于推理的优化,不用考虑服务后,可以程序启动后标准输入读取query,输出分类词。 - -示例prompt和对应的9个分类词: -Analyze the category intent of the given query. Output exactly one label from: dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other. Output nothing else query:{query} - -做专用执行路径,不是在通用生成引擎上做配置优化。 - -prompt已经固定(不要考虑蒸馏、微调、或者缩短prompt。缩短prompt是可以我自己调整的,并且prompt可能改为其他场景,与你做专用推理优化不相关,我给的prompt只是一个例子,你专注于专用推理的框架,适配配置化的prompt+分类词列表打分检查) - - -使用Tesla T4,因此: -1. 使用FP16。不用 BF16 作为主路径 -2. 不用 FlashAttention-3 作为主路径。选用testla T4上最佳的attention -3. 多搜集资料,参考开源的适合于tesla T4的推理优化项目和能用于T4的推理优化实践。包括但不限于Python/C++ runtime 与 TensorRT engine 的工具包;注意硬件支持矩阵覆盖 Turing/T4。 - - -主要考虑优化方向为: -hidden_last -> N-class scorer -> argmax -参考vLLM 的自动 prefix caching 复用相同前缀,并基于 block hash 管理 KV -去 full vocab logits -去 decode / constrained decode -query长度分桶 + 小微批,只为少数 batch 模式 capture(batch= 1 2 4 8) -专用 tail kernel(输出 N 类原始分数) - -以下仅供参考,具体怎么做还请你自己基于搜索的开源项目作为baseline进行优化: -15.1 预分配 -对每个 bucket 预分配: -input ids buffer -attention scratch buffer -output hidden buffer -class id buffer -15.2 Graph capture - -每个 bucket 预先 capture 一张图: -embedding -transformer 主干 -compact last-state -9-way tail kernel - -注意多参考Triton / CUDA C++针对这类单步decode且固定词表打分获取的极致优化方法及其开源项目,找到合适的baseline并进行针对性的开发。 - -你有sudo权限,你可以执行为本项目安装自己的环境 - -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型) -或者,请你在huggingface上面查找其他模型,完成部署,并进行推理耗时的测试。 - - - - - - - -另外,我想要有个命令行工具,在程序启动之初,确保所有该加载的全部加载好,不要做任何懒加载,确保真实请求发生时得到极致的响应时间。 -输入query,输出各个label的打分,以及预测的总耗时(如果能输出各个阶段的耗时更好,如果好做就支持一下),目前是否已经满足了。 - -请深度分析各阶段耗时,继续查找相关资料,看是否在性能上面做到极致。 -使用qwen3:4b-instruct-2507-q4_K_M模型。(因为qwen3:4b不能关闭思考模式,所以使用qwen3:4b-instruct-2507-q4_K_M模型),完成部署,并进行推理耗时的测试。 diff --git a/docs/issues/issue-2026-04-07-服务化.md b/docs/issues/issue-2026-04-07-服务化.md deleted file mode 100644 index c431213..0000000 --- a/docs/issues/issue-2026-04-07-服务化.md +++ /dev/null @@ -1,66 +0,0 @@ -彻底清除掉命令行交互方式的代码,改造为提供 HTTP 接口,端口6001。 -并提供scripts/service_ctl.sh的方式管理服务的启停 -返回分类结果的list(results):list由多个dict组成,每个dict的key为对应的分类任务,value是三元组,即值、打分、概率(如果值为none则该项不输出) -完善日志系统,当前cli方式输出的重要信息,在日志以及http接口的details字段体现。 -提供一个压测脚本放到本项目的合适的目录下,作为压测工具。该压测工具针对每条请求打印出结果,并最后给出性能指标,参考: -#!/bin/bash - -# 默认值 -concurrency=${1:-1} -top_lines=${2:-100} - -# 固定查询文件路径 -query_file="/data/saas-search/scripts/evaluation/queries/queries.txt" - -# 检查文件是否存在 -if [ ! -f "$query_file" ]; then - echo "错误: 查询文件不存在: $query_file" >&2 - exit 1 -fi - -# 检查 jq 是否可用 -if ! command -v jq &> /dev/null; then - echo "错误: 需要安装 jq 来解析 JSON" >&2 - exit 1 -fi - -url="http://127.0.0.1:6001/..." 
-max_jobs=$concurrency -job_count=0 - -# 读取文件前 top_lines 行,每行作为一个 query -while IFS= read -r query; do - # 跳过空行 - [ -z "$query" ] && continue - - # 启动子进程执行请求 - ( - # 安全构建 JSON payload - payload=$(jq -n --arg q "$query" '{query: $q}') - # 发送请求并获取响应 - response=$(curl -s -X POST "$url" \ - -H 'Content-Type: application/json' \ - -d "$payload") - # 提取 results 字段(紧凑 JSON 格式) - results=$(echo "$response" | jq -c '.results') - # 输出 query 和对应的 results - printf "%s\t%s\n" "$query" "$results" - ) & - - # 控制并发数量 - ((job_count++)) - if (( job_count >= max_jobs )); then - wait -n # 等待任意一个后台进程完成 - ((job_count--)) - fi -done < <(head -n "$top_lines" "$query_file") - -# 等待所有剩余后台进程完成 -# 在这里统计处性能情况,指标: - 平均耗时: - 最大耗时: - 最小耗时: - TP50: - TP90: - TP99: -wait \ No newline at end of file diff --git a/docs/issues/issue.md b/docs/issues/issue.md index 2180025..0f4019b 100644 --- a/docs/issues/issue.md +++ b/docs/issues/issue.md @@ -1,5 +1,7 @@ 项目 TODO 清单 +CLAUDE.md需要更新 + 2. 核心搜索功能优化 2.1 意图识别模块 diff --git a/scripts/evaluation/eval_framework/__init__.py b/scripts/evaluation/eval_framework/__init__.py index 236fb67..074e558 100644 --- a/scripts/evaluation/eval_framework/__init__.py +++ b/scripts/evaluation/eval_framework/__init__.py @@ -20,7 +20,6 @@ from .constants import ( # noqa: E402 RELEVANCE_LOW, RELEVANCE_NON_IRRELEVANT, VALID_LABELS, - normalize_stored_label, ) from .framework import SearchEvaluationFramework # noqa: E402 from .store import EvalStore, QueryBuildResult # noqa: E402 @@ -51,7 +50,6 @@ __all__ = [ "create_web_app", "ensure_dir", "main", - "normalize_stored_label", "render_batch_report_markdown", "sha1_text", "utc_now_iso", diff --git a/scripts/evaluation/eval_framework/constants.py b/scripts/evaluation/eval_framework/constants.py index d9921f8..3d1379e 100644 --- a/scripts/evaluation/eval_framework/constants.py +++ b/scripts/evaluation/eval_framework/constants.py @@ -42,20 +42,6 @@ STOP_PROB_MAP = { RELEVANCE_IRRELEVANT: 0.0, } -_LEGACY_LABEL_MAP = { - "Exact": RELEVANCE_EXACT, - "Partial": RELEVANCE_HIGH, -} - - -def normalize_stored_label(label: str) -> str: - """Map legacy 3-way SQLite labels to current 4-way strings; pass through canonical labels.""" - s = str(label).strip() - if s in VALID_LABELS: - return s - return _LEGACY_LABEL_MAP.get(s, s) - - DEFAULT_ARTIFACT_ROOT = PROJECT_ROOT / "artifacts" / "search_evaluation" DEFAULT_QUERY_FILE = _SCRIPTS_EVAL_DIR / "queries" / "queries.txt" diff --git a/scripts/evaluation/eval_framework/static/eval_web.css b/scripts/evaluation/eval_framework/static/eval_web.css index 0e73cd9..a3893bc 100644 --- a/scripts/evaluation/eval_framework/static/eval_web.css +++ b/scripts/evaluation/eval_framework/static/eval_web.css @@ -48,9 +48,9 @@ .results { display: grid; gap: 10px; } .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } - .label-exact-match { background: var(--exact); } - .label-high-relevant { background: var(--high); } - .label-low-relevant { background: var(--low); } + .label-fully-relevant { background: var(--exact); } + .label-mostly-relevant { background: var(--high); } + .label-weakly-relevant { background: var(--low); } .label-irrelevant { background: var(--irrelevant); } .badge-unknown { background: #637381; } .thumb { width: 100px; height: 100px; object-fit: cover; border-radius: 
14px; background: #e7e1d4; } diff --git a/scripts/evaluation/eval_framework/store.py b/scripts/evaluation/eval_framework/store.py index 8261a4b..da030f4 100644 --- a/scripts/evaluation/eval_framework/store.py +++ b/scripts/evaluation/eval_framework/store.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Sequence -from .constants import VALID_LABELS, normalize_stored_label +from .constants import VALID_LABELS from .utils import ensure_dir, safe_json_dumps, utc_now_iso @@ -220,7 +220,7 @@ class EvalStore: """, (tenant_id, query_text), ).fetchall() - return {str(row["spu_id"]): normalize_stored_label(str(row["label"])) for row in rows} + return {str(row["spu_id"]): str(row["label"]) for row in rows} def upsert_labels( self, @@ -379,8 +379,8 @@ class EvalStore: SELECT query_text, COUNT(*) AS total, - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count, - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count, + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count, SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, MAX(updated_at) AS updated_at @@ -409,8 +409,8 @@ class EvalStore: """ SELECT COUNT(*) AS total, - SUM(CASE WHEN label IN ('Fully Relevant','Exact') THEN 1 ELSE 0 END) AS exact_count, - SUM(CASE WHEN label IN ('Mostly Relevant','Partial') THEN 1 ELSE 0 END) AS high_relevant_count, + SUM(CASE WHEN label='Fully Relevant' THEN 1 ELSE 0 END) AS exact_count, + SUM(CASE WHEN label='Mostly Relevant' THEN 1 ELSE 0 END) AS high_relevant_count, SUM(CASE WHEN label='Weakly Relevant' THEN 1 ELSE 0 END) AS low_relevant_count, SUM(CASE WHEN label='Irrelevant' THEN 1 ELSE 0 END) AS irrelevant_count, MAX(updated_at) AS updated_at diff --git a/suggestion/ARCHITECTURE_V2.md b/suggestion/ARCHITECTURE_V2.md deleted file mode 100644 index b84e142..0000000 --- a/suggestion/ARCHITECTURE_V2.md +++ /dev/null @@ -1,304 +0,0 @@ -# Suggestion 架构方案 V2(仅 Suggest,去除结果直达) - -## 0. 结论 - -本方案将 Suggest 设计为**独立高性能检索系统**,只返回建议词,不再返回商品卡片,也不做历史兼容。 - -- 只保留 `/search/suggestions` 的词级自动补全能力 -- 完全移除 `with_results/result_size/products[]` 链路 -- 多语言优先,支持高并发、低延迟、可持续演进 - ---- - -## 1. 当前实现的关键问题(基于现有代码审视) - -1. 在线链路曾包含“suggest -> 二次商品查询”,属于典型 N+1 放大,QPS 上升后延迟和 ES 负载都不稳定。 -2. `builder.py` 全量构建使用“大量 in-memory 聚合 + fetchall”,大租户下内存风险高。 -3. 查询参数上限过大(原 `size<=200`),不符合自动补全接口性能边界。 -4. 文档与实现长期混合(README 仍包含结果直达),导致认知不一致。 -5. 多语言归一化仍偏基础(仅 lower/空白折叠),对 Unicode、变音符、跨语系兼容不够。 - ---- - -## 2. 目标与 SLO - -### 2.1 业务目标 - -- 输入时实时返回高相关建议词(query suggestion) -- 多语言稳定(至少覆盖租户配置 `index_languages`) -- 支持词级排序和运营治理(黑白名单、降噪、降权) - -### 2.2 性能目标(建议) - -- P50 < 10ms,P95 < 25ms,P99 < 50ms(ES 查询耗时,不含网关) -- 单集群支持高并发(千级 QPS 可横向扩展) -- 数据新鲜度:增量 5-15 分钟可见 - ---- - -## 3. 总体架构 - -## 3.1 在线路径(单跳) - -Client -> API `/search/suggestions` -> ES `search_suggestions_v2` -> 返回 suggestions - -原则: - -- **单次 ES 查询完成主路径**(可选双召回融合,但仍在同一次 API 请求内完成) -- 不调用 `search_products`,不返回商品结果 -- 通过 `routing=tenant_id` 避免跨分片 fan-out - -## 3.2 离线路径(构建) - -数据源: - -- 商品字段:`title.{lang}`、`qanchors.{lang}` -- 搜索日志:`shoplazza_search_log`(含 `language/request_params`) -- 行为信号(可选增强):点击、加购、下单 - -产物: - -- Suggest 文档(`tenant_id + lang + text_norm` 唯一) -- completion + prefix 检索字段 -- 排序特征(热度、近期度、质量分) - -发布方式: - -- 写入新物理索引(版本化) -- 原子切换 alias(零停机) - ---- - -## 4. 
索引设计(ES) - -## 4.1 索引组织 - -推荐两级策略: - -1. 默认:环境级共享索引(降低海量租户 index 数量) -2. 大租户:可升级为租户独享索引(隔离资源) - -统一通过 alias 暴露: - -- `search_suggestions_v2_current` - -## 4.2 Mapping(核心字段) - -```json -{ - "settings": { - "number_of_shards": 3, - "number_of_replicas": 1, - "refresh_interval": "30s" - }, - "mappings": { - "properties": { - "tenant_id": { "type": "keyword" }, - "lang": { "type": "keyword" }, - "text": { "type": "keyword" }, - "text_norm": { "type": "keyword" }, - "status": { "type": "byte" }, - "sources": { "type": "keyword" }, - - "query_count_7d": { "type": "integer" }, - "query_count_30d": { "type": "integer" }, - "ctr_30d": { "type": "float" }, - "order_rate_30d": { "type": "float" }, - "rank_score": { "type": "float" }, - - "suggest": { - "type": "completion", - "contexts": [ - { "name": "tenant", "type": "category" }, - { "name": "lang", "type": "category" } - ] - }, - - "sat": { - "properties": { - "zh": { "type": "search_as_you_type", "analyzer": "index_ik" }, - "en": { "type": "search_as_you_type", "analyzer": "english" }, - "ar": { "type": "search_as_you_type", "analyzer": "arabic" } - } - }, - - "updated_at": { "type": "date" } - } - } -} -``` - -说明: - -- `completion` 负责极速前缀命中(主召回) -- `search_as_you_type` 负责多词前缀和召回兜底 -- `contexts` 强制租户与语言隔离 - ---- - -## 5. 多语言策略 - -1. 语言归属优先级:`log.language > request_params.language > 脚本识别 > tenant.primary_language` -2. 统一归一化:NFKC、大小写折叠、空白折叠、标点清洗 -3. 分词器按语言配置: - - 中文:IK/ANSJ(与主索引保持一致) - - 拉丁语系:对应内置 analyzer - - 未覆盖语种:`standard + ICU folding` 兜底 -4. 保证写入语言必须在租户 `index_languages` 内 - ---- - -## 6. 在线检索策略(高性能) - -## 6.1 双通道召回(推荐) - -1. 通道 A:`completion suggester`(prefix,skip_duplicates) -2. 通道 B:`multi_match(type=bool_prefix)` on `search_as_you_type` -3. 融合去重:按 `text_norm` 去重,按最终分排序截断 - -## 6.2 查询约束 - -- 默认 `size=10`,最大 `size=50` -- `track_total_hits=false` -- `_source` 仅返回必要字段(`text/lang/rank_score/sources`) -- `routing=tenant_id` - -## 6.3 打分建议 - -```text -final_score = - es_score - + a1*log1p(query_count_30d) - + a2*log1p(query_count_7d) - + a3*ctr_30d - + a4*order_rate_30d - + a5*freshness_decay -``` - ---- - -## 7. 构建与发布 - -## 7.1 构建模式 - -- 每日全量:重建全量特征,清理脏词 -- 小时级增量:只处理新日志窗口 - -## 7.2 工程要求 - -- 禁止 `fetchall` 全量入内存,改为流式读取(分页/游标) -- ES 扫描采用 `search_after` 流式聚合 -- 批量写入采用 bulk(分块 + 重试 + 失败重放) - -## 7.3 发布策略 - -1. `search_suggestions_v2_YYYYMMDDHHmm` 写入完成 -2. 校验 count/抽样查询/核心词覆盖 -3. alias 原子切换到新索引 -4. 保留上一个版本用于快速回滚 - ---- - -## 8. API 契约(V2) - -请求: - -- `GET /search/suggestions` -- 参数:`q`、`language`、`size` -- Header:`X-Tenant-ID` - -响应: - -```json -{ - "query": "iph", - "language": "en", - "resolved_language": "en", - "suggestions": [ - { - "text": "iphone 15", - "lang": "en", - "score": 8.31, - "rank_score": 6.72, - "sources": ["query_log", "qanchor"] - } - ], - "took_ms": 12 -} -``` - -删除项(明确不支持): - -- `with_results` -- `result_size` -- `products[]` - ---- - -## 9. 观测与治理 - -核心监控: - -- QPS、P50/P95/P99、错误率 -- 空结果率(按语言、按租户) -- suggestion 覆盖率(top query 是否命中) -- 语言冲突率(log vs request_params) -- 噪声词比例、黑名单命中率 - -治理机制: - -- 黑名单:强制下线 -- 白名单:强制保留并可加权 -- 最小热度阈值:低频垃圾词过滤 -- 时间衰减:过期词自动下沉 - ---- - -## 10. 与官方最佳实践对齐(ES) - -本方案直接采用以下官方建议: - -1. `completion` 适合高性能自动补全,支持 `skip_duplicates` 与上下文过滤。 -2. `search_as_you_type + bool_prefix` 是官方推荐的 as-you-type 查询方式。 -3. `edge_ngram` 仅用于索引时分词,查询时应用普通 analyzer(`search_analyzer`)。 -4. 多语言场景使用 ICU Analysis 插件增强 Unicode 处理。 -5. 通过 `routing` 将租户请求路由到单分片,降低 fan-out。 - ---- - -## 11. 分阶段落地 - -1. Phase 1(本次):去除结果直达,稳定 Suggest 单能力 -2. Phase 2:流式增量构建 + alias 原子发布 -3. Phase 3:行为信号排序(CTR/CVR)+ 运营治理台 -4. 
Phase 4:大租户独享索引自动升降级 - ---- - -## 12. Phase 2 落地命令(当前仓库) - -全量重建(版本化索引 + alias 发布): - -```bash -python main.py build-suggestions \ - --tenant-id 162 \ - --mode full \ - --days 365 \ - --publish-alias \ - --keep-versions 2 -``` - -增量更新(基于 watermark): - -```bash -python main.py build-suggestions \ - --tenant-id 162 \ - --mode incremental \ - --overlap-minutes 30 -``` - -一键脚本(全量 + 增量 + ES/API 验证): - -```bash -./scripts/rebuild_suggestions.sh 162 -``` diff --git a/suggestion/README.md b/suggestion/README.md index 645b4c1..4cf9475 100644 --- a/suggestion/README.md +++ b/suggestion/README.md @@ -1,46 +1,357 @@ -# Suggestion 模块说明(统一入口) +# Suggestion 模块说明 -本文档是 suggestion 模块的统一入口,遵循 `docs/DEVELOPER_GUIDE.md` 的“单一入口、避免分叉”原则。 +`suggestion/` 目录负责搜索框自动补全能力,当前实现只关注 suggestion 本身:离线构建建议词索引,在线根据输入前缀返回建议词列表。 -## 1. 当前状态(Phase 2) +这份 README 以当前代码实现为准,重点说明模块现状、关键设计、索引结构、构建发布方式,以及在线检索和排序细节。 -- 仅保留 Suggest 自动补全能力 -- 不支持结果直达(`with_results` / `result_size` / `products[]` 已移除) -- 索引采用版本化发布: - - 物理索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v` - - 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current` -- 支持增量更新(watermark + overlap) +## 1. 当前能力边界 -## 2. 文档导航(唯一推荐顺序) +- 对外接口:`GET /search/suggestions` +- 输入参数:`q`、`size`、`language` +- Header:`X-Tenant-ID` +- 返回内容:建议词列表 `suggestions[]` +- 不做商品结果拼接,也不走二次商品查询链路 + +当前模块由三部分组成: + +1. 离线构建:从商品索引和搜索日志构建 suggestion 文档 +2. 索引发布:写入版本化索引,并通过 alias 原子切换 +3. 在线查询:优先走 completion,必要时再走 `search_as_you_type` 兜底召回 + +## 2. 目录与关键代码 + +- [builder.py](/data/saas-search/suggestion/builder.py):离线构建、增量更新、alias 发布、meta 状态维护 +- [mapping.py](/data/saas-search/suggestion/mapping.py):suggestion 索引 settings 和 mappings 生成 +- [service.py](/data/saas-search/suggestion/service.py):在线查询服务,负责语言归一化、双路召回、去重和最终排序 +- [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md):构建、发布、验证操作说明 +- [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md):常见问题排查 + +命令入口在 [main.py](/data/saas-search/main.py) 中的 `build-suggestions` 子命令。 + +## 3. 整体架构 + +在线路径: + +`Client -> /search/suggestions -> SuggestionService -> Elasticsearch suggestion alias` + +离线路径: + +`商品索引 + 搜索日志 -> SuggestionIndexBuilder -> 版本化 suggestion index -> alias publish` + +设计上有几个核心点: + +- suggestion 独立建索引,不依赖在线商品检索 +- 每个租户单独维护 suggestion alias,避免租户间相互影响 +- 全量构建写新索引,切换 alias 时零停机 +- 增量更新只处理 query log 增量,减少重建成本 + +## 4. 索引组织与发布 + +索引命名在 [builder.py](/data/saas-search/suggestion/builder.py) 中统一定义: + +- 读别名:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_current` +- 版本索引:`{ES_INDEX_NAMESPACE}search_suggestions_tenant_{tenant_id}_v` +- 元信息索引:`{ES_INDEX_NAMESPACE}search_suggestions_meta` + +当前实现是“每租户一个 suggestion alias + 多个版本索引”的模式,而不是环境级共享大索引。 + +全量构建时的发布流程: + +1. 创建新的版本化索引 +2. 写入本次构建出的 suggestion 文档 +3. 校验新索引可分配、可读 +4. alias 原子切换到新索引 +5. 清理旧版本索引,只保留最近若干份 +6. 更新 `search_suggestions_meta` + +元信息索引里记录: + +- `active_alias` +- `active_index` +- `last_full_build_at` +- `last_incremental_build_at` +- `last_incremental_watermark` + +这些信息主要服务于增量更新和排障。 + +## 5. 
Mapping 与索引字段 + +[mapping.py](/data/saas-search/suggestion/mapping.py) 会根据租户的 `index_languages` 动态生成字段。 + +### 5.1 索引设置 + +- `number_of_shards = 1` +- `number_of_replicas = 0` +- `refresh_interval = 30s` + +中文使用自定义 analyzer: + +- `index_ik`:`ik_max_word + lowercase + asciifolding` +- `query_ik`:`ik_smart + lowercase + asciifolding` + +其他语言优先使用 Elasticsearch 内置 analyzer,例如 `english`、`arabic`、`french`、`german` 等;未覆盖语言回退到 `standard`。 + +### 5.2 核心字段 + +- `tenant_id`:租户隔离 +- `lang`:建议词所属语言 +- `text`:原始展示文本 +- `text_norm`:归一化文本,用于唯一键和去重 +- `sources`:来源集合,可能包含 `title`、`qanchor`、`tag`、`query_log` +- `title_doc_count` / `qanchor_doc_count` / `tag_doc_count`:该词被多少商品字段支撑 +- `query_count_7d` / `query_count_30d`:近 7/30 天搜索热度 +- `rank_score`:离线预计算排序分 +- `lang_confidence` / `lang_source` / `lang_conflict`:语言识别与冲突信息 +- `status`:当前是否有效 +- `updated_at`:最近更新时间 + +### 5.3 两类检索字段 + +1. `completion.{lang}` +2. `sat.{lang}` + +`completion.{lang}` 用于极速前缀补全,是短 query 下的主召回通道。 + +`sat.{lang}` 使用 `search_as_you_type`,用于多词前缀和 completion 未补足时的兜底召回。 + +也就是说,当前线上不是只靠一种召回方式,而是 completion 优先、SAT 补全。 + +## 6. 候选词从哪里来 + +[builder.py](/data/saas-search/suggestion/builder.py) 在全量构建中会聚合两大类数据源。 + +### 6.1 商品侧 + +从租户商品索引中流式读取: + +- `title` +- `qanchors` +- `enriched_tags` + +处理方式: + +- `title.{lang}`:经 `_prepare_title_for_suggest()` 裁剪后作为候选词 +- `qanchors.{lang}`:按分隔符拆分后作为候选词 +- `enriched_tags`:支持多语言对象或普通列表,必要时做语言识别 + +商品扫描不是一次性全量拉入内存,而是通过 `search_after` 分批读取,这一点在 [_iter_products()](/data/saas-search/suggestion/builder.py#L363) 已实现。 + +### 6.2 搜索日志侧 + +从 MySQL `shoplazza_search_log` 中按时间窗口流式读取: + +- `query` +- `language` +- `request_params` +- `create_time` + +读取方式使用 `stream_results=True + fetchmany()`,避免 `fetchall()` 带来的内存风险,这也是当前实现相对旧方案的重要改进。 + +搜索日志主要用于补充: + +- 用户真实搜索词 +- 近 7/30 天热度 +- 语言归属信息 + +## 7. 文本清洗与语言策略 + +### 7.1 文本归一化 + +在 [_normalize_text()](/data/saas-search/suggestion/builder.py#L176) 中,当前实现会做: + +- Unicode `NFKC` 归一化 +- 去首尾空白 +- 转小写 +- 多空白折叠为单空格 + +这份 `text_norm` 是 suggestion 文档的稳定键的一部分,文档 `_id` 形式为: + +`{tenant_id}|{lang}|{text_norm}` + +这保证了同租户、同语言、同一归一化词面只会保留一份文档。 + +### 7.2 噪声过滤 + +在 [_looks_noise()](/data/saas-search/suggestion/builder.py#L264) 中,以下内容会被过滤: + +- 空文本 +- 长度超过 120 +- 全部由符号组成的文本 -1. `ARCHITECTURE_V2.md`:架构与设计原则 -2. `RUNBOOK.md`:构建/发布/验证流程 -3. `TROUBLESHOOTING.md`:常见问题排查 +### 7.3 语言判定优先级 -## 3. 命令入口 +日志 query 的语言归属由 [_resolve_query_language()](/data/saas-search/suggestion/builder.py#L299) 负责,优先级是: -- 全量或增量构建: +1. `shoplazza_search_log.language` +2. `request_params.language` +3. `detect_text_language_for_suggestions()` +4. 租户 `primary_language` + +同时会记录: + +- `lang_source`:语言来自哪里 +- `lang_confidence`:识别置信度 +- `lang_conflict`:日志语言与请求语言是否冲突 + +在线查询侧在 [_resolve_language()](/data/saas-search/suggestion/service.py#L24) 也会做一次语言归一化,确保查询只打到租户允许的 `index_languages`。 + +## 8. 排序与 rank 细节 + +当前排序分成两层:离线 `rank_score`,以及在线最终排序。 + +### 8.1 离线 `rank_score` + +在 [_compute_rank_score()](/data/saas-search/suggestion/builder.py#L338) 中,当前公式是: + +```text +rank_score = + 1.8 * log1p(query_count_30d) + + 1.2 * log1p(query_count_7d) + + 1.0 * log1p(qanchor_doc_count) + + 0.85 * log1p(tag_doc_count) + + 0.6 * log1p(title_doc_count) +``` + +含义上是: + +- 搜索日志热度权重大于商品静态字段 +- 30 天热度权重大于 7 天热度,但 7 天热度也会强化近期趋势 +- `qanchor` 比普通标题更像“可搜索表达”,所以权重更高 +- `tag` 次之 +- `title` 提供基础覆盖,但权重相对更低 + +这个分数会被写入: + +- 文档字段 `rank_score` +- `completion.{lang}.weight` + +因此它同时影响 completion 通道和 SAT 通道。
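下面按 8.1 的公式口径给出一个最小的 Python 示意草图,只用于说明分数是怎么算出来的;其中 `compute_rank_score` 的函数名与入参结构都是这里假设的,真实逻辑以 [builder.py](/data/saas-search/suggestion/builder.py) 中的 `_compute_rank_score()` 为准:

```python
import math


def compute_rank_score(doc: dict) -> float:
    """按 README 8.1 描述的权重计算离线 rank_score(示意实现,非 builder.py 原码)。"""

    def term(key: str, weight: float) -> float:
        # log1p 压缩长尾热度,字段缺失按 0 处理
        return weight * math.log1p(max(int(doc.get(key, 0)), 0))

    return (
        term("query_count_30d", 1.8)
        + term("query_count_7d", 1.2)
        + term("qanchor_doc_count", 1.0)
        + term("tag_doc_count", 0.85)
        + term("title_doc_count", 0.6)
    )


# 示例:近 30 天被搜 40 次、近 7 天 12 次、由 3 个 qanchor 支撑的词
print(round(compute_rank_score({"query_count_30d": 40, "query_count_7d": 12, "qanchor_doc_count": 3}), 3))
```

另外注意,ES completion 字段的 `weight` 只接受整数,把 `rank_score` 写入 `completion.{lang}.weight` 前一般需要放大后取整。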
### 8.2 在线召回排序 + +[service.py](/data/saas-search/suggestion/service.py) 中的在线策略如下: + +1. 先查 `completion.{lang}` +2. 若 query 长度大于 2 且 completion 结果不足,再查 `sat.{lang}` +3. 按 `text` 归一化结果去重 +4. 最终排序后截断 + +completion 通道本身依赖 ES completion 的 `_score`;SAT 通道则用 `function_score + field_value_factor(rank_score)`。 + +最终排序由 [_finalize_suggestion_list()](/data/saas-search/suggestion/service.py#L155) 负责,排序 key 为: + +1. `score * 长度惩罚系数` +2. `rank_score` + +长度惩罚定义在 [_suggestion_length_factor()](/data/saas-search/suggestion/service.py#L16): + +```text +length_factor = 1 / sqrt(token_len) +``` + +这意味着在分数相近时,较短、较直接的 suggestion 会更容易排在前面,避免长尾长句把前缀补全结果“顶掉”。 + +## 9. 在线查询细节 + +[SuggestionService.search()](/data/saas-search/suggestion/service.py#L110) 的行为可以概括为: + +- 如果 alias 不存在,直接返回空数组,不抛 500 +- 短 query 优先走 completion 快速返回 +- 对于更长 query,再补一次 `bool_prefix` +- 查询时始终带 `routing=tenant_id` + +这里的 `routing` 很重要,它保证 suggestion 查询尽量只落在目标租户对应的分片路由上,减少无效 fan-out。 + +SAT 查询部分还有两个显式过滤条件: + +- `lang == resolved_language` +- `status == 1` + +这能保证召回结果只来自当前语言、当前有效文档。
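下面用 Python 拼出这两路召回的近似请求体,仅作草图:其中 `completion.{lang}`、`sat.{lang}` 及 `_2gram` / `_3gram` 子字段名、过滤条件的具体写法,都是按本 README 的描述假设的,实际字段与参数以 [service.py](/data/saas-search/suggestion/service.py) 和 [mapping.py](/data/saas-search/suggestion/mapping.py) 为准:

```python
def build_suggest_requests(tenant_id: str, lang: str, q: str, size: int = 10):
    """两路召回请求体草图:completion 前缀主召回 + search_as_you_type bool_prefix 兜底(字段名为假设)。"""
    source_fields = ["text", "lang", "rank_score", "sources"]

    # 通道 A:completion suggester,skip_duplicates 在召回侧先去一轮重
    completion_body = {
        "_source": source_fields,
        "suggest": {
            "sug": {
                "prefix": q,
                "completion": {
                    "field": f"completion.{lang}",
                    "size": size,
                    "skip_duplicates": True,
                },
            }
        },
    }

    # 通道 B:bool_prefix 兜底召回,用 rank_score 做 field_value_factor 调权
    sat_body = {
        "size": size,
        "track_total_hits": False,
        "_source": source_fields,
        "query": {
            "function_score": {
                "query": {
                    "bool": {
                        "filter": [
                            {"term": {"tenant_id": tenant_id}},
                            {"term": {"lang": lang}},
                            {"term": {"status": 1}},
                        ],
                        "must": [
                            {
                                "multi_match": {
                                    "query": q,
                                    "type": "bool_prefix",
                                    "fields": [
                                        f"sat.{lang}",
                                        f"sat.{lang}._2gram",
                                        f"sat.{lang}._3gram",
                                    ],
                                }
                            }
                        ],
                    }
                },
                "field_value_factor": {"field": "rank_score", "missing": 1.0},
            }
        },
    }
    return completion_body, sat_body
```

两路请求执行时都带上 `routing=tenant_id`,结果按归一化后的 `text` 去重,再按 8.2 描述的长度惩罚与 `rank_score` 排序截断。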
## 10. 全量构建与增量更新 + +### 10.1 全量构建 + +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode full`。 + +行为是: + +1. 读取租户配置中的 `index_languages` 和 `primary_language` +2. 创建新版本索引并等待 ready +3. 聚合商品数据和搜索日志,构造候选词 +4. 计算 `rank_score` +5. bulk 写入新索引 +6. refresh +7. 发布 alias +8. 更新 meta 信息 + +### 10.2 增量更新 + +入口在 [main.py](/data/saas-search/main.py#L104) 的 `build-suggestions --mode incremental`。 + +当前增量只处理 query log,不回扫商品数据。它依赖 meta 中的 watermark: + +- `last_incremental_watermark` +- 不存在时回退到 `last_full_build_at` +- 再不行就使用 `fallback_days` + +为了避免边界时间漏数,会额外减去 `overlap_minutes`,形成一个带重叠窗口的增量区间。 + +增量写入不是整文档重建,而是通过 `scripted_upsert` 做原地累加: + +- 增加 `query_count_30d` +- 增加 `query_count_7d` +- 更新 `lang_confidence` / `lang_source` / `lang_conflict` +- 重新计算 `rank_score` +- 更新 `completion` 和 `sat` + +对应逻辑在 [_build_incremental_update_script()](/data/saas-search/suggestion/builder.py#L834)。 + +如果 alias 尚不存在,而 `bootstrap_if_missing=True`,增量任务会先自动做一次全量构建作为初始化。 + +## 11. 当前实现的一些取舍 + +### 11.1 优点 + +- 在线链路很短,没有 suggestion 后再查商品的放大成本 +- 构建和发布流程清晰,支持零停机切换 +- 商品侧和日志侧都采用流式处理,能控制内存占用 +- completion + SAT 双路召回兼顾低延迟和补全能力 +- 语言、热度、来源信息都保存在索引中,便于后续优化 + +### 11.2 现阶段边界 + +- 增量更新目前只增量处理 query log,商品标题、qanchor、tag 变更仍依赖全量构建刷新 +- `rank_score` 目前只使用热度和商品字段覆盖度,没有接入点击、转化等行为质量信号 +- 文本归一化目前以 `NFKC + lower + whitespace fold` 为主,尚未做更激进的跨语种归并策略 +- `number_of_replicas=0` 更偏开发或成本优先配置,生产是否需要副本要结合集群策略评估 + +## 12. 常用命令 + +全量构建: + +```bash +./scripts/build_suggestions.sh --mode full -./scripts/build_suggestions.sh --mode incremental ``` -- 一键重建 + 验证: +增量构建: + +```bash -./scripts/rebuild_suggestions.sh +./scripts/build_suggestions.sh --mode incremental ``` -## 4. API 约定(简版) +一键重建并验证: -- 端点:`GET /search/suggestions` -- 参数:`q`, `size`, `language` -- Header:`X-Tenant-ID` +```bash +./scripts/rebuild_suggestions.sh +``` -示例: +接口示例: ```bash curl "http://localhost:6002/search/suggestions?q=shi&size=10&language=en" \ -H "X-Tenant-ID: 162" ``` + +更多操作细节见 [RUNBOOK.md](/data/saas-search/suggestion/RUNBOOK.md),故障排查看 [TROUBLESHOOTING.md](/data/saas-search/suggestion/TROUBLESHOOTING.md)。 diff --git a/suggestion/RUNBOOK.md b/suggestion/RUNBOOK.md index 502c6aa..3327969 100644 --- a/suggestion/RUNBOOK.md +++ b/suggestion/RUNBOOK.md @@ -34,7 +34,7 @@ DB_PASSWORD=... ### 3.1 执行 ```bash -./scripts/build_suggestions.sh 162 \ +./scripts/build_suggestions.sh 163 \ --mode full \ --days 365 \ --publish-alias \ --keep-versions 2 ``` @@ -56,7 +56,7 @@ DB_PASSWORD=... ### 4.1 执行 ```bash -./scripts/build_suggestions.sh 162 \ +./scripts/build_suggestions.sh 163 \ --mode incremental \ --overlap-minutes 30 ``` @@ -76,7 +76,7 @@ DB_PASSWORD=... > 若 ES 开启鉴权,请附带 `-u "$ES_USERNAME:$ES_PASSWORD"`。 ```bash -ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_162_current" +ALIAS_NAME="${ES_INDEX_NAMESPACE:-}search_suggestions_tenant_163_current" curl "$ES_HOST/$ALIAS_NAME/_count?pretty" @@ -97,10 +97,10 @@ curl "$ES_HOST/$ALIAS_NAME/_search?pretty" -H 'Content-Type: application/json' - ```bash curl "http://localhost:6002/search/suggestions?q=shirt&size=10&language=en" \ - -H "X-Tenant-ID: 162" + -H "X-Tenant-ID: 163" curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \ - -H "X-Tenant-ID: 162" + -H "X-Tenant-ID: 163" ``` 通过标准: @@ -112,7 +112,7 @@ curl "http://localhost:6002/search/suggestions?q=玩具&size=10&language=zh" \ ## 7. 一键验证脚本 ```bash -./scripts/rebuild_suggestions.sh 162 +./scripts/rebuild_suggestions.sh 163 ``` 该脚本执行: -- libgit2 0.21.2