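Overall requirement: on a Tesla T4 GPU, use an open-source base LLM (3B to 6B class) for query classification, with configurable prompts and label sets. Focus on inference optimization first and consider serving last, supporting a modest level of concurrency (for example, 4 simultaneous requests). At program startup, load everything that needs loading; do no lazy loading, so that real requests see the lowest possible response time. For each query the CLI receives, run N prompts to classify it along N dimensions. For each prompt, output the score of every label and the total prediction latency (per-stage latency is even better; support it if it is easy to do).
Some reference material follows, but you do not need to follow it strictly. Use your own judgment and stay flexible in pursuit of maximum performance.
On a Tesla T4, use an open-source decoder-only base model in the 3B to 6B range for query classification. At startup, finish preparing the tokenizer, weights, prefix cache, and shared executor. For each input query, output the score distribution over labels for every prompt, plus the prediction latency and per-stage latency. Do not use the general generation path: no decode, no full-vocab logits, no constrained decode. Optimize multi-token labels specifically and avoid serial decode on the Python side. The prompt and label sets must be configurable. There are only two prompts for now; I will extend this to 8 later. Each request takes one query and runs the 8 prompts in parallel to get the highest-scoring label for each: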
prompts:
  - name: category
    prompt_template: |
      system
      Analyze the category intent. Output exactly one label from [none, dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other]. Use 'none' if no category intent.
      user
      query: {query}
      assistant
      label:
    label_prefix: " "
    labels: [dress, jeans, shirt, trench, skirt, tee, hoodie, knit, other, none]
  - name: audience
    prompt_template: |
      system
      Analyze the target user group. Output exactly one label from [none, boy, girl, man, woman, pregnant]. Use 'none' if no audience mentioned.
      user
      query: {query}
      assistant
      label:
    label_prefix: " "
    labels: [boy, girl, man, woman, pregnant, none]
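The runner code quoted later reads `task_cfg.prompt_parts` to obtain a cacheable prefix and a per-query suffix, but its implementation is not shown. A minimal sketch of one way to derive the two parts, assuming the template contains exactly one `{query}` placeholder (`split_prompt_template` is a hypothetical helper name, and the line breaks in the sample template are assumed):

```python
def split_prompt_template(template: str, placeholder: str = "{query}") -> tuple[str, str]:
    """Split a template into the static text before and after the query placeholder."""
    if template.count(placeholder) != 1:
        raise ValueError(f"template must contain exactly one {placeholder!r}")
    prefix, suffix = template.split(placeholder)
    return prefix, suffix


template = (
    "system\n"
    "Analyze the target user group. Output exactly one label.\n"
    "user\n"
    "query: {query}\n"
    "assistant\n"
    "label:"
)
prefix, suffix = split_prompt_template(template)
# The prefix is identical for every request, so its KV cache can be built once at startup.
# Only query + suffix has to pass through the model per request.
```

Tokenizing the prefix and the continuation separately can produce different tokens at the boundary than tokenizing the full string, so this is worth checking against the real tokenizer.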
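Build a dedicated execution path rather than tuning the configuration of a general-purpose generation engine. The goal is the absolute lowest classification latency.
The main optimization directions to consider are:
- hidden_last -> N-class scorer -> argmax
- Follow vLLM's automatic prefix caching to reuse identical prefixes, and manage KV by block hash
- Drop full-vocab logits
- Drop decode / constrained decode
- Dedicated tail kernel (outputs raw scores for the N classes)
- Run the N configured prompts in parallel (2 to 8 prompts)
- The target is a Tesla T4, so FlashAttention-3 is not the main path. Choose the best attention implementation for the T4. Gather more material, and consult open-source inference optimization projects suited to the T4 and optimization practices that work on it, including but not limited to toolkits that combine a Python/C++ runtime with a TensorRT engine. Make sure the hardware support matrix covers Turing/T4. Pay particular attention to how Triton / CUDA C++ projects optimize this kind of single-step decode with scoring over a fixed vocabulary subset, find a suitable baseline, and develop on top of it.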
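As a small illustration of managing KV by block hash, the sketch below chains each block's hash with its predecessor, so two equal hashes imply that the whole prefix up to that block is equal. It only mirrors the idea described in the vLLM prefix-caching docs and is not vLLM's code; `chained_block_hashes` and `shared_block_count` are hypothetical names, and skipping the trailing partial block is an assumption.

```python
import hashlib
from typing import Iterable


def chained_block_hashes(token_ids: Iterable[int], block_size: int) -> list[str]:
    """Hash every full block together with the hash of the block before it."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    tokens = list(token_ids)
    hashes: list[str] = []
    parent = ""
    # Only full blocks are hashed; a trailing partial block is left uncached.
    for block_idx in range(len(tokens) // block_size):
        block = tokens[block_idx * block_size : (block_idx + 1) * block_size]
        payload = (parent + "|" + ",".join(str(t) for t in block)).encode("utf-8")
        parent = hashlib.sha256(payload).hexdigest()
        hashes.append(parent)
    return hashes


def shared_block_count(a: list[str], b: list[str]) -> int:
    """Number of leading blocks whose KV could be reused between two sequences."""
    count = 0
    for left, right in zip(a, b):
        if left != right:
            break
        count += 1
    return count
```

With a fixed set of prompts, the lookup can be resolved entirely at startup, so no hashing needs to happen on the request path.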
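You have sudo privileges and may install your own environment for this project.
Use a Q4 or Q8 quantized version of Qwen/Qwen3-8B. To decide which exact version, look up the relevant material on Hugging Face, choose a suitable one, deploy it, and measure inference latency.
Analyze the latency of each stage in depth, and keep researching to check whether performance has been pushed to the limit.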
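For the per-stage analysis, a small timing helper keeps the measurement code out of the hot path logic. The following is a minimal sketch; `StageTimer` is a hypothetical name, and the optional `sync` callback is where something like `torch.cuda.synchronize` would be passed so that asynchronous GPU work is included in the measured stage.

```python
import time
from contextlib import contextmanager
from typing import Callable, Iterator, Optional


class StageTimer:
    """Accumulates wall-clock milliseconds per named stage."""

    def __init__(self) -> None:
        self.stage_ms: dict[str, float] = {}

    @contextmanager
    def stage(self, name: str, sync: Optional[Callable[[], None]] = None) -> Iterator[None]:
        if sync is not None:
            sync()  # drain earlier async work so it is not billed to this stage
        start = time.perf_counter()
        try:
            yield
        finally:
            if sync is not None:
                sync()  # wait for this stage's async work before stopping the clock
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            self.stage_ms[name] = self.stage_ms.get(name, 0.0) + elapsed_ms

    @property
    def total_ms(self) -> float:
        return sum(self.stage_ms.values())
```

Synchronizing around every stage adds overhead of its own, so it may be worth making the per-stage breakdown optional and measuring only the end-to-end time in the serving path.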
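One important issue: some labels are not a single token, even though they look like a single word, so the design must handle labels that tokenize to several tokens. Multi-token labels need to be optimized to the limit, and serial decode must be avoided. The end goal is to find which label scores highest, not necessarily to obtain exact probabilities. Computing log P(id1 | query, prompt) + log P(id2 | query, prompt, id1) may make performance hard to optimize, so giving up exact probabilities is acceptable. Keep the end goal in mind: getting the classification right is enough, performance comes first, and exact probabilities can be dropped. How to process the whole batch, including multi-token labels, in a single model forward pass is the question you need to explore.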
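To make the trade-off concrete, the exact score and one possible cheaper surrogate can be written as follows. This is a sketch under assumptions, not a method taken from the referenced projects: $x$ is the prompt plus query, $y_t$ is the $t$-th token of label $y$, $z(\cdot)$ is the pre-softmax logit, and $C(y_{<t})$ is a hypothetical candidate set containing the tokens that continue at least one configured label after the prefix $y_{<t}$.

```latex
\begin{aligned}
s_{\text{exact}}(y) &= \sum_{t=1}^{|y|} \log P\left(y_t \mid x,\, y_{<t}\right) \\
s_{\text{restricted}}(y) &= \sum_{t=1}^{|y|} \left[ z\left(y_t \mid x,\, y_{<t}\right)
  - \log \sum_{v \in C(y_{<t})} \exp z\left(v \mid x,\, y_{<t}\right) \right]
\end{aligned}
```

Both sums can be evaluated from a single teacher-forced forward pass. In the restricted form, any position with only one possible continuation contributes exactly zero, so labels of different token lengths are compared only at the positions where they actually compete.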
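The single-token fast path is fairly settled: last_hidden -> small class scorer -> argmax. Take only the LM head rows that correspond to the target labels and produce no full-vocab output. How to handle multi-token labels requires researching the relevant material. Ideally the cost should match the single-token case (given that exact log-probs are dropped). However, the scores of multi-token and single-token labels must be comparable, otherwise the classification will be wrong, so both performance and scoring accuracy have to be taken into account.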
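The arithmetic of the single-token fast path is small enough to show in plain Python. The sketch below uses toy numbers and a hypothetical `score_single_token_labels` helper; in a real runtime the same computation would be one FP16 matrix product between the last hidden state and the gathered LM head rows.

```python
from typing import Optional


def score_single_token_labels(
    hidden: list[float],
    lm_head_rows: dict[int, list[float]],
    labels: list[str],
    label_token_ids: list[int],
    bias: Optional[dict[int, float]] = None,
) -> tuple[str, list[float]]:
    """Dot the last hidden state with only the LM head rows of the label tokens."""
    scores: list[float] = []
    for token_id in label_token_ids:
        row = lm_head_rows[token_id]
        logit = sum(h * w for h, w in zip(hidden, row))
        if bias is not None:
            logit += bias[token_id]
        scores.append(logit)
    best = max(range(len(scores)), key=scores.__getitem__)
    return labels[best], scores


# Toy values: hidden size 2, vocabulary rows 3, 7 and 9; only rows 7 and 9 are labels.
hidden = [1.0, 2.0]
lm_head_rows = {3: [1.0, 1.0], 7: [0.5, 0.0], 9: [0.0, 1.0]}
label, scores = score_single_token_labels(hidden, lm_head_rows, ["dress", "jeans"], [7, 9])
```

Row 3 is never touched, which is the point: the cost scales with the number of labels rather than with the vocabulary size.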
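One more configuration option is needed: force_single_token_labels, which treats every label by its first token only. The reasoning is that if the first tokens of all labels differ, the first-token score can be taken as an approximation of the whole-label score. You need to find the best practice for multi-token label scoring in terms of both performance and accuracy, while also supporting force_single_token_labels for maximum performance.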
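One way to decide when first-token scoring is sufficient is to look at where the label token sequences actually diverge. The sketch below builds a token trie and returns the non-empty prefixes after which more than one outcome is still possible. `branching_prefixes` is a hypothetical helper and the token ids are made up; this illustrates one possible approach and is not what llm-qp does.

```python
from collections import defaultdict


def branching_prefixes(label_token_ids: dict[str, list[int]]) -> list[tuple[int, ...]]:
    """Return non-empty token prefixes where labels still compete with each other."""
    children: dict[tuple[int, ...], set[int]] = defaultdict(set)
    terminals: set[tuple[int, ...]] = set()
    for label, ids in label_token_ids.items():
        if not ids:
            raise ValueError(f"label {label!r} has no tokens")
        for depth in range(len(ids)):
            children[tuple(ids[:depth])].add(ids[depth])
        terminals.add(tuple(ids))

    needed: list[tuple[int, ...]] = []
    for prefix, next_tokens in children.items():
        if not prefix:
            continue  # the root is always scored by the base row
        # A label that ends exactly here counts as one more competing outcome.
        options = len(next_tokens) + (1 if prefix in terminals else 0)
        if options >= 2:
            needed.append(prefix)
    return sorted(needed)
```

An empty result means every label is already decided by its first token, which is the case force_single_token_labels assumes. Otherwise, each returned prefix is a candidate for one extra teacher-forced row in the same forward pass.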
Also search carefully for relevant material, especially best practices for this scenario in the frameworks involved (Triton / Ollama / CUDA C++). Put them into practice, and find the state of the art for query classification on the T4 with performance optimized to the limit. Some reference examples:
vLLM Automatic Prefix Caching: https://docs.vllm.ai/en/stable/design/prefix_caching/
PyTorch SDPA / memory-efficient attention: https://pytorch.org/blog/out-of-the-box-acceleration/
TensorRT-LLM Support Matrix: https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
Ollama Modelfile / Generate / FAQ: https://docs.ollama.com/modelfile , https://docs.ollama.com/api/generate , https://docs.ollama.com/faq
TensorRT support matrix: T4 / SM7.5 supports FP16 and INT8, but not BF16/FP8 in the main matrix. https://docs.nvidia.com/deeplearning/tensorrt/pdf/TensorRT-Support-Matrix-Guide.pdf
TensorRT-LLM support matrix: current official hardware list omits Turing/T4, so T4 is effectively community-support territory there. https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html
FlashAttention repo: FA3 is Hopper-focused; current published benchmarks are A100/H100-centric. https://github.com/Dao-AILab/flash-attention
vLLM APC docs: KV reuse via hashed KV blocks was the baseline idea for the prefix-cache metadata. https://docs.vllm.ai/_/downloads/en/v0.6.2/pdf/
SGLang HiCache/RadixAttention docs: useful reference for prefix-cache reuse and page-granular KV organization. https://docs.sglang.io/advanced_features/hicache_design.html
FasterTransformer repo: still a useful T4 FP16 optimization baseline and historical Turing-oriented reference. https://github.com/NVIDIA/FasterTransformer
xFormers README: relevant as a Turing-friendly attention alternative; my mainline choice here is PyTorch SDPA on T4, which is an engineering inference from these sources rather than a direct vendor recommendation. https://github.com/facebookresearch/xformers
Note: two projects already exist, llm-qp and llm-qp2. Their handling of single-token labels is acceptable: SDPA, prefix cache, prebuilt buckets, and CUDA graph. The core code is:

from __future__ import annotations

import hashlib
import time
from dataclasses import asdict, dataclass
from typing import Iterable

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

try:
    from transformers import BitsAndBytesConfig
except ImportError:  # pragma: no cover
    BitsAndBytesConfig = None

from llm_qp.config import PromptTaskConfig, RuntimeConfig
from llm_qp.scorer import SmallClassScorer

try:
    from torch.nn.attention import SDPBackend, sdpa_kernel
except ImportError:  # pragma: no cover
    SDPBackend = None
    sdpa_kernel = None

@dataclass(slots=True)
class EncodedLabel:
    text: str
    token_ids: list[int]


@dataclass(slots=True)
class PrefixCache:
    prefix_ids: list[int]
    prefix_hashes: list[str]
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def prefix_len(self) -> int:
        return len(self.prefix_ids)

@dataclass(slots=True)
class MultiTokenTables:
    label_token_ids: torch.Tensor
    label_token_mask: torch.Tensor
    label_prefix_ids: torch.Tensor
    label_prefix_mask: torch.Tensor
    label_position_offsets: torch.Tensor

    @property
    def max_label_len(self) -> int:
        return self.label_token_ids.shape[1]

    @property
    def max_label_prefix_len(self) -> int:
        return self.label_prefix_ids.shape[1]

@dataclass(slots=True)
class QueryScoreResult:
    task_name: str
    query: str
    predicted_label: str
    scores: list[tuple[str, float, float]]  # (label, raw score, softmax prob)
    total_ms: float
    stage_ms: dict[str, float]
    fast_path: bool
    prefix_tokens: int
    continuation_tokens: int
    label_token_lengths: dict[str, int]

    @property
    def predicted_prob(self) -> float:
        for label, _score, prob in self.scores:
            if label == self.predicted_label:
                return prob
        return 0.0

@dataclass(slots=True)
class MultiPromptScoreResult:
    query: str
    total_ms: float
    details: list[QueryScoreResult]
    stage_ms: dict[str, float]

    def http_json(self) -> dict[str, object]:
        return {
            "query": self.query,
            "total_ms": self.total_ms,
            "stage_ms": self.stage_ms,
            "details": [asdict(t) for t in self.details],
            # Tasks whose prediction is 'none' are omitted from the summary.
            "task_results": {
                t.task_name: [t.predicted_label, t.continuation_tokens, t.predicted_prob]
                for t in self.details
                if t.predicted_label != 'none'
            },
        }

@dataclass(slots=True)
class BatchScoreResult:
    batch_size: int
    total_ms: float
    results: list[MultiPromptScoreResult]
    stage_ms: dict[str, float]


@dataclass(slots=True)
class SharedRuntime:
    device: torch.device
    dtype: torch.dtype
    tokenizer: object
    model: object
    backbone: object
    hidden_size: int
    graph_capture_pool: object | None = None
    graph_capture_stream: torch.cuda.Stream | None = None


@dataclass(slots=True)
class PromptBatchPlan:
    runner: "PromptClassifierRunner"
    row_start: int
    row_count: int
    score_buffer: torch.Tensor

    @property
    def row_stop(self) -> int:
        return self.row_start + self.row_count

@dataclass(slots=True)
class MixedPrefixCache:
    batch_size: int
    total_rows: int
    prefix_lengths: torch.Tensor
    attention_mask: torch.Tensor
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...]

    @property
    def max_prefix_len(self) -> int:
        return int(self.prefix_lengths.max().item()) if self.prefix_lengths.numel() else 0

@dataclass(slots=True)
class BatchLayout:
    batch_size: int
    total_rows: int
    plans: list[PromptBatchPlan]


@dataclass(slots=True)
class MixedBucketRuntime:
    batch_size: int
    total_rows: int
    continuation_len: int
    max_input_len: int
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    last_hidden_state: torch.Tensor
    graph: torch.cuda.CUDAGraph | None = None


@dataclass(slots=True)
class PreloadReport:
    total_ms: float
    stage_ms: dict[str, float]
    runtime: dict[str, object]

def _hash_blocks(token_ids: Iterable[int], block_size: int) -> list[str]:
    token_list = list(token_ids)
    hashes: list[str] = []
    for start in range(0, len(token_list), block_size):
        block = token_list[start : start + block_size]
        payload = ",".join(str(x) for x in block).encode("utf-8")
        hashes.append(hashlib.sha1(payload).hexdigest())
    return hashes


def _expand_legacy_cache(
    raw_cache: tuple[tuple[torch.Tensor, torch.Tensor], ...],
    batch_size: int,
) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
    expanded: list[tuple[torch.Tensor, torch.Tensor]] = []
    for key, value in raw_cache:
        expanded.append(
            (
                key.expand(batch_size, *key.shape[1:]).contiguous(),
                value.expand(batch_size, *value.shape[1:]).contiguous(),
            )
        )
    return tuple(expanded)

class PromptClassifierRunner:
    def __init__(
        self,
        cfg: RuntimeConfig,
        task_cfg: PromptTaskConfig,
        shared_runtime: SharedRuntime,
    ):
        self.cfg = cfg
        self.task_cfg = task_cfg
        self.device = shared_runtime.device
        self.dtype = shared_runtime.dtype
        self.tokenizer = shared_runtime.tokenizer
        self.model = shared_runtime.model
        self.backbone = shared_runtime.backbone
        self.hidden_size = shared_runtime.hidden_size
        self.prefix_text, self.suffix_text = task_cfg.prompt_parts
        self.prefix_ids = self.tokenizer.encode(self.prefix_text, add_special_tokens=False)
        self.suffix_ids = self.tokenizer.encode(self.suffix_text, add_special_tokens=False)
        self.labels = list(task_cfg.labels)
        self.encoded_labels = [
            EncodedLabel(text=label, token_ids=self._encode_label_token_ids(label))
            for label in self.labels
        ]
        self.num_labels = len(self.labels)
        self.lm_head = self.model.get_output_embeddings()
        self.lm_head_weight = self.lm_head.weight.detach()
        self.lm_head_bias = self.lm_head.bias.detach() if getattr(self.lm_head, "bias", None) is not None else None
        if self.cfg.force_single_token_labels and not self._has_unique_single_token_labels():
            raise ValueError(
                f"prompt task '{self.task_cfg.name}' cannot force single-token labels because first tokens collide"
            )
        # The fast path applies only when every label is one token and all tokens are distinct.
        self.fast_path = self._has_unique_single_token_labels()
        self.fast_path_token_ids = [item.token_ids[0] for item in self.encoded_labels] if self.fast_path else []
        self.scorer = self._build_scorer() if self.fast_path else None
        self.multi_token_tables = self._build_multi_token_tables() if not self.fast_path else None
        self.prefix_cache = self._build_prefix_cache()

    def _encode_label_token_ids(self, label: str) -> list[int]:
        token_ids = self.tokenizer.encode(
            f"{self.task_cfg.label_prefix}{label}",
            add_special_tokens=False,
        )
        if not token_ids:
            raise ValueError(f"label '{label}' in prompt '{self.task_cfg.name}' tokenizes to an empty sequence")
        if self.cfg.force_single_token_labels:
            return token_ids[:1]
        return token_ids

    def _has_unique_single_token_labels(self) -> bool:
        token_ids: list[int] = []
        for item in self.encoded_labels:
            if len(item.token_ids) != 1:
                return False
            token_ids.append(item.token_ids[0])
        return len(token_ids) == len(set(token_ids))

    def _build_scorer(self) -> SmallClassScorer:
        # Gather only the LM head rows for the label tokens.
        token_ids = torch.tensor(self.fast_path_token_ids, dtype=torch.long, device=self.device)
        weights = torch.index_select(self.lm_head_weight, 0, token_ids).to(dtype=self.dtype).contiguous()
        bias = None
        if self.lm_head_bias is not None:
            bias = torch.index_select(self.lm_head_bias, 0, token_ids).to(dtype=self.dtype).contiguous()
        return SmallClassScorer(weights=weights, bias=bias)

    def _build_multi_token_tables(self) -> MultiTokenTables:
        max_label_len = max(len(item.token_ids) for item in self.encoded_labels)
        max_prefix_len = max(len(item.token_ids) - 1 for item in self.encoded_labels)
        label_token_ids = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.long)
        label_token_mask = torch.zeros((self.num_labels, max_label_len), device=self.device, dtype=torch.float32)
        label_prefix_ids = torch.full(
            (self.num_labels, max_prefix_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        label_prefix_mask = torch.zeros((self.num_labels, max_prefix_len), device=self.device, dtype=torch.long)
        for idx, item in enumerate(self.encoded_labels):
            token_ids = torch.tensor(item.token_ids, device=self.device, dtype=torch.long)
            token_len = token_ids.numel()
            label_token_ids[idx, :token_len] = token_ids
            label_token_mask[idx, :token_len] = 1.0
            if token_len > 1:
                # All label tokens except the last are fed as teacher-forced input.
                prefix_len = token_len - 1
                label_prefix_ids[idx, :prefix_len] = token_ids[:-1]
                label_prefix_mask[idx, :prefix_len] = 1
        return MultiTokenTables(
            label_token_ids=label_token_ids.contiguous(),
            label_token_mask=label_token_mask.contiguous(),
            label_prefix_ids=label_prefix_ids.contiguous(),
            label_prefix_mask=label_prefix_mask.contiguous(),
            label_position_offsets=torch.arange(max_label_len, device=self.device, dtype=torch.long),
        )

    @torch.inference_mode()
    def _build_prefix_cache(self) -> PrefixCache:
        if not self.prefix_ids:
            return PrefixCache(prefix_ids=[], prefix_hashes=[], raw_cache=tuple())
        prefix_tensor = torch.tensor([self.prefix_ids], dtype=torch.long, device=self.device)
        attention_mask = torch.ones_like(prefix_tensor, dtype=torch.long, device=self.device)
        outputs = self.model(
            input_ids=prefix_tensor,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True,
        )
        raw_cache = tuple(
            (layer.keys.detach(), layer.values.detach())
            for layer in outputs.past_key_values.layers
        )
        return PrefixCache(
            prefix_ids=list(self.prefix_ids),
            prefix_hashes=_hash_blocks(self.prefix_ids, self.cfg.prefix_block_size),
            raw_cache=raw_cache,
        )

    def expand_prefix_raw_cache(self, batch_size: int) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
        if not self.prefix_cache.raw_cache:
            return tuple()
        return _expand_legacy_cache(self.prefix_cache.raw_cache, batch_size)

    def build_continuation_from_query_ids(self, query_ids: list[int]) -> list[int]:
        continuation = query_ids + self.suffix_ids
        if not continuation:
            raise ValueError("prompt continuation is empty after substituting query")
        if self.prefix_cache.prefix_len + len(continuation) > self.cfg.max_length:
            raise ValueError(
                f"sequence length {self.prefix_cache.prefix_len + len(continuation)} exceeds max_length={self.cfg.max_length}"
            )
        return continuation

    @torch.inference_mode()
    def reduce_fast_scores(
        self,
        hidden: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.scorer is not None
        out_scores.copy_(self.scorer(hidden))

    @torch.inference_mode()
    def reduce_multi_token_scores(
        self,
        last_hidden_state: torch.Tensor,
        batch_size: int,
        max_input_len: int,
        score_positions: torch.Tensor,
        out_scores: torch.Tensor,
    ) -> None:
        assert self.multi_token_tables is not None
        hidden = last_hidden_state.reshape(batch_size, self.num_labels, max_input_len, self.hidden_size)
        gather_index = score_positions[:, None, :, None].expand(
            batch_size,
            self.num_labels,
            self.multi_token_tables.max_label_len,
            self.hidden_size,
        )
        gathered_hidden = torch.gather(hidden, 2, gather_index)
        used_mask = self.multi_token_tables.label_token_mask.unsqueeze(0).expand(batch_size, -1, -1).bool()
        token_log_probs = torch.zeros(
            (batch_size, self.num_labels, self.multi_token_tables.max_label_len),
            device=self.device,
            dtype=torch.float32,
        )
        if used_mask.any():
            flat_hidden = gathered_hidden[used_mask]
            flat_token_ids = self.multi_token_tables.label_token_ids.unsqueeze(0).expand(batch_size, -1, -1)[used_mask]
            linear_hidden = flat_hidden.to(self.dtype) if self.device.type == "cuda" else flat_hidden.float()
            linear_weight = self.lm_head_weight if self.device.type == "cuda" else self.lm_head_weight.float()
            linear_bias = self.lm_head_bias
            if linear_bias is not None and self.device.type != "cuda":
                linear_bias = linear_bias.float()
            # Note: this projects onto the full vocabulary for every used position.
            flat_logits = F.linear(linear_hidden, linear_weight, linear_bias)
            flat_selected = flat_logits.gather(1, flat_token_ids.unsqueeze(1)).squeeze(1).float()
            flat_log_norm = torch.logsumexp(flat_logits.float(), dim=-1)
            token_log_probs[used_mask] = flat_selected - flat_log_norm
        out_scores.copy_(token_log_probs.sum(dim=-1))

    def build_score_result(
        self,
        query: str,
        scores: torch.Tensor,
        stage_ms: dict[str, float],
        continuation_tokens: int,
    ) -> QueryScoreResult:
        score_values = scores.detach().float().cpu().tolist()
        best_idx = max(range(len(score_values)), key=score_values.__getitem__)
        probs = torch.softmax(torch.tensor(score_values, dtype=torch.float32), dim=0).tolist()
        return QueryScoreResult(
            task_name=self.task_cfg.name,
            query=query,
            predicted_label=self.labels[best_idx],
            scores=[
                (label, score, prob)
                for label, score, prob in zip(self.labels, score_values, probs, strict=True)
            ],
            total_ms=sum(stage_ms.values()),
            stage_ms=stage_ms,
            fast_path=self.fast_path,
            prefix_tokens=self.prefix_cache.prefix_len,
            continuation_tokens=continuation_tokens,
            label_token_lengths={item.text: len(item.token_ids) for item in self.encoded_labels},
        )

class MultiPromptRunner:
    def __init__(self, cfg: RuntimeConfig):
        self.cfg = cfg
        t0 = time.perf_counter()
        self.shared_runtime = self.build_shared_runtime(cfg)
        t1 = time.perf_counter()
        self.device = self.shared_runtime.device
        self.dtype = self.shared_runtime.dtype
        self.tokenizer = self.shared_runtime.tokenizer
        self.model = self.shared_runtime.model
        self.backbone = self.shared_runtime.backbone
        self.hidden_size = self.shared_runtime.hidden_size
        self.graph_capture_pool = self.shared_runtime.graph_capture_pool
        self.graph_capture_stream = self.shared_runtime.graph_capture_stream
        self.runners = [
            PromptClassifierRunner(cfg=cfg, task_cfg=task_cfg, shared_runtime=self.shared_runtime)
            for task_cfg in cfg.tasks
        ]
        t2 = time.perf_counter()
        self.batch_layouts = {batch_size: self._build_batch_layout(batch_size) for batch_size in self.cfg.batch_sizes}
        t3 = time.perf_counter()
        self.mixed_prefix_caches = {
            batch_size: self._build_mixed_prefix_cache(self.batch_layouts[batch_size])
            for batch_size in self.cfg.batch_sizes
        }
        t4 = time.perf_counter()
        self.max_label_prefix_len = max(
            (runner.multi_token_tables.max_label_prefix_len if runner.multi_token_tables is not None else 0)
            for runner in self.runners
        )
        # One prebuilt bucket (and CUDA graph, if enabled) per (batch size, continuation length).
        self.mixed_buckets = {
            (batch_size, continuation_len): self._build_mixed_bucket(
                self.batch_layouts[batch_size],
                self.mixed_prefix_caches[batch_size],
                continuation_len,
            )
            for batch_size in self.cfg.batch_sizes
            for continuation_len in self.cfg.continuation_buckets
        }
        t5 = time.perf_counter()
        self._warmup_results: dict[int, BatchScoreResult] = {}
        self._preload_report: PreloadReport | None = None
        self._init_stage_ms = {
            "load_model_and_tokenizer": (t1 - t0) * 1000.0,
            "build_prompt_runtimes": (t2 - t1) * 1000.0,
            "build_batch_layouts": (t3 - t2) * 1000.0,
            "build_mixed_prefix_caches": (t4 - t3) * 1000.0,
            "build_mixed_buckets_and_graphs": (t5 - t4) * 1000.0,
        }
        self._init_total_ms = sum(self._init_stage_ms.values())

    @staticmethod
    def build_shared_runtime(cfg: RuntimeConfig) -> SharedRuntime:
        device = torch.device(cfg.device)
        dtype = torch.float16
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.resolved_model_source,
            trust_remote_code=cfg.resolved_trust_remote_code,
            token=cfg.hf_token,
            cache_dir=cfg.hf_cache_dir,
            local_files_only=cfg.resolved_local_files_only,
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
        attn_impl = MultiPromptRunner._resolve_attn_impl(cfg.attn_backend)
        quantization_config = None
        model_kwargs: dict[str, object] = {
            "trust_remote_code": cfg.resolved_trust_remote_code,
            "attn_implementation": attn_impl,
            "token": cfg.hf_token,
            "cache_dir": cfg.hf_cache_dir,
            "local_files_only": cfg.resolved_local_files_only,
        }
        if cfg.load_in_4bit:
            if BitsAndBytesConfig is None:
                raise ImportError("transformers BitsAndBytesConfig is unavailable; install bitsandbytes support first")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
                bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
            )
            model_kwargs["quantization_config"] = quantization_config
            model_kwargs["device_map"] = {"": device.index or 0}
        else:
            model_kwargs["dtype"] = dtype
            model_kwargs["device_map"] = None
        model = AutoModelForCausalLM.from_pretrained(
            cfg.resolved_model_source,
            **model_kwargs,
        ).eval()
        if not cfg.load_in_4bit:
            model = model.to(device)
        backbone = model.get_submodule(model.base_model_prefix)
        hidden_size = model.get_output_embeddings().weight.shape[1]
        graph_capture_pool = None
        graph_capture_stream = None
        # CUDA graphs are disabled on the 4-bit path.
        if device.type == "cuda" and torch.cuda.is_available() and cfg.cuda_graphs and not cfg.load_in_4bit:
            graph_capture_pool = torch.cuda.graph_pool_handle()
            graph_capture_stream = torch.cuda.Stream(device=device)
        return SharedRuntime(
            device=device,
            dtype=dtype,
            tokenizer=tokenizer,
            model=model,
            backbone=backbone,
            hidden_size=hidden_size,
            graph_capture_pool=graph_capture_pool,
            graph_capture_stream=graph_capture_stream,
        )

    @staticmethod
    def _resolve_attn_impl(requested: str) -> str:
        if requested in {"sdpa", "eager"}:
            return requested
        if requested == "auto":
            return "sdpa"
        raise ValueError(f"unsupported attn_backend: {requested}")

    def _attn_context(self):
        if sdpa_kernel is not None and self.cfg.attn_backend in {"auto", "sdpa"}:
            return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
        return torch.no_grad()

    def _sync(self) -> None:
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _pick_bucket(self, continuation_len: int) -> int:
        # Buckets are assumed to be listed in ascending order.
        for bucket in self.cfg.continuation_buckets:
            if continuation_len <= bucket:
                return bucket
        if self.cfg.pad_to_bucket:
            raise ValueError(
                f"continuation length {continuation_len} exceeds configured buckets; extend continuation_buckets"
            )
        return continuation_len

    def _build_batch_layout(self, batch_size: int) -> BatchLayout:
        plans: list[PromptBatchPlan] = []
        row_start = 0
        for runner in self.runners:
            # Fast-path prompts need one row per query; others need one row per (query, label).
            row_count = batch_size if runner.fast_path else batch_size * runner.num_labels
            plans.append(
                PromptBatchPlan(
                    runner=runner,
                    row_start=row_start,
                    row_count=row_count,
                    score_buffer=torch.empty((batch_size, runner.num_labels), device=self.device, dtype=torch.float32),
                )
            )
            row_start += row_count
        return BatchLayout(batch_size=batch_size, total_rows=row_start, plans=plans)

    def _build_mixed_prefix_cache(self, layout: BatchLayout) -> MixedPrefixCache:
        prefix_lengths = torch.zeros((layout.total_rows,), device=self.device, dtype=torch.long)
        non_empty = [plan.runner.prefix_cache.raw_cache for plan in layout.plans if plan.runner.prefix_cache.raw_cache]
        if not non_empty:
            return MixedPrefixCache(
                batch_size=layout.batch_size,
                total_rows=layout.total_rows,
                prefix_lengths=prefix_lengths,
                attention_mask=torch.zeros((layout.total_rows, 0), device=self.device, dtype=torch.long),
                raw_cache=tuple(),
            )
        max_prefix_len = max(plan.runner.prefix_cache.prefix_len for plan in layout.plans)
        num_layers = len(non_empty[0])
        attention_mask = torch.zeros((layout.total_rows, max_prefix_len), device=self.device, dtype=torch.long)
        raw_layers: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_idx in range(num_layers):
            sample_key, sample_value = non_empty[0][layer_idx]
            merged_key = sample_key.new_zeros(
                (layout.total_rows, sample_key.shape[1], max_prefix_len, sample_key.shape[3])
            )
            merged_value = sample_value.new_zeros(
                (layout.total_rows, sample_value.shape[1], max_prefix_len, sample_value.shape[3])
            )
            raw_layers.append((merged_key, merged_value))
        for plan in layout.plans:
            runner = plan.runner
            prefix_len = runner.prefix_cache.prefix_len
            row_slice = slice(plan.row_start, plan.row_stop)
            prefix_lengths[row_slice] = prefix_len
            if prefix_len == 0:
                continue
            attention_mask[row_slice, :prefix_len] = 1
            raw_cache = runner.expand_prefix_raw_cache(plan.row_count)
            for layer_idx, (key, value) in enumerate(raw_cache):
                merged_key, merged_value = raw_layers[layer_idx]
                merged_key[row_slice, :, :prefix_len, :] = key
                merged_value[row_slice, :, :prefix_len, :] = value
        return MixedPrefixCache(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            prefix_lengths=prefix_lengths,
            attention_mask=attention_mask.contiguous(),
            raw_cache=tuple(raw_layers),
        )

    def _build_mixed_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        continuation_len: int,
    ) -> MixedBucketRuntime:
        max_input_len = continuation_len + self.max_label_prefix_len
        total_len = prefix_cache.max_prefix_len + max_input_len
        input_ids = torch.full(
            (layout.total_rows, max_input_len),
            fill_value=self.tokenizer.pad_token_id,
            device=self.device,
            dtype=torch.long,
        )
        attention_mask = torch.zeros((layout.total_rows, total_len), device=self.device, dtype=torch.long)
        if prefix_cache.max_prefix_len:
            attention_mask[:, : prefix_cache.max_prefix_len] = prefix_cache.attention_mask
        position_ids = (
            prefix_cache.prefix_lengths[:, None]
            + torch.arange(max_input_len, device=self.device, dtype=torch.long).unsqueeze(0)
        ).contiguous()
        last_hidden_state = torch.empty(
            (layout.total_rows, max_input_len, self.hidden_size),
            device=self.device,
            dtype=self.dtype,
        )
        bucket = MixedBucketRuntime(
            batch_size=layout.batch_size,
            total_rows=layout.total_rows,
            continuation_len=continuation_len,
            max_input_len=max_input_len,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            last_hidden_state=last_hidden_state,
        )
        if self.cfg.cuda_graphs:
            self._capture_mixed_bucket(bucket, prefix_cache)
        return bucket

    @torch.inference_mode()
    def _run_mixed_backbone(
        self,
        bucket: MixedBucketRuntime,
        prefix_cache: MixedPrefixCache,
    ) -> None:
        cache = DynamicCache(ddp_cache_data=prefix_cache.raw_cache, config=self.model.config)
        with self._attn_context():
            outputs = self.backbone(
                input_ids=bucket.input_ids,
                attention_mask=bucket.attention_mask,
                position_ids=bucket.position_ids,
                past_key_values=cache,
                use_cache=False,
                return_dict=True,
            )
        bucket.last_hidden_state.copy_(outputs.last_hidden_state)

    def _capture_mixed_bucket(self, bucket: MixedBucketRuntime, prefix_cache: MixedPrefixCache) -> None:
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.synchronize()
            stream = self.graph_capture_stream or torch.cuda.Stream(device=self.device)
            # Warm up on the side stream before capture.
            with torch.cuda.stream(stream):
                for _ in range(self.cfg.graph_warmups):
                    self._run_mixed_backbone(bucket, prefix_cache)
            stream.synchronize()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_capture_pool, stream=stream):
                self._run_mixed_backbone(bucket, prefix_cache)
            bucket.graph = graph
        except RuntimeError:
            # Fall back to eager execution for this bucket if capture fails.
            bucket.graph = None

    def _prepare_bucket(
        self,
        layout: BatchLayout,
        prefix_cache: MixedPrefixCache,
        bucket: MixedBucketRuntime,
        query_ids_batch: list[list[int]],
    ) -> tuple[list[list[int]], dict[str, list[object]]]:
        del prefix_cache
        bucket.input_ids.fill_(self.tokenizer.pad_token_id)
        bucket.attention_mask.zero_()
        if self.mixed_prefix_caches[layout.batch_size].max_prefix_len:
            bucket.attention_mask[:, : self.mixed_prefix_caches[layout.batch_size].max_prefix_len] = (
                self.mixed_prefix_caches[layout.batch_size].attention_mask
            )
        continuation_lengths_per_task: dict[str, list[int]] = {}
        continuation_tokens_per_task: dict[str, list[list[int]]] = {}
        prefix_base = self.mixed_prefix_caches[layout.batch_size].max_prefix_len
        for plan in layout.plans:
            runner = plan.runner
            per_query_continuations = [runner.build_continuation_from_query_ids(query_ids) for query_ids in query_ids_batch]
            continuation_tokens_per_task[runner.task_cfg.name] = per_query_continuations
            continuation_lengths_per_task[runner.task_cfg.name] = [len(ids) for ids in per_query_continuations]
            if runner.fast_path:
                for batch_idx, continuation in enumerate(per_query_continuations):
                    cont_len = len(continuation)
                    row_idx = plan.row_start + batch_idx
                    bucket.input_ids[row_idx, :cont_len] = torch.tensor(continuation, device=self.device, dtype=torch.long)
                    bucket.attention_mask[row_idx, prefix_base : prefix_base + cont_len] = 1
                continue
            assert runner.multi_token_tables is not None
            for batch_idx, continuation in enumerate(per_query_continuations):
                cont_len = len(continuation)
                row_start = plan.row_start + batch_idx * runner.num_labels
                row_stop = row_start + runner.num_labels
                row_slice = slice(row_start, row_stop)
                cont_tensor = torch.tensor(continuation, device=self.device, dtype=torch.long)
                bucket.input_ids[row_slice, :cont_len] = cont_tensor.unsqueeze(0).expand(runner.num_labels, -1)
                bucket.attention_mask[row_slice, prefix_base : prefix_base + cont_len] = 1
                if runner.multi_token_tables.max_label_prefix_len:
                    bucket.input_ids[
                        row_slice,
                        cont_len : cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_ids
                    bucket.attention_mask[
                        row_slice,
                        prefix_base + cont_len : prefix_base + cont_len + runner.multi_token_tables.max_label_prefix_len,
                    ] = runner.multi_token_tables.label_prefix_mask
        return query_ids_batch, {
            "continuation_lengths_per_task": continuation_lengths_per_task,
            "continuation_tokens_per_task": continuation_tokens_per_task,
        }

    def _reduce_prompt_scores(
        self,
        layout: BatchLayout,
        bucket: MixedBucketRuntime,
        query_texts: list[str],
        prep_meta: dict[str, list[object]],
        shared_stage_ms: dict[str, float],
    ) -> list[MultiPromptScoreResult]:
        result_rows = [[] for _ in range(layout.batch_size)]
        prompt_reduce_total_ms = 0.0
        for plan in layout.plans:
            runner = plan.runner
            continuation_lengths = prep_meta["continuation_lengths_per_task"][runner.task_cfg.name]
            reduce_start = time.perf_counter()
            if runner.fast_path:
                hidden_rows = []
                row_slice = bucket.last_hidden_state[plan.row_start : plan.row_start + layout.batch_size]
                for batch_idx, cont_len in enumerate(continuation_lengths):
                    hidden_rows.append(row_slice[batch_idx, cont_len - 1])
                hidden = torch.stack(hidden_rows, dim=0)
                runner.reduce_fast_scores(hidden=hidden, out_scores=plan.score_buffer)
                stage_name = "tail_scorer"
            else:
                assert runner.multi_token_tables is not None
                score_positions = torch.stack(
                    [
                        cont_len - 1 + runner.multi_token_tables.label_position_offsets
                        for cont_len in continuation_lengths
                    ],
                    dim=0,
                )
                runner.reduce_multi_token_scores(
                    last_hidden_state=bucket.last_hidden_state[plan.row_start : plan.row_stop],
                    batch_size=layout.batch_size,
                    max_input_len=bucket.max_input_len,
                    score_positions=score_positions,
                    out_scores=plan.score_buffer,
                )
                stage_name = "candidate_reduce"
            self._sync()
            reduce_end = time.perf_counter()
            reduce_ms = (reduce_end - reduce_start) * 1000.0
            prompt_reduce_total_ms += reduce_ms
            for batch_idx, query in enumerate(query_texts):
                stage_ms = dict(shared_stage_ms)
                stage_ms[stage_name] = reduce_ms / layout.batch_size
                result_rows[batch_idx].append(
                    runner.build_score_result(
                        query=query,
                        scores=plan.score_buffer[batch_idx],
                        stage_ms=stage_ms,
                        continuation_tokens=continuation_lengths[batch_idx],
                    )
                )
        batch_total_ms = sum(shared_stage_ms.values()) + prompt_reduce_total_ms
        shared_plus_reduce = dict(shared_stage_ms)
        shared_plus_reduce["prompt_reduce_total"] = prompt_reduce_total_ms
        results: list[MultiPromptScoreResult] = []
        for batch_idx, query in enumerate(query_texts):
            results.append(
                MultiPromptScoreResult(
                    query=query,
                    total_ms=batch_total_ms / layout.batch_size,
                    details=result_rows[batch_idx],
                    stage_ms={
                        **shared_plus_reduce,
                        "per_query_total_estimate": batch_total_ms / layout.batch_size,
                    },
                )
            )
        return results

    @torch.inference_mode()
    def score_queries(self, queries: list[str]) -> BatchScoreResult:
        if not queries:
            raise ValueError("queries must not be empty")
        batch_size = len(queries)
        if batch_size not in self.batch_layouts:
            raise ValueError(f"batch size {batch_size} is not preloaded; configured batch_sizes={self.cfg.batch_sizes}")
        layout = self.batch_layouts[batch_size]
        prefix_cache = self.mixed_prefix_caches[batch_size]
        self._sync()
        t0 = time.perf_counter()
        query_ids_batch = [self.tokenizer.encode(query, add_special_tokens=False) for query in queries]
        self._sync()
        t1 = time.perf_counter()
        max_continuation_len = max(
            len(plan.runner.build_continuation_from_query_ids(query_ids))
            for plan in layout.plans
            for query_ids in query_ids_batch
        )
        picked_bucket = self._pick_bucket(max_continuation_len)
        bucket = self.mixed_buckets[(batch_size, picked_bucket)]
        _, prep_meta = self._prepare_bucket(layout, prefix_cache, bucket, query_ids_batch)
        self._sync()
        t2 = time.perf_counter()
        # Replay the captured graph when available; otherwise run the backbone eagerly.
        if bucket.graph is not None:
            bucket.graph.replay()
        else:
            self._run_mixed_backbone(bucket, prefix_cache)
        self._sync()
        t3 = time.perf_counter()
        shared_stage_ms = {
            "encode_queries_shared": (t1 - t0) * 1000.0,
            "prepare_batch_shared": (t2 - t1) * 1000.0,
            "backbone_shared": (t3 - t2) * 1000.0,
        }
        results = self._reduce_prompt_scores(layout, bucket, queries, prep_meta, shared_stage_ms)
        total_ms = sum(shared_stage_ms.values()) + results[0].stage_ms["prompt_reduce_total"]
        return BatchScoreResult(
            batch_size=batch_size,
            total_ms=total_ms,
            results=results,
            stage_ms={
                **shared_stage_ms,
                "prompt_reduce_total": results[0].stage_ms["prompt_reduce_total"],
            },
        )

    def score_query(self, query: str) -> MultiPromptScoreResult:
        return self.score_queries([query]).results[0]

    def preload(self) -> PreloadReport:
        if self._preload_report is not None:
            return self._preload_report
        stage_ms: dict[str, float] = dict(self._init_stage_ms)
        start = time.perf_counter()
        self._sync()
        t0 = time.perf_counter()
        warmup_batch_sizes = self.cfg.warmup_batch_sizes or self.cfg.batch_sizes
        for batch_size in warmup_batch_sizes:
            queries = [self.cfg.warmup_query] * batch_size
            self._warmup_results[batch_size] = self.score_queries(queries)
        self._sync()
        t1 = time.perf_counter()
        stage_ms["warmup_end_to_end"] = (t1 - t0) * 1000.0
        stage_ms["startup_total_before_warmup"] = self._init_total_ms
        total_ms = self._init_total_ms + (t1 - start) * 1000.0
        runtime = self.preload_report()
        self._preload_report = PreloadReport(total_ms=total_ms, stage_ms=stage_ms, runtime=runtime)
        return self._preload_report

    def preload_report(self) -> dict[str, object]:
        return {
            "model_name": self.cfg.resolved_model_name,
            "model_source": self.cfg.resolved_model_source,
            "device": str(self.device),
            "dtype": self.cfg.dtype,
            "attn_backend": self.cfg.attn_backend,
            "execution_model": "single_mixed_backbone_per_batch",
            "num_tasks": len(self.runners),
            "task_names": [runner.task_cfg.name for runner in self.runners],
            "batch_sizes": list(self.cfg.batch_sizes),
            "continuation_buckets": list(self.cfg.continuation_buckets),
            "mixed_bucket_count": len(self.mixed_buckets),
            "captured_mixed_buckets": sum(bucket.graph is not None for bucket in self.mixed_buckets.values()),
            "all_configured_buckets_preloaded": True,
            "init_stage_ms": dict(self._init_stage_ms),
            "init_total_ms": self._init_total_ms,
            "force_single_token_labels": self.cfg.force_single_token_labels,
            "warmup_query": self.cfg.warmup_query,
            "tasks": [
                {
                    "task_name": runner.task_cfg.name,
                    "fast_path": runner.fast_path,
                    "num_labels": runner.num_labels,
                    "label_token_lengths": {item.text: len(item.token_ids) for item in runner.encoded_labels},
                    "prefix_tokens": runner.prefix_cache.prefix_len,
                    "prefix_hashes": runner.prefix_cache.prefix_hashes,
                    "label_prefix": runner.task_cfg.label_prefix,
                }
                for runner in self.runners
            ],
        }

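However, its handling of multi-token labels is inefficient and wrong. Do not use it as a reference; reimplement that part yourself. The local .venv has already been created (it is reused from llm-qp); please use that environment. I have already quoted the useful code here. Based on the requirements, research the relevant material and focus on developing a completely new version. The task is heavy: work through it patiently, step by step, and keep iterating until the service is built and the tests fully meet expectations.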