"""
Qwen3-Reranker via vLLM ``LLM.score()`` (pooling / cross-encoder score API).
Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired
``llm.score(query_texts, doc_texts)`` with the recommended prefix/suffix templates.
Requires vLLM >= 0.17 (uses ``runner``/``convert`` auto, not legacy ``task="score"``).
Dedicated venv: ``.venv-reranker-score`` + ``requirements_reranker_qwen3_vllm_score.txt``
(see ``./scripts/setup_reranker_venv.sh qwen3_vllm_score``). Default ``model_name`` can match
``qwen3_vllm``; only the Python env differs for pinned high-performance vLLM.
Reference: https://docs.vllm.ai/ — Qwen3 reranker example
"""
from __future__ import annotations
import logging
import os
import threading
import time
from typing import Any, Dict, List, Tuple
logger = logging.getLogger("reranker.backends.qwen3_vllm_score")
import torch
from vllm import LLM
from reranker.backends.qwen3_vllm import deduplicate_with_positions
# Official vLLM Qwen3 reranker prompt layout (im_start blocks + assistant suffix).
_DEFAULT_PREFIX = (
"<|im_start|>system\n"
"Judge whether the Document meets the requirements based on the Query and the Instruct "
'provided. Note that the answer can only be "yes" or "no".'
"<|im_end|>\n<|im_start|>user\n"
)
_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
_DEFAULT_QUERY_TEMPLATE = "{prefix}: {instruction}\n: {query}\n"
_DEFAULT_DOCUMENT_TEMPLATE = ": {doc}{suffix}"
# compact:与 qwen3_vllm._format_instruction 一致(instruction 作 system,user 内重复 Instruct)
_IM_USER_START = "<|im_end|>\n<|im_start|>user\n"
def _parse_env_bool(raw: str | None) -> bool | None:
if raw is None:
return None
s = str(raw).strip().lower()
if not s:
return None
if s in {"1", "true", "yes", "y", "on"}:
return True
if s in {"0", "false", "no", "n", "off"}:
return False
return None
def _auto_triton_on_sm_lt_8_enabled(config: Dict[str, Any]) -> bool:
    """
    When True (default), sm < 8 injects TRITON_ATTN to avoid FA2-only paths that error on T4/V100.
    When False, vLLM may choose FLASHINFER on Turing; first ``score()`` can JIT-compile and needs
    ``ninja`` on PATH (``requirements_reranker_qwen3_vllm_score.txt``). Use
    ``./scripts/start_reranker.sh`` (prepends the backend venv's ``bin`` to ``PATH``) or
    ``source .../bin/activate``.

    Precedence: RERANK_VLLM_AUTO_TRITON_ATTN env var, then the
    ``auto_triton_attn_on_sm_lt_8`` config key, then True.
    """
    from_env = _parse_env_bool(os.getenv("RERANK_VLLM_AUTO_TRITON_ATTN"))
    if from_env is not None:
        return from_env
    from_config = config.get("auto_triton_attn_on_sm_lt_8")
    if from_config is None:
        return True
    if isinstance(from_config, bool):
        return from_config
    parsed = _parse_env_bool(str(from_config))
    return parsed if parsed is not None else True
def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None:
    """
    Resolve vLLM's optional ``attention_config`` (``{"backend": NAME}`` or None).

    Explicit backend via vllm_attention_backend / RERANK_VLLM_ATTENTION_BACKEND
    (env wins); "auto" or empty means no explicit choice. With no explicit
    choice, on compute capability < 8 vLLM may default to Flash-Attention 2,
    which is not supported on Turing/Volta; this module historically injected
    TRITON_ATTN. That can be slower than vLLM's other fallbacks — disable with
    auto_triton_attn_on_sm_lt_8: false or RERANK_VLLM_AUTO_TRITON_ATTN=0 if
    your stack runs without errors.

    Note: calls ``torch.cuda.get_device_capability()``; the caller checks CUDA
    availability before invoking this helper.
    """
    env = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip()
    raw = config.get("vllm_attention_backend")
    if env:
        choice = env
    elif raw is not None and str(raw).strip():
        choice = str(raw).strip()
    else:
        choice = ""
    backend = choice.upper()
    # "AUTO" (from env or config) is treated the same as no explicit choice.
    if backend and backend != "AUTO":
        logger.info("[Qwen3_VLLM_SCORE] attention_config.backend=%s (from config/env)", backend)
        return {"backend": backend}
    major, minor = torch.cuda.get_device_capability()
    if major >= 8:
        return None
    # sm < 8: evaluate the auto-TRITON knob exactly once.
    if _auto_triton_on_sm_lt_8_enabled(config):
        logger.info(
            "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
            "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
            "To use vLLM default instead: auto_triton_attn_on_sm_lt_8: false or "
            "RERANK_VLLM_AUTO_TRITON_ATTN=0; or set vllm_attention_backend / "
            "RERANK_VLLM_ATTENTION_BACKEND.",
            major,
            minor,
        )
        return {"backend": "TRITON_ATTN"}
    logger.info(
        "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; auto TRITON_ATTN disabled — "
        "leaving attention backend to vLLM (no attention_config). "
        "If the first score() fails on 'ninja', install ninja in the score venv, ensure "
        "PATH includes that venv's bin (see start_reranker.sh), or use system ninja-build.",
        major,
        minor,
    )
    return None
class Qwen3VLLMScoreRerankerBackend:
    """
    Qwen3 reranker using vLLM ``LLM.score()`` (pooling runner) for cross-encoder scores.

    Config from ``services.rerank.backends.qwen3_vllm_score``.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Parse/validate config and load the vLLM model (requires a CUDA GPU).

        Raises:
            RuntimeError: if CUDA is unavailable.
            ValueError: on invalid instruction_format, dtype, infer_batch_size,
                vllm_runner, or vllm_convert.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(self._config.get("max_model_len", 2048))
        tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
        enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False))
        enforce_eager = bool(self._config.get("enforce_eager", True))
        dtype = str(self._config.get("dtype", "float16")).strip().lower()
        use_hf_overrides = self._config.get("use_original_qwen3_hf_overrides")
        if use_hf_overrides is None:
            use_hf_overrides = True
        use_hf_overrides = bool(use_hf_overrides)
        self._instruction = str(
            self._config.get("instruction")
            or "Given a query, score the product for relevance"
        )
        _fmt = str(self._config.get("instruction_format") or "standard").strip().lower()
        if _fmt not in {"standard", "compact"}:
            raise ValueError(
                f"instruction_format must be 'standard' or 'compact', got {_fmt!r}"
            )
        self._instruction_format = _fmt
        # Prompt pieces are all overridable from config; defaults follow the
        # official vLLM qwen3_reranker example.
        self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE)
        self._document_template = str(
            self._config.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE
        )
        # Env vars take precedence over config for runtime batching knobs.
        infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get(
            "infer_batch_size", 64
        )
        sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH")
        if sort_by_doc_length is None:
            sort_by_doc_length = self._config.get("sort_by_doc_length", True)
        self._infer_batch_size = int(infer_batch_size)
        # Accepts bools or truthy strings ("1"/"true"/"yes"/"y"/"on").
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False"
            )
        if dtype not in {"float16", "half", "auto"}:
            raise ValueError(
                f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto."
            )
        if self._infer_batch_size <= 0:
            raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
        runner = str(self._config.get("vllm_runner") or "auto").strip().lower()
        convert = str(self._config.get("vllm_convert") or "auto").strip().lower()
        if runner not in {"auto", "generate", "pooling", "draft"}:
            raise ValueError(f"Invalid vllm_runner: {runner!r}")
        if convert not in {"auto", "none", "embed", "classify"}:
            raise ValueError(f"Invalid vllm_convert: {convert!r}")
        logger.info(
            "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, "
            "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, "
            "instruction_format=%s)",
            model_name,
            runner,
            convert,
            use_hf_overrides,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            dtype,
            enable_prefix_caching,
            self._instruction_format,
        )
        # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official
        # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify.
        llm_kwargs: Dict[str, Any] = {
            "model": model_name,
            "runner": runner,
            "convert": convert,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": max_model_len,
            "gpu_memory_utilization": gpu_memory_utilization,
            "enable_prefix_caching": enable_prefix_caching,
            "enforce_eager": enforce_eager,
            "dtype": dtype,
        }
        # User-provided hf_overrides are kept; the official reranker overrides
        # are layered on top when use_original_qwen3_hf_overrides is enabled.
        hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {})
        if use_hf_overrides:
            hf_overrides = {
                **hf_overrides,
                "architectures": ["Qwen3ForSequenceClassification"],
                "classifier_from_token": ["no", "yes"],
                "is_original_qwen3_reranker": True,
            }
        if hf_overrides:
            llm_kwargs["hf_overrides"] = hf_overrides
        attn_cfg = _resolve_vllm_attention_config(self._config)
        if attn_cfg is not None:
            llm_kwargs["attention_config"] = attn_cfg
        self._llm = LLM(**llm_kwargs)
        # vLLM score path: single-process safety (mirrors generate backend until verified).
        self._infer_lock = threading.Lock()
        self._model_name = model_name
        logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name)

    def _format_pair(self, query: str, doc: str) -> Tuple[str, str]:
        """Render one (query_text, document_text) pair for ``LLM.score()``."""
        if self._instruction_format == "compact":
            # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score().
            # The <Instruct>/<Query>/<Document> tags are part of the trained prompt format.
            compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}"
            q_text = (
                f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n"
            )
            d_text = f"\n<Document>: {doc}{self._suffix}"
            return q_text, d_text
        q_text = self._query_template.format(
            prefix=self._prefix,
            instruction=self._instruction,
            query=query,
        )
        d_text = self._document_template.format(doc=doc, suffix=self._suffix)
        return q_text, d_text

    def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]:
        """Score a batch of (query, doc) pairs via ``LLM.score``; returns one float per pair."""
        if not pairs:
            return []
        queries: List[str] = []
        documents: List[str] = []
        for q, d in pairs:
            qt, dt = self._format_pair(q, d)
            queries.append(qt)
            documents.append(dt)
        # Serialize calls into the engine; see note on _infer_lock in __init__.
        with self._infer_lock:
            outputs = self._llm.score(queries, documents, use_tqdm=False)
        scores: List[float] = []
        for out in outputs:
            so = out.outputs
            scores.append(float(so.score))
        return scores

    @staticmethod
    def _estimate_doc_lengths(docs: List[str]) -> List[int]:
        """Cheap length proxy (character count) used only to order docs for batching."""
        if not docs:
            return []
        return [len(text) for text in docs]

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query``; returns (scores, meta).

        Scores align positionally with ``docs``; None/empty docs (and an empty
        query) score 0.0. Duplicate docs are scored once and fanned back out.
        ``normalize`` is echoed in meta only — this backend does not rescale.
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs
        query = "" if query is None else str(query).strip()
        # Keep (original index, cleaned text) for every usable doc.
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))
        if not query or not indexed:
            # Nothing to score: return all-zero scores with bookkeeping meta.
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_vllm_score",
                "normalize": normalize,
                "infer_batch_size": self._infer_batch_size,
                "inference_batches": 0,
                "sort_by_doc_length": self._sort_by_doc_length,
                "instruction_format": self._instruction_format,
            }
        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)
        lengths = self._estimate_doc_lengths(unique_texts)
        # Optionally sort by doc length so each batch holds similar-length docs.
        order = list(range(len(unique_texts)))
        if self._sort_by_doc_length and len(unique_texts) > 1:
            order = sorted(order, key=lambda i: lengths[i])
        unique_scores: List[float] = [0.0] * len(unique_texts)
        inference_batches = 0
        for start in range(0, len(order), self._infer_batch_size):
            batch_indices = order[start : start + self._infer_batch_size]
            inference_batches += 1
            pairs = [(query, unique_texts[i]) for i in batch_indices]
            batch_scores = self._score_batch(pairs)
            if len(batch_scores) != len(batch_indices):
                raise RuntimeError(
                    f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}"
                )
            for idx, score in zip(batch_indices, batch_scores):
                unique_scores[idx] = float(score)
        # Fan unique scores back out to every original (possibly duplicate) position.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])
        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm_score",
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
            "instruction_format": self._instruction_format,
        }
        return output_scores, meta