""" Qwen3-Reranker-0.6B backend using Transformers (direct usage). No vLLM required. Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B Requires: transformers>=4.51.0, torch. """ from __future__ import annotations import logging import time from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger("reranker.backends.qwen3_transformers") try: import torch from transformers import AutoModelForCausalLM, AutoTokenizer except ImportError as e: raise ImportError( "Qwen3-Transformers reranker backend requires transformers>=4.51.0 and torch. " "Install with: pip install transformers>=4.51.0 torch" ) from e def _format_instruction(instruction: str, query: str, doc: str) -> str: """Format (query, doc) pair per official Qwen3-Reranker spec.""" return ": {instruction}\n: {query}\n: {doc}".format( instruction=instruction, query=query, doc=doc ) class Qwen3TransformersRerankerBackend: """ Qwen3-Reranker-0.6B with Transformers (AutoModelForCausalLM) inference. Config from services.rerank.backends.qwen3_transformers. No vLLM dependency; lighter than qwen3_vllm, suitable for CPU or small GPU. """ def __init__(self, config: Dict[str, Any]) -> None: self._config = config or {} model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") self._instruction = str( self._config.get("instruction") or "Given a web search query, retrieve relevant passages that answer the query" ) max_length = int(self._config.get("max_length", 8192)) batch_size = int(self._config.get("batch_size", 64)) use_fp16 = bool(self._config.get("use_fp16", True)) device = self._config.get("device") attn_impl = self._config.get("attn_implementation") # e.g. "flash_attention_2" self._model_name = model_name self._batch_size = batch_size logger.info( "[Qwen3_Transformers] Loading model %s (max_length=%s, batch=%s, fp16=%s)", model_name, max_length, batch_size, use_fp16, ) self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self._tokenizer.pad_token = self._tokenizer.eos_token # Prefix/suffix from official reference prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n" suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" self._prefix_tokens = self._tokenizer.encode(prefix, add_special_tokens=False) self._suffix_tokens = self._tokenizer.encode(suffix, add_special_tokens=False) self._max_length = max_length self._effective_max_len = max_length - len(self._prefix_tokens) - len(self._suffix_tokens) self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes") self._token_false_id = self._tokenizer.convert_tokens_to_ids("no") kwargs = {} if use_fp16 and torch.cuda.is_available(): kwargs["torch_dtype"] = torch.float16 if attn_impl: kwargs["attn_implementation"] = attn_impl self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval() if device is not None: self._model = self._model.to(device) elif torch.cuda.is_available(): self._model = self._model.cuda() logger.info( "[Qwen3_Transformers] Model ready | model=%s device=%s", model_name, next(self._model.parameters()).device, ) def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]: """Tokenize pairs and add prefix/suffix tokens. 
        Returns batched tensors on the model device.
        """
        inputs = self._tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self._effective_max_len,
        )
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self._prefix_tokens + ele + self._suffix_tokens
        inputs = self._tokenizer.pad(inputs, padding=True, return_tensors="pt")
        for key in inputs:
            inputs[key] = inputs[key].to(self._model.device)
        return inputs

    @torch.no_grad()
    def _compute_scores(self, pairs: List[str]) -> List[float]:
        """Run a forward pass and compute the yes-probability per pair."""
        if not pairs:
            return []
        inputs = self._process_inputs(pairs)
        outputs = self._model(**inputs)
        # Next-token logits at the last position (valid for every row thanks to
        # left padding); score = P("yes") vs P("no") after a 2-way softmax.
        batch_scores = outputs.logits[:, -1, :]
        true_vector = batch_scores[:, self._token_true_id]
        false_vector = batch_scores[:, self._token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score docs against query; returns (scores, meta).

        Scores are yes-probabilities in [0, 1]; None/empty docs score 0.0.
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs
        query = "" if query is None else str(query).strip()

        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_transformers",
                "normalize": normalize,
            }

        # Deduplicate by text (not just adjacent repeats), keeping a mapping
        # from each usable position back to its unique-text index.
        unique_texts: List[str] = []
        unique_index: Dict[str, int] = {}
        position_to_unique: List[int] = []
        for _idx, text in indexed:
            if text not in unique_index:
                unique_index[text] = len(unique_texts)
                unique_texts.append(text)
            position_to_unique.append(unique_index[text])

        pairs = [_format_instruction(self._instruction, query, t) for t in unique_texts]

        # Batch inference
        unique_scores: List[float] = []
        for i in range(0, len(pairs), self._batch_size):
            batch = pairs[i : i + self._batch_size]
            unique_scores.extend(self._compute_scores(batch))

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_transformers",
            "normalize": normalize,
        }
        return output_scores, meta
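

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the backend's public surface).
# Assumptions: CPU in fp32 so the demo runs without a GPU, and a small
# max_length/batch_size to keep memory modest. The config keys mirror the
# ones read in __init__ above; the example query/docs are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    backend = Qwen3TransformersRerankerBackend(
        {
            "model_name": "Qwen/Qwen3-Reranker-0.6B",
            "max_length": 2048,
            "batch_size": 8,
            "use_fp16": False,  # CPU demo: the fp16 path is skipped without CUDA anyway
            "device": "cpu",
        }
    )
    scores, meta = backend.score_with_meta(
        "What is the capital of China?",
        [
            "The capital of China is Beijing.",
            "Apples fall to the ground because of gravity.",
        ],
    )
    print(scores)  # yes-probabilities in [0, 1]; the first doc should score higher
    print(meta)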