issue-2026-04-06-推理优化-重建.md

The overall requirement: on a Tesla T4 GPU, use an open-source base LLM (3-6B class) for query classification. Prompts and label sets must be configurable. Focus on inference optimization first; only at the end consider turning it into a service that supports a modest level of concurrency (e.g. 4 concurrent requests). At program startup, load everything that needs loading up front, with no lazy loading of any kind, so real requests get the best possible response time. For each query the CLI receives, run N prompts to classify along N dimensions; for each prompt, output every label's score plus the total prediction latency (per-stage latencies are a bonus: support them if they are easy to add).

The reference material below is not binding; keep some flexibility and pursue the best possible performance.

On a Tesla T4, use a 3B-6B open-source decoder-only base model for query classification. At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor. Each request takes one query and outputs, for every prompt, the score distribution over its labels, plus the prediction latency and per-stage latencies. Do not go through the generic generation path, do no decoding, do not materialize full-vocab logits, and do no constrained decoding. Optimize multi-token labels specifically and avoid serial decoding on the Python side. The prompt and label sets must be configurable (there are only two right now; I will grow this to 8 later. Each request supplies one query, and all 8 prompts are run in parallel, each returning its highest-scoring label):

```yaml
prompts:
  - name: category
    prompt_template: |
      system
      Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.
      user
      query: {query}
      assistant
      label:
    label_prefix: " "
    labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]
  - name: audience
    prompt_template: |
      system
      Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.
      user
      query: {query}
      assistant
      label:
    label_prefix: " "
    labels: [boy, girl, man, woman, pregnant, none]
```
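To make a prefix cache work, each template has to be split once at startup around its `{query}` slot: everything before it is the constant, cacheable prefix. A minimal sketch of that split (the helper name is illustrative; the runner code later in this issue consumes the result as `prompt_parts`):

```python
def split_prompt_template(template: str) -> tuple[str, str]:
    """Split a prompt template into (prefix, suffix) around its {query} slot.

    The prefix is constant per task, so its KV cache can be built once at
    startup; only the suffix and the query itself change per request.
    """
    before, sep, after = template.partition("{query}")
    if not sep:
        raise ValueError("prompt template must contain a {query} placeholder")
    return before, after
```

With the templates above, everything before `{query}` becomes the cacheable prefix, and the trailing assistant/label part becomes the per-request suffix.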

Build a dedicated execution path rather than tuning a general-purpose generation engine through configuration; the goal is the absolute lowest classification latency.

The main optimization directions to consider:

  1. hidden_last -> N-class scorer -> argmax
  2. Reuse identical prefixes the way vLLM's automatic prefix caching does, managing KV by block hash
  3. Drop full-vocab logits
  4. Drop decode / constrained decode
  5. A dedicated tail kernel that emits the N raw class scores
  6. The N configured prompts (2-8) must run their inference in parallel
  7. The target GPU is a Tesla T4, so FlashAttention-3 cannot be the main path. Research the best attention implementation for the T4; gather material on open-source inference-optimization projects and practices that actually work on T4, including but not limited to Python/C++ runtimes plus TensorRT engine toolkits (note that the hardware support matrix must cover Turing/T4). Pay special attention to Triton / CUDA C++ techniques for exactly this pattern, single-step decode with fixed-vocabulary scoring, find a suitable open-source baseline, and develop against it.
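Direction 6 (running all configured prompts in one forward) boils down to a static row layout computed at startup: fast-path prompts contribute one row per query, multi-token prompts one row per (query, label) pair. A toy sketch of that layout arithmetic (names are illustrative; the code later in this issue does the same thing in `_build_batch_layout`):

```python
def plan_rows(batch_size: int, tasks: list[tuple[str, int, bool]]) -> tuple[int, list[tuple[str, int, int]]]:
    """Static row layout for running every prompt in one batched forward.

    tasks: (name, num_labels, fast_path) triples.  Fast-path prompts need one
    row per query (scores come from the last hidden state alone); multi-token
    prompts need one row per (query, label) pair, since each label's token
    prefix is appended to the query row.  Returns (total_rows, plans), where
    each plan is (name, row_start, row_count).
    """
    plans: list[tuple[str, int, int]] = []
    row_start = 0
    for name, num_labels, fast_path in tasks:
        row_count = batch_size if fast_path else batch_size * num_labels
        plans.append((name, row_start, row_count))
        row_start += row_count
    return row_start, plans
```

All rows from all prompts then go through the backbone together, so per-request cost stays one forward pass regardless of how many prompts are configured.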

You have sudo access; install whatever environment this project needs.

Use a Q4 or Q8 build of Qwen/Qwen3-8B. To decide which, research the relevant Hugging Face material, pick a suitable version, deploy it, and benchmark inference latency.

Profile each stage's latency in depth, and keep researching whether performance can be pushed any further.
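For the stage-latency breakdown requested above, a minimal stdlib-only sketch of the bookkeeping (illustrative names; on CUDA you would synchronize the device around each stage so the measurement covers the kernels, not just the launches):

```python
import time
from contextlib import contextmanager


class StageTimer:
    """Accumulates per-stage wall-clock timings in milliseconds."""

    def __init__(self) -> None:
        self.stage_ms: dict[str, float] = {}

    @contextmanager
    def stage(self, name: str):
        # On CUDA, call torch.cuda.synchronize() here and in `finally` so
        # the measured span includes queued kernel work.
        start = time.perf_counter()
        try:
            yield
        finally:
            self.stage_ms[name] = (time.perf_counter() - start) * 1000.0

    @property
    def total_ms(self) -> float:
        return sum(self.stage_ms.values())
```

Each request can then report `stage_ms` (tokenize, forward, reduce, ...) next to the total, which is the shape the result dataclasses below already expect.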

One important problem: some label words are not single tokens (even when they look like one word), so the multi-token case must be handled. Multi-token labels need aggressive optimization: avoid serial decoding. The end goal is only to find the highest-scoring label, not exact probabilities. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) exactly may make the fast path hard to optimize, so exact probabilities can be sacrificed. Keep the real objective in mind: classification is all that matters, performance comes first, and exact probabilities can be dropped. How to process the whole batch, including multi-token labels, in a single model forward is the problem you need to explore.
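To make the one-forward idea concrete: if the input row is [prompt + label[:-1]], then output position t already holds the logits that predict label token t, so every term of the label's log-probability comes out of a single forward with no serial decode. A toy, framework-free sketch of that reduction (illustrative; the exact normalizer is precisely what the text above says may later be dropped for speed):

```python
import math


def label_score(step_logits: list[list[float]], label_ids: list[int]) -> float:
    """Sum of per-token log-probabilities for one label, teacher-forced.

    step_logits[t] are the vocab logits emitted at the position that predicts
    label token t: position 0 is the prompt's last token, position 1 comes
    after label token 0 has been fed as input, and so on.  Feeding
    [prompt + label[:-1]] in one row therefore yields every term of
    sum_t log P(label[t] | prompt, label[:t]) from a single forward pass.
    """
    score = 0.0
    for logits, token_id in zip(step_logits, label_ids):
        log_norm = math.log(sum(math.exp(v) for v in logits))  # naive logsumexp
        score += logits[token_id] - log_norm
    return score
```

A single-token label reduces to one term, which is why single- and multi-token scores stay comparable as long as the same normalizer (or the same approximation of it) is applied to every candidate.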

The single-token fast path is well understood: last_hidden -> small class scorer -> argmax, taking only the LM-head rows of the target labels and never producing full-vocab output. How to handle multi-token labels needs research; ideally it should cost the same as the single-token path (with exact log-probs sacrificed). But note: multi-token and single-token label scores must remain comparable for classification to be correct, so balance performance against scoring accuracy.
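The fast path above is small enough to state in a few lines. A framework-free sketch, assuming `label_rows` are the LM-head rows pre-gathered for the label tokens (illustrative names; in the real implementation this is a fp16 matmul on the GPU):

```python
def fast_path_scores(last_hidden: list[float], label_rows: list[list[float]]) -> tuple[list[float], int]:
    """Single-token fast path: score = <last hidden state, LM-head row of the
    label's token>.  Only the N label rows are gathered, so there is no
    full-vocab matmul and no softmax; argmax over the raw logits picks the
    same label because the log-normalizer is shared across all N candidates.
    """
    scores = [sum(h * w for h, w in zip(last_hidden, row)) for row in label_rows]
    best = max(range(len(scores)), key=scores.__getitem__)
    return scores, best
```

On the GPU this is a (batch, hidden) x (hidden, N) matmul against the pre-gathered rows, which is what the SmallClassScorer in the code below amounts to.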

Also add a config flag, force_single_token_labels, under which every label is scored by its first token alone: when the labels' first tokens all differ, the first-token score is a reasonable proxy for the whole label's score. Find the best practice that balances multi-label scoring performance and accuracy, and also support force_single_token_labels for maximum speed.
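The soundness condition for force_single_token_labels is just first-token uniqueness, which is cheap to validate at startup (a sketch with made-up token ids; the runner code below performs the equivalent check in `_has_unique_single_token_labels`):

```python
def can_force_single_token(label_token_ids: dict[str, list[int]]) -> bool:
    """force_single_token_labels is only sound when every label's first token
    is distinct; if two labels share a first token, they collapse to the same
    score and become indistinguishable.
    """
    first_tokens = [ids[0] for ids in label_token_ids.values()]
    return len(first_tokens) == len(set(first_tokens))
```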

Also search carefully for best practices for this scenario in the relevant stacks, in particular Triton / Ollama / CUDA C++, put them into practice, and find the state of the art for query classification on T4, pushing performance to the limit. Some starting points:

  • vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
  • PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
  • TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
  • Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq

  • TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
  • TensorRT-LLM support matrix: the current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
  • FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
  • vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
  • SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
  • FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
  • xFormers README: relevant as a Turing-friendly attention alternative; the mainline choice here is PyTorch SDPA on T4, an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers

Note: two projects already exist, llm-qp and llm-qp2, and their handling of the single-token case is sound: SDPA + prefix cache + prebuilt buckets + CUDA graphs. Their core code is:

from __future__ import annotations

import hashlib
import time
from dataclasses import asdict, dataclass
from typing import Iterable

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

try:
    from transformers import BitsAndBytesConfig
except ImportError:  # pragma: no cover
    BitsAndBytesConfig = None

from llm_qp.config import PromptTaskConfig, RuntimeConfig
from llm_qp.scorer import SmallClassScorer

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:  # pragma: no cover
    SDPBackend = None
    sdpa_kernel = None

@dataclass(slots=True)
class EncodedLabel:
    text: str
    token_ids: list[int]

@dataclass(slots=True)
class PrefixCache:
    prefix_ids: list[int]
    prefix_hashes: list[str]
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def prefix_len(self) -> int:
        return len(self.prefix_ids)

@dataclass(slots=True)
class MultiTokenTables:
    label_token_ids: torch.Tensor
    label_token_mask: torch.Tensor
    label_prefix_ids: torch.Tensor
    label_prefix_mask: torch.Tensor
    label_position_offsets: torch.Tensor

    @property
    def max_label_len(self) -> int:
        return self.label_token_ids.shape[1]

    @property
    def max_label_prefix_len(self) -> int:
        return self.label_prefix_ids.shape[1]

@dataclass(slots=True)
class QueryScoreResult:
    task_name: str
    query: str
    predicted_label: str
    scores: list[tuple[str, float, float]]
    total_ms: float
    stage_ms: dict[str, float]
    fast_path: bool
    prefix_tokens: int
    continuation_tokens: int
    label_token_lengths: dict[str, int]

    @property
    def predicted_prob(self) -> float:
        for label, _score, prob in self.scores:
            if label == self.predicted_label:
                return prob
        return 0.0

@dataclass(slots=True)
class MultiPromptScoreResult:
    query: str
    total_ms: float
    details: list[QueryScoreResult]
    stage_ms: dict[str, float]

    def http_json(self) -> dict[str, object]:
        return {
            "query": self.query,
            "total_ms": self.total_ms,
            "stage_ms": self.stage_ms,
            "details": [asdict(t) for t in self.details],
            "task_results": {
                t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob]
                for t in self.details
                if t.predicted_label != "none"
            },
        }

@dataclass(slots=True)
class BatchScoreResult:
    batch_size: int
    total_ms: float
    results: list[MultiPromptScoreResult]
    stage_ms: dict[str, float]

@dataclass(slots=True)
class SharedRuntime:
    device: torch.device
    dtype: torch.dtype
    tokenizer: object
    model: object
    backbone: object
    hidden_size: int
    graph_capture_pool: object | None = None
    graph_capture_stream: torch.cuda.Stream | None = None

@dataclass(slots=True)
class PromptBatchPlan:
    runner: "PromptClassifierRunner"
    row_start: int
    row_count: int
    score_buffer: torch.Tensor

    @property
    def row_stop(self) -> int:
        return self.row_start + self.row_count

@dataclass(slots=True)
class MixedPrefixCache:
    batch_size: int
    total_rows: int
    prefix_lengths: torch.Tensor
    attention_mask: torch.Tensor
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def max_prefix_len(self) -> int:
        return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0

@dataclass(slots=True)
class BatchLayout:
    batch_size: int
    total_rows: int
    plans: list[PromptBatchPlan]

@dataclass(slots=True)
class MixedBucketRuntime:
    batch_size: int
    total_rows: int
    continuation_len: int
    max_input_len: int
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    last_hidden_state: torch.Tensor
    graph: torch.cuda.CUDAGraph | None = None

@dataclass(slots=True)
class PreloadReport:
    total_ms: float
    stage_ms: dict[str, float]
    runtime: dict[str, object]

def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
    token_list = list(token_ids)
    hashes: list[str] = []
    for start in range(0, len(token_list), block_size):
        block = token_list[start : start + block_size]
        payload = ",".join(str(x) for x in block).encode("utf-8")
        hashes.append(hashlib.sha1(payload).hexdigest())
    return hashes

def _expand_legacy_cache(
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
    batch_size: int,
) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
    expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
    for key, value in raw_cache:
        expanded.append(
            (
                key.expand(batch_size, *key.shape[1:]).contiguous(),
                value.expand(batch_size, *value.shape[1:]).contiguous(),
            )
        )
    return tuple(expanded)

class PromptClassifierRunner:

def __init__(
    self,
    cfg: RuntimeConfig,
    task_cfg: PromptTaskConfig,
    shared_runtime: SharedRuntime,
):
    self.cfg = cfg
    self.task_cfg = task_cfg
    self.device = shared_runtime.device
    self.dtype = shared_runtime.dtype
    self.tokenizer = shared_runtime.tokenizer
    self.model = shared_runtime.model
    self.backbone = shared_runtime.backbone
    self.hidden_size = shared_runtime.hidden_size
    self.prefix_text, self.suffix_text = task_cfg.prompt_parts
    self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
    self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
    self.labels = list(task_cfg.labels)
    self.encoded_labels = [
        EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
        for label in self.labels
    ]
    self.num_labels = len(self.labels)
    self.lm_head = self.model.get_output_embeddings()
    self.lm_head_weight = self.lm_head.weight.detach()
    self.lm_head_bias = (
        self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
    )
    if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
        raise ValueError(
            f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
        )
    self.fast_path = self._has_unique_single_token_labels()
    self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
    self.scorer = self._build_scorer() if self.fast_path else None
    self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
    self.prefix_cache = self._build_prefix_cache()

def _encode_label_token_ids(self, label: str) -> list[int]:
    token_ids = self.tokenizer.encode(
        f"{self.task_cfg.label_prefix}{label}",
        add_special_tokens=False,
    )
    if not token_ids:
        raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
    if self.cfg.force_single_token_labels:
        return token_ids[:1]
    return token_ids

def _has_unique_single_token_labels(self) -> bool:
    token_ids: list[int] = []
    for item in self.encoded_labels:
        if len(item.token_ids) != 1:
            return False
        token_ids.append(item.token_ids[0])
    return len(token_ids) == len(set(token_ids))

def _build_scorer(self) -> SmallClassScorer:
    token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
    weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
    bias = None
    if self.lm_head_bias is not None:
        bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
    return SmallClassScorer(weights=weights, bias=bias)

def _build_multi_token_tables(self) -> MultiTokenTables:
    max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
    max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
    label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
    label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
    label_prefix_ids = torch.full(
        (self.num_labels, max_prefix_len),
        fill_value=self.tokenizer.pad_token_id,
        device=self.device,
        dtype=torch.long,
    )
    label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
    for idx, item in enumerate(self.encoded_labels):
        token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
        token_len = token_ids.numel()
        label_token_ids[idx, :token_len] = token_ids
        label_token_mask[idx, :token_len] = 1.0
        if token_len > 1:
            prefix_len = token_len - 1
            label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
            label_prefix_mask[idx, :prefix_len] = 1
    return MultiTokenTables(
        label_token_ids=label_token_ids.contiguous(),
        label_token_mask=label_token_mask.contiguous(),
        label_prefix_ids=label_prefix_ids.contiguous(),
        label_prefix_mask=label_prefix_mask.contiguous(),
        label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
    )

@torch.inference_mode()
def _build_prefix_cache(self) -> PrefixCache:
    if not self.prefix_ids:
        return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
    prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
    attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
    outputs = self.model(
        input_ids=prefix_tensor,
        attention_mask=attention_mask,
        use_cache=True,
        return_dict=True,
    )
    raw_cache = tuple(
        (layer.keys.detach(), layer.values.detach())
        for layer in outputs.past_key_values.layers
    )
    return PrefixCache(
        prefix_ids=list(self.prefix_ids),
        prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
        raw_cache=raw_cache,
    )

def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
    if not self.prefix_cache.raw_cache:
        return tuple()
    return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)

def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
    continuation = query_ids + self.suffix_ids
    if not continuation:
        raise ValueError("prompt continuation is empty after substituting query")
    if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
        raise ValueError(
            f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
        )
    return continuation

@torch.inference_mode()
def reduce_fast_scores(
    self,
    hidden: torch.Tensor,
    out_scores: torch.Tensor,
) -> None:
    assert self.scorer is not None
    out_scores.copy_(self.scorer(hidden))

@torch.inference_mode()
def reduce_multi_token_scores(
    self,
    last_hidden_state: torch.Tensor,
    batch_size: int,
    max_input_len: int,
    score_positions: torch.Tensor,
    out_scores: torch.Tensor,
) -> None:
    assert self.multi_token_tables is not None
    hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
    gather_index = score_positions[:, None, :, None].expand(
        batch_size,
        self.num_labels,
        self.multi_token_tables.max_label_len,
        self.hidden_size,
    )
    gathered_hidden = torch.gather(hidden, 2, gather_index)
    used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()

    token_log_probs = torch.zeros(
        (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
        device=self.device,
        dtype=torch.float32,
    )
    if used_mask.any():
        flat_hidden = gathered_hidden[used_mask]
        flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]
        linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()
        linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()
        linear_bias = self.lm_head_bias
        if linear_bias is not None and self.device.type != "cuda":
            linear_bias = linear_bias.float()
        flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)
        flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()
        flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)
        token_log_probs[used_mask] = flat_selected - flat_log_norm
    out_scores.copy_(token_log_probs.sum(dim=-1))

def build_score_result(
    self,
    query: str,
    scores: torch.Tensor,
    stage_ms: dict[str, float],
    continuation_tokens: int,
) -> QueryScoreResult:
    score_values = scores.detach().float().cpu().tolist()
    best_idx = max(range(len(score_values)), key=score_values.__getitem__)
    probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()
    return QueryScoreResult(
        task_name=self.task_cfg.name,
        query=query,
        predicted_label=self.labels[best_idx],
        scores=[
            (label, score, prob)
            for label, score, prob in zip(self.labels, score_values, probs, strict=True)
        ],
        total_ms=sum(stage_ms.values()),
        stage_ms=stage_ms,
        fast_path=self.fast_path,
        prefix_tokens=self.prefix_cache.prefix_len,
        continuation_tokens=continuation_tokens,
        label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},
    )

class MultiPromptRunner:

def __init__(self, cfg: RuntimeConfig):
    self.cfg = cfg
    t0 = time.perf_counter()
    self.shared_runtime = self.build_shared_runtime(cfg)
    t1 = time.perf_counter()
    self.device = self.shared_runtime.device
    self.dtype = self.shared_runtime.dtype
    self.tokenizer = self.shared_runtime.tokenizer
    self.model = self.shared_runtime.model
    self.backbone = self.shared_runtime.backbone
    self.hidden_size = self.shared_runtime.hidden_size
    self.graph_capture_pool = self.shared_runtime.graph_capture_pool
    self.graph_capture_stream = self.shared_runtime.graph_capture_stream
    self.runners = [
        PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)
        for task_cfg in cfg.tasks
    ]
    t2 = time.perf_counter()
    self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}
    t3 = time.perf_counter()
    self.mixed_prefix_caches = {
        batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])
        for batch_size in self.cfg.batch_sizes
    }
    t4 = time.perf_counter()
    self.max_label_prefix_len = max(
        (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)
        for runner in self.runners
    )
    self.mixed_buckets = {
        (batch_size, continuation_len): self._build_mixed_bucket(
            self.batch_layouts[batch_size],
            self.mixed_prefix_caches[batch_size],
            continuation_len,
        )
        for batch_size in self.cfg.batch_sizes
        for continuation_len in self.cfg.continuation_buckets
    }
    t5 = time.perf_counter()
    self.warmup_results: dict[int, BatchScoreResult] = {}
    self.preload_report: PreloadReport | None = None
    self._init_stage_ms = {
        "load_model_and_tokenizer": (t1 - t0) * 1000.0,
        "build_prompt_runtimes": (t2 - t1) * 1000.0,
        "build_batch_layouts": (t3 - t2) * 1000.0,
        "build_mixed_prefix_caches": (t4 - t3) * 1000.0,
        "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,
    }
    self._init_total_ms = sum(self._init_stage_ms.values())

@staticmethod
def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:
    device = torch.device(cfg.device)
    dtype = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.resolved_model_source,
        trust_remote_code=cfg.resolved_trust_remote_code,
        token=cfg.hf_token,
        cache_dir=cfg.hf_cache_dir,
        local_files_only=cfg.resolved_local_files_only,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)
    quantization_config = None
    model_kwargs: dict[str, object] = {
        "trust_remote_code": cfg.resolved_trust_remote_code,
        "attn_implementation": attn_impl,
        "token": cfg.hf_token,
        "cache_dir": cfg.hf_cache_dir,
        "local_files_only": cfg.resolved_local_files_only,
    }
    if cfg.load_in_4bit:
        if BitsAndBytesConfig is None:
            raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
        )
        model_kwargs["quantization_config"] = quantization_config
        model_kwargs["device_map"] = {"": device.index or 0}
    else:
        model_kwargs["dtype"] = dtype
        model_kwargs["device_map"] = None
    model = AutoModelForCausalLM.from_pretrained(
        cfg.resolved_model_source,
        **model_kwargs,
    ).eval()
    if not cfg.load_in_4bit:
        model = model.to(device)
    backbone = model.get_submodule(model.base_model_prefix)
    hidden_size = model.get_output_embeddings().weight.shape[1]
    graph_capture_pool = None
    graph_capture_stream = None
    if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:
        graph_capture_pool = torch.cuda.graph_pool_handle()
        graph_capture_stream = torch.cuda.Stream(device=device)
    return SharedRuntime(
        device=device,
        dtype=dtype,
        tokenizer=tokenizer,
        model=model,
        backbone=backbone,
        hidden_size=hidden_size,
        graph_capture_pool=graph_capture_pool,
        graph_capture_stream=graph_capture_stream,
    )

@staticmethod
def _resolve_attn_impl(requested: str) -> str:
    if requested in {"sdpa", "eager"}:
        return requested
    if requested == "auto":
        return "sdpa"
    raise ValueError(f"unsupported attn_backend: {requested}")

def _attn_context(self):
    if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
        return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
    return torch.no_grad()

def _sync(self) -> None:
    if self.device.type == "cuda":
        torch.cuda.synchronize()

def _pick_bucket(self, continuation_len: int) -> int:
    for bucket in self.cfg.continuation_buckets:
        if continuation_len <= bucket:
            return bucket
    if self.cfg.pad_to_bucket:
        raise ValueError(
            f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
        )
    return continuation_len

def _build_batch_layout(self, batch_size: int) -> BatchLayout:
    plans: list[PromptBatchPlan] = []
    row_start = 0
    for runner in self.runners:
        row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
        plans.append(
            PromptBatchPlan(
                runner=runner,
                row_start=row_start,
                row_count=row_count,
                score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
            )
        )
        row_start += row_count
    return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)

def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
    prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
    non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
    if not non_empty:
        return MixedPrefixCache(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            prefix_lengths=prefix_lengths,
            attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
            raw_cache=tuple(),
        )

    max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
    num_layers = len(non_empty[0])
    attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
    raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
    for layer_idx in range(num_layers):
        sample_key, sample_value = non_empty[0][layer_idx]
        merged_key = sample_key.new_zeros(
            (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
        )
        merged_value = sample_value.new_zeros(
            (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
        )
        raw_layers.append((merged_key, merged_value))

    for plan in layout.plans:
        runner = plan.runner
        prefix_len = runner.prefix_cache.prefix_len
        row_slice = slice(plan.row_start, plan.row_stop)
        prefix_lengths[row_slice] = prefix_len
        if prefix_len == 0:
            continue
        attention_mask[row_slice, :prefix_len] = 1
        raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
        for layer_idx, (key, value) in enumerate(raw_cache):
            merged_key, merged_value = raw_layers[layer_idx]
            merged_key[row_slice, :, :prefix_len, :] = key
            merged_value[row_slice, :, :prefix_len, :] = value

    return MixedPrefixCache(
        batch_size=layout.batch_size,
        total_rows=layout.total_rows,
        prefix_lengths=prefix_lengths,
        attention_mask=attention_mask.contiguous(),
        raw_cache=tuple(raw_layers),
    )

def _build_mixed_bucket(
    self,
    layout: BatchLayout,
    prefix_cache: MixedPrefixCache,
    continuation_len: int,
) -> MixedBucketRuntime:
    max_input_len = continuation_len + self.max_label_prefix_len
    total_len = prefix_cache.max_prefix_len + max_input_len
    input_ids = torch.full(
        (layout.total_rows, max_input_len),
        fill_value=self.tokenizer.pad_token_id,
        device=self.device,
        dtype=torch.long,
    )
    attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
    if prefix_cache.max_prefix_len:
        attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
    position_ids = (
        prefix_cache.prefix_lengths[:, None]
        + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
    ).contiguous()
    last_hidden_state = torch.empty(
        (layout.total_rows, max_input_len, self.hidden_size),
        device=self.device,
        dtype=self.dtype,
    )
    bucket = MixedBucketRuntime(
        batch_size=layout.batch_size,
        total_rows=layout.total_rows,
        continuation_len=continuation_len,
        max_input_len=max_input_len,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        last_hidden_state=last_hidden_state,
    )
    if self.cfg.cuda_graphs:
        self._capture_mixed_bucket(bucket, prefix_cache)
    return bucket

@torch.inference_mode()
def _run_mixed_backbone(
    self,
    bucket: MixedBucketRuntime,
    prefix_cache: MixedPrefixCache,
) -> None:
    cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
    with self._attn_context():
        outputs = self.backbone(
            input_ids=bucket.input_ids,
            attention_mask=bucket.attention_mask,
            position_ids=bucket.position_ids,
            past_key_values=cache,
            use_cache=False,
            return_dict=True,
        )
    bucket.last_hidden_state.copy_(outputs.last_hidden_state)

def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.synchronize()
        stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
        with torch.cuda.stream(stream):
            for _ in range(self.cfg.graph_warmups):
                self._run_mixed_backbone(bucket, prefix_cache)
        stream.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
            self._run_mixed_backbone(bucket, prefix_cache)
        bucket.graph = graph
    except RuntimeError:
        bucket.graph = None

def _prepare_bucket(
    self,
    layout: BatchLayout,
    prefix_cache: MixedPrefixCache,
    bucket: MixedBucketRuntime,
    query_ids_batch: list[list[int]],
) -> tuple[list[list[int]], dict[str, list[object]]]:
    del prefix_cache
    bucket.input_ids.fill_(self.tokenizer.pad_token_id)
    bucket.attention_mask.zero_()
    if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
        bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
            self.mixed_prefix_caches[layout.batch_size].attention_mask
        )
    continuation_lengths_per_task: dict[str, list[int]] = {}
    continuation_tokens_per_task: dict[str, list[list[int]]] = {}
    prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
    for plan in layout.plans:
        runner = plan.runner
        per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
        continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
        continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
        if runner.fast_path:
            for batch_idx, continuation in enumerate(per_query_continuations):
                cont_len = len(continuation)
                row_idx = plan.row_start + batch_idx
                bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
                bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
            continue

        assert runner.multi_token_tables is not None
        for batch_idx, continuation in enumerate(per_query_continuations):
            cont_len = len(continuation)
            row_start = plan.row_start + batch_idx * runner.num_labels
            row_stop = row_start + runner.num_labels
            row_slice = slice(row_start, row_stop)
            cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
            bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
            bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
            if runner.multi_token_tables.max_label_prefix_len:
                bucket.input_ids[
                    row_slice,
                    cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
                ] = runner.multi_token_tables.label_prefix_ids
                bucket.attention_mask[
                    row_slice,
                    prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
                ] = runner.multi_token_tables.label_prefix_mask
    return query_ids_batch, {
        "continuation_lengths_per_task": continuation_lengths_per_task,
        "continuation_tokens_per_task": continuation_tokens_per_task,
    }

def _reduce_prompt_scores(
    self,
    layout: BatchLayout,
    bucket: MixedBucketRuntime,
    query_texts: list[str],
    prep_meta: dict[str, list[object]],
    shared_stage_ms: dict[str, float],
) -> list[MultiPromptScoreResult]:
    result_rows = [[] for _ in range(layout.batch_size)]
    prompt_reduce_total_ms = 0.0
    for plan in layout.plans:
        runner = plan.runner
        continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
        reduce_start = time.perf_counter()
        if runner.fast_path:
            row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
            last_positions = torch.tensor(continuation_lengths, device=row_slice.device, dtype=torch.long) - 1
            # Gather each query's last-token hidden state in one advanced-indexing
            # op instead of a Python-level loop + stack.
            hidden = row_slice[torch.arange(layout.batch_size, device=row_slice.device), last_positions]
            runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
            stage_name = "tail_scorer"
        else:
            assert runner.multi_token_tables is not None
            cont_lens = torch.tensor(
                continuation_lengths,
                device=runner.multi_token_tables.label_position_offsets.device,
                dtype=torch.long,
            )
            # Broadcast [B, 1] + [P] -> [B, P] instead of stacking per-query rows in Python.
            score_positions = cont_lens.unsqueeze(1) - 1 + runner.multi_token_tables.label_position_offsets
            runner.reduce_multi_token_scores(
                last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
                batch_size=layout.batch_size,
                max_input_len=bucket.max_input_len,
                score_positions=score_positions,
                out_scores=plan.score_buffer,
            )
            stage_name = "candidate_reduce"
        self._sync()
        reduce_end = time.perf_counter()
        reduce_ms = (reduce_end - reduce_start) * 1000.0
        prompt_reduce_total_ms += reduce_ms
        for batch_idx, query in enumerate(query_texts):
            stage_ms = dict(shared_stage_ms)
            stage_ms[stage_name] = reduce_ms / layout.batch_size
            result_rows[batch_idx].append(
                runner.build_score_result(
                    query=query,
                    scores=plan.score_buffer[batch_idx],
                    stage_ms=stage_ms,
                    continuation_tokens=continuation_lengths[batch_idx],
                )
            )

    batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
    shared_plus_reduce = dict(shared_stage_ms)
    shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
    results: list[MultiPromptScoreResult] = []
    for batch_idx, query in enumerate(query_texts):
        results.append(
            MultiPromptScoreResult(
                query=query,
                total_ms=batch_total_ms / layout.batch_size,
                details=result_rows[batch_idx],
                stage_ms={
                    **shared_plus_reduce,
                    "per_query_total_estimate": batch_total_ms / layout.batch_size,
                },
            )
        )
    return results

@torch.inference_mode()
def score_queries(self, queries: list[str]) -> BatchScoreResult:
    if not queries:
        raise ValueError("queries must not be empty")
    batch_size = len(queries)
    if batch_size not in self.batch_layouts:
        raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
    layout = self.batch_layouts[batch_size]
    prefix_cache = self.mixed_prefix_caches[batch_size]

    self._sync()
    t0 = time.perf_counter()
    query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
    self._sync()
    t1 = time.perf_counter()

    max_continuation_len = max(
        len(plan.runner.build_continuation_from_query_ids(query_ids))
        for plan in layout.plans
        for query_ids in query_ids_batch
    )
    picked_bucket = self._pick_bucket(max_continuation_len)
    bucket = self.mixed_buckets[(batch_size, picked_bucket)]
    _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
    self._sync()
    t2 = time.perf_counter()

    if bucket.graph is not None:
        bucket.graph.replay()
    else:
        self._run_mixed_backbone(bucket, prefix_cache)
    self._sync()
    t3 = time.perf_counter()

    shared_stage_ms = {
        "encode_queries_shared": (t1 - t0) * 1000.0,
        "prepare_batch_shared": (t2 - t1) * 1000.0,
        "backbone_shared": (t3 - t2) * 1000.0,
    }
    results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
    total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
    return BatchScoreResult(
        batch_size=batch_size,
        total_ms=total_ms,
        results=results,
        stage_ms={
            **shared_stage_ms,
            "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
        },
    )


def score_query(self, query: str) -> MultiPromptScoreResult:
    return self.score_queries([query]).results[0]

def preload(self) -> PreloadReport:
    if self._preload_report is not None:
        return self._preload_report
    stage_ms: dict[str, float] = dict(self._init_stage_ms)
    start = time.perf_counter()
    self._sync()
    t0 = time.perf_counter()
    warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
    for batch_size in warmup_batch_sizes:
        queries = [self.cfg.warmup_query] * batch_size
        self._warmup_results[batch_size] = self.score_queries(queries)
    self._sync()
    t1 = time.perf_counter()
    stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
    stage_ms["startup_total_before_warmup"] = self._init_total_ms
    total_ms = self._init_total_ms + (t1 - start) * 1000.0
    runtime = self.preload_report()
    self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
    return self._preload_report

def preload_report(self) -> dict[str, object]:
    return {
        "model_name": self.cfg.resolved_model_name,
        "model_source": self.cfg.resolved_model_source,
        "device": str(self.device),
        "dtype": self.cfg.dtype,
        "attn_backend": self.cfg.attn_backend,
        "execution_model": "single_mixed_backbone_per_batch",
        "num_tasks": len(self.runners),
        "task_names": [runner.task_cfg.name for runner in self.runners],
        "batch_sizes": list(self.cfg.batch_sizes),
        "continuation_buckets": list(self.cfg.continuation_buckets),
        "mixed_bucket_count": len(self.mixed_buckets),
        "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
        "all_configured_buckets_preloaded": True,
        "init_stage_ms": dict(self._init_stage_ms),
        "init_total_ms": self._init_total_ms,
        "force_single_token_labels": self.cfg.force_single_token_labels,
        "warmup_query": self.cfg.warmup_query,
        "tasks": [
            {
                "task_name": runner.task_cfg.name,
                "fast_path": runner.fast_path,
                "num_labels": runner.num_labels,
                "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
                "prefix_tokens": runner.prefix_cache.prefix_len,
                "prefix_hashes": runner.prefix_cache.prefix_hashes,
                "label_prefix": runner.task_cfg.label_prefix,
            }
            for runner in self.runners
        ],
    }
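The `prefix_hashes` entries in this report follow the vLLM-style chained block-hash scheme: each full KV block's hash commits to every token up to and including that block, so prompts that share a prefix share a leading hash chain and can reuse the cached KV blocks. A minimal sketch (hypothetical encoding and block size; vLLM's actual hashing differs in detail):

```python
import hashlib

BLOCK_SIZE = 16  # tokens per KV cache block (assumed; vLLM defaults to 16)

def prefix_block_hashes(token_ids: list[int]) -> list[str]:
    """Chained hashes over full token blocks: hash i commits to blocks 0..i,
    so equal prefixes yield equal leading hash chains usable as cache keys.
    The trailing partial block is not hashed (it cannot be shared)."""
    hashes, parent = [], ""
    full_len = len(token_ids) - len(token_ids) % BLOCK_SIZE
    for start in range(0, full_len, BLOCK_SIZE):
        block = token_ids[start : start + BLOCK_SIZE]
        digest = hashlib.sha256((parent + ":" + ",".join(map(str, block))).encode()).hexdigest()
        hashes.append(digest)
        parent = digest
    return hashes

a = prefix_block_hashes(list(range(48)))                # 3 full blocks
b = prefix_block_hashes(list(range(32)) + [99] * 16)    # same first 32 tokens
assert a[:2] == b[:2] and a[2] != b[2]                  # shared prefix -> shared chain
```

Chaining via the parent digest is what makes the match prefix-exact: a block hash only matches when every earlier block also matched.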

However, the multi-token handling above is inefficient and wrong — do not follow it; reimplement it from scratch. The local .venv has already been created (reused from llm-qp); use that environment. The useful code has been quoted above; based on the requirements, research the relevant material and focus on developing a brand-new version. The task is substantial, so be patient and work through it step by step, iterating over the long term until the service is fully built and the tests fully meet expectations.
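One batched direction for that multi-token reimplementation (a sketch over random stand-in logits, not the final design): run each label as its own batch row behind the shared prefix, then score all labels with a single gather + masked sum over target-token log-probs, with no Python-side per-token loop:

```python
import torch

torch.manual_seed(0)
num_labels, max_label_len, vocab = 3, 4, 50
# Per-label, per-position logits as produced by one batched forward over
# [prompt + label_i] rows (random stand-ins here; position alignment to the
# previous token is elided).
logits = torch.randn(num_labels, max_label_len, vocab)
label_ids = torch.tensor([[7, 0, 0, 0], [8, 9, 0, 0], [8, 9, 10, 11]])
label_mask = torch.tensor([[1., 0., 0., 0.], [1., 1., 0., 0.], [1., 1., 1., 1.]])

logprobs = logits.log_softmax(dim=-1)
# Gather each target token's log-prob, zero out padding, sum -> one score per label.
tok_lp = logprobs.gather(-1, label_ids.unsqueeze(-1)).squeeze(-1)   # [num_labels, max_label_len]
scores = (tok_lp * label_mask).sum(dim=-1)                          # [num_labels]
# Optional length normalization so longer labels are not systematically penalized:
norm_scores = scores / label_mask.sum(dim=-1)
```

Everything here is one gather and two reductions on the GPU; the only per-label Python work left is building `label_ids` / `label_mask` once at startup.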