qwen3_vllm.py 11.5 KB
"""
Qwen3-Reranker-0.6B backend using vLLM.

Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
Requires: vllm>=0.8.5, transformers, and a CUDA GPU (the backend refuses to start without one).
"""

from __future__ import annotations

import logging
import math
import os
import threading
import time
from typing import Any, Dict, List, Tuple

# Module-level logger. The name is fixed rather than derived from __name__,
# so it stays "reranker.backends.qwen3_vllm" regardless of how this file is imported.
logger = logging.getLogger("reranker.backends.qwen3_vllm")

try:
    import torch
    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams
    from vllm.inputs.data import TokensPrompt
except ImportError as e:
    raise ImportError(
        "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. "
        "Install with: pip install vllm transformers"
    ) from e


def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]:
    """
    Collapse repeated texts into a unique list, keeping first-occurrence order.

    Args:
        texts: input strings, possibly containing duplicates. Iterated once.

    Returns:
        A pair ``(unique_texts, position_to_unique)``:
        ``unique_texts`` holds each distinct text once, ordered by first
        appearance; ``position_to_unique[i]`` is the index into
        ``unique_texts`` of the i-th input text.
    """
    # dicts preserve insertion order, so the keys double as the unique list.
    first_index: Dict[str, int] = {}
    mapping: List[int] = []
    for text in texts:
        # len(first_index) is evaluated before insertion -> next free slot.
        mapping.append(first_index.setdefault(text, len(first_index)))
    return list(first_index), mapping


def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
    """Build chat messages for one (query, doc) pair."""
    return [
        {
            "role": "system",
            "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
        },
        {
            "role": "user",
            "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}",
        },
    ]


class Qwen3VLLMRerankerBackend:
    """
    Qwen3-Reranker-0.6B with vLLM inference.

    Each (query, doc) pair is rendered as a chat prompt, and the model generates
    a single token restricted to "yes"/"no". The score is the "yes" probability
    normalised over those two tokens, so it lies in [0, 1].

    Config comes from services.rerank.backends.qwen3_vllm. Recognised keys:
    model_name, max_model_len, tensor_parallel_size, gpu_memory_utilization,
    enable_prefix_caching, enforce_eager, dtype, instruction,
    infer_batch_size, sort_by_doc_length. The environment variables
    RERANK_VLLM_INFER_BATCH_SIZE and RERANK_VLLM_SORT_BY_DOC_LENGTH override
    the corresponding config keys.
    """

    # String forms (stripped, lower-cased) treated as true for boolean settings.
    _TRUTHY = frozenset({"1", "true", "yes", "y", "on"})
    # Logprob substituted when "yes"/"no" is absent from the returned top-k
    # logprobs (mirrors the official Qwen3-Reranker example).
    _MISSING_LOGPROB = -10.0

    @classmethod
    def _parse_bool(cls, value: Any) -> bool:
        """
        Interpret a config/env value as a boolean.

        The value's ``str()`` form is compared case-insensitively against
        ``_TRUTHY`` (``True`` -> "true", ``1`` -> "1"). Unlike ``bool()``, this
        makes strings such as "false"/"0"/"off" evaluate to False.
        """
        return str(value).strip().lower() in cls._TRUTHY

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Read settings, load the vLLM engine and tokenizer, and precompute the
        prompt suffix and the yes/no token ids.

        Args:
            config: backend settings; ``None``/empty means all defaults.

        Raises:
            RuntimeError: CUDA is not available.
            ValueError: unsupported ``dtype``, non-positive ``infer_batch_size``,
                or ``max_model_len`` too small to hold the prompt suffix.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(self._config.get("max_model_len", 2048))
        tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
        # Parse booleans textually so "false"/"0" strings are honoured
        # (bool("false") would be True).
        enable_prefix_caching = self._parse_bool(self._config.get("enable_prefix_caching", False))
        enforce_eager = self._parse_bool(self._config.get("enforce_eager", True))
        dtype = str(self._config.get("dtype", "float16")).strip().lower()
        self._instruction = str(
            self._config.get("instruction")
            or "Given a shopping query, rank product titles by relevance"
        )
        # Environment variables take precedence over config for these two knobs.
        # An empty env value falls through to config; an explicit ``None`` in
        # config falls back to the default (0 is kept and rejected below).
        infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size")
        if infer_batch_size is None:
            infer_batch_size = 64
        sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH")
        if sort_by_doc_length is None:
            sort_by_doc_length = self._config.get("sort_by_doc_length", True)

        self._infer_batch_size = int(infer_batch_size)
        self._sort_by_doc_length = self._parse_bool(sort_by_doc_length)
        if not torch.cuda.is_available():
            raise RuntimeError("qwen3_vllm backend requires CUDA GPU, but torch.cuda.is_available() is False")
        if dtype not in {"float16", "half", "auto"}:
            raise ValueError(f"Unsupported dtype for qwen3_vllm: {dtype!r}. Use float16/half/auto.")
        if self._infer_batch_size <= 0:
            raise ValueError(
                f"infer_batch_size must be > 0, got {self._infer_batch_size}"
            )

        logger.info(
            "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)",
            model_name,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            dtype,
            enable_prefix_caching,
        )

        self._llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            enable_prefix_caching=enable_prefix_caching,
            enforce_eager=enforce_eager,
            dtype=dtype,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._tokenizer.padding_side = "left"
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Suffix for generation prompt (assistant answer with an empty think block).
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._suffix_tokens = self._tokenizer.encode(
            self._suffix, add_special_tokens=False
        )
        # Prompt body is truncated to this many tokens so body + suffix fits
        # in max_model_len. A non-positive value would turn the slice in
        # _process_inputs into a negative index that silently drops tokens.
        self._max_prompt_len = max_model_len - len(self._suffix_tokens)
        if self._max_prompt_len <= 0:
            raise ValueError(
                f"max_model_len ({max_model_len}) must exceed the prompt suffix "
                f"length ({len(self._suffix_tokens)} tokens)"
            )

        self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0]
        self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0]
        # One greedy token, restricted to yes/no, with top-20 logprobs returned.
        self._sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self._true_token, self._false_token],
        )
        # vLLM generate path is unstable under concurrent calls in this process model.
        # Serialize infer calls to avoid engine-core protocol corruption.
        self._infer_lock = threading.Lock()

        self._model_name = model_name
        logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name)

    def _process_inputs(
        self,
        pairs: List[Tuple[str, str]],
    ) -> List[TokensPrompt]:
        """
        Build tokenized vLLM prompts from (query, doc) pairs.

        The chat template is applied to all pairs in one batched call. Each
        token sequence is truncated to ``self._max_prompt_len`` and then the
        assistant suffix tokens are appended.
        """
        messages_batch = [
            _format_instruction(self._instruction, q, d) for q, d in pairs
        ]
        tokenized = self._tokenizer.apply_chat_template(
            messages_batch,
            tokenize=True,
            add_generation_prompt=False,
            enable_thinking=False,
            # Be explicit: we need plain token-id lists, not a BatchEncoding.
            return_dict=False,
        )
        # Single conv returns flat list; batch returns list of lists
        if tokenized and not isinstance(tokenized[0], list):
            tokenized = [tokenized]
        return [
            TokensPrompt(
                prompt_token_ids=ids[: self._max_prompt_len] + self._suffix_tokens
            )
            for ids in tokenized
        ]

    @staticmethod
    def _yes_probability(true_logprob: float, false_logprob: float) -> float:
        """
        Return exp(t) / (exp(t) + exp(f)), computed as a stable sigmoid of t - f.

        The naive form divides by zero when both exponentials underflow and
        raises OverflowError for large logprob gaps; this form does neither.
        A NaN gap (e.g. both inputs -inf) yields 0.0, the same value used when
        no logprobs are available at all.
        """
        diff = true_logprob - false_logprob
        if math.isnan(diff):
            return 0.0
        if diff >= 0:
            return 1.0 / (1.0 + math.exp(-diff))
        ratio = math.exp(diff)
        return ratio / (1.0 + ratio)

    def _compute_scores(
        self,
        prompts: List[TokensPrompt],
    ) -> List[float]:
        """
        Run vLLM generate and return P(yes) per prompt, in prompt order.

        A request with no completion or no logprobs scores 0.0. If "yes" or "no"
        is missing from the last step's logprobs, ``_MISSING_LOGPROB`` is used
        for that token. Callers are expected to hold ``self._infer_lock``.
        """
        if not prompts:
            return []
        outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False)
        scores: List[float] = []
        for request_output in outputs:
            completions = request_output.outputs
            if not completions or not completions[0].logprobs:
                scores.append(0.0)
                continue
            # Logprobs for the single generated position: {token_id: Logprob}.
            last = completions[0].logprobs[-1]
            true_logprob = (
                last[self._true_token].logprob
                if self._true_token in last
                else self._MISSING_LOGPROB
            )
            false_logprob = (
                last[self._false_token].logprob
                if self._false_token in last
                else self._MISSING_LOGPROB
            )
            scores.append(float(self._yes_probability(true_logprob, false_logprob)))
        return scores

    def _estimate_doc_lengths(self, docs: List[str]) -> List[int]:
        """
        Return a cheap length proxy for each document: its character count.

        This is used only to order documents into similar-length batches; no
        tokenization is performed.
        """
        return [len(text) for text in docs]

    def _score_unique(self, query: str, texts: List[str]) -> Tuple[List[float], int]:
        """
        Score every text in ``texts`` against ``query`` in batches of
        ``self._infer_batch_size``.

        When length sorting is enabled, texts are batched in ascending
        character-length order. Otherwise they are batched in input order.

        Returns:
            A pair: the scores aligned with ``texts``, and the number of vLLM
            batches run.

        Raises:
            RuntimeError: a batch returned a different number of scores than prompts.
        """
        order = list(range(len(texts)))
        if self._sort_by_doc_length and len(texts) > 1:
            lengths = self._estimate_doc_lengths(texts)
            order.sort(key=lengths.__getitem__)  # stable: ties keep input order

        scores: List[float] = [0.0] * len(texts)
        batch_size = self._infer_batch_size
        inference_batches = 0
        for start in range(0, len(order), batch_size):
            batch_indices = order[start : start + batch_size]
            inference_batches += 1
            prompts = self._process_inputs([(query, texts[i]) for i in batch_indices])
            # Only the vLLM call is serialised; tokenization runs outside the lock.
            with self._infer_lock:
                batch_scores = self._compute_scores(prompts)
            if len(batch_scores) != len(batch_indices):
                raise RuntimeError(
                    f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}"
                )
            for idx, score in zip(batch_indices, batch_scores):
                scores[idx] = float(score)
        return scores, inference_batches

    def _build_meta(
        self,
        *,
        start_ts: float,
        input_docs: int,
        usable_docs: int,
        unique_docs: int,
        dedup_ratio: float,
        inference_batches: int,
        normalize: bool,
    ) -> Dict[str, Any]:
        """
        Assemble the diagnostics dict returned by ``score_with_meta``.

        ``start_ts`` is a ``time.perf_counter()`` reading. Elapsed time is
        measured up to this call.
        """
        elapsed_ms = (time.perf_counter() - start_ts) * 1000.0
        return {
            "input_docs": input_docs,
            "usable_docs": usable_docs,
            "unique_docs": unique_docs,
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm",
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """
        Score ``docs`` against ``query``.

        Args:
            query: the search query. If it is ``None`` or whitespace-only, every
                score is 0.0.
            docs: candidate documents. ``None`` or blank entries are skipped and
                score 0.0. Texts that are identical after stripping are scored
                once and share the result.
            normalize: only echoed in the returned meta. Scores are always
                P(yes) in [0, 1].

        Returns:
            A pair ``(scores, meta)``. ``scores`` is aligned with ``docs``.
            ``meta`` holds document counts, the dedup ratio, batching info and
            the elapsed time in ms.

        Raises:
            RuntimeError: vLLM returned a different number of results than prompts.
        """
        start_ts = time.perf_counter()  # monotonic, unlike time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        # (original index, stripped text) for every non-empty document.
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed.append((i, text))

        if not query or not indexed:
            return output_scores, self._build_meta(
                start_ts=start_ts,
                input_docs=total_docs,
                usable_docs=len(indexed),
                unique_docs=0,
                dedup_ratio=0.0,
                inference_batches=0,
                normalize=normalize,
            )

        # Deduplicate globally by text, keep mapping to original indices.
        unique_texts, position_to_unique = deduplicate_with_positions(
            [text for _, text in indexed]
        )
        unique_scores, inference_batches = self._score_unique(query, unique_texts)

        # Fan each unique score back out to every original position that had that text.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        return output_scores, self._build_meta(
            start_ts=start_ts,
            input_docs=total_docs,
            usable_docs=len(indexed),
            unique_docs=len(unique_texts),
            dedup_ratio=1.0 - (len(unique_texts) / float(len(indexed))),
            inference_batches=inference_batches,
            normalize=normalize,
        )