""" Qwen3-Reranker-0.6B backend using vLLM. Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B Requires: vllm>=0.8.5, transformers; GPU recommended. """ from __future__ import annotations import logging import math import time from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger("reranker.backends.qwen3_vllm") try: import torch from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.inputs import TokensPrompt except ImportError as e: raise ImportError( "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. " "Install with: pip install vllm transformers" ) from e def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: """Build chat messages for one (query, doc) pair.""" return [ { "role": "system", "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".", }, { "role": "user", "content": f": {instruction}\n\n: {query}\n\n: {doc}", }, ] class Qwen3VLLMRerankerBackend: """ Qwen3-Reranker-0.6B with vLLM inference. Config from services.rerank.backends.qwen3_vllm. """ def __init__(self, config: Dict[str, Any]) -> None: self._config = config or {} model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") max_model_len = int(self._config.get("max_model_len", 8192)) tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1)) gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.8)) enable_prefix_caching = bool(self._config.get("enable_prefix_caching", True)) self._instruction = str( self._config.get("instruction") or "Given a web search query, retrieve relevant passages that answer the query" ) logger.info( "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, prefix_caching=%s)", model_name, max_model_len, tensor_parallel_size, enable_prefix_caching, ) self._llm = LLM( model=model_name, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, enable_prefix_caching=enable_prefix_caching, ) self._tokenizer = AutoTokenizer.from_pretrained(model_name) self._tokenizer.padding_side = "left" self._tokenizer.pad_token = self._tokenizer.eos_token # Suffix for generation prompt (assistant answer) self._suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" self._suffix_tokens = self._tokenizer.encode( self._suffix, add_special_tokens=False ) self._max_prompt_len = max_model_len - len(self._suffix_tokens) self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0] self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0] self._sampling_params = SamplingParams( temperature=0, max_tokens=1, logprobs=20, allowed_token_ids=[self._true_token, self._false_token], ) self._model_name = model_name logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name) def _process_inputs( self, pairs: List[Tuple[str, str]], ) -> List[TokensPrompt]: """Build tokenized prompts for vLLM from (query, doc) pairs.""" prompts = [] for q, d in pairs: messages = _format_instruction(self._instruction, q, d) # One conversation per call (apply_chat_template expects single conversation) token_ids = self._tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=False, enable_thinking=False, ) if isinstance(token_ids, list) and token_ids and isinstance(token_ids[0], list): token_ids = token_ids[0] ids = token_ids[: self._max_prompt_len] + 
self._suffix_tokens prompts.append(TokensPrompt(prompt_token_ids=ids)) return prompts def _compute_scores( self, prompts: List[TokensPrompt], ) -> List[float]: """Run vLLM generate and compute yes/no probability per prompt.""" if not prompts: return [] outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False) scores = [] for i in range(len(outputs)): out = outputs[i] if not out.outputs: scores.append(0.0) continue logprobs = out.outputs[0].logprobs if not logprobs: scores.append(0.0) continue last = logprobs[-1] true_logp = last.get(self._true_token) false_logp = last.get(self._false_token) true_p = math.exp(true_logp.logprob) if true_logp else 1e-10 false_p = math.exp(false_logp.logprob) if false_logp else 1e-10 score = true_p / (true_p + false_p) scores.append(float(score)) return scores def score_with_meta( self, query: str, docs: List[str], normalize: bool = True, ) -> Tuple[List[float], Dict[str, Any]]: start_ts = time.time() total_docs = len(docs) if docs else 0 output_scores: List[float] = [0.0] * total_docs query = "" if query is None else str(query).strip() indexed: List[Tuple[int, str]] = [] for i, doc in enumerate(docs or []): if doc is None: continue text = str(doc).strip() if not text: continue indexed.append((i, text)) if not query or not indexed: elapsed_ms = (time.time() - start_ts) * 1000.0 return output_scores, { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": 0, "dedup_ratio": 0.0, "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_vllm", "normalize": normalize, } # Deduplicate by text, keep mapping to original indices unique_texts: List[str] = [] position_to_unique: List[int] = [] prev: Optional[str] = None for _idx, text in indexed: if text != prev: unique_texts.append(text) prev = text position_to_unique.append(len(unique_texts) - 1) pairs = [(query, t) for t in unique_texts] prompts = self._process_inputs(pairs) unique_scores = self._compute_scores(prompts) for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): # Score is already P(yes) in [0,1] from yes/(yes+no) output_scores[orig_idx] = float(unique_scores[unique_idx]) elapsed_ms = (time.time() - start_ts) * 1000.0 dedup_ratio = 0.0 if indexed: dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) meta = { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": len(unique_texts), "dedup_ratio": round(dedup_ratio, 4), "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_vllm", "normalize": normalize, } return output_scores, meta
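

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the service contract):
# it assumes a CUDA-capable GPU with enough free memory for
# Qwen/Qwen3-Reranker-0.6B and that the caller passes config as a plain dict,
# as the constructor above expects. The query/document strings are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    backend = Qwen3VLLMRerankerBackend(
        {
            "model_name": "Qwen/Qwen3-Reranker-0.6B",
            "max_model_len": 8192,
            "gpu_memory_utilization": 0.8,
        }
    )
    scores, meta = backend.score_with_meta(
        query="What is the capital of China?",
        docs=[
            "The capital of China is Beijing.",
            "Gravity causes apples to fall toward the ground.",
        ],
    )
    print(scores)  # one P("yes") score in [0, 1] per input document
    print(meta)    # timing, dedup stats, model/backend identifiers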