# qwen3_vllm.py
"""
Qwen3-Reranker-0.6B backend using vLLM.

Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
Requires: vllm>=0.8.5, transformers; GPU recommended.
"""

from __future__ import annotations

import logging
import math
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger("reranker.backends.qwen3_vllm")

try:
    import torch
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
except ImportError as e:
    raise ImportError(
        "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. "
        "Install with: pip install vllm transformers"
    ) from e


def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
    """Build chat messages for one (query, doc) pair."""
    return [
        {
            "role": "system",
            "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
        },
        {
            "role": "user",
            "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}",
        },
    ]


class Qwen3VLLMRerankerBackend:
    """
    Qwen3-Reranker-0.6B with vLLM inference.

    Each (query, doc) pair is rendered as a chat prompt whose single-token
    answer is constrained to "yes"/"no"; the relevance score is
    P(yes) = exp(yes_logprob) / (exp(yes_logprob) + exp(no_logprob)).
    Config from services.rerank.backends.qwen3_vllm.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Build the vLLM engine and tokenizer from *config*.

        Args:
            config: Backend settings. Recognized keys: model_name,
                max_model_len, tensor_parallel_size, gpu_memory_utilization,
                enable_prefix_caching, enforce_eager, dtype, instruction.

        Raises:
            RuntimeError: If no CUDA GPU is available.
            ValueError: If dtype is not one of float16/half/auto.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(self._config.get("max_model_len", 2048))
        tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
        enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False))
        enforce_eager = bool(self._config.get("enforce_eager", True))
        dtype = str(self._config.get("dtype", "float16")).strip().lower()
        self._instruction = str(
            self._config.get("instruction")
            or "Given a web search query, retrieve relevant passages that answer the query"
        )
        if not torch.cuda.is_available():
            raise RuntimeError("qwen3_vllm backend requires CUDA GPU, but torch.cuda.is_available() is False")
        if dtype not in {"float16", "half", "auto"}:
            raise ValueError(f"Unsupported dtype for qwen3_vllm: {dtype!r}. Use float16/half/auto.")

        logger.info(
            "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)",
            model_name,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            dtype,
            enable_prefix_caching,
        )

        self._llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            enable_prefix_caching=enable_prefix_caching,
            enforce_eager=enforce_eager,
            dtype=dtype,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._tokenizer.padding_side = "left"
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Suffix for generation prompt (assistant answer); appended after the
        # (possibly truncated) chat-template tokens of every prompt.
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._suffix_tokens = self._tokenizer.encode(
            self._suffix, add_special_tokens=False
        )
        # Reserve room for the suffix so prompt + suffix fits in max_model_len.
        self._max_prompt_len = max_model_len - len(self._suffix_tokens)

        # Token ids the model is allowed to emit as its one-token verdict.
        self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0]
        self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0]
        self._sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self._true_token, self._false_token],
        )
        # vLLM generate path is unstable under concurrent calls in this process model.
        # Serialize infer calls to avoid engine-core protocol corruption.
        self._infer_lock = threading.Lock()

        self._model_name = model_name
        logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name)

    def _process_inputs(
        self,
        pairs: List[Tuple[str, str]],
    ) -> List[TokensPrompt]:
        """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template."""
        messages_batch = [
            _format_instruction(self._instruction, q, d) for q, d in pairs
        ]
        tokenized = self._tokenizer.apply_chat_template(
            messages_batch,
            tokenize=True,
            add_generation_prompt=False,
            enable_thinking=False,
        )
        # Single conv returns flat list; batch returns list of lists
        if tokenized and not isinstance(tokenized[0], list):
            tokenized = [tokenized]
        # Truncate each prompt to leave room for the fixed assistant suffix.
        return [
            TokensPrompt(
                prompt_token_ids=ids[: self._max_prompt_len] + self._suffix_tokens
            )
            for ids in tokenized
        ]

    def _compute_scores(
        self,
        prompts: List[TokensPrompt],
    ) -> List[float]:
        """Run vLLM generate and compute yes/no probability per prompt.

        Returns one P(yes) in [0, 1] per prompt; 0.0 when the engine yields
        no output or no logprobs for a prompt.
        """
        if not prompts:
            return []
        outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False)
        scores: List[float] = []
        for out in outputs:
            if not out.outputs:
                scores.append(0.0)
                continue
            logprob_steps = out.outputs[0].logprobs
            if not logprob_steps:
                scores.append(0.0)
                continue
            last = logprob_steps[-1]
            # Match official: missing token -> logprob = -10
            true_logit = last[self._true_token].logprob if self._true_token in last else -10
            false_logit = last[self._false_token].logprob if self._false_token in last else -10
            true_score = math.exp(true_logit)
            false_score = math.exp(false_logit)
            # Two-way softmax over the yes/no logprobs.
            scores.append(float(true_score / (true_score + false_score)))
        return scores

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score *docs* against *query*; return (scores, meta).

        Empty/None docs keep score 0.0 at their original position. Duplicate
        document texts are scored once and the score fanned back out.

        Args:
            query: Search query; blank query short-circuits to all zeros.
            docs: Candidate documents, possibly containing None/blank entries.
            normalize: Echoed into meta only; scores are already P(yes) in [0, 1].

        Returns:
            Tuple of (scores aligned with *docs*, metadata dict with timing
            and dedup statistics).
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_vllm",
                "normalize": normalize,
            }

        # Deduplicate by text, keep mapping to original indices.
        # BUGFIX: the previous implementation only compared against the
        # immediately preceding text, so non-adjacent duplicates were scored
        # again. A dict gives full dedup regardless of ordering.
        unique_texts: List[str] = []
        text_to_unique: Dict[str, int] = {}
        position_to_unique: List[int] = []
        for _idx, text in indexed:
            unique_idx = text_to_unique.get(text)
            if unique_idx is None:
                unique_idx = len(unique_texts)
                text_to_unique[text] = unique_idx
                unique_texts.append(text)
            position_to_unique.append(unique_idx)

        pairs = [(query, t) for t in unique_texts]
        with self._infer_lock:
            prompts = self._process_inputs(pairs)
            unique_scores = self._compute_scores(prompts)

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            # Score is already P(yes) in [0,1] from yes/(yes+no)
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm",
            "normalize": normalize,
        }
        return output_scores, meta