# qwen3_vllm_score.py (13.1 KB)
"""
Qwen3-Reranker via vLLM ``LLM.score()`` (pooling / cross-encoder score API).

Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired
``llm.score(query_texts, doc_texts)`` with the recommended prefix/suffix templates.
Requires vLLM >= 0.17 (uses ``runner``/``convert`` auto, not legacy ``task="score"``).

Dedicated venv: ``.venv-reranker-score`` + ``requirements_reranker_qwen3_vllm_score.txt``
(see ``./scripts/setup_reranker_venv.sh qwen3_vllm_score``). The default ``model_name`` can be the
same as for the ``qwen3_vllm`` backend; only the Python environment differs (it pins a specific
high-performance vLLM release).

Reference: https://docs.vllm.ai/ — Qwen3 reranker example
"""

from __future__ import annotations

import logging
import os
import threading
import time
from typing import Any, Dict, List, Tuple

logger = logging.getLogger("reranker.backends.qwen3_vllm_score")

import torch
from vllm import LLM

from reranker.backends.qwen3_vllm import deduplicate_with_positions

# Closes a system block and opens the user turn. It ends _DEFAULT_PREFIX below and is also
# reused by the "compact" instruction_format, which mirrors qwen3_vllm._format_instruction
# (instruction used as the system message, then repeated as <Instruct> inside the user message).
_IM_USER_START = "<|im_end|>\n<|im_start|>user\n"

# Official vLLM Qwen3 reranker prompt layout: a system block with the yes/no judging
# instruction, then the opening of the user turn (query/document come from the templates below).
_DEFAULT_PREFIX = "".join(
    (
        "<|im_start|>system\n",
        "Judge whether the Document meets the requirements based on the Query ",
        "and the Instruct provided. ",
        'Note that the answer can only be "yes" or "no".',
        _IM_USER_START,
    )
)
# Closes the user turn and opens the assistant turn with an empty <think> section.
_DEFAULT_SUFFIX = "<|im_end|>\n" "<|im_start|>assistant\n" "<think>\n\n</think>\n\n"
_DEFAULT_QUERY_TEMPLATE = "{prefix}" "<Instruct>: {instruction}\n" "<Query>: {query}\n"
_DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}" "{suffix}"


def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None:
    """
    Choose the vLLM ``attention_config`` to pass to ``LLM(...)``, or ``None`` for vLLM's default.

    vLLM 0.18 defaults to Flash-Attention paths that require compute capability >= 8 (Ampere+).
    Turing / Volta (e.g. T4 sm_75) must use a non-FA backend such as TRITON_ATTN.

    Resolution order:
      1. ``RERANK_VLLM_ATTENTION_BACKEND`` env var, if non-empty after stripping;
      2. otherwise ``config["vllm_attention_backend"]``, if set and non-empty;
      3. if the chosen value is empty or ``"auto"`` (any case), auto-detect: ``TRITON_ATTN``
         when the current CUDA device reports compute capability major < 8, else ``None``.
    A non-empty env var always wins, so ``RERANK_VLLM_ATTENTION_BACKEND=auto`` forces
    auto-detection even if the config names a backend.

    Args:
        config: backend config mapping (``services.rerank.backends.qwen3_vllm_score``).

    Returns:
        ``{"backend": NAME}`` with NAME upper-cased, or ``None`` to keep vLLM's default.
    """
    env_choice = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip()
    if env_choice:
        choice = env_choice
    else:
        raw = config.get("vllm_attention_backend")
        choice = "" if raw is None else str(raw).strip()

    backend = choice.upper()
    if backend and backend != "AUTO":
        logger.info("[Qwen3_VLLM_SCORE] attention_config.backend=%s (from config/env)", backend)
        return {"backend": backend}

    # No explicit backend: inspect the current CUDA device (torch's default when no index given).
    major, minor = torch.cuda.get_device_capability()
    if major < 8:
        logger.info(
            "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
            "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
            "Override with services.rerank.backends.qwen3_vllm_score.vllm_attention_backend "
            "or RERANK_VLLM_ATTENTION_BACKEND.",
            major,
            minor,
        )
        return {"backend": "TRITON_ATTN"}
    return None


class Qwen3VLLMScoreRerankerBackend:
    """
    Qwen3 reranker using vLLM ``LLM.score()`` (pooling runner) for cross-encoder scores.

    Config from ``services.rerank.backends.qwen3_vllm_score``. Keys read (defaults in
    parentheses): ``model_name`` ("Qwen/Qwen3-Reranker-0.6B"), ``max_model_len`` (2048),
    ``tensor_parallel_size`` (1), ``gpu_memory_utilization`` (0.4), ``enable_prefix_caching``
    (False), ``enforce_eager`` (True), ``dtype`` ("float16"), ``use_original_qwen3_hf_overrides``
    (True), ``hf_overrides``, ``instruction``, ``instruction_format`` ("standard" | "compact"),
    ``prompt_prefix``, ``prompt_suffix``, ``query_template``, ``document_template``,
    ``infer_batch_size`` (64), ``sort_by_doc_length`` (True), ``vllm_runner`` / ``vllm_convert``
    ("auto") and ``vllm_attention_backend``.

    Environment overrides: ``RERANK_VLLM_INFER_BATCH_SIZE``, ``RERANK_VLLM_SORT_BY_DOC_LENGTH``
    and ``RERANK_VLLM_ATTENTION_BACKEND``.
    """

    # Strings treated as "true" when a boolean option arrives as text (env vars, templated config).
    _TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Coerce a config/env value to bool; strings are matched case-insensitively against
        ``_TRUTHY_STRINGS`` (so ``"false"``/``""`` are False), anything else uses ``bool()``."""
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUTHY_STRINGS
        return bool(value)

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Parse and validate ``config``, then load the model into a vLLM ``LLM`` engine.

        Raises:
            ValueError: invalid ``instruction_format``, ``dtype``, ``infer_batch_size``,
                ``vllm_runner`` or ``vllm_convert``.
            RuntimeError: ``torch.cuda.is_available()`` is False.
        """
        self._config = config or {}
        cfg = self._config
        model_name = str(cfg.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(cfg.get("max_model_len", 2048))
        tensor_parallel_size = int(cfg.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(cfg.get("gpu_memory_utilization", 0.4))
        # _as_bool instead of bool(): bool("false") is True, which would silently invert
        # string-valued flags (matches the parsing already used for sort_by_doc_length).
        enable_prefix_caching = self._as_bool(cfg.get("enable_prefix_caching", False))
        enforce_eager = self._as_bool(cfg.get("enforce_eager", True))
        dtype = str(cfg.get("dtype", "float16")).strip().lower()
        raw_hf_flag = cfg.get("use_original_qwen3_hf_overrides")
        # Absent / null means "apply the official Qwen3 reranker hf_overrides".
        use_hf_overrides = True if raw_hf_flag is None else self._as_bool(raw_hf_flag)

        self._instruction = str(
            cfg.get("instruction") or "Given a query, score the product for relevance"
        )
        fmt = str(cfg.get("instruction_format") or "standard").strip().lower()
        if fmt not in {"standard", "compact"}:
            raise ValueError(
                f"instruction_format must be 'standard' or 'compact', got {fmt!r}"
            )
        self._instruction_format = fmt
        self._prefix = str(cfg.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(cfg.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._query_template = str(cfg.get("query_template") or _DEFAULT_QUERY_TEMPLATE)
        self._document_template = str(
            cfg.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE
        )

        # Env vars take precedence over config for these tuning knobs.
        infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or cfg.get(
            "infer_batch_size"
        )
        if infer_batch_size is None:  # unset or explicit null -> default (int(None) would crash)
            infer_batch_size = 64
        sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH")
        if sort_by_doc_length is None:
            sort_by_doc_length = cfg.get("sort_by_doc_length", True)

        self._infer_batch_size = int(infer_batch_size)
        self._sort_by_doc_length = self._as_bool(sort_by_doc_length)

        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False"
            )
        if dtype not in {"float16", "half", "auto"}:
            raise ValueError(
                f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto."
            )
        if self._infer_batch_size <= 0:
            raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")

        runner = str(cfg.get("vllm_runner") or "auto").strip().lower()
        convert = str(cfg.get("vllm_convert") or "auto").strip().lower()
        if runner not in {"auto", "generate", "pooling", "draft"}:
            raise ValueError(f"Invalid vllm_runner: {runner!r}")
        if convert not in {"auto", "none", "embed", "classify"}:
            raise ValueError(f"Invalid vllm_convert: {convert!r}")

        logger.info(
            "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, "
            "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, "
            "instruction_format=%s)",
            model_name,
            runner,
            convert,
            use_hf_overrides,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            dtype,
            enable_prefix_caching,
            self._instruction_format,
        )

        # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official
        # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify.
        llm_kwargs: Dict[str, Any] = {
            "model": model_name,
            "runner": runner,
            "convert": convert,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": max_model_len,
            "gpu_memory_utilization": gpu_memory_utilization,
            "enable_prefix_caching": enable_prefix_caching,
            "enforce_eager": enforce_eager,
            "dtype": dtype,
        }
        hf_overrides: Dict[str, Any] = dict(cfg.get("hf_overrides") or {})
        if use_hf_overrides:
            # The official overrides win over any user-supplied keys with the same name.
            hf_overrides = {
                **hf_overrides,
                "architectures": ["Qwen3ForSequenceClassification"],
                "classifier_from_token": ["no", "yes"],
                "is_original_qwen3_reranker": True,
            }
        if hf_overrides:
            llm_kwargs["hf_overrides"] = hf_overrides

        attn_cfg = _resolve_vllm_attention_config(cfg)
        if attn_cfg is not None:
            llm_kwargs["attention_config"] = attn_cfg

        self._llm = LLM(**llm_kwargs)
        # Serialise LLM.score() calls from concurrent request threads; the engine is not assumed
        # to be thread-safe (mirrors the generate backend until verified).
        self._infer_lock = threading.Lock()

        self._model_name = model_name
        logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name)

    def _format_pair(self, query: str, doc: str) -> Tuple[str, str]:
        """Render one (query, doc) pair into the (query_text, doc_text) pair fed to ``LLM.score``.

        "standard" fills ``query_template`` / ``document_template``; "compact" puts the
        instruction in the system block and repeats it as <Instruct> in the user turn.
        """
        if self._instruction_format == "compact":
            # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score().
            compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}"
            q_text = (
                f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n"
            )
            d_text = f"\n<Document>: {doc}{self._suffix}"
            return q_text, d_text
        q_text = self._query_template.format(
            prefix=self._prefix,
            instruction=self._instruction,
            query=query,
        )
        d_text = self._document_template.format(doc=doc, suffix=self._suffix)
        return q_text, d_text

    def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]:
        """Score raw (query, doc) pairs in one ``LLM.score`` call; one float per pair, in order."""
        if not pairs:
            return []
        formatted = [self._format_pair(q, d) for q, d in pairs]
        queries = [qt for qt, _ in formatted]
        documents = [dt for _, dt in formatted]
        with self._infer_lock:
            outputs = self._llm.score(queries, documents, use_tqdm=False)
        return [float(out.outputs.score) for out in outputs]

    @staticmethod
    def _estimate_doc_lengths(docs: List[str]) -> List[int]:
        """Cheap per-doc length proxy (character count) used only for batch ordering."""
        return [len(text) for text in docs]

    def _build_meta(
        self,
        *,
        total_docs: int,
        usable_docs: int,
        unique_docs: int,
        dedup_ratio: float,
        elapsed_ms: float,
        normalize: bool,
        inference_batches: int,
    ) -> Dict[str, Any]:
        """Assemble the diagnostics dict returned by ``score_with_meta`` (same keys on every path)."""
        return {
            "input_docs": total_docs,
            "usable_docs": usable_docs,
            "unique_docs": unique_docs,
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm_score",
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
            "instruction_format": self._instruction_format,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """
        Score ``docs`` against ``query`` and return ``(scores, meta)``.

        ``scores`` is aligned with ``docs``. A doc that is None or empty after stripping gets
        0.0, and every doc gets 0.0 when the stripped query is empty. Identical doc texts
        (after stripping) are scored once and the score is copied to every position.
        Unique docs are optionally sorted by character length and scored in chunks of
        ``infer_batch_size``.

        ``normalize`` is not applied here; it is only echoed in ``meta``. The scores are
        whatever ``LLM.score()`` returns. NOTE(review): presumably already probabilities in
        [0, 1] for the Qwen3 classifier head; confirm.

        Raises:
            RuntimeError: vLLM returned a different number of scores than pairs submitted.
        """
        start_ts = time.perf_counter()  # monotonic; wall-clock time.time() can jump
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []  # (original position, stripped text) of usable docs
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed.append((i, text))

        if not query or not indexed:
            return output_scores, self._build_meta(
                total_docs=total_docs,
                usable_docs=len(indexed),
                unique_docs=0,
                dedup_ratio=0.0,
                elapsed_ms=(time.perf_counter() - start_ts) * 1000.0,
                normalize=normalize,
                inference_batches=0,
            )

        indexed_texts = [text for _, text in indexed]
        # position_to_unique[k] is the index into unique_texts for indexed_texts[k].
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        lengths = self._estimate_doc_lengths(unique_texts)
        order = list(range(len(unique_texts)))
        if self._sort_by_doc_length and len(unique_texts) > 1:
            # Group similarly sized prompts into the same batch (stable sort).
            order.sort(key=lengths.__getitem__)

        unique_scores: List[float] = [0.0] * len(unique_texts)
        inference_batches = 0
        for start in range(0, len(order), self._infer_batch_size):
            batch_indices = order[start : start + self._infer_batch_size]
            inference_batches += 1
            pairs = [(query, unique_texts[i]) for i in batch_indices]
            batch_scores = self._score_batch(pairs)
            if len(batch_scores) != len(batch_indices):
                raise RuntimeError(
                    f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}"
                )
            for idx, score in zip(batch_indices, batch_scores):
                unique_scores[idx] = float(score)

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.perf_counter() - start_ts) * 1000.0
        # indexed is non-empty on this path, so the division is safe.
        dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = self._build_meta(
            total_docs=total_docs,
            usable_docs=len(indexed),
            unique_docs=len(unique_texts),
            dedup_ratio=dedup_ratio,
            elapsed_ms=elapsed_ms,
            normalize=normalize,
            inference_batches=inference_batches,
        )
        return output_scores, meta