# qwen3_vllm.py
"""
Qwen3-Reranker-0.6B backend using vLLM.

Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
Requires: vllm>=0.8.5, transformers; GPU recommended.
"""

from __future__ import annotations

import logging
import math
import time
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger("reranker.backends.qwen3_vllm")

try:
    import torch
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt
except ImportError as e:
    raise ImportError(
        "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. "
        "Install with: pip install vllm transformers"
    ) from e


def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
    """Build chat messages for one (query, doc) pair."""
    return [
        {
            "role": "system",
            "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
        },
        {
            "role": "user",
            "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}",
        },
    ]


class Qwen3VLLMRerankerBackend:
    """
    Qwen3-Reranker-0.6B with vLLM inference.

    Scores (query, document) pairs by asking the model a yes/no relevance
    question and reading P(yes) = p(yes) / (p(yes) + p(no)) from the first
    generated token's logprobs.
    Config from services.rerank.backends.qwen3_vllm.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Load the model via vLLM and prepare the tokenizer and sampling setup.

        Args:
            config: Backend config dict. Recognized keys: model_name,
                max_model_len, tensor_parallel_size, gpu_memory_utilization,
                enable_prefix_caching, instruction. Missing keys fall back
                to the defaults below.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(self._config.get("max_model_len", 8192))
        tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.8))
        enable_prefix_caching = bool(self._config.get("enable_prefix_caching", True))
        self._instruction = str(
            self._config.get("instruction")
            or "Given a web search query, retrieve relevant passages that answer the query"
        )

        logger.info(
            "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, prefix_caching=%s)",
            model_name,
            max_model_len,
            tensor_parallel_size,
            enable_prefix_caching,
        )

        self._llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            enable_prefix_caching=enable_prefix_caching,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._tokenizer.padding_side = "left"
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Fixed suffix that closes the user turn and opens an (empty-think)
        # assistant turn, so generation starts at the yes/no answer token.
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._suffix_tokens = self._tokenizer.encode(
            self._suffix, add_special_tokens=False
        )
        # Reserve room for the suffix AND the single generated token:
        # vLLM requires prompt_len + max_tokens <= max_model_len, so a prompt
        # truncated only to (max_model_len - suffix_len) could hit exactly
        # max_model_len and fail at decode time (off-by-one).
        self._max_prompt_len = max_model_len - len(self._suffix_tokens) - 1

        self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0]
        self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0]
        # Greedy, single-token decode constrained to the yes/no token ids;
        # logprobs=20 ensures both candidate tokens appear in the returned map.
        self._sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self._true_token, self._false_token],
        )

        self._model_name = model_name
        logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name)

    def _process_inputs(
        self,
        pairs: List[Tuple[str, str]],
    ) -> List[TokensPrompt]:
        """Build tokenized prompts for vLLM from (query, doc) pairs.

        Each pair is rendered via the chat template, truncated to
        ``_max_prompt_len`` tokens, then extended with the assistant suffix.
        """
        prompts: List[TokensPrompt] = []
        for q, d in pairs:
            messages = _format_instruction(self._instruction, q, d)
            # One conversation per call (apply_chat_template expects single conversation)
            token_ids = self._tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=False,
                enable_thinking=False,
            )
            # Some tokenizer versions return a batch (list of lists) — unwrap.
            if isinstance(token_ids, list) and token_ids and isinstance(token_ids[0], list):
                token_ids = token_ids[0]
            ids = token_ids[: self._max_prompt_len] + self._suffix_tokens
            prompts.append(TokensPrompt(prompt_token_ids=ids))
        return prompts

    def _compute_scores(
        self,
        prompts: List[TokensPrompt],
    ) -> List[float]:
        """Run vLLM generate and compute the yes-probability per prompt.

        Returns 0.0 for any prompt whose output or logprobs are missing.
        """
        if not prompts:
            return []
        outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False)
        scores: List[float] = []
        for out in outputs:
            if not out.outputs:
                scores.append(0.0)
                continue
            logprobs = out.outputs[0].logprobs
            if not logprobs:
                scores.append(0.0)
                continue
            # Logprob map for the single generated token, keyed by token id.
            last = logprobs[-1]
            true_logp = last.get(self._true_token)
            false_logp = last.get(self._false_token)
            # Tiny floor avoids division by zero when a token is absent.
            true_p = math.exp(true_logp.logprob) if true_logp else 1e-10
            false_p = math.exp(false_logp.logprob) if false_logp else 1e-10
            scores.append(float(true_p / (true_p + false_p)))
        return scores

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score docs against query; return (scores aligned to docs, meta dict).

        Blank/None docs (and a blank query) score 0.0. Duplicate doc texts are
        scored once and the result is fanned back out to every original index.

        Args:
            query: Search query; None/blank short-circuits to all-zero scores.
            docs: Candidate documents; may contain None or blank entries.
            normalize: Recorded in meta only; scores are already P(yes) in [0,1].
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_vllm",
                "normalize": normalize,
            }

        # Deduplicate by exact text (dict-based, so non-adjacent duplicates
        # are merged too — the previous adjacent-only compare missed them),
        # keeping a map from each usable position to its unique-text index.
        unique_texts: List[str] = []
        text_to_unique: Dict[str, int] = {}
        position_to_unique: List[int] = []
        for _idx, text in indexed:
            unique_idx = text_to_unique.get(text)
            if unique_idx is None:
                unique_idx = len(unique_texts)
                text_to_unique[text] = unique_idx
                unique_texts.append(text)
            position_to_unique.append(unique_idx)

        pairs = [(query, t) for t in unique_texts]
        prompts = self._process_inputs(pairs)
        unique_scores = self._compute_scores(prompts)

        # Fan unique scores back out to the original doc positions.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            # Score is already P(yes) in [0,1] from yes/(yes+no)
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm",
            "normalize": normalize,
        }
        return output_scores, meta