Commit 540fb5af2e1b2ac687c5920114e68cd8de45736e

Authored by tangwang
1 parent 52ea6529

Add a switchable opt-out: keep the default behavior (avoiding FA2 errors on T4), and allow vLLM to choose the attention backend itself via config or environment variable. -- temporary version

config/config.yaml
... ... @@ -404,7 +404,7 @@ services:
404 404 sort_by_doc_length: true
405 405 # Matches reranker/backends/qwen3_vllm.py: standard=_format_instruction__standard (fixed yes/no system); compact=_format_instruction (instruction as system, with Instruct repeated inside the user turn)
406 406 # instruction_format: compact
407   - instruction_format: compact
  407 + instruction_format: standard
408 408 # instruction: "Given a query, score the product for relevance"
409 409 # "rank products by given query" works slightly better than "Given a query, score the product for relevance"
410 410 # instruction: "rank products by given query, category match first"
... ... @@ -420,7 +420,10 @@ services:
420 420 model_name: "Qwen/Qwen3-Reranker-0.6B"
421 421 # The original Hub release needs true; set false when using already-converted seq-cls weights (e.g. tomaarsen/...-seq-cls)
422 422 use_original_qwen3_hf_overrides: true
423   - # vLLM 0.18: compute capability < 8 (e.g. T4) automatically uses TRITON_ATTN by default; on Ampere+ omit or set auto. The env var RERANK_VLLM_ATTENTION_BACKEND also works
  423 + # vLLM 0.18: compute capability < 8 (e.g. T4) injects TRITON_ATTN by default to avoid FA2 errors on sm<80; if that turns out slower, disable the fallback and let vLLM choose:
  424 + # auto_triton_attn_on_sm_lt_8: false
  425 + # With the fallback disabled vLLM may pick FLASHINFER; the first score() JIT-compiles and needs ninja on PATH (ninja is listed in requirements; use ./scripts/start_reranker.sh or source venv/bin/activate, do not run a bare /usr/bin-resolved python with no venv/bin on PATH)
  426 + # or set the env var RERANK_VLLM_AUTO_TRITON_ATTN=0; an explicit backend can still be pinned via RERANK_VLLM_ATTENTION_BACKEND / vllm_attention_backend
424 427 # vllm_attention_backend: "auto"
425 428 # Optional: align with vLLM; usually keep auto
426 429 # vllm_runner: "auto"
... ...
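For a one-off run, the opt-out documented above can also be passed through the environment. A minimal sketch (not part of this commit; the env var and script names come from this diff, everything else is illustrative):

```python
# Hedged sketch: launch the reranker with the TRITON_ATTN fallback disabled for one run.
# RERANK_VLLM_AUTO_TRITON_ATTN takes precedence over the auto_triton_attn_on_sm_lt_8 key.
import os
import subprocess

env = dict(os.environ)
env["RERANK_VLLM_AUTO_TRITON_ATTN"] = "0"              # let vLLM pick the attention backend
# env["RERANK_VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or pin one explicitly instead

subprocess.run(["./scripts/start_reranker.sh"], env=env, check=True)
```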
requirements_reranker_qwen3_vllm_score.txt
... ... @@ -9,6 +9,8 @@
9 9 # https://docs.vllm.ai/en/latest/getting_started/installation.html
10 10  
11 11 -r requirements_reranker_base.txt
  12 +# FlashInfer JIT (vLLM may select it on Turing when TRITON_ATTN is not forced) needs a ninja binary on PATH.
  13 +ninja>=1.11
12 14 vllm==0.18.0
13 15 # Match vLLM 0.18 stack; cap <5 to avoid pip prefetching incompatible transformers 5.x.
14 16 transformers>=4.51.0,<5
... ...
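Since the FlashInfer JIT path only fails at the first score() call, a preflight check for the ninja prerequisite can save a slow startup. A standard-library-only sketch (not part of this commit):

```python
# Hedged sketch: fail fast if ninja is not resolvable on PATH before loading the score backend
# with the TRITON_ATTN fallback disabled (FlashInfer JIT-compiles on the first score()).
import shutil
import sys

ninja = shutil.which("ninja")
if ninja is None:
    sys.exit(
        "ninja not found on PATH; activate the score venv "
        "(source .venv-reranker-score/bin/activate) or launch via ./scripts/start_reranker.sh"
    )
print(f"ninja resolved to: {ninja}")
```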
reranker/backends/qwen3_vllm_score.py
... ... @@ -41,10 +41,48 @@ _DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}"
41 41 _IM_USER_START = "<|im_end|>\n<|im_start|>user\n"
42 42  
43 43  
  44 +def _parse_env_bool(raw: str | None) -> bool | None:
  45 + if raw is None:
  46 + return None
  47 + s = str(raw).strip().lower()
  48 + if not s:
  49 + return None
  50 + if s in {"1", "true", "yes", "y", "on"}:
  51 + return True
  52 + if s in {"0", "false", "no", "n", "off"}:
  53 + return False
  54 + return None
  55 +
  56 +
  57 +def _auto_triton_on_sm_lt_8_enabled(config: Dict[str, Any]) -> bool:
  58 + """
  59 + When True (default), sm < 8 injects TRITON_ATTN to avoid FA2-only paths that error on T4/V100.
  60 +
  61 + When False, vLLM may choose FLASHINFER on Turing; first ``score()`` can JIT-compile and needs
  62 + ``ninja`` on PATH (``requirements_reranker_qwen3_vllm_score.txt``). Use
  63 + ``./scripts/start_reranker.sh`` (prepends the backend venv's ``bin`` to ``PATH``) or
  64 + ``source .../bin/activate``.
  65 + """
  66 + env = _parse_env_bool(os.getenv("RERANK_VLLM_AUTO_TRITON_ATTN"))
  67 + if env is not None:
  68 + return env
  69 + raw = config.get("auto_triton_attn_on_sm_lt_8")
  70 + if raw is None:
  71 + return True
  72 + if isinstance(raw, bool):
  73 + return raw
  74 + parsed = _parse_env_bool(str(raw))
  75 + return True if parsed is None else parsed
  76 +
  77 +
44 78 def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None:
45 79 """
46   - vLLM 0.18 defaults to Flash-Attention paths that require compute capability >= 8 (Ampere+).
47   - Turing / Volta (e.g. T4 sm_75) must use a non-FA backend such as TRITON_ATTN.
  80 + Optional explicit backend via vllm_attention_backend / RERANK_VLLM_ATTENTION_BACKEND.
  81 +
  82 + On compute capability < 8, vLLM may default to Flash-Attention 2, which is not supported on
  83 + Turing/Volta; this module historically injected TRITON_ATTN. That can be slower than vLLM's
  84 + other fallbacks — disable with auto_triton_attn_on_sm_lt_8: false or
  85 + RERANK_VLLM_AUTO_TRITON_ATTN=0 if your stack runs without errors.
48 86 """
49 87 env = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip()
50 88 raw = config.get("vllm_attention_backend")
... ... @@ -63,16 +101,26 @@ def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | N
63 101 return {"backend": backend}
64 102  
65 103 major, minor = torch.cuda.get_device_capability()
66   - if major < 8:
  104 + if major < 8 and _auto_triton_on_sm_lt_8_enabled(config):
67 105 logger.info(
68 106 "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
69 107 "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
70   - "Override with services.rerank.backends.qwen3_vllm_score.vllm_attention_backend "
71   - "or RERANK_VLLM_ATTENTION_BACKEND.",
  108 + "To use vLLM default instead: auto_triton_attn_on_sm_lt_8: false or "
  109 + "RERANK_VLLM_AUTO_TRITON_ATTN=0; or set vllm_attention_backend / "
  110 + "RERANK_VLLM_ATTENTION_BACKEND.",
72 111 major,
73 112 minor,
74 113 )
75 114 return {"backend": "TRITON_ATTN"}
  115 + if major < 8 and not _auto_triton_on_sm_lt_8_enabled(config):
  116 + logger.info(
  117 + "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; auto TRITON_ATTN disabled — "
  118 + "leaving attention backend to vLLM (no attention_config). "
  119 + "If the first score() fails on 'ninja', install ninja in the score venv, ensure "
  120 + "PATH includes that venv's bin (see start_reranker.sh), or use system ninja-build.",
  121 + major,
  122 + minor,
  123 + )
76 124 return None
77 125  
78 126  
... ...
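To make the resolution order concrete, a small usage sketch (assumptions: run inside the score venv on a T4-class GPU reporting sm_75; the "expect" comments follow the code above and are not captured output):

```python
import os

from reranker.backends.qwen3_vllm_score import _resolve_vllm_attention_config

# 1) Default config on sm < 8: the fallback is on, TRITON_ATTN is injected.
print(_resolve_vllm_attention_config({}))                                      # expect {'backend': 'TRITON_ATTN'}

# 2) YAML opt-out: no attention_config, vLLM chooses its own backend.
print(_resolve_vllm_attention_config({"auto_triton_attn_on_sm_lt_8": False}))  # expect None

# 3) The env var wins over the YAML key.
os.environ["RERANK_VLLM_AUTO_TRITON_ATTN"] = "0"
print(_resolve_vllm_attention_config({"auto_triton_attn_on_sm_lt_8": True}))   # expect None

# 4) An explicit backend short-circuits the capability check entirely.
print(_resolve_vllm_attention_config({"vllm_attention_backend": "FLASHINFER"}))  # expect {'backend': 'FLASHINFER'}
```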
reranker/性能优化版本的qwen3_vllm_score 为什么反而更慢.md 0 → 100644
... ... @@ -0,0 +1,141 @@
  1 +
  2 +Conclusion first: **the items that can be aligned in YAML (`model_name`, `max_model_len`, `infer_batch_size`, `prefix_caching`, etc.) are already essentially aligned**; `qwen3_vllm_score` is slower mainly because **the two backends do not go through the same vLLM inference path**, because **the score backend forces an attention backend on T4**, and because **the generate path more easily benefits from the "same query, many docs" optimization**.
  3 +
  4 +---
  5 +
  6 +## 1. Config level: which settings are equivalent, and which simply do not exist on the other side
  7 +
  8 +The logic shared by both sides is identical in code: `infer_batch_size`, `sort_by_doc_length`, deduplication, and the semantics of `instruction` / `instruction_format` (within each implementation) are aligned by design.
  9 +
  10 +The difference lies in **the extra LLM constructor arguments that `qwen3_vllm_score` must pass**: `runner` / `convert` / `hf_overrides` (the chain that converts the Hub model to `Qwen3ForSequenceClassification`). `qwen3_vllm` has none of these because it is a **plain causal LM + `generate`**. This is not a missing key in `config.yaml`; it is a necessary difference between the two APIs.
  11 +
  12 +```132:140:reranker/backends/qwen3_vllm.py
  13 + self._llm = LLM(
  14 + model=model_name,
  15 + tensor_parallel_size=tensor_parallel_size,
  16 + max_model_len=max_model_len,
  17 + gpu_memory_utilization=gpu_memory_utilization,
  18 + enable_prefix_caching=enable_prefix_caching,
  19 + enforce_eager=enforce_eager,
  20 + dtype=dtype,
  21 + )
  22 +```
  23 +
  24 +```167:195:reranker/backends/qwen3_vllm_score.py
  25 + llm_kwargs: Dict[str, Any] = {
  26 + "model": model_name,
  27 + "runner": runner,
  28 + "convert": convert,
  29 + "tensor_parallel_size": tensor_parallel_size,
  30 + "max_model_len": max_model_len,
  31 + "gpu_memory_utilization": gpu_memory_utilization,
  32 + "enable_prefix_caching": enable_prefix_caching,
  33 + "enforce_eager": enforce_eager,
  34 + "dtype": dtype,
  35 + }
  36 + hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {})
  37 + if use_hf_overrides:
  38 + hf_overrides = {
  39 + **hf_overrides,
  40 + "architectures": ["Qwen3ForSequenceClassification"],
  41 + "classifier_from_token": ["no", "yes"],
  42 + "is_original_qwen3_reranker": True,
  43 + }
  44 + if hf_overrides:
  45 + llm_kwargs["hf_overrides"] = hf_overrides
  46 +
  47 + attn_cfg = _resolve_vllm_attention_config(self._config)
  48 + if attn_cfg is not None:
  49 + llm_kwargs["attention_config"] = attn_cfg
  50 +
  51 + self._llm = LLM(**llm_kwargs)
  52 +```
  53 +
  54 +**Minor pitfall (only if someone deletes the YAML field):**
  55 +The **code defaults for `instruction_format` differ**: `qwen3_vllm` defaults to `compact`, `qwen3_vllm_score` to `standard`. In the snippet you pasted both sides set `standard`, so they are currently aligned.
  56 +
  57 +```93:98:reranker/backends/qwen3_vllm.py
  58 + _fmt = str(self._config.get("instruction_format") or "compact").strip().lower()
  59 +```
  60 +
  61 +```104:109:reranker/backends/qwen3_vllm_score.py
  62 + _fmt = str(self._config.get("instruction_format") or "standard").strip().lower()
  63 +```
  64 +
  65 +---
  66 +
  67 +## 2. Why "score should in principle be faster" comes out reversed on your machine
  68 +
  69 +Your own report lists a **Tesla T4** (compute capability **sm_75 < 8.0**). That maps directly onto the behavior in the code.
  70 +
  71 +### (1) Only the score backend **forces** `TRITON_ATTN` when sm < 8
  72 +
  73 +```65:75:reranker/backends/qwen3_vllm_score.py
  74 + major, minor = torch.cuda.get_device_capability()
  75 + if major < 8:
  76 + logger.info(
  77 + "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
  78 + "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
  79 + ...
  80 + )
  81 + return {"backend": "TRITON_ATTN"}
  82 +```
  83 +
  84 +`qwen3_vllm` has **none** of this logic and does **not** set `attention_config`; it leaves the choice entirely to vLLM on the **generate** path.
  85 +So on a T4 it is easy for **the two paths to end up using different attention / kernel combinations**; if the default path suits your batch size and sequence length better than the forced `TRITON_ATTN`, **score ends up slower**.
  86 +To verify, try `vllm_attention_backend` in the score YAML (or align `RERANK_VLLM_ATTENTION_BACKEND` with whatever backend generate actually uses), or re-run the matrix on Ampere+.
  87 +
  88 +### (2) The workload and vLLM's optimization focus differ (one of the main causes)
  89 +
  90 +- **generate backend**: `max_tokens=1` and `allowed_token_ids` restricted to yes/no; essentially **prefill + a very short decode**, with logprobs only needed for the final step's distribution.
  91 +- **score backend**: `LLM.score()` goes through a **pooling / cross-encoder style** scoring graph on a different runner; that does **not** mean "necessarily less compute than 1-token generate", and in vLLM the **causal generate path is usually the more heavily tuned one**.
  92 +
  93 +So "the score API is higher-level, therefore it must be faster" **does not necessarily hold** for this model usage.
  94 +
  95 +### (3) `enable_prefix_caching: true` gives the two sides asymmetric cacheable prefixes
  96 +
  97 +With the same query and many docs, the **generate** path builds its prompts via the chat template, so **the long prefix from system through query is identical across the batch**, an ideal scenario for prefix caching.
  98 +
  99 +The **score** path hands the inputs to `score()` as separate `queries` / `documents` columns; how it chunks them internally, and whether "one query, many docs" maps to prefix reuse as strong as generate's, depends on the vLLM implementation, and on many versions **generate + shared prefix** gets the better deal. Your `max_model_len: 160` is very short, so prefill cost dominates and **whichever side benefits more from the cache** pulls clearly ahead.
  100 +
  101 +### (4) Tokenizer side: the generate backend gets an extra "batched template" step
  102 +
  103 +`qwen3_vllm` runs `apply_chat_template` once over the whole batch before calling `generate`:
  104 +
  105 +```171:180:reranker/backends/qwen3_vllm.py
  106 + messages_batch = [
  107 + self._format_messages(self._instruction, q, d) for q, d in pairs
  108 + ]
  109 + tokenized = self._tokenizer.apply_chat_template(
  110 + messages_batch,
  111 + tokenize=True,
  112 + add_generation_prompt=False,
  113 + enable_thinking=False,
  114 + )
  115 +```
  116 +
  117 +`qwen3_vllm_score` concatenates strings pair by pair in Python and then calls `score()` (tokenization happens inside vLLM). This is usually not the primary bottleneck, but with **large batches and short sequences** it adds a small difference.
  118 +
  119 +### (5) The two venvs run different vLLM versions
  120 +
  121 +- `.venv-reranker`: `vllm>=0.8.5` (the exact installed version varies)
  122 +- `.venv-reranker-score`: pinned to `vllm==0.18.0`
  123 +
  124 +When asking "which is faster", **version and code path are bundled together**; you cannot assume "newer vLLM + score" necessarily beats "older vLLM + 1-token generate" on a T4.
  125 +
  126 +---
  127 +
  128 +## 3. How this maps onto your `RESULTS.md`
  129 +
  130 +In `perf_reports/.../RESULTS.md`, **`qwen3_vllm` stays below `qwen3_vllm_score` throughout for the same `instruction_format`**, consistent with the explanation above (**T4 + forced attention + different runner + prefix cache utilization**); the report also notes the ranking may change on other GPUs / vLLM versions, which is reasonable.
  131 +
  132 +---
  133 +
  134 +## 4. How to run an "aligned experiment" (directional suggestions)
  135 +
  136 +1. **Run the same script on Ampere (A10/A100 etc., sm >= 80)** and see whether score pulls ahead (with a more complete FlashAttention path, the score path sometimes becomes the more reasonable choice).
  137 +2. **Explicitly set `vllm_attention_backend` on the score side** (or align it with the env var) so that, on a T4, score is not the only one locked to `TRITON_ATTN` while generate takes a different path.
  138 +3. **Pin both sides to the same `pip show vllm` version** before comparing; otherwise the version gap pollutes the conclusion.
  139 +4. Use vLLM's profiler / logs to confirm the **prefix cache hit** difference between the two backends (if you want to quantify the caching factor).
  140 +
  141 +**Summary:** it is not that a few keys were left out of `config.yaml`; rather, **different inference graphs, the asymmetric attention policy on T4, and generate being friendlier to "same query, many docs"** make it **reasonable that `qwen3_vllm` is faster than `qwen3_vllm_score`** in your current environment, which does not contradict "the score API is theoretically cleaner".
0 142 \ No newline at end of file
... ...
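To illustrate section 2 (2) of the note above, a minimal sketch of the two call shapes being compared. Assumptions: vLLM 0.18's `LLM.generate` / `LLM.score` entry points, with each backend owning its own `LLM` engine as in the snippets quoted earlier; the prompt text and token ids below are placeholders, not the backends' actual values:

```python
from vllm import LLM, SamplingParams


def generate_style_score(llm: LLM, prompt: str) -> None:
    """qwen3_vllm shape: prefill plus one constrained decode step, read last-step logprobs."""
    params = SamplingParams(
        max_tokens=1,                    # a single decode step
        allowed_token_ids=[9454, 2753],  # placeholder ids standing in for "yes" / "no"
        logprobs=2,                      # only the final distribution is needed
    )
    out = llm.generate([prompt], params)
    print(out[0].outputs[0].logprobs)    # per-step logprob dicts; the yes/no scores live here


def score_style_score(llm: LLM, query: str, doc: str) -> None:
    """qwen3_vllm_score shape: pooling / cross-encoder runner via LLM.score()."""
    print(llm.score([query], [doc]))     # one relevance score per (query, doc) pair
```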
scripts/smoke_qwen3_vllm_score_backend.py 0 → 100644
... ... @@ -0,0 +1,87 @@
  1 +#!/usr/bin/env python3
  2 +"""
  3 +Smoke test: load Qwen3VLLMScoreRerankerBackend (must run as a file, not stdin — vLLM spawn).
  4 +
  5 +Usage (from repo root, score venv):
  6 + PYTHONPATH=. ./.venv-reranker-score/bin/python scripts/smoke_qwen3_vllm_score_backend.py
  7 +
  8 +Same as production: vLLM child processes need the venv's ``bin`` on PATH (for pip's ``ninja`` when
  9 +using FLASHINFER). ``start_reranker.sh`` exports that; this script prepends ``sysconfig.get_path("scripts")``
  10 +(the stdlib location for this environment's console scripts, independent of ``python`` symlink targets).
  11 +"""
  12 +
  13 +from __future__ import annotations
  14 +
  15 +import argparse
  16 +import logging
  17 +import os
  18 +import sys
  19 +import sysconfig
  20 +from pathlib import Path
  21 +
  22 +# Repo root on sys.path when run as scripts/smoke_*.py
  23 +_ROOT = Path(__file__).resolve().parents[1]
  24 +if str(_ROOT) not in sys.path:
  25 + sys.path.insert(0, str(_ROOT))
  26 +
  27 +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
  28 +
  29 +import torch
  30 +
  31 +from reranker.backends.qwen3_vllm_score import (
  32 + Qwen3VLLMScoreRerankerBackend,
  33 + _resolve_vllm_attention_config,
  34 +)
  35 +
  36 +
  37 +def main() -> int:
  38 + p = argparse.ArgumentParser()
  39 + p.add_argument(
  40 + "--no-auto-triton",
  41 + action="store_true",
  42 + help="Set auto_triton_attn_on_sm_lt_8=False (match config opt-out)",
  43 + )
  44 + p.add_argument(
  45 + "--gpu-memory-utilization",
  46 + type=float,
  47 + default=0.12,
  48 + help="vLLM gpu_memory_utilization (default 0.12 for tight GPUs)",
  49 + )
  50 + args = p.parse_args()
  51 +
  52 + scripts = sysconfig.get_path("scripts")
  53 + if scripts:
  54 + os.environ["PATH"] = scripts + os.pathsep + os.environ.get("PATH", "")
  55 +
  56 + if not torch.cuda.is_available():
  57 + print("SKIP: CUDA not available")
  58 + return 0
  59 +
  60 + cfg = {
  61 + "model_name": "Qwen/Qwen3-Reranker-0.6B",
  62 + "max_model_len": 160,
  63 + "tensor_parallel_size": 1,
  64 + "gpu_memory_utilization": args.gpu_memory_utilization,
  65 + "dtype": "float16",
  66 + "enable_prefix_caching": False,
  67 + "enforce_eager": True,
  68 + "infer_batch_size": 4,
  69 + "instruction_format": "standard",
  70 + }
  71 + if args.no_auto_triton:
  72 + cfg["auto_triton_attn_on_sm_lt_8"] = False
  73 +
  74 + attn = _resolve_vllm_attention_config(cfg)
  75 + print("attention_config:", attn)
  76 +
  77 + print("Loading backend ...")
  78 + backend = Qwen3VLLMScoreRerankerBackend(cfg)
  79 + scores, meta = backend.score_with_meta("smoke query", ["title one", "title two"], normalize=False)
  80 + print("scores:", scores)
  81 + print("meta:", {k: meta[k] for k in ("backend", "infer_batch_size", "instruction_format") if k in meta})
  82 + print("OK")
  83 + return 0
  84 +
  85 +
  86 +if __name__ == "__main__":
  87 + raise SystemExit(main())
... ...
scripts/start_reranker.sh
... ... @@ -41,6 +41,8 @@ export TRITON_CACHE_DIR="${RERANKER_RUNTIME_DIR}/triton"
41 41 export TORCHINDUCTOR_CACHE_DIR="${RERANKER_RUNTIME_DIR}/torch_compile"
42 42 export TMPDIR="${RERANKER_RUNTIME_DIR}/tmp"
43 43 export VLLM_NO_USAGE_STATS="${VLLM_NO_USAGE_STATS:-1}"
  44 +# venv bin must be on PATH before Python starts: vLLM worker inherits it; FlashInfer JIT needs
  45 +# pip-installed ninja when qwen3_vllm_score does not force TRITON_ATTN (e.g. T4 + auto_triton off).
44 46 export PATH="${RERANKER_VENV}/bin:${PATH}"
45 47  
46 48 if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then
... ...