Overall requirement: run query classification on a Tesla T4 GPU using an open-source base LLM (3-6B class), with configurable prompts and classification labels.
Focus on inference optimization first and leave serving for last, supporting a modest level of concurrency (say 4 concurrent requests). At program startup, load everything that needs loading; do no lazy loading at all, so that real requests get the fastest possible response. For each query the CLI receives, classify along N dimensions using N prompts; for each prompt, output every label's score and the total prediction latency (per-stage timings too, if they are easy to add).

Some reference material follows, but you need not follow it strictly; keep some flexibility in pursuit of maximum performance.

On a Tesla T4, do query classification with a 3B-6B open-source decoder-only base model.
At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor.
For each input query, output the score distribution over labels for every prompt, plus total and per-stage latency.
Do not take the generic generation path: no decode loop, no full-vocab logits, no constrained decoding.
Optimize multi-token labels specifically, avoiding serial decoding on the Python side.
The prompt and label sets must be configurable (there are only two prompts now; I will later grow this to 8, with each request carrying one query, all 8 prompts running in parallel, and the top-scoring label returned per prompt):

prompts:
  - name: category
    prompt_template: |
      <|im_start|>system
      Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|>
      <|im_start|>user
      query: {query}<|im_end|>
      <|im_start|>assistant
      label:
    label_prefix: " "
    labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]

  - name: audience
    prompt_template: |
      <|im_start|>system
      Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|>
      <|im_start|>user
      query: {query}<|im_end|>
      <|im_start|>assistant
      label:
    label_prefix: " "
    labels: [boy, girl, man, woman, pregnant, none]
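The runner code later in this document splits each prompt_template into a static prefix and a per-request suffix around the {query} placeholder (see task_cfg.prompt_parts), so the prefix KV cache can be built once at startup. A minimal sketch of that split; split_template is a hypothetical helper name:

```python
def split_template(template: str, placeholder: str = "{query}") -> tuple[str, str]:
    """Split a prompt template into (prefix, suffix) around the query slot.

    The prefix is byte-identical for every request, so its KV cache can be
    computed once at startup and reused; only the suffix tokens (query plus
    the trailing "label:" scaffold) change per request.
    """
    if placeholder not in template:
        raise ValueError(f"template is missing {placeholder!r}")
    prefix, suffix = template.split(placeholder, 1)
    return prefix, suffix


prefix, suffix = split_template(
    "<|im_start|>user\nquery: {query}<|im_end|>\n<|im_start|>assistant\nlabel:"
)
# prefix ends right before the query; suffix carries the "label:" scaffold
```

Keeping everything before {query} static is exactly what makes prefix caching pay off.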

Build a dedicated execution path; do not just tune configuration on top of a general-purpose generation engine. The target is the absolute lowest classification latency.

Main optimization directions:
1. hidden_last -> N-class scorer -> argmax
2. Reuse identical prefixes via automatic prefix caching in the style of vLLM, managing KV by block hash
3. Skip full-vocab logits
4. Skip decode / constrained decode
5. A dedicated tail kernel that outputs N raw class scores
6. The N configured prompts (2-8) must run inference in parallel
7. The hardware is a Tesla T4, so FlashAttention-3 cannot be the main path; choose the best attention available on the T4
Research widely: study open-source inference-optimization projects and practices suited to the Tesla T4, including but not limited to toolkits combining a Python/C++ runtime with TensorRT engines; make sure the hardware support matrix covers Turing/T4. Pay particular attention to Triton / CUDA C++ techniques, and open-source projects, aimed at exactly this pattern of single-step decode with fixed-vocabulary scoring; find a suitable baseline and build on it.
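Direction 1 reduces per-request scoring to one tiny matmul: slice the N label-token rows out of the LM head once at startup, then multiply by the last hidden state and take the argmax. A numpy sketch of the idea (shapes and names are illustrative, not this project's API):

```python
import numpy as np

hidden_size, vocab_size, num_labels = 8, 100, 4
rng = np.random.default_rng(0)

lm_head = rng.standard_normal((vocab_size, hidden_size))  # full LM head, [V, H]
label_token_ids = np.array([17, 3, 42, 99])               # first token id per label
last_hidden = rng.standard_normal(hidden_size)            # hidden state at the "label:" position

# Slice once at startup: [N, H] instead of [V, H]. The per-request cost is
# an [N, H] @ [H] matvec plus an argmax over N, never a full-vocab projection.
class_weights = lm_head[label_token_ids]
scores = class_weights @ last_hidden
best = int(np.argmax(scores))

# Same winner as a full-vocab projection restricted to the label tokens:
full_logits = lm_head @ last_hidden
assert best == int(np.argmax(full_logits[label_token_ids]))
```

For a 150k-token vocabulary and ~10 labels this shrinks the tail projection by four orders of magnitude.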

You have sudo; install whatever environment this project needs.

Use a Q4 or Q8 quantization of Qwen/Qwen3-8B; research the relevant Hugging Face material to pick the right variant, deploy it, and benchmark inference latency.

Profile each stage's latency in depth, and keep researching to see whether performance has truly been pushed to the limit.
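For the per-stage breakdown, a plain perf_counter harness that accumulates a stage_ms dict (the core code below does the same) is sufficient, with one caveat: on CUDA each stage boundary must synchronize, or use CUDA events, otherwise only kernel-launch time is measured. A minimal sketch with illustrative stage names:

```python
import time
from contextlib import contextmanager

stage_ms: dict[str, float] = {}


@contextmanager
def stage(name: str):
    # On GPU, call torch.cuda.synchronize() on entry and exit so the wall
    # clock covers the kernels themselves, not just their launch.
    t0 = time.perf_counter()
    try:
        yield
    finally:
        stage_ms[name] = stage_ms.get(name, 0.0) + (time.perf_counter() - t0) * 1000.0


with stage("tokenize"):
    tokens = "red dress for women".split()
with stage("forward"):
    time.sleep(0.001)  # stand-in for the model forward
with stage("reduce"):
    best = max(range(len(tokens)), key=lambda i: len(tokens[i]))

total_ms = sum(stage_ms.values())
```

The same dict can be returned verbatim in the CLI/HTTP response alongside the scores.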

One important issue: some labels are not single tokens (even though each looks like a single word), so the multi-token case must be handled.
Multi-token labels need aggressive optimization that avoids serial decoding.
Our end goal is only to find the label with the highest score, not exact probabilities. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) exactly may make performance hard to optimize, so exact probabilities can be sacrificed. Keep the end goal in mind: classification is enough, and performance comes first.
How to process the whole batch, multi-token labels included, in a single model forward pass is the problem you need to explore.

The single-token fast path is settled: last_hidden -> small class scorer -> argmax, taking only the LM-head rows for the target labels and never computing full-vocab output.
How to handle multi-token labels needs research; ideally it should cost the same as the single-token path (given that exact log-probs are abandoned). But multi-token and single-token label scores must stay comparable or classification breaks; balance performance against scoring accuracy.
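One way to keep multi-token labels inside a single forward, and the layout the core code below builds with label_prefix_ids, is to give each label its own batch row: append the label's tokens minus the last one after the shared continuation, then read logits at every label-token position in that one pass (teacher forcing, no serial decode). A pure-Python sketch of just the row layout, with toy token ids and build_label_rows as a hypothetical helper:

```python
PAD = 0


def build_label_rows(continuation: list[int], label_token_ids: list[list[int]]):
    """One row per label: continuation + all label tokens except the last,
    right-padded so every row shares one length and fits a single forward.

    Row i's hidden states at the listed positions then score label i's
    tokens teacher-forced, with no sequential decode steps.
    """
    max_prefix = max(len(ids) - 1 for ids in label_token_ids)
    width = len(continuation) + max_prefix
    rows, score_positions = [], []
    for ids in label_token_ids:
        prefix = ids[:-1]  # the final label token is only predicted, never fed
        row = continuation + prefix
        rows.append(row + [PAD] * (width - len(row)))
        # the position whose logits score label token j sits one step before it
        score_positions.append([len(continuation) - 1 + j for j in range(len(ids))])
    return rows, score_positions


rows, positions = build_label_rows(
    continuation=[5, 6, 7],                # query + suffix tokens
    label_token_ids=[[11], [21, 22, 23]],  # a 1-token label vs a 3-token label
)
```

A single-token label costs one scored position; a k-token label costs k scored positions in the same forward, so batch size grows but forward count stays at one.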

Also add a config option, force_single_token_labels, which treats every label by its first token only: when the labels' first tokens all differ, the first-token score is an acceptable proxy for the whole label's score.
Find the best practice for multi-label scoring in both performance and accuracy, and also support force_single_token_labels for maximum speed.
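force_single_token_labels is only safe when the labels' first tokens are all distinct; this is the guard the core code implements as _has_unique_single_token_labels. A sketch of the check over a toy tokenization table (the ids are made up for illustration; real ids come from tokenizer.encode(label_prefix + label)):

```python
# Hypothetical first tokens a BPE tokenizer might produce for " <label>".
toy_encoding = {
    " dress": [31],
    " jeans": [44],
    " trench": [52, 108],  # multi-token: would be truncated to [52]
    " tee": [52, 9],       # collides with " trench" on the first token
}


def can_force_single_token(encoded: dict[str, list[int]]) -> bool:
    """First-token scoring is a valid proxy only if no two labels share
    their first token id; otherwise those labels become indistinguishable."""
    first_tokens = [ids[0] for ids in encoded.values()]
    return len(first_tokens) == len(set(first_tokens))


ok = can_force_single_token(toy_encoding)  # False: " trench" and " tee" collide
```

When the check fails, that prompt task must fall back to the multi-token path (or raise, as the core code does).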

Also search carefully for relevant material, especially Triton / Ollama / CUDA C++ best practices used by technical frameworks for this scenario; put them into practice and reach SOTA, maximally optimized performance for query classification on the T4.
Reference links:
vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq

TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
TensorRT-LLM support matrix: the current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers

Note: two projects already exist, llm-qp and llm-qp2, whose single-token handling is sound:
SDPA
prefix cache
prebuilt bucket + CUDA graph
Their core code:
from __future__ import annotations

import hashlib
import time
from dataclasses import asdict, dataclass
from typing import Iterable

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

try:
    from transformers import BitsAndBytesConfig
except ImportError:  # pragma: no cover
    BitsAndBytesConfig = None

from llm_qp.config import PromptTaskConfig, RuntimeConfig
from llm_qp.scorer import SmallClassScorer

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:  # pragma: no cover
    SDPBackend = None
    sdpa_kernel = None


@dataclass(slots=True)
class EncodedLabel:
    text: str
    token_ids: list[int]


@dataclass(slots=True)
class PrefixCache:
    prefix_ids: list[int]
    prefix_hashes: list[str]
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def prefix_len(self) -> int:
        return len(self.prefix_ids)


@dataclass(slots=True)
class MultiTokenTables:
    label_token_ids: torch.Tensor
    label_token_mask: torch.Tensor
    label_prefix_ids: torch.Tensor
    label_prefix_mask: torch.Tensor
    label_position_offsets: torch.Tensor

    @property
    def max_label_len(self) -> int:
        return self.label_token_ids.shape[1]

    @property
    def max_label_prefix_len(self) -> int:
        return self.label_prefix_ids.shape[1]


@dataclass(slots=True)
class QueryScoreResult:
    task_name: str
    query: str
    predicted_label: str
    scores: list[tuple[str, float, float]]
    total_ms: float
    stage_ms: dict[str, float]
    fast_path: bool
    prefix_tokens: int
    continuation_tokens: int
    label_token_lengths: dict[str, int]

    @property
    def predicted_prob(self) -> float:
        for label, _score, prob in self.scores:
            if label == self.predicted_label:
                return prob
        return 0.0


@dataclass(slots=True)
class MultiPromptScoreResult:
    query: str
    total_ms: float
    details: list[QueryScoreResult]
    stage_ms: dict[str, float]

    def http_json(self) -> dict[str, object]:
        return {
            "query": self.query,
            "total_ms": self.total_ms,
            "stage_ms": self.stage_ms,
            "details": [asdict(t) for t in self.details],
            "task_results": {
                t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob]
                for t in self.details
                if t.predicted_label != 'none'
            },
        }


@dataclass(slots=True)
class BatchScoreResult:
    batch_size: int
    total_ms: float
    results: list[MultiPromptScoreResult]
    stage_ms: dict[str, float]


@dataclass(slots=True)
class SharedRuntime:
    device: torch.device
    dtype: torch.dtype
    tokenizer: object
    model: object
    backbone: object
    hidden_size: int
    graph_capture_pool: object | None = None
    graph_capture_stream: torch.cuda.Stream | None = None


@dataclass(slots=True)
class PromptBatchPlan:
    runner: "PromptClassifierRunner"
    row_start: int
    row_count: int
    score_buffer: torch.Tensor

    @property
    def row_stop(self) -> int:
        return self.row_start + self.row_count


@dataclass(slots=True)
class MixedPrefixCache:
    batch_size: int
    total_rows: int
    prefix_lengths: torch.Tensor
    attention_mask: torch.Tensor
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def max_prefix_len(self) -> int:
        return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0


@dataclass(slots=True)
class BatchLayout:
    batch_size: int
    total_rows: int
    plans: list[PromptBatchPlan]


@dataclass(slots=True)
class MixedBucketRuntime:
    batch_size: int
    total_rows: int
    continuation_len: int
    max_input_len: int
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    last_hidden_state: torch.Tensor
    graph: torch.cuda.CUDAGraph | None = None


@dataclass(slots=True)
class PreloadReport:
    total_ms: float
    stage_ms: dict[str, float]
    runtime: dict[str, object]


def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
    token_list = list(token_ids)
    hashes: list[str] = []
    for start in range(0, len(token_list), block_size):
        block = token_list[start : start + block_size]
        payload = ",".join(str(x) for x in block).encode("utf-8")
        hashes.append(hashlib.sha1(payload).hexdigest())
    return hashes


def _expand_legacy_cache(
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
    batch_size: int,
) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
    expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
    for key, value in raw_cache:
        expanded.append(
            (
                key.expand(batch_size, *key.shape[1:]).contiguous(),
                value.expand(batch_size, *value.shape[1:]).contiguous(),
            )
        )
    return tuple(expanded)


class PromptClassifierRunner:
    def __init__(
        self,
        cfg: RuntimeConfig,
        task_cfg: PromptTaskConfig,
        shared_runtime: SharedRuntime,
    ):
        self.cfg = cfg
        self.task_cfg = task_cfg
        self.device = shared_runtime.device
        self.dtype = shared_runtime.dtype
        self.tokenizer = shared_runtime.tokenizer
        self.model = shared_runtime.model
        self.backbone = shared_runtime.backbone
        self.hidden_size = shared_runtime.hidden_size
        self.prefix_text, self.suffix_text = task_cfg.prompt_parts
        self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
        self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
        self.labels = list(task_cfg.labels)
        self.encoded_labels = [
            EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
            for label in self.labels
        ]
        self.num_labels = len(self.labels)
        self.lm_head = self.model.get_output_embeddings()
        self.lm_head_weight = self.lm_head.weight.detach()
        self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
        if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
            raise ValueError(
                f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
            )
        self.fast_path = self._has_unique_single_token_labels()
        self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
        self.scorer = self._build_scorer() if self.fast_path else None
        self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
        self.prefix_cache = self._build_prefix_cache()

    def _encode_label_token_ids(self, label: str) -> list[int]:
        token_ids = self.tokenizer.encode(
            f"{self.task_cfg.label_prefix}{label}",
            add_special_tokens=False,
        )
        if not token_ids:
            raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
        if self.cfg.force_single_token_labels:
            return token_ids[:1]
        return token_ids

    def _has_unique_single_token_labels(self) -> bool:
        token_ids: list[int] = []
        for item in self.encoded_labels:
            if len(item.token_ids) != 1:
                return False
            token_ids.append(item.token_ids[0])
        return len(token_ids) == len(set(token_ids))

    def _build_scorer(self) -> SmallClassScorer:
        token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
        weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
        bias = None
        if self.lm_head_bias is not None:
            bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
        return SmallClassScorer(weights=weights, bias=bias)

    def _build_multi_token_tables(self) -> MultiTokenTables:
        max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
        max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
        label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
        label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
        label_prefix_ids = torch.full(
            (self.num_labels, max_prefix_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
        for idx, item in enumerate(self.encoded_labels):
            token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
            token_len = token_ids.numel()
            label_token_ids[idx, :token_len] = token_ids
            label_token_mask[idx, :token_len] = 1.0
            if token_len > 1:
                prefix_len = token_len - 1
                label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
                label_prefix_mask[idx, :prefix_len] = 1
        return MultiTokenTables(
            label_token_ids=label_token_ids.contiguous(),
            label_token_mask=label_token_mask.contiguous(),
            label_prefix_ids=label_prefix_ids.contiguous(),
            label_prefix_mask=label_prefix_mask.contiguous(),
            label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
        )

    @torch.inference_mode()
    def _build_prefix_cache(self) -> PrefixCache:
        if not self.prefix_ids:
            return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
        prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
        attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
        outputs = self.model(
            input_ids=prefix_tensor,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        raw_cache = tuple(
            (layer.keys.detach(), layer.values.detach())
            for layer in outputs.past_key_values.layers
        )
        return PrefixCache(
            prefix_ids=list(self.prefix_ids),
            prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
            raw_cache=raw_cache,
        )

    def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
        if not self.prefix_cache.raw_cache:
            return tuple()
        return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)

    def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
        continuation = query_ids + self.suffix_ids
        if not continuation:
            raise ValueError("prompt continuation is empty after substituting query")
        if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
            raise ValueError(
                f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
            )
        return continuation

    @torch.inference_mode()
    def reduce_fast_scores(
        self,
        hidden: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.scorer is not None
        out_scores.copy_(self.scorer(hidden))

    @torch.inference_mode()
    def reduce_multi_token_scores(
        self,
        last_hidden_state: torch.Tensor,
        batch_size: int,
        max_input_len: int,
        score_positions: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.multi_token_tables is not None
        hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
        gather_index = score_positions[:, None, :, None].expand(
            batch_size,
            self.num_labels,
            self.multi_token_tables.max_label_len,
            self.hidden_size,
        )
        gathered_hidden = torch.gather(hidden, 2, gather_index)
        used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()

        token_log_probs = torch.zeros(
            (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
            device=self.device,
            dtype=torch.float32,
        )
        if used_mask.any():
            flat_hidden = gathered_hidden[used_mask]
            flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]
            linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()
            linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()
            linear_bias = self.lm_head_bias
            if linear_bias is not None and self.device.type != "cuda":
                linear_bias = linear_bias.float()
            flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)
            flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()
            flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)
            token_log_probs[used_mask] = flat_selected - flat_log_norm
        out_scores.copy_(token_log_probs.sum(dim=-1))

    def build_score_result(
        self,
        query: str,
        scores: torch.Tensor,
        stage_ms: dict[str, float],
        continuation_tokens: int,
    ) -> QueryScoreResult:
        score_values = scores.detach().float().cpu().tolist()
        best_idx = max(range(len(score_values)), key=score_values.__getitem__)
        probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()
        return QueryScoreResult(
            task_name=self.task_cfg.name,
            query=query,
            predicted_label=self.labels[best_idx],
            scores=[
                (label, score, prob)
                for label, score, prob in zip(self.labels, score_values, probs, strict=True)
            ],
            total_ms=sum(stage_ms.values()),
            stage_ms=stage_ms,
            fast_path=self.fast_path,
            prefix_tokens=self.prefix_cache.prefix_len,
            continuation_tokens=continuation_tokens,
            label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},
        )


class MultiPromptRunner:
    def __init__(self, cfg: RuntimeConfig):
        self.cfg = cfg
        t0 = time.perf_counter()
        self.shared_runtime = self.build_shared_runtime(cfg)
        t1 = time.perf_counter()
        self.device = self.shared_runtime.device
        self.dtype = self.shared_runtime.dtype
        self.tokenizer = self.shared_runtime.tokenizer
        self.model = self.shared_runtime.model
        self.backbone = self.shared_runtime.backbone
        self.hidden_size = self.shared_runtime.hidden_size
        self.graph_capture_pool = self.shared_runtime.graph_capture_pool
        self.graph_capture_stream = self.shared_runtime.graph_capture_stream
        self.runners = [
            PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)
            for task_cfg in cfg.tasks
        ]
        t2 = time.perf_counter()
        self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}
        t3 = time.perf_counter()
        self.mixed_prefix_caches = {
            batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])
            for batch_size in self.cfg.batch_sizes
        }
        t4 = time.perf_counter()
        self.max_label_prefix_len = max(
            (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)
            for runner in self.runners
        )
        self.mixed_buckets = {
            (batch_size, continuation_len): self._build_mixed_bucket(
                self.batch_layouts[batch_size],
                self.mixed_prefix_caches[batch_size],
                continuation_len,
            )
            for batch_size in self.cfg.batch_sizes
            for continuation_len in self.cfg.continuation_buckets
        }
        t5 = time.perf_counter()
        self._warmup_results: dict[int, BatchScoreResult] = {}
        self._preload_report: PreloadReport | None = None
        self._init_stage_ms = {
            "load_model_and_tokenizer": (t1 - t0) * 1000.0,
            "build_prompt_runtimes": (t2 - t1) * 1000.0,
            "build_batch_layouts": (t3 - t2) * 1000.0,
            "build_mixed_prefix_caches": (t4 - t3) * 1000.0,
            "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,
        }
        self._init_total_ms = sum(self._init_stage_ms.values())

    @staticmethod
    def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:
        device = torch.device(cfg.device)
        dtype = torch.float16
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.resolved_model_source,
            trust_remote_code=cfg.resolved_trust_remote_code,
            token=cfg.hf_token,
            cache_dir=cfg.hf_cache_dir,
            local_files_only=cfg.resolved_local_files_only,
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)
        quantization_config = None
        model_kwargs: dict[str, object] = {
            "trust_remote_code": cfg.resolved_trust_remote_code,
            "attn_implementation": attn_impl,
            "token": cfg.hf_token,
            "cache_dir": cfg.hf_cache_dir,
            "local_files_only": cfg.resolved_local_files_only,
        }
        if cfg.load_in_4bit:
            if BitsAndBytesConfig is None:
                raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
                bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
            )
            model_kwargs["quantization_config"] = quantization_config
            model_kwargs["device_map"] = {"": device.index or 0}
        else:
            model_kwargs["dtype"] = dtype
            model_kwargs["device_map"] = None
        model = AutoModelForCausalLM.from_pretrained(
            cfg.resolved_model_source,
            **model_kwargs,
        ).eval()
        if not cfg.load_in_4bit:
            model = model.to(device)
        backbone = model.get_submodule(model.base_model_prefix)
        hidden_size = model.get_output_embeddings().weight.shape[1]
        graph_capture_pool = None
        graph_capture_stream = None
        if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:
            graph_capture_pool = torch.cuda.graph_pool_handle()
            graph_capture_stream = torch.cuda.Stream(device=device)
        return SharedRuntime(
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            model=model,
            backbone=backbone,
            hidden_size=hidden_size,
            graph_capture_pool=graph_capture_pool,
            graph_capture_stream=graph_capture_stream,
        )

    @staticmethod
    def _resolve_attn_impl(requested: str) -> str:
        if requested in {"sdpa", "eager"}:
            return requested
        if requested == "auto":
            return "sdpa"
        raise ValueError(f"unsupported attn_backend: {requested}")

    def _attn_context(self):
        if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
            return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
        return torch.no_grad()

    def _sync(self) -> None:
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _pick_bucket(self, continuation_len: int) -> int:
        for bucket in self.cfg.continuation_buckets:
            if continuation_len <= bucket:
                return bucket
        if self.cfg.pad_to_bucket:
            raise ValueError(
                f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
            )
        return continuation_len

    def _build_batch_layout(self, batch_size: int) -> BatchLayout:
        plans: list[PromptBatchPlan] = []
        row_start = 0
        for runner in self.runners:
            row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
            plans.append(
                PromptBatchPlan(
                    runner=runner,
                    row_start=row_start,
                    row_count=row_count,
                    score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
                )
            )
            row_start += row_count
        return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)

    def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
        prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
        non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
        if not non_empty:
            return MixedPrefixCache(
                batch_size=layout.batch_size,
                total_rows=layout.total_rows,
                prefix_lengths=prefix_lengths,
                attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
                raw_cache=tuple(),
            )

        max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
        num_layers = len(non_empty[0])
        attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
        raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_idx in range(num_layers):
            sample_key, sample_value = non_empty[0][layer_idx]
            merged_key = sample_key.new_zeros(
                (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
            )
            merged_value = sample_value.new_zeros(
                (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
            )
            raw_layers.append((merged_key, merged_value))

        for plan in layout.plans:
            runner = plan.runner
            prefix_len = runner.prefix_cache.prefix_len
            row_slice = slice(plan.row_start, plan.row_stop)
            prefix_lengths[row_slice] = prefix_len
            if prefix_len == 0:
                continue
            attention_mask[row_slice, :prefix_len] = 1
            raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
            for layer_idx, (key, value) in enumerate(raw_cache):
                merged_key, merged_value = raw_layers[layer_idx]
                merged_key[row_slice, :, :prefix_len, :] = key
                merged_value[row_slice, :, :prefix_len, :] = value

        return MixedPrefixCache(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            prefix_lengths=prefix_lengths,
            attention_mask=attention_mask.contiguous(),
            raw_cache=tuple(raw_layers),
        )

    def _build_mixed_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        continuation_len: int,
    ) -> MixedBucketRuntime:
        max_input_len = continuation_len + self.max_label_prefix_len
        total_len = prefix_cache.max_prefix_len + max_input_len
        input_ids = torch.full(
            (layout.total_rows, max_input_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
        if prefix_cache.max_prefix_len:
            attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
        position_ids = (
            prefix_cache.prefix_lengths[:, None]
            + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
        ).contiguous()
        last_hidden_state = torch.empty(
            (layout.total_rows, max_input_len, self.hidden_size),
            device=self.device,
            dtype=self.dtype,
        )
        bucket = MixedBucketRuntime(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            continuation_len=continuation_len,
            max_input_len=max_input_len,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            last_hidden_state=last_hidden_state,
        )
        if self.cfg.cuda_graphs:
            self._capture_mixed_bucket(bucket, prefix_cache)
        return bucket

    @torch.inference_mode()
    def _run_mixed_backbone(
        self,
        bucket: MixedBucketRuntime,
        prefix_cache: MixedPrefixCache,
    ) -> None:
        cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
        with self._attn_context():
            outputs = self.backbone(
                input_ids=bucket.input_ids,
                attention_mask=bucket.attention_mask,
                position_ids=bucket.position_ids,
                past_key_values=cache,
                use_cache=False,
                return_dict=True,
            )
        bucket.last_hidden_state.copy_(outputs.last_hidden_state)

    def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
            stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
            with torch.cuda.stream(stream):
                for _ in range(self.cfg.graph_warmups):
                    self._run_mixed_backbone(bucket, prefix_cache)
            stream.synchronize()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
                self._run_mixed_backbone(bucket, prefix_cache)
            bucket.graph = graph
        except RuntimeError:
            bucket.graph = None

    def _prepare_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        bucket: MixedBucketRuntime,
        query_ids_batch: list[list[int]],
    ) -> tuple[list[list[int]], dict[str, list[object]]]:
        del prefix_cache
        bucket.input_ids.fill_(self.tokenizer.pad_token_id)
        bucket.attention_mask.zero_()
        if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
            bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
                self.mixed_prefix_caches[layout.batch_size].attention_mask
            )
        continuation_lengths_per_task: dict[str, list[int]] = {}
        continuation_tokens_per_task: dict[str, list[list[int]]] = {}
        prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
        for plan in layout.plans:
            runner = plan.runner
            per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
            continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
            continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
            if runner.fast_path:
                for batch_idx, continuation in enumerate(per_query_continuations):
                    cont_len = len(continuation)
                    row_idx = plan.row_start + batch_idx
                    bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
                    bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
                continue

            assert runner.multi_token_tables is not None
            for batch_idx, continuation in enumerate(per_query_continuations):
                cont_len = len(continuation)
                row_start = plan.row_start + batch_idx * runner.num_labels
                row_stop = row_start + runner.num_labels
                row_slice = slice(row_start, row_stop)
                cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
                bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
                bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
                if runner.multi_token_tables.max_label_prefix_len:
                    bucket.input_ids[
                        row_slice,
                        cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_ids
                    bucket.attention_mask[
                        row_slice,
                        prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_mask
        return query_ids_batch, {
            "continuation_lengths_per_task": continuation_lengths_per_task,
            "continuation_tokens_per_task": continuation_tokens_per_task,
        }
819
+
820
+    def _reduce_prompt_scores(
+        self,
+        layout: BatchLayout,
+        bucket: MixedBucketRuntime,
+        query_texts: list[str],
+        prep_meta: dict[str, list[object]],
+        shared_stage_ms: dict[str, float],
+    ) -> list[MultiPromptScoreResult]:
+        result_rows = [[] for _ in range(layout.batch_size)]
+        prompt_reduce_total_ms = 0.0
+        for plan in layout.plans:
+            runner = plan.runner
+            continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
+            reduce_start = time.perf_counter()
+            if runner.fast_path:
+                hidden_rows = []
+                row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
+                for batch_idx, cont_len in enumerate(continuation_lengths):
+                    hidden_rows.append(row_slice[batch_idx, cont_len - 1])
+                hidden = torch.stack(hidden_rows, dim=0)
+                runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
+                stage_name = "tail_scorer"
+            else:
+                assert runner.multi_token_tables is not None
+                score_positions = torch.stack(
+                    [
+                        cont_len - 1 + runner.multi_token_tables.label_position_offsets
+                        for cont_len in continuation_lengths
+                    ],
+                    dim=0,
+                )
+                runner.reduce_multi_token_scores(
+                    last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
+                    batch_size=layout.batch_size,
+                    max_input_len=bucket.max_input_len,
+                    score_positions=score_positions,
+                    out_scores=plan.score_buffer,
+                )
+                stage_name = "candidate_reduce"
+            self._sync()
+            reduce_end = time.perf_counter()
+            reduce_ms = (reduce_end - reduce_start) * 1000.0
+            prompt_reduce_total_ms += reduce_ms
+            for batch_idx, query in enumerate(query_texts):
+                stage_ms = dict(shared_stage_ms)
+                stage_ms[stage_name] = reduce_ms / layout.batch_size
+                result_rows[batch_idx].append(
+                    runner.build_score_result(
+                        query=query,
+                        scores=plan.score_buffer[batch_idx],
+                        stage_ms=stage_ms,
+                        continuation_tokens=continuation_lengths[batch_idx],
+                    )
+                )
+
+        batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
+        shared_plus_reduce = dict(shared_stage_ms)
+        shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
+        results: list[MultiPromptScoreResult] = []
+        for batch_idx, query in enumerate(query_texts):
+            results.append(
+                MultiPromptScoreResult(
+                    query=query,
+                    total_ms=batch_total_ms / layout.batch_size,
+                    details=result_rows[batch_idx],
+                    stage_ms={
+                        **shared_plus_reduce,
+                        "per_query_total_estimate": batch_total_ms / layout.batch_size,
+                    },
+                )
+            )
+        return results
+
893
+    @torch.inference_mode()
+    def score_queries(self, queries: list[str]) -> BatchScoreResult:
+        if not queries:
+            raise ValueError("queries must not be empty")
+        batch_size = len(queries)
+        if batch_size not in self.batch_layouts:
+            raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
+        layout = self.batch_layouts[batch_size]
+        prefix_cache = self.mixed_prefix_caches[batch_size]
+
+        self._sync()
+        t0 = time.perf_counter()
+        query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
+        self._sync()
+        t1 = time.perf_counter()
+
+        max_continuation_len = max(
+            len(plan.runner.build_continuation_from_query_ids(query_ids))
+            for plan in layout.plans
+            for query_ids in query_ids_batch
+        )
+        picked_bucket = self._pick_bucket(max_continuation_len)
+        bucket = self.mixed_buckets[(batch_size, picked_bucket)]
+        _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
+        self._sync()
+        t2 = time.perf_counter()
+
+        if bucket.graph is not None:
+            bucket.graph.replay()
+        else:
+            self._run_mixed_backbone(bucket, prefix_cache)
+        self._sync()
+        t3 = time.perf_counter()
+
+        shared_stage_ms = {
+            "encode_queries_shared": (t1 - t0) * 1000.0,
+            "prepare_batch_shared": (t2 - t1) * 1000.0,
+            "backbone_shared": (t3 - t2) * 1000.0,
+        }
+        results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
+        total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
+        return BatchScoreResult(
+            batch_size=batch_size,
+            total_ms=total_ms,
+            results=results,
+            stage_ms={
+                **shared_stage_ms,
+                "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
+            },
+        )
+
+    def score_query(self, query: str) -> MultiPromptScoreResult:
+        return self.score_queries([query]).results[0]
+
947
+    def preload(self) -> PreloadReport:
+        if self._preload_report is not None:
+            return self._preload_report
+        stage_ms: dict[str, float] = dict(self._init_stage_ms)
+        start = time.perf_counter()
+        self._sync()
+        t0 = time.perf_counter()
+        warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
+        for batch_size in warmup_batch_sizes:
+            queries = [self.cfg.warmup_query] * batch_size
+            self._warmup_results[batch_size] = self.score_queries(queries)
+        self._sync()
+        t1 = time.perf_counter()
+        stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
+        stage_ms["startup_total_before_warmup"] = self._init_total_ms
+        total_ms = self._init_total_ms + (t1 - start) * 1000.0
+        runtime = self.preload_report()
+        self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
+        return self._preload_report
+
+    def preload_report(self) -> dict[str, object]:
+        return {
+            "model_name": self.cfg.resolved_model_name,
+            "model_source": self.cfg.resolved_model_source,
+            "device": str(self.device),
+            "dtype": self.cfg.dtype,
+            "attn_backend": self.cfg.attn_backend,
+            "execution_model": "single_mixed_backbone_per_batch",
+            "num_tasks": len(self.runners),
+            "task_names": [runner.task_cfg.name for runner in self.runners],
+            "batch_sizes": list(self.cfg.batch_sizes),
+            "continuation_buckets": list(self.cfg.continuation_buckets),
+            "mixed_bucket_count": len(self.mixed_buckets),
+            "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
+            "all_configured_buckets_preloaded": True,
+            "init_stage_ms": dict(self._init_stage_ms),
+            "init_total_ms": self._init_total_ms,
+            "force_single_token_labels": self.cfg.force_single_token_labels,
+            "warmup_query": self.cfg.warmup_query,
+            "tasks": [
+                {
+                    "task_name": runner.task_cfg.name,
+                    "fast_path": runner.fast_path,
+                    "num_labels": runner.num_labels,
+                    "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
+                    "prefix_tokens": runner.prefix_cache.prefix_len,
+                    "prefix_hashes": runner.prefix_cache.prefix_hashes,
+                    "label_prefix": runner.task_cfg.label_prefix,
+                }
+                for runner in self.runners
+            ],
+        }
+
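One note on the stage timings in the reference code: they bracket every stage with `self._sync()` (a device-wide `torch.cuda.synchronize()`), which adds barrier overhead to the very latency being measured. A lower-overhead alternative (a sketch, not part of the quoted code; the helper name is hypothetical) is CUDA events, which report GPU execution time between two recorded points on the stream:

```python
import torch

def time_stage_ms(fn):
    """Run fn() on the current CUDA stream and return (result, elapsed_ms),
    timed with CUDA events instead of a device-wide synchronize."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn()
    end.record()
    end.synchronize()  # block only until this stage's end event has completed
    return out, start.elapsed_time(end)  # elapsed_time() is in milliseconds
```

The measurement reflects GPU-side time between the two events rather than host wall-clock around a global barrier, so it composes better with CUDA graph replay and multi-stream execution.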
+However, the multi-token label handling in this reference code is inefficient and incorrect; do not follow it. Reimplement that part from scratch.
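For the reimplementation, one direction worth considering: the quoted code expands each query into `num_labels` rows and reduces per position. A cheaper pattern is a single teacher-forced pass plus one batched gather over a restricted candidate vocabulary, so no full-vocab logits and no Python-side per-label loop. A minimal sketch (all tensor names are hypothetical; renormalizing the softmax over the union of label tokens instead of the full vocabulary is an approximation that preserves the ranking signal used here):

```python
import torch

def score_multi_token_labels(
    hidden: torch.Tensor,           # [num_labels, max_len, d] hidden states at the
                                    # positions that predict each label token
    lm_head_weight: torch.Tensor,   # [vocab, d] output projection (tied embedding)
    label_ids: torch.Tensor,        # [num_labels, max_len] padded label token ids
    label_mask: torch.Tensor,       # [num_labels, max_len] 1.0 real token, 0.0 pad
) -> torch.Tensor:
    # Project only onto the union of candidate label tokens, not the full vocab.
    cand = torch.unique(label_ids)                       # sorted unique ids, [C]
    logits = hidden @ lm_head_weight[cand].T             # [num_labels, max_len, C]
    logprobs = torch.log_softmax(logits, dim=-1)         # renormalized over C
    # Map each target id to its index inside the sorted candidate set, then gather.
    idx = torch.searchsorted(cand, label_ids)
    tok_lp = logprobs.gather(-1, idx.unsqueeze(-1)).squeeze(-1)  # [num_labels, max_len]
    # Length-normalized sum of per-token log-probs is the label score.
    return (tok_lp * label_mask).sum(-1) / label_mask.sum(-1)
```

Everything here is a handful of batched tensor ops, so it fuses well into a dedicated tail path and can later be replaced by a Triton or CUDA kernel with the same contract.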
+A local .venv has already been created (reused from llm-qp); use that environment.
+The useful reference code has been quoted above. Based on the requirements, research the relevant material and focus on developing a brand-new version.
+This is a substantial task: work through it patiently, step by step, iterating over a long horizon until the service is built and all tests fully meet expectations.
...
...