""" Qwen3-Reranker-0.6B backend using vLLM. Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B Requires: vllm>=0.8.5, transformers; GPU recommended. """ from __future__ import annotations import logging import math import os import threading import time from typing import Any, Dict, List, Tuple from reranker.backends.batching_utils import ( deduplicate_with_positions, iter_batches, sort_indices_by_length, ) logger = logging.getLogger("reranker.backends.qwen3_vllm") try: import torch from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.inputs.data import TokensPrompt except ImportError as e: raise ImportError( "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. " "Install with: pip install vllm transformers" ) from e def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: """Build chat messages for one (query, doc) pair.""" return [ { "role": "system", "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".", }, { "role": "user", "content": f": {instruction}\n\n: {query}\n\n: {doc}", }, ] class Qwen3VLLMRerankerBackend: """ Qwen3-Reranker-0.6B with vLLM inference. Config from services.rerank.backends.qwen3_vllm. """ def __init__(self, config: Dict[str, Any]) -> None: self._config = config or {} model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") max_model_len = int(self._config.get("max_model_len", 2048)) tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1)) gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4)) enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False)) enforce_eager = bool(self._config.get("enforce_eager", True)) dtype = str(self._config.get("dtype", "float16")).strip().lower() self._instruction = str( self._config.get("instruction") or "Given a web search query, retrieve relevant passages that answer the query" ) infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64) sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") if sort_by_doc_length is None: sort_by_doc_length = self._config.get("sort_by_doc_length", True) length_sort_mode = os.getenv("RERANK_VLLM_LENGTH_SORT_MODE") or self._config.get("length_sort_mode", "char") self._infer_batch_size = int(infer_batch_size) self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {"1", "true", "yes", "y", "on"} self._length_sort_mode = str(length_sort_mode).strip().lower() if not torch.cuda.is_available(): raise RuntimeError("qwen3_vllm backend requires CUDA GPU, but torch.cuda.is_available() is False") if dtype not in {"float16", "half", "auto"}: raise ValueError(f"Unsupported dtype for qwen3_vllm: {dtype!r}. Use float16/half/auto.") if self._infer_batch_size <= 0: raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}") if self._length_sort_mode not in {"char", "token"}: raise ValueError(f"length_sort_mode must be 'char' or 'token', got {self._length_sort_mode!r}") logger.info( "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", model_name, max_model_len, tensor_parallel_size, gpu_memory_utilization, dtype, enable_prefix_caching, ) self._llm = LLM( model=model_name, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, enable_prefix_caching=enable_prefix_caching, enforce_eager=enforce_eager, dtype=dtype, ) self._tokenizer = AutoTokenizer.from_pretrained(model_name) self._tokenizer.padding_side = "left" self._tokenizer.pad_token = self._tokenizer.eos_token # Suffix for generation prompt (assistant answer) self._suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" self._suffix_tokens = self._tokenizer.encode( self._suffix, add_special_tokens=False ) self._max_prompt_len = max_model_len - len(self._suffix_tokens) self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0] self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0] self._sampling_params = SamplingParams( temperature=0, max_tokens=1, logprobs=20, allowed_token_ids=[self._true_token, self._false_token], ) # vLLM generate path is unstable under concurrent calls in this process model. # Serialize infer calls to avoid engine-core protocol corruption. self._infer_lock = threading.Lock() self._model_name = model_name logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name) def _process_inputs( self, pairs: List[Tuple[str, str]], ) -> List[TokensPrompt]: """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template.""" messages_batch = [ _format_instruction(self._instruction, q, d) for q, d in pairs ] tokenized = self._tokenizer.apply_chat_template( messages_batch, tokenize=True, add_generation_prompt=False, enable_thinking=False, ) # Single conv returns flat list; batch returns list of lists if tokenized and not isinstance(tokenized[0], list): tokenized = [tokenized] prompts = [ TokensPrompt( prompt_token_ids=ids[: self._max_prompt_len] + self._suffix_tokens ) for ids in tokenized ] return prompts def _compute_scores( self, prompts: List[TokensPrompt], ) -> List[float]: """Run vLLM generate and compute yes/no probability per prompt.""" if not prompts: return [] outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False) scores = [] for i in range(len(outputs)): out = outputs[i] if not out.outputs: scores.append(0.0) continue final_logits = out.outputs[0].logprobs if not final_logits: scores.append(0.0) continue last = final_logits[-1] # Match official: missing token -> logprob = -10 if self._true_token not in last: true_logit = -10 else: true_logit = last[self._true_token].logprob if self._false_token not in last: false_logit = -10 else: false_logit = last[self._false_token].logprob true_score = math.exp(true_logit) false_score = math.exp(false_logit) score = true_score / (true_score + false_score) scores.append(float(score)) return scores def _estimate_doc_lengths(self, docs: List[str]) -> List[int]: """ Estimate token lengths for sorting documents into similar-length batches. Falls back to character length when tokenizer length output is unavailable. """ if not docs: return [] if self._length_sort_mode == "char": return [len(text) for text in docs] try: enc = self._tokenizer( docs, add_special_tokens=False, truncation=True, max_length=self._max_prompt_len, return_length=True, ) lengths = enc.get("length") if isinstance(lengths, list) and len(lengths) == len(docs): return [int(x) for x in lengths] except Exception as exc: logger.debug("Length estimation fallback to char length: %s", exc) return [len(text) for text in docs] def score_with_meta( self, query: str, docs: List[str], normalize: bool = True, ) -> Tuple[List[float], Dict[str, Any]]: start_ts = time.time() total_docs = len(docs) if docs else 0 output_scores: List[float] = [0.0] * total_docs query = "" if query is None else str(query).strip() indexed: List[Tuple[int, str]] = [] for i, doc in enumerate(docs or []): if doc is None: continue text = str(doc).strip() if not text: continue indexed.append((i, text)) if not query or not indexed: elapsed_ms = (time.time() - start_ts) * 1000.0 return output_scores, { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": 0, "dedup_ratio": 0.0, "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_vllm", "normalize": normalize, "infer_batch_size": self._infer_batch_size, "inference_batches": 0, "sort_by_doc_length": self._sort_by_doc_length, "length_sort_mode": self._length_sort_mode, } # Deduplicate globally by text, keep mapping to original indices. indexed_texts = [text for _, text in indexed] unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts) lengths = self._estimate_doc_lengths(unique_texts) order = list(range(len(unique_texts))) if self._sort_by_doc_length and len(unique_texts) > 1: order = sort_indices_by_length(lengths) unique_scores: List[float] = [0.0] * len(unique_texts) inference_batches = 0 for batch_indices in iter_batches(order, self._infer_batch_size): inference_batches += 1 pairs = [(query, unique_texts[i]) for i in batch_indices] prompts = self._process_inputs(pairs) with self._infer_lock: batch_scores = self._compute_scores(prompts) if len(batch_scores) != len(batch_indices): raise RuntimeError( f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}" ) for idx, score in zip(batch_indices, batch_scores): unique_scores[idx] = float(score) for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): # Score is already P(yes) in [0,1] from yes/(yes+no) output_scores[orig_idx] = float(unique_scores[unique_idx]) elapsed_ms = (time.time() - start_ts) * 1000.0 dedup_ratio = 0.0 if indexed: dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) meta = { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": len(unique_texts), "dedup_ratio": round(dedup_ratio, 4), "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_vllm", "normalize": normalize, "infer_batch_size": self._infer_batch_size, "inference_batches": inference_batches, "sort_by_doc_length": self._sort_by_doc_length, "length_sort_mode": self._length_sort_mode, } return output_scores, meta