Overall requirement: build query classification on a Tesla T4 GPU with an open-source base LLM (3-6B class). Prompts and label sets must be configurable. Focus on inference optimization first and leave the service layer for last; the service should support a modest level of concurrency (e.g. 4 in-flight requests). At program startup, load everything that needs loading — no lazy loading of any kind — so that real requests get the fastest possible response. For every query the CLI receives, run N prompts to classify along N dimensions; for each prompt, output the score of every label plus the total prediction latency (per-stage latencies too, if that is easy to add).

The technical material below is for reference only; you are not bound to it strictly and should keep some flexibility in pursuit of extreme performance.

- On a Tesla T4, use a 3B-6B open-source decoder-only base model for query classification.
- At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor.
- For each input query, output the score distribution over labels under every prompt, plus total prediction latency and per-stage latency.
- Do not go through the generic generation path: no decode step, no full-vocab logits, no constrained decoding.
- Optimize multi-token labels specifically; avoid serial decoding on the Python side.
- The prompt and label sets must be configurable (there are two today; I will grow this to 8 — each request takes one query, runs the 8 prompts in parallel, and returns the highest-scoring label for each):

  prompts:
    - name: category
      prompt_template: |
        <|im_start|>system
        Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.<|im_end|>
        <|im_start|>user
        query: {query}<|im_end|>
        <|im_start|>assistant
        label:
      label_prefix: " "
      labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]
    - name: audience
      prompt_template: |
        <|im_start|>system
        Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.<|im_end|>
        <|im_start|>user
        query: {query}<|im_end|>
        <|im_start|>assistant
        label:
      label_prefix: " "
      labels: [boy, girl, man, woman, pregnant, none]

Build a dedicated execution path, not configuration tuning on top of a general-purpose generation engine; the target is the absolute lowest classification latency. Main optimization directions:
1. hidden_last -> N-class scorer -> argmax
2. Reuse shared prefixes as in vLLM's automatic prefix caching, managing KV by block hash
3. Drop full-vocab logits
4. Drop decode / constrained decode
5. A dedicated tail kernel that outputs the N raw class scores
6. Run the configured N prompts (2-8) in parallel per request
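Direction 1 can be sketched in a few lines — a minimal, self-contained illustration (the toy hidden size, vocab size, and token ids below are made up for the example): instead of projecting the last hidden state through the full LM head, gather only the N rows corresponding to the label tokens and take the argmax over those N logits. Since softmax is monotonic and shares one normalizer across the vocab, restricting the matmul to the label rows preserves the argmax.

```python
import torch

def fast_label_argmax(last_hidden: torch.Tensor,
                      lm_head_weight: torch.Tensor,
                      label_token_ids: list[int]) -> int:
    """last_hidden: [hidden]; lm_head_weight: [vocab, hidden].
    Returns the index (into label_token_ids) of the best label."""
    # Gather only the N LM-head rows that correspond to label tokens.
    rows = lm_head_weight[torch.tensor(label_token_ids)]  # [N, hidden]
    scores = rows @ last_hidden                           # [N] raw class scores
    return int(scores.argmax())

# Sanity check against the full-vocab path restricted to the label ids.
torch.manual_seed(0)
hidden = torch.randn(64)
weight = torch.randn(1000, 64)
label_ids = [17, 42, 99]
pred = fast_label_argmax(hidden, weight, label_ids)
full = int((weight @ hidden)[label_ids].argmax())
```

The same reduction works batched (an `[N, hidden]` x `[hidden, B]` matmul), which is what the "dedicated tail kernel" in direction 5 would fuse.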
The target hardware is a Tesla T4, so FlashAttention-3 cannot be the main path; choose the best attention implementation available on T4. Research widely: open-source inference-optimization projects and practices that actually work on T4, including but not limited to Python/C++ runtimes and TensorRT engine toolkits — check that each hardware support matrix covers Turing/T4. Pay particular attention to Triton / CUDA C++ techniques for this exact pattern — single-step decode with fixed-vocabulary scoring — and their open-source projects; find a suitable baseline and develop against it.

You have sudo and may install your own environment for this project.

Use a Q4 or Q8 build of Qwen/Qwen3-8B; research the relevant Hugging Face material to decide which quantization to deploy, then benchmark inference latency. Profile every stage in depth, keep researching, and push performance to the limit.

One important issue: some labels are not single tokens even though they look like single words, so the multi-token case must be handled. It must be optimized aggressively — no serial decoding. Our end goal is only to find the highest-scoring label, not exact probabilities: computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may make performance hard to optimize, so exact probabilities can be sacrificed. Keep the goal clear: we only need the classification; prioritize performance, and let exact probabilities go.

How to score the whole batch — multi-token labels included — in a single model forward is a problem you need to explore. The single-token fast path is settled: last_hidden -> small class scorer -> argmax, taking only the LM-head rows of the target labels, never the full vocab output. The multi-token path needs research; ideally it should cost the same as the single-token path (given that exact log-probs are abandoned). But the scores of multi-token and single-token labels must remain comparable, or classification breaks — balance performance against scoring accuracy.

Also add a config flag, force_single_token_labels: treat every label by its first token only. If the labels' first tokens differ, the first-token score is a reasonable approximation of the whole-label score.

Find the best practice for multi-label scoring on both the performance and the accuracy axis, while also supporting force_single_token_labels for the extreme-performance case. Also search carefully for relevant material — in particular Triton / Ollama / CUDA C++ best practices for this scenario — put them into practice, and reach SOTA, truly extreme performance for query classification on T4.

Reference examples:
- vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
- PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
- TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
- Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
- TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
- TensorRT-LLM support matrix: the current official hardware list omits Turing/T4, so T4 is effectively community-support territory there.
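To make the force_single_token_labels decision concrete, here is a sketch of the per-task planning logic (the `encode` callable stands in for the real tokenizer's `encode(label_prefix + label, add_special_tokens=False)`, and the toy vocab below is invented for the example): the flag is only safe when the labels' first tokens are all distinct.

```python
from typing import Callable

def plan_label_scoring(labels: list[str],
                       encode: Callable[[str], list[int]],
                       force_single_token: bool) -> dict:
    """Decide a task's scoring mode from its label tokenizations."""
    token_ids = {label: encode(label) for label in labels}
    first_tokens = [ids[0] for ids in token_ids.values()]
    all_single = all(len(ids) == 1 for ids in token_ids.values())
    first_unique = len(first_tokens) == len(set(first_tokens))
    if force_single_token and not first_unique:
        # First-token scores would be ambiguous between colliding labels.
        raise ValueError("first tokens collide; cannot force single-token labels")
    return {
        "fast_path": all_single or (force_single_token and first_unique),
        "multi_token_labels": [l for l, ids in token_ids.items() if len(ids) > 1],
    }

# Toy stand-in tokenizer: ' hoodie' splits into two tokens, the rest into one.
toy_vocab = {" dress": [11], " jeans": [12], " hoodie": [13, 14]}
plan = plan_label_scoring(["dress", "jeans", "hoodie"],
                          encode=lambda s: toy_vocab[" " + s],
                          force_single_token=True)
```

With the flag on and distinct first tokens, "hoodie" is scored by the logit of token 13 alone, which collapses the task back onto the single-token fast path.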
  https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
- FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
- vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
- SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
- FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
- xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers

Note: two projects already exist, llm-qp and llm-qp2. Their handling of the single-token case is sound:

    SDPA
    prefix cache
    prebuilt buckets + CUDA graphs

Their core code is:

from __future__ import annotations

import hashlib
import time
from dataclasses import asdict, dataclass
from typing import Iterable

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

try:
    from transformers import BitsAndBytesConfig
except ImportError:  # pragma: no cover
    BitsAndBytesConfig = None

from llm_qp.config import PromptTaskConfig, RuntimeConfig
from llm_qp.scorer import SmallClassScorer

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:  # pragma: no cover
    SDPBackend = None
    sdpa_kernel = None


@dataclass(slots=True)
class EncodedLabel:
    text: str
    token_ids: list[int]


@dataclass(slots=True)
class PrefixCache:
    prefix_ids: list[int]
    prefix_hashes: list[str]
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]
    @property
    def prefix_len(self) -> int:
        return len(self.prefix_ids)


@dataclass(slots=True)
class MultiTokenTables:
    label_token_ids: torch.Tensor
    label_token_mask: torch.Tensor
    label_prefix_ids: torch.Tensor
    label_prefix_mask: torch.Tensor
    label_position_offsets: torch.Tensor

    @property
    def max_label_len(self) -> int:
        return self.label_token_ids.shape[1]

    @property
    def max_label_prefix_len(self) -> int:
        return self.label_prefix_ids.shape[1]


@dataclass(slots=True)
class QueryScoreResult:
    task_name: str
    query: str
    predicted_label: str
    scores: list[tuple[str, float, float]]
    total_ms: float
    stage_ms: dict[str, float]
    fast_path: bool
    prefix_tokens: int
    continuation_tokens: int
    label_token_lengths: dict[str, int]

    @property
    def predicted_prob(self) -> float:
        for label, _score, prob in self.scores:
            if label == self.predicted_label:
                return prob
        return 0.0


@dataclass(slots=True)
class MultiPromptScoreResult:
    query: str
    total_ms: float
    details: list[QueryScoreResult]
    stage_ms: dict[str, float]

    def http_json(self) -> dict[str, object]:
        return {
            "query": self.query,
            "total_ms": self.total_ms,
            "stage_ms": self.stage_ms,
            "details": [asdict(t) for t in self.details],
            "task_results": {
                t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob]
                for t in self.details
                if t.predicted_label != 'none'
            },
        }


@dataclass(slots=True)
class BatchScoreResult:
    batch_size: int
    total_ms: float
    results: list[MultiPromptScoreResult]
    stage_ms: dict[str, float]


@dataclass(slots=True)
class SharedRuntime:
    device: torch.device
    dtype: torch.dtype
    tokenizer: object
    model: object
    backbone: object
    hidden_size: int
    graph_capture_pool: object | None = None
    graph_capture_stream: torch.cuda.Stream | None = None


@dataclass(slots=True)
class PromptBatchPlan:
    runner: "PromptClassifierRunner"
    row_start: int
    row_count: int
    score_buffer: torch.Tensor

    @property
    def row_stop(self) -> int:
        return self.row_start + self.row_count


@dataclass(slots=True)
class MixedPrefixCache:
    batch_size: int
    total_rows: int
    prefix_lengths: torch.Tensor
    attention_mask: torch.Tensor
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def max_prefix_len(self) -> int:
        return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0


@dataclass(slots=True)
class BatchLayout:
    batch_size: int
    total_rows: int
    plans: list[PromptBatchPlan]


@dataclass(slots=True)
class MixedBucketRuntime:
    batch_size: int
    total_rows: int
    continuation_len: int
    max_input_len: int
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    last_hidden_state: torch.Tensor
    graph: torch.cuda.CUDAGraph | None = None


@dataclass(slots=True)
class PreloadReport:
    total_ms: float
    stage_ms: dict[str, float]
    runtime: dict[str, object]


def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
    token_list = list(token_ids)
    hashes: list[str] = []
    for start in range(0, len(token_list), block_size):
        block = token_list[start : start + block_size]
        payload = ",".join(str(x) for x in block).encode("utf-8")
        hashes.append(hashlib.sha1(payload).hexdigest())
    return hashes


def _expand_legacy_cache(
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
    batch_size: int,
) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
    expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
    for key, value in raw_cache:
        expanded.append(
            (
                key.expand(batch_size, *key.shape[1:]).contiguous(),
                value.expand(batch_size, *value.shape[1:]).contiguous(),
            )
        )
    return tuple(expanded)


class PromptClassifierRunner:
    def __init__(
        self,
        cfg: RuntimeConfig,
        task_cfg: PromptTaskConfig,
        shared_runtime: SharedRuntime,
    ):
        self.cfg = cfg
        self.task_cfg = task_cfg
        self.device = shared_runtime.device
        self.dtype = shared_runtime.dtype
        self.tokenizer = shared_runtime.tokenizer
        self.model = shared_runtime.model
        self.backbone = shared_runtime.backbone
        self.hidden_size = shared_runtime.hidden_size
        self.prefix_text, self.suffix_text = task_cfg.prompt_parts
        self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
        self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
        self.labels = list(task_cfg.labels)
        self.encoded_labels = [
            EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
            for label in self.labels
        ]
        self.num_labels = len(self.labels)
        self.lm_head = self.model.get_output_embeddings()
        self.lm_head_weight = self.lm_head.weight.detach()
        self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
        if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
            raise ValueError(
                f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
            )
        self.fast_path = self._has_unique_single_token_labels()
        self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
        self.scorer = self._build_scorer() if self.fast_path else None
        self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
        self.prefix_cache = self._build_prefix_cache()

    def _encode_label_token_ids(self, label: str) -> list[int]:
        token_ids = self.tokenizer.encode(
            f"{self.task_cfg.label_prefix}{label}",
            add_special_tokens=False,
        )
        if not token_ids:
            raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
        if self.cfg.force_single_token_labels:
            return token_ids[:1]
        return token_ids

    def _has_unique_single_token_labels(self) -> bool:
        token_ids: list[int] = []
        for item in self.encoded_labels:
            if len(item.token_ids) != 1:
                return False
            token_ids.append(item.token_ids[0])
        return len(token_ids) == len(set(token_ids))

    def _build_scorer(self) -> SmallClassScorer:
        token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
        weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
        bias = None
        if self.lm_head_bias is not None:
            bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
        return SmallClassScorer(weights=weights, bias=bias)

    def _build_multi_token_tables(self) -> MultiTokenTables:
        max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
        max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
        label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
        label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
        label_prefix_ids = torch.full(
            (self.num_labels, max_prefix_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
        for idx, item in enumerate(self.encoded_labels):
            token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
            token_len = token_ids.numel()
            label_token_ids[idx, :token_len] = token_ids
            label_token_mask[idx, :token_len] = 1.0
            if token_len > 1:
                prefix_len = token_len - 1
                label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
                label_prefix_mask[idx, :prefix_len] = 1
        return MultiTokenTables(
            label_token_ids=label_token_ids.contiguous(),
            label_token_mask=label_token_mask.contiguous(),
            label_prefix_ids=label_prefix_ids.contiguous(),
            label_prefix_mask=label_prefix_mask.contiguous(),
            label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
        )

    @torch.inference_mode()
    def _build_prefix_cache(self) -> PrefixCache:
        if not self.prefix_ids:
            return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
        prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
        attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
        outputs = self.model(
            input_ids=prefix_tensor,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        raw_cache = tuple(
            (layer.keys.detach(), layer.values.detach())
            for layer in outputs.past_key_values.layers
        )
        return PrefixCache(
            prefix_ids=list(self.prefix_ids),
            prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
            raw_cache=raw_cache,
        )

    def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
        if not self.prefix_cache.raw_cache:
            return tuple()
        return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)

    def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
        continuation = query_ids + self.suffix_ids
        if not continuation:
            raise ValueError("prompt continuation is empty after substituting query")
        if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
            raise ValueError(
                f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
            )
        return continuation

    @torch.inference_mode()
    def reduce_fast_scores(
        self,
        hidden: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.scorer is not None
        out_scores.copy_(self.scorer(hidden))

    @torch.inference_mode()
    def reduce_multi_token_scores(
        self,
        last_hidden_state: torch.Tensor,
        batch_size: int,
        max_input_len: int,
        score_positions: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.multi_token_tables is not None
        hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
        gather_index = score_positions[:, None, :, None].expand(
            batch_size,
            self.num_labels,
            self.multi_token_tables.max_label_len,
            self.hidden_size,
        )
        gathered_hidden = torch.gather(hidden, 2, gather_index)
        used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()
        token_log_probs = torch.zeros(
            (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
            device=self.device,
            dtype=torch.float32,
        )
        if used_mask.any():
            flat_hidden = gathered_hidden[used_mask]
            flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]
            linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()
            linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()
            linear_bias = self.lm_head_bias
            if linear_bias is not None and self.device.type != "cuda":
                linear_bias = linear_bias.float()
            flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)
            flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()
            flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)
            token_log_probs[used_mask] = flat_selected - flat_log_norm
        out_scores.copy_(token_log_probs.sum(dim=-1))

    def build_score_result(
        self,
        query: str,
        scores: torch.Tensor,
        stage_ms: dict[str, float],
        continuation_tokens: int,
    ) -> QueryScoreResult:
        score_values = scores.detach().float().cpu().tolist()
        best_idx = max(range(len(score_values)), key=score_values.__getitem__)
        probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()
        return QueryScoreResult(
            task_name=self.task_cfg.name,
            query=query,
            predicted_label=self.labels[best_idx],
            scores=[
                (label, score, prob)
                for label, score, prob in zip(self.labels, score_values, probs, strict=True)
            ],
            total_ms=sum(stage_ms.values()),
            stage_ms=stage_ms,
            fast_path=self.fast_path,
            prefix_tokens=self.prefix_cache.prefix_len,
            continuation_tokens=continuation_tokens,
            label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},
        )


class MultiPromptRunner:
    def __init__(self, cfg: RuntimeConfig):
        self.cfg = cfg
        t0 = time.perf_counter()
        self.shared_runtime = self.build_shared_runtime(cfg)
        t1 = time.perf_counter()
        self.device = self.shared_runtime.device
        self.dtype = self.shared_runtime.dtype
        self.tokenizer = self.shared_runtime.tokenizer
        self.model = self.shared_runtime.model
        self.backbone = self.shared_runtime.backbone
        self.hidden_size = self.shared_runtime.hidden_size
        self.graph_capture_pool = self.shared_runtime.graph_capture_pool
        self.graph_capture_stream = self.shared_runtime.graph_capture_stream
        self.runners = [
            PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)
            for task_cfg in cfg.tasks
        ]
        t2 = time.perf_counter()
        self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}
        t3 = time.perf_counter()
        self.mixed_prefix_caches = {
            batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])
            for batch_size in self.cfg.batch_sizes
        }
        t4 = time.perf_counter()
        self.max_label_prefix_len = max(
            (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)
            for runner in self.runners
        )
        self.mixed_buckets = {
            (batch_size, continuation_len): self._build_mixed_bucket(
                self.batch_layouts[batch_size],
                self.mixed_prefix_caches[batch_size],
                continuation_len,
            )
            for batch_size in self.cfg.batch_sizes
            for continuation_len in self.cfg.continuation_buckets
        }
        t5 = time.perf_counter()
        self._warmup_results: dict[int, BatchScoreResult] = {}
        self._preload_report: PreloadReport | None = None
        self._init_stage_ms = {
            "load_model_and_tokenizer": (t1 - t0) * 1000.0,
            "build_prompt_runtimes": (t2 - t1) * 1000.0,
            "build_batch_layouts": (t3 - t2) * 1000.0,
            "build_mixed_prefix_caches": (t4 - t3) * 1000.0,
            "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,
        }
        self._init_total_ms = sum(self._init_stage_ms.values())

    @staticmethod
    def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:
        device = torch.device(cfg.device)
        dtype = torch.float16
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.resolved_model_source,
            trust_remote_code=cfg.resolved_trust_remote_code,
            token=cfg.hf_token,
            cache_dir=cfg.hf_cache_dir,
            local_files_only=cfg.resolved_local_files_only,
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)
        quantization_config = None
        model_kwargs: dict[str, object] = {
            "trust_remote_code": cfg.resolved_trust_remote_code,
            "attn_implementation": attn_impl,
            "token": cfg.hf_token,
            "cache_dir": cfg.hf_cache_dir,
            "local_files_only": cfg.resolved_local_files_only,
        }
        if cfg.load_in_4bit:
            if BitsAndBytesConfig is None:
                raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
                bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
            )
            model_kwargs["quantization_config"] = quantization_config
            model_kwargs["device_map"] = {"": device.index or 0}
        else:
            model_kwargs["dtype"] = dtype
            model_kwargs["device_map"] = None
        model = AutoModelForCausalLM.from_pretrained(
            cfg.resolved_model_source,
            **model_kwargs,
        ).eval()
        if not cfg.load_in_4bit:
            model = model.to(device)
        backbone = model.get_submodule(model.base_model_prefix)
        hidden_size = model.get_output_embeddings().weight.shape[1]
        graph_capture_pool = None
        graph_capture_stream = None
        if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:
            graph_capture_pool = torch.cuda.graph_pool_handle()
            graph_capture_stream = torch.cuda.Stream(device=device)
        return SharedRuntime(
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            model=model,
            backbone=backbone,
            hidden_size=hidden_size,
            graph_capture_pool=graph_capture_pool,
            graph_capture_stream=graph_capture_stream,
        )

    @staticmethod
    def _resolve_attn_impl(requested: str) -> str:
        if requested in {"sdpa", "eager"}:
            return requested
        if requested == "auto":
            return "sdpa"
        raise ValueError(f"unsupported attn_backend: {requested}")

    def _attn_context(self):
        if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
            return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
        return torch.no_grad()

    def _sync(self) -> None:
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _pick_bucket(self, continuation_len: int) -> int:
        for bucket in self.cfg.continuation_buckets:
            if continuation_len <= bucket:
                return bucket
        if self.cfg.pad_to_bucket:
            raise ValueError(
                f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
            )
        return continuation_len

    def _build_batch_layout(self, batch_size: int) -> BatchLayout:
        plans: list[PromptBatchPlan] = []
        row_start = 0
        for runner in self.runners:
            row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
            plans.append(
                PromptBatchPlan(
                    runner=runner,
                    row_start=row_start,
                    row_count=row_count,
                    score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
                )
            )
            row_start += row_count
        return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)

    def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
        prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
        non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
        if not non_empty:
            return MixedPrefixCache(
                batch_size=layout.batch_size,
                total_rows=layout.total_rows,
                prefix_lengths=prefix_lengths,
                attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
                raw_cache=tuple(),
            )
        max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
        num_layers = len(non_empty[0])
        attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
        raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_idx in range(num_layers):
            sample_key, sample_value = non_empty[0][layer_idx]
            merged_key = sample_key.new_zeros(
                (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
            )
            merged_value = sample_value.new_zeros(
                (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
            )
            raw_layers.append((merged_key, merged_value))
        for plan in layout.plans:
            runner = plan.runner
            prefix_len = runner.prefix_cache.prefix_len
            row_slice = slice(plan.row_start, plan.row_stop)
            prefix_lengths[row_slice] = prefix_len
            if prefix_len == 0:
                continue
            attention_mask[row_slice, :prefix_len] = 1
            raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
            for layer_idx, (key, value) in enumerate(raw_cache):
                merged_key, merged_value = raw_layers[layer_idx]
                merged_key[row_slice, :, :prefix_len, :] = key
                merged_value[row_slice, :, :prefix_len, :] = value
        return MixedPrefixCache(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            prefix_lengths=prefix_lengths,
            attention_mask=attention_mask.contiguous(),
            raw_cache=tuple(raw_layers),
        )

    def _build_mixed_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        continuation_len: int,
    ) -> MixedBucketRuntime:
        max_input_len = continuation_len + self.max_label_prefix_len
        total_len = prefix_cache.max_prefix_len + max_input_len
        input_ids = torch.full(
            (layout.total_rows, max_input_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
        if prefix_cache.max_prefix_len:
            attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
        position_ids = (
            prefix_cache.prefix_lengths[:, None]
            + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
        ).contiguous()
        last_hidden_state = torch.empty(
            (layout.total_rows, max_input_len, self.hidden_size),
            device=self.device,
            dtype=self.dtype,
        )
        bucket = MixedBucketRuntime(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            continuation_len=continuation_len,
            max_input_len=max_input_len,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            last_hidden_state=last_hidden_state,
        )
        if self.cfg.cuda_graphs:
            self._capture_mixed_bucket(bucket, prefix_cache)
        return bucket

    @torch.inference_mode()
    def _run_mixed_backbone(
        self,
        bucket: MixedBucketRuntime,
        prefix_cache: MixedPrefixCache,
    ) -> None:
        cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
        with self._attn_context():
            outputs = self.backbone(
                input_ids=bucket.input_ids,
                attention_mask=bucket.attention_mask,
                position_ids=bucket.position_ids,
                past_key_values=cache,
                use_cache=False,
                return_dict=True,
            )
        bucket.last_hidden_state.copy_(outputs.last_hidden_state)

    def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
            stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
            with torch.cuda.stream(stream):
                for _ in range(self.cfg.graph_warmups):
                    self._run_mixed_backbone(bucket, prefix_cache)
                stream.synchronize()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
                self._run_mixed_backbone(bucket, prefix_cache)
            bucket.graph = graph
        except RuntimeError:
            bucket.graph = None

    def _prepare_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        bucket: MixedBucketRuntime,
        query_ids_batch: list[list[int]],
    ) -> tuple[list[list[int]], dict[str, list[object]]]:
        del prefix_cache
        bucket.input_ids.fill_(self.tokenizer.pad_token_id)
        bucket.attention_mask.zero_()
        if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
            bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
                self.mixed_prefix_caches[layout.batch_size].attention_mask
            )
        continuation_lengths_per_task: dict[str, list[int]] = {}
        continuation_tokens_per_task: dict[str, list[list[int]]] = {}
        prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
        for plan in layout.plans:
            runner = plan.runner
            per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
            continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
            continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
            if runner.fast_path:
                for batch_idx, continuation in enumerate(per_query_continuations):
                    cont_len = len(continuation)
                    row_idx = plan.row_start + batch_idx
                    bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
                    bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
                continue
            assert runner.multi_token_tables is not None
            for batch_idx, continuation in enumerate(per_query_continuations):
                cont_len = len(continuation)
                row_start = plan.row_start + batch_idx * runner.num_labels
                row_stop = row_start + runner.num_labels
                row_slice = slice(row_start, row_stop)
                cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
                bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
                bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
                if runner.multi_token_tables.max_label_prefix_len:
                    bucket.input_ids[
                        row_slice,
                        cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_ids
                    bucket.attention_mask[
                        row_slice,
                        prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_mask
        return query_ids_batch, {
            "continuation_lengths_per_task": continuation_lengths_per_task,
            "continuation_tokens_per_task": continuation_tokens_per_task,
        }

    def _reduce_prompt_scores(
        self,
        layout: BatchLayout,
        bucket: MixedBucketRuntime,
        query_texts: list[str],
        prep_meta: dict[str, list[object]],
        shared_stage_ms: dict[str, float],
    ) -> list[MultiPromptScoreResult]:
        result_rows = [[] for _ in range(layout.batch_size)]
        prompt_reduce_total_ms = 0.0
        for plan in layout.plans:
            runner = plan.runner
            continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
            reduce_start = time.perf_counter()
            if runner.fast_path:
                hidden_rows = []
                row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
                for batch_idx, cont_len in enumerate(continuation_lengths):
                    hidden_rows.append(row_slice[batch_idx, cont_len - 1])
                hidden = torch.stack(hidden_rows, dim=0)
                runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
                stage_name = "tail_scorer"
            else:
                assert runner.multi_token_tables is not None
                score_positions = torch.stack(
                    [
                        cont_len - 1 + runner.multi_token_tables.label_position_offsets
                        for cont_len in continuation_lengths
                    ],
                    dim=0,
                )
                runner.reduce_multi_token_scores(
                    last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
                    batch_size=layout.batch_size,
                    max_input_len=bucket.max_input_len,
                    score_positions=score_positions,
                    out_scores=plan.score_buffer,
                )
                stage_name = "candidate_reduce"
            self._sync()
            reduce_end = time.perf_counter()
            reduce_ms = (reduce_end - reduce_start) * 1000.0
            prompt_reduce_total_ms += reduce_ms
            for batch_idx, query in enumerate(query_texts):
                stage_ms = dict(shared_stage_ms)
                stage_ms[stage_name] = reduce_ms / layout.batch_size
                result_rows[batch_idx].append(
                    runner.build_score_result(
                        query=query,
                        scores=plan.score_buffer[batch_idx],
                        stage_ms=stage_ms,
                        continuation_tokens=continuation_lengths[batch_idx],
                    )
                )
        batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
        shared_plus_reduce = dict(shared_stage_ms)
        shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
        results: list[MultiPromptScoreResult] = []
        for batch_idx, query in enumerate(query_texts):
            results.append(
                MultiPromptScoreResult(
                    query=query,
                    total_ms=batch_total_ms / layout.batch_size,
                    details=result_rows[batch_idx],
                    stage_ms={
                        **shared_plus_reduce,
                        "per_query_total_estimate": batch_total_ms / layout.batch_size,
                    },
                )
            )
        return results

    @torch.inference_mode()
    def score_queries(self, queries: list[str]) -> BatchScoreResult:
        if not queries:
            raise ValueError("queries must not be empty")
        batch_size = len(queries)
        if batch_size not in self.batch_layouts:
            raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
        layout = self.batch_layouts[batch_size]
        prefix_cache = self.mixed_prefix_caches[batch_size]
        self._sync()
        t0 = time.perf_counter()
        query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
        self._sync()
        t1 = time.perf_counter()
        max_continuation_len = max(
            len(plan.runner.build_continuation_from_query_ids(query_ids))
            for plan in layout.plans
            for query_ids in query_ids_batch
        )
        picked_bucket = self._pick_bucket(max_continuation_len)
        bucket = self.mixed_buckets[(batch_size, picked_bucket)]
        _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
        self._sync()
        t2 = time.perf_counter()
        if bucket.graph is not None:
            bucket.graph.replay()
        else:
            self._run_mixed_backbone(bucket, prefix_cache)
        self._sync()
        t3 = time.perf_counter()
        shared_stage_ms = {
            "encode_queries_shared": (t1 - t0) * 1000.0,
            "prepare_batch_shared": (t2 - t1) * 1000.0,
            "backbone_shared": (t3 - t2) * 1000.0,
        }
        results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
        total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
        return BatchScoreResult(
            batch_size=batch_size,
            total_ms=total_ms,
            results=results,
            stage_ms={
                **shared_stage_ms,
                "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
            },
        )

    def score_query(self, query: str) -> MultiPromptScoreResult:
        return self.score_queries([query]).results[0]

    def preload(self) -> PreloadReport:
        if self._preload_report is not None:
            return self._preload_report
        stage_ms: dict[str, float] = dict(self._init_stage_ms)
        start = time.perf_counter()
        self._sync()
        t0 = time.perf_counter()
        warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
        for batch_size in warmup_batch_sizes:
            queries = [self.cfg.warmup_query] * batch_size
            self._warmup_results[batch_size] = self.score_queries(queries)
        self._sync()
        t1 = time.perf_counter()
        stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
        stage_ms["startup_total_before_warmup"] = self._init_total_ms
        total_ms = self._init_total_ms + (t1 - start) * 1000.0
        runtime = self.preload_report()
        self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
        return self._preload_report

    def preload_report(self) -> dict[str, object]:
        return {
            "model_name": self.cfg.resolved_model_name,
            "model_source": self.cfg.resolved_model_source,
            "device": str(self.device),
            "dtype": self.cfg.dtype,
            "attn_backend": self.cfg.attn_backend,
            "execution_model": "single_mixed_backbone_per_batch",
            "num_tasks": len(self.runners),
            "task_names": [runner.task_cfg.name for runner in self.runners],
            "batch_sizes": list(self.cfg.batch_sizes),
            "continuation_buckets": list(self.cfg.continuation_buckets),
            "mixed_bucket_count": len(self.mixed_buckets),
            "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
            "all_configured_buckets_preloaded": True,
            "init_stage_ms": dict(self._init_stage_ms),
            "init_total_ms": self._init_total_ms,
            "force_single_token_labels": self.cfg.force_single_token_labels,
            "warmup_query": self.cfg.warmup_query,
            "tasks": [
                {
                    "task_name": runner.task_cfg.name,
                    "fast_path": runner.fast_path,
                    "num_labels": runner.num_labels,
                    "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
                    "prefix_tokens": runner.prefix_cache.prefix_len,
                    "prefix_hashes": runner.prefix_cache.prefix_hashes,
                    "label_prefix": runner.task_cfg.label_prefix,
                }
                for runner in self.runners
            ],
        }

However, its handling of multi-token labels is inefficient and incorrect — do not follow it; reimplement that part from scratch.

A local .venv has already been created (reused from llm-qp); use that environment. The useful code has been quoted above; based on the requirements, research further material and focus on developing a brand-new version.

This is a substantial task. Work through it patiently, step by step, and keep iterating for as long as it takes, until the service is fully built and the tests completely meet expectations.
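One candidate reduction for the rewrite — a sketch, not the mandated design: after the single batched forward has produced, for each label row, logits at every label-token position, score each label by the length-normalized sum of its token log-probs. The mean keeps one-token and multi-token labels on a comparable scale (a plain sum systematically penalizes longer labels), at the cost of not being an exact joint probability — which the requirements explicitly allow. The tensors below are random stand-ins for what the real backbone would return.

```python
import torch

def comparable_label_scores(token_logits: torch.Tensor,
                            label_token_ids: torch.Tensor,
                            label_token_mask: torch.Tensor) -> torch.Tensor:
    """token_logits: [num_labels, max_len, vocab], the logits at each label-token
    position; label_token_ids / label_token_mask: [num_labels, max_len].
    Returns [num_labels]: mean log P(token) over each label's real tokens,
    with padding positions masked out."""
    log_probs = torch.log_softmax(token_logits.float(), dim=-1)
    picked = log_probs.gather(-1, label_token_ids.unsqueeze(-1)).squeeze(-1)
    picked = picked * label_token_mask                 # zero out padding slots
    lengths = label_token_mask.sum(dim=-1).clamp(min=1)
    return picked.sum(dim=-1) / lengths                # length-normalized

# Stand-in data: 4 labels, up to 2 tokens each, toy vocab of 50.
torch.manual_seed(0)
num_labels, max_len, vocab = 4, 2, 50
logits = torch.randn(num_labels, max_len, vocab)
ids = torch.randint(0, vocab, (num_labels, max_len))
mask = torch.tensor([[1., 0.], [1., 1.], [1., 0.], [1., 1.]])  # rows 0, 2 single-token
scores = comparable_label_scores(logits, ids, mask)
best = int(scores.argmax())
```

For a single-token label the score degenerates to that token's log-prob, so the fast path and the multi-token path stay directly comparable; the full-vocab `log_softmax` here can later be replaced by the restricted tail scorer once a shared normalizer approximation is chosen.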