|
| @@ -0,0 +1,1003 @@ |
|
| |
1
| +Overall requirement: run query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). Prompts and label sets must be configurable. |
|
| |
2
| +Focus on inference optimization first; turn it into a service last, supporting a modest level of concurrency (e.g. 4 in-flight requests). At program startup, load everything that needs loading — no lazy loading of any kind — so real requests get the fastest possible response. For each query the CLI receives, run N prompts for N classification dimensions; for each prompt, output every label's score plus total prediction latency (per-stage latency too, if that is easy to add). |
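The per-stage latency reporting requested above can be collected with a tiny helper; a minimal sketch (the `StageTimer` name and API are illustrative, not taken from the existing projects):

```python
import time
from contextlib import contextmanager


class StageTimer:
    """Accumulate wall-clock milliseconds per named stage."""

    def __init__(self):
        self.stage_ms: dict[str, float] = {}

    @contextmanager
    def measure(self, name: str):
        # Time the enclosed block and add it to the named stage's total.
        t0 = time.perf_counter()
        try:
            yield
        finally:
            self.stage_ms[name] = self.stage_ms.get(name, 0.0) + (time.perf_counter() - t0) * 1e3

    @property
    def total_ms(self) -> float:
        return sum(self.stage_ms.values())
```

Each request would then wrap its tokenize / forward / reduce steps in `timer.measure(...)` blocks and report `timer.stage_ms` alongside the prediction.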
|
| |
3
| + |
|
| |
4
| +Some reference material follows, but you are not bound to it strictly; keep some flexibility in pursuit of maximum performance. |
|
| |
5
| + |
|
| |
6
| +On a Tesla T4, use an open-source 3-6B decoder-only base model for query classification. |
|
| |
7
| +At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor. |
|
| |
8
| +For each input query, output the score distribution over labels under every prompt, plus total and per-stage prediction latency. |
|
| |
9
| +Do not go through the generic generation path: no decode, no full-vocab logits, no constrained decode. |
|
| |
10
| +Optimize multi-token labels specifically; avoid serial decoding on the Python side. |
|
| |
11
| +The prompt and label sets must be configurable (there are only two prompts now; I will grow this to 8. Each request carries one query, and all 8 prompts run inference in parallel, each returning its highest-scoring label): |
|
| |
12
| + |
|
| |
13
| + |
|
| |
14
| +prompts: |
|
| |
15
| + - name: category |
|
| |
16
| + prompt_template: | |
|
| |
17
| + <|im_start|>system |
|
| |
18
| + Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|> |
|
| |
19
| + <|im_start|>user |
|
| |
20
| + query: {query}<|im_end|> |
|
| |
21
| + <|im_start|>assistant |
|
| |
22
| + label: |
|
| |
23
| + label_prefix: " " |
|
| |
24
| + labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none] |
|
| |
25
| + |
|
| |
26
| + - name: audience |
|
| |
27
| + prompt_template: | |
|
| |
28
| + <|im_start|>system |
|
| |
29
| + Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|> |
|
| |
30
| + <|im_start|>user |
|
| |
31
| + query: {query}<|im_end|> |
|
| |
32
| + <|im_start|>assistant |
|
| |
33
| + label: |
|
| |
34
| + label_prefix: " " |
|
| |
35
| + labels: [boy, girl, man, woman, pregnant, none] |
|
| |
36
| + |
|
| |
37
| + |
|
| |
38
| +Build a dedicated execution path rather than tuning configuration on a generic generation engine; the goal is the absolute lowest classification latency. |
|
| |
39
| + |
|
| |
40
| +The main optimization directions are: |
|
| |
41
| +1. last_hidden -> N-class scorer -> argmax |
|
| |
42
| +2. Reuse identical prefixes via automatic prefix caching (as in vLLM), managing KV by block hash |
|
| |
43
| +3. Skip full-vocab logits |
|
| |
44
| +4. Skip decode / constrained decode |
|
| |
45
| +5. A dedicated tail kernel (emits raw scores for the N classes) |
|
| |
46
| +6. The N configured prompts (2-8) must run inference in parallel |
|
| |
47
| +7. The target is a Tesla T4, so FlashAttention-3 is not the main path; pick the best attention implementation available on the T4 |
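Direction 2 can be sketched as vLLM-style block hashing, where each KV block's hash chains in the previous block's hash, so two prompts share a block id iff they share the entire prefix up to and including that block. A simplified illustration (not vLLM's actual implementation):

```python
import hashlib


def hash_prefix_blocks(token_ids: list[int], block_size: int = 16) -> list[str]:
    """Hash a token prefix block by block, chaining each hash into the next
    so identical prefixes map to identical block ids."""
    hashes: list[str] = []
    prev = ""
    for start in range(0, len(token_ids), block_size):
        block = token_ids[start:start + block_size]
        # Include the previous block's hash so the id depends on the full prefix.
        payload = (prev + ":" + ",".join(map(str, block))).encode()
        prev = hashlib.sha1(payload).hexdigest()
        hashes.append(prev)
    return hashes
```

Two prompts that diverge in block k then share cached KV for blocks 0..k-1 and recompute only from block k onward.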
|
| |
48
| +Do broad research: survey open-source inference-optimization projects and practices suited to the Tesla T4, including but not limited to toolkits with Python/C++ runtimes and TensorRT engines; check that each hardware support matrix covers Turing/T4. Pay special attention to Triton / CUDA C++ techniques (and their open-source projects) for extreme optimization of this single-step-decode, fixed-vocabulary scoring workload; find a suitable baseline and develop against it. |
|
| |
49
| + |
|
| |
50
| + |
|
| |
51
| +You have sudo; you may install your own environment for this project. |
|
| |
52
| + |
|
| |
53
| +Use a Q4 or Q8 quantized build of Qwen/Qwen3-8B. To pick the exact version, consult the relevant Hugging Face material, deploy a suitable one, and benchmark inference latency. |
|
| |
54
| + |
|
| |
55
| +Analyze per-stage latency in depth and keep researching to confirm performance is truly maxed out. |
|
| |
56
| + |
|
| |
57
| +One important issue: some labels are not single tokens (even though they look like single words), so the multi-token case must be handled. |
|
| |
58
| +Multi-token labels need aggressive performance optimization; avoid serial decode. |
|
| |
59
| +The end goal is only to find the highest-scoring label, not exact probabilities. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may be hard to optimize. Be clear about the end goal: classification. As long as we get the right class, performance comes first and exact probabilities can be dropped. |
|
| |
60
| +How to process the whole batch, multi-token labels included, in a single model forward is the problem you need to explore. |
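One way to sketch the single-forward layout: expand each (query, label) pair into one batch row that already contains the label tokens, pad all label sequences into a dense table with a mask, and read each label token's score at its recorded position. A framework-free illustration of the padding tables (mirroring the `MultiTokenTables` idea in the core code below; the helper name is illustrative):

```python
def build_label_tables(label_token_ids: list[list[int]], pad_id: int = 0):
    """Pad variable-length label token sequences into a dense table plus a
    0/1 mask, so every label's per-token scores can be gathered in one shot
    and padded positions contribute nothing to the summed score."""
    max_len = max(len(ids) for ids in label_token_ids)
    table = [[pad_id] * max_len for _ in label_token_ids]
    mask = [[0] * max_len for _ in label_token_ids]
    for row, ids in enumerate(label_token_ids):
        for col, tok in enumerate(ids):
            table[row][col] = tok
            mask[row][col] = 1
    return table, mask
```

On GPU these become two `(num_labels, max_label_len)` tensors; one forward over all rows plus one masked gather/reduce replaces any per-token decode loop.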
|
| |
61
| + |
|
| |
62
| +The single-token fast path is well understood: last_hidden -> small class scorer -> argmax. Take only the LM-head rows for the target labels; never materialize full-vocab output. |
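The fast path just described boils down to a tiny GEMM against pre-sliced LM-head rows; a numpy sketch of the shape of the computation (the real path runs this in fp16 on GPU, and the function name is illustrative):

```python
import numpy as np


def fast_path_scores(last_hidden: np.ndarray,
                     lm_head: np.ndarray,
                     label_first_token_ids: list[int]) -> np.ndarray:
    """last_hidden: (batch, hidden); lm_head: (vocab, hidden).
    Slice only the rows for the labels' first tokens, so one small matmul
    replaces the full-vocab logits projection."""
    w_small = lm_head[label_first_token_ids]   # (num_labels, hidden)
    return last_hidden @ w_small.T             # (batch, num_labels)
```

The output feeds straight into `argmax` per row; the full `(batch, vocab)` logits tensor is never built.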
|
| |
63
| +How to handle multi-token labels needs research; ideally it should cost the same as the single-token path (given that exact log-probs are abandoned). But note: scores for multi-token and single-token labels must remain comparable, or classification will be wrong — balance performance against scoring accuracy. |
|
| |
64
| + |
|
| |
65
| +Also add a config option, force_single_token_labels, which treats every label by its first token only: when the labels' first tokens all differ, the first-token score is a reasonable proxy for the whole label's score. |
|
| |
66
| +Find the best practice for multi-label scoring in both performance and accuracy, while also supporting force_single_token_labels for maximum speed. |
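The guard for force_single_token_labels follows directly from the rule above: truncating to first tokens is only safe when those first tokens are pairwise distinct. A minimal sketch (the helper name is illustrative; the core code's `_has_unique_single_token_labels` performs the same check):

```python
def truncate_to_first_tokens(label_token_ids: dict[str, list[int]]) -> dict[str, int]:
    """Map each label to its first token id, refusing if two labels collide
    on the first token (the first-token score would then be ambiguous)."""
    first = {label: ids[0] for label, ids in label_token_ids.items()}
    if len(set(first.values())) != len(first):
        raise ValueError("first tokens collide; cannot force single-token labels")
    return first
```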
|
| |
67
| + |
|
| |
68
| +Also research carefully the best practices that frameworks such as Triton / Ollama / CUDA C++ apply to this scenario, put them into practice, and reach state-of-the-art, fully optimized query classification on the T4. |
|
| |
69
| +Some reference examples: |
|
| |
70
| +vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/ |
|
| |
71
| +PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/ |
|
| |
72
| +TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html |
|
| |
73
| +Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq |
|
| |
74
| + |
|
| |
75
| +TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf |
|
| |
76
| +TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html |
|
| |
77
| +FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention |
|
| |
78
| +vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/ |
|
| |
79
| +SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html |
|
| |
80
| +FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer |
|
| |
81
| +xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers |
|
| |
82
| + |
|
| |
83
| +Note: two projects already exist, llm-qp and llm-qp2, whose single-token handling is sound: |
|
| |
84
| +SDPA |
|
| |
85
| +prefix cache |
|
| |
86
| +prebuilt bucket + CUDA graph |
|
| |
87
| +Their core code is: |
|
| |
88
| +from __future__ import annotations |
|
| |
89
| + |
|
| |
90
| +import hashlib |
|
| |
91
| +import time |
|
| |
92
| +from dataclasses import asdict, dataclass |
|
| |
93
| +from typing import Iterable |
|
| |
94
| + |
|
| |
95
| +import torch |
|
| |
96
| +import torch.nn.functional as F |
|
| |
97
| +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache |
|
| |
98
| + |
|
| |
99
| +try: |
|
| |
100
| + from transformers import BitsAndBytesConfig |
|
| |
101
| +except ImportError: # pragma: no cover |
|
| |
102
| + BitsAndBytesConfig = None |
|
| |
103
| + |
|
| |
104
| +from llm_qp.config import PromptTaskConfig, RuntimeConfig |
|
| |
105
| +from llm_qp.scorer import SmallClassScorer |
|
| |
106
| + |
|
| |
107
| +try: |
|
| |
108
| + from torch.nn.attention import SDPBackend, sdpa_kernel |
|
| |
109
| +except ImportError: # pragma: no cover |
|
| |
110
| + SDPBackend = None |
|
| |
111
| + sdpa_kernel = None |
|
| |
112
| + |
|
| |
113
| + |
|
| |
114
| +@dataclass(slots=True) |
|
| |
115
| +class EncodedLabel: |
|
| |
116
| + text: str |
|
| |
117
| + token_ids: list[int] |
|
| |
118
| + |
|
| |
119
| + |
|
| |
120
| +@dataclass(slots=True) |
|
| |
121
| +class PrefixCache: |
|
| |
122
| + prefix_ids: list[int] |
|
| |
123
| + prefix_hashes: list[str] |
|
| |
124
| + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] |
|
| |
125
| + |
|
| |
126
| + @property |
|
| |
127
| + def prefix_len(self) -> int: |
|
| |
128
| + return len(self.prefix_ids) |
|
| |
129
| + |
|
| |
130
| + |
|
| |
131
| +@dataclass(slots=True) |
|
| |
132
| +class MultiTokenTables: |
|
| |
133
| + label_token_ids: torch.Tensor |
|
| |
134
| + label_token_mask: torch.Tensor |
|
| |
135
| + label_prefix_ids: torch.Tensor |
|
| |
136
| + label_prefix_mask: torch.Tensor |
|
| |
137
| + label_position_offsets: torch.Tensor |
|
| |
138
| + |
|
| |
139
| + @property |
|
| |
140
| + def max_label_len(self) -> int: |
|
| |
141
| + return self.label_token_ids.shape[1] |
|
| |
142
| + |
|
| |
143
| + @property |
|
| |
144
| + def max_label_prefix_len(self) -> int: |
|
| |
145
| + return self.label_prefix_ids.shape[1] |
|
| |
146
| + |
|
| |
147
| + |
|
| |
148
| +@dataclass(slots=True) |
|
| |
149
| +class QueryScoreResult: |
|
| |
150
| + task_name: str |
|
| |
151
| + query: str |
|
| |
152
| + predicted_label: str |
|
| |
153
| + scores: list[tuple[str, float, float]] |
|
| |
154
| + total_ms: float |
|
| |
155
| + stage_ms: dict[str, float] |
|
| |
156
| + fast_path: bool |
|
| |
157
| + prefix_tokens: int |
|
| |
158
| + continuation_tokens: int |
|
| |
159
| + label_token_lengths: dict[str, int] |
|
| |
160
| + |
|
| |
161
| + @property |
|
| |
162
| + def predicted_prob(self) -> float: |
|
| |
163
| + for label, _score, prob in self.scores: |
|
| |
164
| + if label == self.predicted_label: |
|
| |
165
| + return prob |
|
| |
166
| + return 0.0 |
|
| |
167
| + |
|
| |
168
| + |
|
| |
169
| +@dataclass(slots=True) |
|
| |
170
| +class MultiPromptScoreResult: |
|
| |
171
| + query: str |
|
| |
172
| + total_ms: float |
|
| |
173
| + details: list[QueryScoreResult] |
|
| |
174
| + stage_ms: dict[str, float] |
|
| |
175
| + |
|
| |
176
| + def http_json(self) -> dict[str, object]: |
|
| |
177
| + return { |
|
| |
178
| + "query": self.query, |
|
| |
179
| + "total_ms": self.total_ms, |
|
| |
180
| + "stage_ms": self.stage_ms, |
|
| |
181
| + "details": [asdict(t) for t in self.details], |
|
| |
182
| + "task_results": { |
|
| |
183
| + t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob] for t in self.details if t.predicted_label != 'none' |
|
| |
184
| + }, |
|
| |
185
| + } |
|
| |
186
| + |
|
| |
187
| + |
|
| |
188
| +@dataclass(slots=True) |
|
| |
189
| +class BatchScoreResult: |
|
| |
190
| + batch_size: int |
|
| |
191
| + total_ms: float |
|
| |
192
| + results: list[MultiPromptScoreResult] |
|
| |
193
| + stage_ms: dict[str, float] |
|
| |
194
| + |
|
| |
195
| + |
|
| |
196
| +@dataclass(slots=True) |
|
| |
197
| +class SharedRuntime: |
|
| |
198
| + device: torch.device |
|
| |
199
| + dtype: torch.dtype |
|
| |
200
| + tokenizer: object |
|
| |
201
| + model: object |
|
| |
202
| + backbone: object |
|
| |
203
| + hidden_size: int |
|
| |
204
| + graph_capture_pool: object | None = None |
|
| |
205
| + graph_capture_stream: torch.cuda.Stream | None = None |
|
| |
206
| + |
|
| |
207
| + |
|
| |
208
| +@dataclass(slots=True) |
|
| |
209
| +class PromptBatchPlan: |
|
| |
210
| + runner: "PromptClassifierRunner" |
|
| |
211
| + row_start: int |
|
| |
212
| + row_count: int |
|
| |
213
| + score_buffer: torch.Tensor |
|
| |
214
| + |
|
| |
215
| + @property |
|
| |
216
| + def row_stop(self) -> int: |
|
| |
217
| + return self.row_start + self.row_count |
|
| |
218
| + |
|
| |
219
| + |
|
| |
220
| +@dataclass(slots=True) |
|
| |
221
| +class MixedPrefixCache: |
|
| |
222
| + batch_size: int |
|
| |
223
| + total_rows: int |
|
| |
224
| + prefix_lengths: torch.Tensor |
|
| |
225
| + attention_mask: torch.Tensor |
|
| |
226
| + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...] |
|
| |
227
| + |
|
| |
228
| + @property |
|
| |
229
| + def max_prefix_len(self) -> int: |
|
| |
230
| + return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0 |
|
| |
231
| + |
|
| |
232
| + |
|
| |
233
| +@dataclass(slots=True) |
|
| |
234
| +class BatchLayout: |
|
| |
235
| + batch_size: int |
|
| |
236
| + total_rows: int |
|
| |
237
| + plans: list[PromptBatchPlan] |
|
| |
238
| + |
|
| |
239
| + |
|
| |
240
| +@dataclass(slots=True) |
|
| |
241
| +class MixedBucketRuntime: |
|
| |
242
| + batch_size: int |
|
| |
243
| + total_rows: int |
|
| |
244
| + continuation_len: int |
|
| |
245
| + max_input_len: int |
|
| |
246
| + input_ids: torch.Tensor |
|
| |
247
| + attention_mask: torch.Tensor |
|
| |
248
| + position_ids: torch.Tensor |
|
| |
249
| + last_hidden_state: torch.Tensor |
|
| |
250
| + graph: torch.cuda.CUDAGraph | None = None |
|
| |
251
| + |
|
| |
252
| + |
|
| |
253
| +@dataclass(slots=True) |
|
| |
254
| +class PreloadReport: |
|
| |
255
| + total_ms: float |
|
| |
256
| + stage_ms: dict[str, float] |
|
| |
257
| + runtime: dict[str, object] |
|
| |
258
| + |
|
| |
259
| + |
|
| |
260
| +def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]: |
|
| |
261
| + token_list = list(token_ids) |
|
| |
262
| + hashes: list[str] = [] |
|
| |
263
| + for start in range(0, len(token_list), block_size): |
|
| |
264
| + block = token_list[start : start + block_size] |
|
| |
265
| + payload = ",".join(str(x) for x in block).encode("utf-8") |
|
| |
266
| + hashes.append(hashlib.sha1(payload).hexdigest()) |
|
| |
267
| + return hashes |
|
| |
268
| + |
|
| |
269
| + |
|
| |
270
| +def _expand_legacy_cache( |
|
| |
271
| + raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...], |
|
| |
272
| + batch_size: int, |
|
| |
273
| +) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: |
|
| |
274
| + expanded: list[tuple[torch.Tensor, torch.Tensor]] = [] |
|
| |
275
| + for key, value in raw_cache: |
|
| |
276
| + expanded.append( |
|
| |
277
| + ( |
|
| |
278
| + key.expand(batch_size, *key.shape[1:]).contiguous(), |
|
| |
279
| + value.expand(batch_size, *value.shape[1:]).contiguous(), |
|
| |
280
| + ) |
|
| |
281
| + ) |
|
| |
282
| + return tuple(expanded) |
|
| |
283
| + |
|
| |
284
| + |
|
| |
285
| +class PromptClassifierRunner: |
|
| |
286
| + def __init__( |
|
| |
287
| + self, |
|
| |
288
| + cfg: RuntimeConfig, |
|
| |
289
| + task_cfg: PromptTaskConfig, |
|
| |
290
| + shared_runtime: SharedRuntime, |
|
| |
291
| + ): |
|
| |
292
| + self.cfg = cfg |
|
| |
293
| + self.task_cfg = task_cfg |
|
| |
294
| + self.device = shared_runtime.device |
|
| |
295
| + self.dtype = shared_runtime.dtype |
|
| |
296
| + self.tokenizer = shared_runtime.tokenizer |
|
| |
297
| + self.model = shared_runtime.model |
|
| |
298
| + self.backbone = shared_runtime.backbone |
|
| |
299
| + self.hidden_size = shared_runtime.hidden_size |
|
| |
300
| + self.prefix_text, self.suffix_text = task_cfg.prompt_parts |
|
| |
301
| + self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False) |
|
| |
302
| + self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False) |
|
| |
303
| + self.labels = list(task_cfg.labels) |
|
| |
304
| + self.encoded_labels = [ |
|
| |
305
| + EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label)) |
|
| |
306
| + for label in self.labels |
|
| |
307
| + ] |
|
| |
308
| + self.num_labels = len(self.labels) |
|
| |
309
| + self.lm_head = self.model.get_output_embeddings() |
|
| |
310
| + self.lm_head_weight = self.lm_head.weight.detach() |
|
| |
311
| + self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None |
|
| |
312
| + if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels(): |
|
| |
313
| + raise ValueError( |
|
| |
314
| + f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide" |
|
| |
315
| + ) |
|
| |
316
| + self.fast_path = self._has_unique_single_token_labels() |
|
| |
317
| + self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else [] |
|
| |
318
| + self.scorer = self._build_scorer() if self.fast_path else None |
|
| |
319
| + self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None |
|
| |
320
| + self.prefix_cache = self._build_prefix_cache() |
|
| |
321
| + |
|
| |
322
| + def _encode_label_token_ids(self, label: str) -> list[int]: |
|
| |
323
| + token_ids = self.tokenizer.encode( |
|
| |
324
| + f"{self.task_cfg.label_prefix}{label}", |
|
| |
325
| + add_special_tokens=False, |
|
| |
326
| + ) |
|
| |
327
| + if not token_ids: |
|
| |
328
| + raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence") |
|
| |
329
| + if self.cfg.force_single_token_labels: |
|
| |
330
| + return token_ids[:1] |
|
| |
331
| + return token_ids |
|
| |
332
| + |
|
| |
333
| + def _has_unique_single_token_labels(self) -> bool: |
|
| |
334
| + token_ids: list[int] = [] |
|
| |
335
| + for item in self.encoded_labels: |
|
| |
336
| + if len(item.token_ids) != 1: |
|
| |
337
| + return False |
|
| |
338
| + token_ids.append(item.token_ids[0]) |
|
| |
339
| + return len(token_ids) == len(set(token_ids)) |
|
| |
340
| + |
|
| |
341
| + def _build_scorer(self) -> SmallClassScorer: |
|
| |
342
| + token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device) |
|
| |
343
| + weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous() |
|
| |
344
| + bias = None |
|
| |
345
| + if self.lm_head_bias is not None: |
|
| |
346
| + bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous() |
|
| |
347
| + return SmallClassScorer(weights=weights, bias=bias) |
|
| |
348
| + |
|
| |
349
| + def _build_multi_token_tables(self) -> MultiTokenTables: |
|
| |
350
| + max_label_len = max(len(item.token_ids) for item in self.encoded_labels) |
|
| |
351
| + max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels) |
|
| |
352
| + label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long) |
|
| |
353
| + label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32) |
|
| |
354
| + label_prefix_ids = torch.full( |
|
| |
355
| + (self.num_labels, max_prefix_len), |
|
| |
356
| + fill_value=self.tokenizer.pad_token_id, |
|
| |
357
| + device=self.device, |
|
| |
358
| + dtype=torch.long, |
|
| |
359
| + ) |
|
| |
360
| + label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long) |
|
| |
361
| + for idx, item in enumerate(self.encoded_labels): |
|
| |
362
| + token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long) |
|
| |
363
| + token_len = token_ids.numel() |
|
| |
364
| + label_token_ids[idx, :token_len] = token_ids |
|
| |
365
| + label_token_mask[idx, :token_len] = 1.0 |
|
| |
366
| + if token_len > 1: |
|
| |
367
| + prefix_len = token_len - 1 |
|
| |
368
| + label_prefix_ids[idx, :prefix_len] = token_ids[:-1] |
|
| |
369
| + label_prefix_mask[idx, :prefix_len] = 1 |
|
| |
370
| + return MultiTokenTables( |
|
| |
371
| + label_token_ids=label_token_ids.contiguous(), |
|
| |
372
| + label_token_mask=label_token_mask.contiguous(), |
|
| |
373
| + label_prefix_ids=label_prefix_ids.contiguous(), |
|
| |
374
| + label_prefix_mask=label_prefix_mask.contiguous(), |
|
| |
375
| + label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long), |
|
| |
376
| + ) |
|
| |
377
| + |
|
| |
378
| + @torch.inference_mode() |
|
| |
379
| + def _build_prefix_cache(self) -> PrefixCache: |
|
| |
380
| + if not self.prefix_ids: |
|
| |
381
| + return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple()) |
|
| |
382
| + prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device) |
|
| |
383
| + attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device) |
|
| |
384
| + outputs = self.model( |
|
| |
385
| + input_ids=prefix_tensor, |
|
| |
386
| + attention_mask=attention_mask, |
|
| |
387
| + use_cache=True, |
|
| |
388
| + return_dict=True, |
|
| |
389
| + ) |
|
| |
390
| + raw_cache = tuple( |
|
| |
391
| + (layer.keys.detach(), layer.values.detach()) |
|
| |
392
| + for layer in outputs.past_key_values.layers |
|
| |
393
| + ) |
|
| |
394
| + return PrefixCache( |
|
| |
395
| + prefix_ids=list(self.prefix_ids), |
|
| |
396
| + prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size), |
|
| |
397
| + raw_cache=raw_cache, |
|
| |
398
| + ) |
|
| |
399
| + |
|
| |
400
| + def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: |
|
| |
401
| + if not self.prefix_cache.raw_cache: |
|
| |
402
| + return tuple() |
|
| |
403
| + return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size) |
|
| |
404
| + |
|
| |
405
| + def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]: |
|
| |
406
| + continuation = query_ids + self.suffix_ids |
|
| |
407
| + if not continuation: |
|
| |
408
| + raise ValueError("prompt continuation is empty after substituting query") |
|
| |
409
| + if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length: |
|
| |
410
| + raise ValueError( |
|
| |
411
| + f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}" |
|
| |
412
| + ) |
|
| |
413
| + return continuation |
|
| |
414
| + |
|
| |
415
| + @torch.inference_mode() |
|
| |
416
| + def reduce_fast_scores( |
|
| |
417
| + self, |
|
| |
418
| + hidden: torch.Tensor, |
|
| |
419
| + out_scores: torch.Tensor, |
|
| |
420
| + ) -> None: |
|
| |
421
| + assert self.scorer is not None |
|
| |
422
| + out_scores.copy_(self.scorer(hidden)) |
|
| |
423
| + |
|
| |
424
| + @torch.inference_mode() |
|
| |
425
| + def reduce_multi_token_scores( |
|
| |
426
| + self, |
|
| |
427
| + last_hidden_state: torch.Tensor, |
|
| |
428
| + batch_size: int, |
|
| |
429
| + max_input_len: int, |
|
| |
430
| + score_positions: torch.Tensor, |
|
| |
431
| + out_scores: torch.Tensor, |
|
| |
432
| + ) -> None: |
|
| |
433
| + assert self.multi_token_tables is not None |
|
| |
434
| + hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size) |
|
| |
435
| + gather_index = score_positions[:, None, :, None].expand( |
|
| |
436
| + batch_size, |
|
| |
437
| + self.num_labels, |
|
| |
438
| + self.multi_token_tables.max_label_len, |
|
| |
439
| + self.hidden_size, |
|
| |
440
| + ) |
|
| |
441
| + gathered_hidden = torch.gather(hidden, 2, gather_index) |
|
| |
442
| + used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool() |
|
| |
443
| + |
|
| |
444
| + token_log_probs = torch.zeros( |
|
| |
445
| + (batch_size, self.num_labels, self.multi_token_tables.max_label_len), |
|
| |
446
| + device=self.device, |
|
| |
447
| + dtype=torch.float32, |
|
| |
448
| + ) |
|
| |
449
| + if used_mask.any(): |
|
| |
450
| + flat_hidden = gathered_hidden[used_mask] |
|
| |
451
| + flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask] |
|
| |
452
| + linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float() |
|
| |
453
| + linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float() |
|
| |
454
| + linear_bias = self.lm_head_bias |
|
| |
455
| + if linear_bias is not None and self.device.type != "cuda": |
|
| |
456
| + linear_bias = linear_bias.float() |
|
| |
457
| + flat_logits = F.linear(linear_hidden, linear_weight, linear_bias) |
|
| |
458
| + flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float() |
|
| |
459
| + flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1) |
|
| |
460
| + token_log_probs[used_mask] = flat_selected - flat_log_norm |
|
| |
461
| + out_scores.copy_(token_log_probs.sum(dim=-1)) |
|
| |
462
| + |
|
| |
463
| + def build_score_result( |
|
| |
464
| + self, |
|
| |
465
| + query: str, |
|
| |
466
| + scores: torch.Tensor, |
|
| |
467
| + stage_ms: dict[str, float], |
|
| |
468
| + continuation_tokens: int, |
|
| |
469
| + ) -> QueryScoreResult: |
|
| |
470
| + score_values = scores.detach().float().cpu().tolist() |
|
| |
471
| + best_idx = max(range(len(score_values)), key=score_values.__getitem__) |
|
| |
472
| + probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist() |
|
| |
473
| + return QueryScoreResult( |
|
| |
474
| + task_name=self.task_cfg.name, |
|
| |
475
| + query=query, |
|
| |
476
| + predicted_label=self.labels[best_idx], |
|
| |
477
| + scores=[ |
|
| |
478
| + (label, score, prob) |
|
| |
479
| + for label, score, prob in zip(self.labels, score_values, probs, strict=True) |
|
| |
480
| + ], |
|
| |
481
| + total_ms=sum(stage_ms.values()), |
|
| |
482
| + stage_ms=stage_ms, |
|
| |
483
| + fast_path=self.fast_path, |
|
| |
484
| + prefix_tokens=self.prefix_cache.prefix_len, |
|
| |
485
| + continuation_tokens=continuation_tokens, |
|
| |
486
| + label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels}, |
|
| |
487
| + ) |
|
| |
488
| + |
|
| |
489
| + |
|
| |
490
| +class MultiPromptRunner: |
|
| |
491
| + def __init__(self, cfg: RuntimeConfig): |
|
| |
492
| + self.cfg = cfg |
|
| |
493
| + t0 = time.perf_counter() |
|
| |
494
| + self.shared_runtime = self.build_shared_runtime(cfg) |
|
| |
495
| + t1 = time.perf_counter() |
|
| |
496
| + self.device = self.shared_runtime.device |
|
| |
497
| + self.dtype = self.shared_runtime.dtype |
|
| |
498
| + self.tokenizer = self.shared_runtime.tokenizer |
|
| |
499
| + self.model = self.shared_runtime.model |
|
| |
500
| + self.backbone = self.shared_runtime.backbone |
|
| |
501
| + self.hidden_size = self.shared_runtime.hidden_size |
|
| |
502
| + self.graph_capture_pool = self.shared_runtime.graph_capture_pool |
|
| |
503
| + self.graph_capture_stream = self.shared_runtime.graph_capture_stream |
|
| |
504
| + self.runners = [ |
|
| |
505
| + PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime) |
|
| |
506
| + for task_cfg in cfg.tasks |
|
| |
507
| + ] |
|
| |
508
| + t2 = time.perf_counter() |
|
| |
509
| + self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes} |
|
| |
510
| + t3 = time.perf_counter() |
|
| |
511
| + self.mixed_prefix_caches = { |
|
| |
512
| + batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size]) |
|
| |
513
| + for batch_size in self.cfg.batch_sizes |
|
| |
514
| + } |
|
| |
515
| + t4 = time.perf_counter() |
|
| |
516
| + self.max_label_prefix_len = max( |
|
| |
517
| + (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0) |
|
| |
518
| + for runner in self.runners |
|
| |
519
| + ) |
|
| |
520
| + self.mixed_buckets = { |
|
| |
521
| + (batch_size, continuation_len): self._build_mixed_bucket( |
|
| |
522
| + self.batch_layouts[batch_size], |
|
| |
523
| + self.mixed_prefix_caches[batch_size], |
|
| |
524
| + continuation_len, |
|
| |
525
| + ) |
|
| |
526
| + for batch_size in self.cfg.batch_sizes |
|
| |
527
| + for continuation_len in self.cfg.continuation_buckets |
|
| |
528
| + } |
|
| |
529
| + t5 = time.perf_counter() |
|
| |
530
| + self._warmup_results: dict[int, BatchScoreResult] = {} |
|
| |
531
| + self._preload_report: PreloadReport | None = None |
|
| |
532
| + self._init_stage_ms = { |
|
| |
533
| + "load_model_and_tokenizer": (t1 - t0) * 1000.0, |
|
| |
534
| + "build_prompt_runtimes": (t2 - t1) * 1000.0, |
|
| |
535
| + "build_batch_layouts": (t3 - t2) * 1000.0, |
|
| |
536
| + "build_mixed_prefix_caches": (t4 - t3) * 1000.0, |
|
| |
537
| + "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0, |
|
| |
538
| + } |
|
| |
539
| + self._init_total_ms = sum(self._init_stage_ms.values()) |
|
| |
540
| + |
|
| |
541
| + @staticmethod |
|
| |
542
| + def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime: |
|
| |
543
| + device = torch.device(cfg.device) |
|
| |
544
| + dtype = torch.float16 |
|
| |
545
| + tokenizer = AutoTokenizer.from_pretrained( |
|
| |
546
| + cfg.resolved_model_source, |
|
| |
547
| + trust_remote_code=cfg.resolved_trust_remote_code, |
|
| |
548
| + token=cfg.hf_token, |
|
| |
549
| + cache_dir=cfg.hf_cache_dir, |
|
| |
550
| + local_files_only=cfg.resolved_local_files_only, |
|
| |
551
| + ) |
|
| |
552
| + if tokenizer.pad_token_id is None: |
|
| |
553
| + tokenizer.pad_token = tokenizer.eos_token |
|
| |
554
| + attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend) |
|
| |
555
| + quantization_config = None |
|
| |
556
| + model_kwargs: dict[str, object] = { |
|
| |
557
| + "trust_remote_code": cfg.resolved_trust_remote_code, |
|
| |
558
| + "attn_implementation": attn_impl, |
|
| |
559
| + "token": cfg.hf_token, |
|
| |
560
| + "cache_dir": cfg.hf_cache_dir, |
|
| |
561
| + "local_files_only": cfg.resolved_local_files_only, |
|
| |
562
| + } |
|
| |
563
| + if cfg.load_in_4bit: |
|
| |
564
| + if BitsAndBytesConfig is None: |
|
| |
565
| + raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first") |
|
| |
566
| + quantization_config = BitsAndBytesConfig( |
|
| |
567
| + load_in_4bit=True, |
|
| |
568
| + bnb_4bit_compute_dtype=dtype, |
|
| |
569
| + bnb_4bit_quant_type=cfg.bnb_4bit_quant_type, |
|
| |
570
| + bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant, |
|
| |
571
| + ) |
|
| |
572
| + model_kwargs["quantization_config"] = quantization_config |
|
| |
573
| + model_kwargs["device_map"] = {"": device.index or 0} |
|
| |
574
| + else: |
|
| |
575
| + model_kwargs["dtype"] = dtype |
|
| |
576
| + model_kwargs["device_map"] = None |
|
| |
577
| + model = AutoModelForCausalLM.from_pretrained( |
|
| |
578
| + cfg.resolved_model_source, |
|
| |
579
| + **model_kwargs, |
|
| |
580
| + ).eval() |
|
| |
581
| + if not cfg.load_in_4bit: |
|
| |
582
| + model = model.to(device) |
|
| |
583
| + backbone = model.get_submodule(model.base_model_prefix) |
|
| |
584
| + hidden_size = model.get_output_embeddings().weight.shape[1] |
|
| |
585
| + graph_capture_pool = None |
|
| |
586
| + graph_capture_stream = None |
|
| |
587
| + if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit: |
|
| |
588
| + graph_capture_pool = torch.cuda.graph_pool_handle() |
|
| |
589
| + graph_capture_stream = torch.cuda.Stream(device=device) |
|
| |
590
| + return SharedRuntime( |
|
| |
591
| + device=device, |
|
| |
592
| + dtype=dtype, |
|
| |
593
| + tokenizer=tokenizer, |
|
| |
594
| + model=model, |
|
| |
595
            backbone=backbone,
            hidden_size=hidden_size,
            graph_capture_pool=graph_capture_pool,
            graph_capture_stream=graph_capture_stream,
        )

    @staticmethod
    def _resolve_attn_impl(requested: str) -> str:
        if requested in {"sdpa", "eager"}:
            return requested
        if requested == "auto":
            return "sdpa"
        raise ValueError(f"unsupported attn_backend: {requested}")

    def _attn_context(self):
        if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
            return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
        return torch.no_grad()

    def _sync(self) -> None:
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _pick_bucket(self, continuation_len: int) -> int:
        for bucket in self.cfg.continuation_buckets:
            if continuation_len <= bucket:
                return bucket
        if self.cfg.pad_to_bucket:
            raise ValueError(
                f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
            )
        return continuation_len

    def _build_batch_layout(self, batch_size: int) -> BatchLayout:
        plans: list[PromptBatchPlan] = []
        row_start = 0
        for runner in self.runners:
            row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
            plans.append(
                PromptBatchPlan(
                    runner=runner,
                    row_start=row_start,
                    row_count=row_count,
                    score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
                )
            )
            row_start += row_count
        return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)

    def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
        prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
        non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
        if not non_empty:
            return MixedPrefixCache(
                batch_size=layout.batch_size,
                total_rows=layout.total_rows,
                prefix_lengths=prefix_lengths,
                attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
                raw_cache=tuple(),
            )

        max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
        num_layers = len(non_empty[0])
        attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
        raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_idx in range(num_layers):
            sample_key, sample_value = non_empty[0][layer_idx]
            merged_key = sample_key.new_zeros(
                (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
            )
            merged_value = sample_value.new_zeros(
                (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
            )
            raw_layers.append((merged_key, merged_value))

        for plan in layout.plans:
            runner = plan.runner
            prefix_len = runner.prefix_cache.prefix_len
            row_slice = slice(plan.row_start, plan.row_stop)
            prefix_lengths[row_slice] = prefix_len
            if prefix_len == 0:
                continue
            attention_mask[row_slice, :prefix_len] = 1
            raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
            for layer_idx, (key, value) in enumerate(raw_cache):
                merged_key, merged_value = raw_layers[layer_idx]
                merged_key[row_slice, :, :prefix_len, :] = key
                merged_value[row_slice, :, :prefix_len, :] = value

        return MixedPrefixCache(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            prefix_lengths=prefix_lengths,
            attention_mask=attention_mask.contiguous(),
            raw_cache=tuple(raw_layers),
        )

    def _build_mixed_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        continuation_len: int,
    ) -> MixedBucketRuntime:
        max_input_len = continuation_len + self.max_label_prefix_len
        total_len = prefix_cache.max_prefix_len + max_input_len
        input_ids = torch.full(
            (layout.total_rows, max_input_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
        if prefix_cache.max_prefix_len:
            attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
        position_ids = (
            prefix_cache.prefix_lengths[:, None]
            + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
        ).contiguous()
        last_hidden_state = torch.empty(
            (layout.total_rows, max_input_len, self.hidden_size),
            device=self.device,
            dtype=self.dtype,
        )
        bucket = MixedBucketRuntime(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            continuation_len=continuation_len,
            max_input_len=max_input_len,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            last_hidden_state=last_hidden_state,
        )
        if self.cfg.cuda_graphs:
            self._capture_mixed_bucket(bucket, prefix_cache)
        return bucket

    @torch.inference_mode()
    def _run_mixed_backbone(
        self,
        bucket: MixedBucketRuntime,
        prefix_cache: MixedPrefixCache,
    ) -> None:
        cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
        with self._attn_context():
            outputs = self.backbone(
                input_ids=bucket.input_ids,
                attention_mask=bucket.attention_mask,
                position_ids=bucket.position_ids,
                past_key_values=cache,
                use_cache=False,
                return_dict=True,
            )
        bucket.last_hidden_state.copy_(outputs.last_hidden_state)

    def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
            stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
            with torch.cuda.stream(stream):
                for _ in range(self.cfg.graph_warmups):
                    self._run_mixed_backbone(bucket, prefix_cache)
                stream.synchronize()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
                self._run_mixed_backbone(bucket, prefix_cache)
            bucket.graph = graph
        except RuntimeError:
            bucket.graph = None

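The bucket selection used above can be isolated as a small standalone rule (a sketch under assumed semantics, not the class method itself): each request is padded up to the smallest preallocated continuation bucket so that a CUDA graph captured for that bucket shape can be replayed.

```python
# Standalone sketch of the bucket-selection rule: the smallest configured
# bucket that fits the continuation wins; lengths beyond every bucket either
# raise (pad_to_bucket=True) or fall back to the exact length.
def pick_bucket(continuation_len: int, buckets: list[int], pad_to_bucket: bool = True) -> int:
    for bucket in sorted(buckets):
        if continuation_len <= bucket:
            return bucket
    if pad_to_bucket:
        raise ValueError(f"continuation length {continuation_len} exceeds buckets {buckets}")
    return continuation_len
```

Note the sketch sorts the buckets defensively; the engine version assumes `continuation_buckets` is already configured in ascending order.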
    def _prepare_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        bucket: MixedBucketRuntime,
        query_ids_batch: list[list[int]],
    ) -> tuple[list[list[int]], dict[str, list[object]]]:
        del prefix_cache
        bucket.input_ids.fill_(self.tokenizer.pad_token_id)
        bucket.attention_mask.zero_()
        if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
            bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
                self.mixed_prefix_caches[layout.batch_size].attention_mask
            )
        continuation_lengths_per_task: dict[str, list[int]] = {}
        continuation_tokens_per_task: dict[str, list[list[int]]] = {}
        prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
        for plan in layout.plans:
            runner = plan.runner
            per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
            continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
            continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
            if runner.fast_path:
                for batch_idx, continuation in enumerate(per_query_continuations):
                    cont_len = len(continuation)
                    row_idx = plan.row_start + batch_idx
                    bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
                    bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
                continue

            assert runner.multi_token_tables is not None
            for batch_idx, continuation in enumerate(per_query_continuations):
                cont_len = len(continuation)
                row_start = plan.row_start + batch_idx * runner.num_labels
                row_stop = row_start + runner.num_labels
                row_slice = slice(row_start, row_stop)
                cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
                bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
                bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
                if runner.multi_token_tables.max_label_prefix_len:
                    bucket.input_ids[
                        row_slice,
                        cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_ids
                    bucket.attention_mask[
                        row_slice,
                        prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_mask
        return query_ids_batch, {
            "continuation_lengths_per_task": continuation_lengths_per_task,
            "continuation_tokens_per_task": continuation_tokens_per_task,
        }

    def _reduce_prompt_scores(
        self,
        layout: BatchLayout,
        bucket: MixedBucketRuntime,
        query_texts: list[str],
        prep_meta: dict[str, list[object]],
        shared_stage_ms: dict[str, float],
    ) -> list[MultiPromptScoreResult]:
        result_rows = [[] for _ in range(layout.batch_size)]
        prompt_reduce_total_ms = 0.0
        for plan in layout.plans:
            runner = plan.runner
            continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
            reduce_start = time.perf_counter()
            if runner.fast_path:
                hidden_rows = []
                row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
                for batch_idx, cont_len in enumerate(continuation_lengths):
                    hidden_rows.append(row_slice[batch_idx, cont_len - 1])
                hidden = torch.stack(hidden_rows, dim=0)
                runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
                stage_name = "tail_scorer"
            else:
                assert runner.multi_token_tables is not None
                score_positions = torch.stack(
                    [
                        cont_len - 1 + runner.multi_token_tables.label_position_offsets
                        for cont_len in continuation_lengths
                    ],
                    dim=0,
                )
                runner.reduce_multi_token_scores(
                    last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
                    batch_size=layout.batch_size,
                    max_input_len=bucket.max_input_len,
                    score_positions=score_positions,
                    out_scores=plan.score_buffer,
                )
                stage_name = "candidate_reduce"
            self._sync()
            reduce_end = time.perf_counter()
            reduce_ms = (reduce_end - reduce_start) * 1000.0
            prompt_reduce_total_ms += reduce_ms
            for batch_idx, query in enumerate(query_texts):
                stage_ms = dict(shared_stage_ms)
                stage_ms[stage_name] = reduce_ms / layout.batch_size
                result_rows[batch_idx].append(
                    runner.build_score_result(
                        query=query,
                        scores=plan.score_buffer[batch_idx],
                        stage_ms=stage_ms,
                        continuation_tokens=continuation_lengths[batch_idx],
                    )
                )

        batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
        shared_plus_reduce = dict(shared_stage_ms)
        shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
        results: list[MultiPromptScoreResult] = []
        for batch_idx, query in enumerate(query_texts):
            results.append(
                MultiPromptScoreResult(
                    query=query,
                    total_ms=batch_total_ms / layout.batch_size,
                    details=result_rows[batch_idx],
                    stage_ms={
                        **shared_plus_reduce,
                        "per_query_total_estimate": batch_total_ms / layout.batch_size,
                    },
                )
            )
        return results

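The fast path above repeatedly needs "the hidden state at each row's own last real token" (index `cont_len - 1`). That gather can be expressed vectorized rather than with a Python loop; a minimal numpy sketch, with illustrative names only:

```python
import numpy as np

# For every row, pick the hidden vector at index (length - 1), i.e. the last
# non-padded position of that row. hidden has shape (rows, seq, hidden_size).
def gather_last_hidden(hidden: np.ndarray, lengths: list[int]) -> np.ndarray:
    idx = np.asarray(lengths, dtype=np.int64) - 1
    return hidden[np.arange(hidden.shape[0]), idx]  # (rows, hidden_size)
```

The same advanced-indexing pattern works on torch tensors, replacing the per-row `torch.stack` loop with a single indexed read.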
    @torch.inference_mode()
    def score_queries(self, queries: list[str]) -> BatchScoreResult:
        if not queries:
            raise ValueError("queries must not be empty")
        batch_size = len(queries)
        if batch_size not in self.batch_layouts:
            raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
        layout = self.batch_layouts[batch_size]
        prefix_cache = self.mixed_prefix_caches[batch_size]

        self._sync()
        t0 = time.perf_counter()
        query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
        self._sync()
        t1 = time.perf_counter()

        max_continuation_len = max(
            len(plan.runner.build_continuation_from_query_ids(query_ids))
            for plan in layout.plans
            for query_ids in query_ids_batch
        )
        picked_bucket = self._pick_bucket(max_continuation_len)
        bucket = self.mixed_buckets[(batch_size, picked_bucket)]
        _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
        self._sync()
        t2 = time.perf_counter()

        if bucket.graph is not None:
            bucket.graph.replay()
        else:
            self._run_mixed_backbone(bucket, prefix_cache)
        self._sync()
        t3 = time.perf_counter()

        shared_stage_ms = {
            "encode_queries_shared": (t1 - t0) * 1000.0,
            "prepare_batch_shared": (t2 - t1) * 1000.0,
            "backbone_shared": (t3 - t2) * 1000.0,
        }
        results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
        total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
        return BatchScoreResult(
            batch_size=batch_size,
            total_ms=total_ms,
            results=results,
            stage_ms={
                **shared_stage_ms,
                "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
            },
        )

    def score_query(self, query: str) -> MultiPromptScoreResult:
        return self.score_queries([query]).results[0]

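The sync-then-`perf_counter` bracketing that produces the per-stage timings in `score_queries` can be factored into a tiny helper (a sketch only; the actual code inlines the pattern):

```python
import time

# Measure a stage in milliseconds. `sync` defaults to a no-op; on GPU you
# would pass torch.cuda.synchronize so queued kernels finish before the
# second timestamp is taken, otherwise async launches make the stage look
# artificially cheap.
def timed_ms(fn, sync=lambda: None):
    sync()
    t0 = time.perf_counter()
    out = fn()
    sync()
    t1 = time.perf_counter()
    return out, (t1 - t0) * 1000.0
```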
    def preload(self) -> PreloadReport:
        if self._preload_report is not None:
            return self._preload_report
        stage_ms: dict[str, float] = dict(self._init_stage_ms)
        start = time.perf_counter()
        self._sync()
        t0 = time.perf_counter()
        warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
        for batch_size in warmup_batch_sizes:
            queries = [self.cfg.warmup_query] * batch_size
            self._warmup_results[batch_size] = self.score_queries(queries)
        self._sync()
        t1 = time.perf_counter()
        stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
        stage_ms["startup_total_before_warmup"] = self._init_total_ms
        total_ms = self._init_total_ms + (t1 - start) * 1000.0
        runtime = self.preload_report()
        self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
        return self._preload_report

    def preload_report(self) -> dict[str, object]:
        return {
            "model_name": self.cfg.resolved_model_name,
            "model_source": self.cfg.resolved_model_source,
            "device": str(self.device),
            "dtype": self.cfg.dtype,
            "attn_backend": self.cfg.attn_backend,
            "execution_model": "single_mixed_backbone_per_batch",
            "num_tasks": len(self.runners),
            "task_names": [runner.task_cfg.name for runner in self.runners],
            "batch_sizes": list(self.cfg.batch_sizes),
            "continuation_buckets": list(self.cfg.continuation_buckets),
            "mixed_bucket_count": len(self.mixed_buckets),
            "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
            "all_configured_buckets_preloaded": True,
            "init_stage_ms": dict(self._init_stage_ms),
            "init_total_ms": self._init_total_ms,
            "force_single_token_labels": self.cfg.force_single_token_labels,
            "warmup_query": self.cfg.warmup_query,
            "tasks": [
                {
                    "task_name": runner.task_cfg.name,
                    "fast_path": runner.fast_path,
                    "num_labels": runner.num_labels,
                    "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
                    "prefix_tokens": runner.prefix_cache.prefix_len,
                    "prefix_hashes": runner.prefix_cache.prefix_hashes,
                    "label_prefix": runner.task_cfg.label_prefix,
                }
                for runner in self.runners
            ],
        }

However, the multi-token label handling above is both inefficient and incorrect. Do not follow it; reimplement that part from scratch.
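Whatever batching strategy the rewrite chooses, the scoring math for a multi-token label is fixed: under teacher forcing, a label's score is the sum of per-step log-softmax values of its tokens (optionally length-normalized before comparing labels of different token lengths). A minimal pure-Python sketch of that reduction, with illustrative names only:

```python
import math

def label_logprob(step_logits: list[dict[int, float]], token_ids: list[int]) -> float:
    # step_logits[i] maps token id -> raw logit at the position that predicts
    # the i-th label token; the label's score is the sum of log-softmax values
    # of its own tokens, computed with the usual max-shift for stability.
    total = 0.0
    for pos, tok in enumerate(token_ids):
        logits = step_logits[pos]
        m = max(logits.values())
        log_z = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total += logits[tok] - log_z
    return total
```

In the batched GPU implementation this sum should become a single gather over candidate-token columns followed by a per-label reduction, with all label rows sharing the cached prompt prefix, rather than a Python-side loop per label.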
The local .venv has already been created (reused from llm-qp); use that environment.
The useful code has already been brought over for reference. Based on the requirements, research the relevant material and focus on developing a brand-new version.
This is a substantial task; work through it patiently, step by step, iterating for as long as it takes until the service is fully built and the tests completely meet expectations.