""" Qwen3-Reranker backend using packed inference with Transformers. This backend implements the sequence stitching optimization described in Qwen3-Reranker packed inference examples: 1. Share the query/instruction prefix across many documents. 2. Reset document ``position_ids`` relative to the shared prefix. 3. Use a custom causal attention mask so each document can attend to the prefix and itself, but never to other documents. Compared with the standard per-pair batching path, this reduces repeated prefix computation and removes inter-sample padding waste. For online search requests like ``1 query + 400 docs``, the backend further packs documents into multiple chunks under a configurable total token budget. """ from __future__ import annotations import logging import threading import time from typing import Any, Dict, List, Sequence, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer logger = logging.getLogger("reranker.backends.qwen3_transformers_packed") _DEFAULT_PREFIX = ( "<|im_start|>system\n" "Judge whether the Document meets the requirements based on the Query and the Instruct " 'provided. Note that the answer can only be "yes" or "no".' "<|im_end|>\n<|im_start|>user\n" ) _DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" _DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}: {instruction}\n: {query}\n: " def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]: unique_texts: List[str] = [] position_to_unique: List[int] = [] seen: Dict[str, int] = {} for text in texts: idx = seen.get(text) if idx is None: idx = len(unique_texts) seen[text] = idx unique_texts.append(text) position_to_unique.append(idx) return unique_texts, position_to_unique class Qwen3TransformersPackedRerankerBackend: """ Qwen3-Reranker packed inference backend using Transformers. Config from ``services.rerank.backends.qwen3_transformers_packed``. """ def __init__(self, config: Dict[str, Any]) -> None: self._config = config or {} model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") self._instruction = str( self._config.get("instruction") or "Rank products by query with category & style match prioritized" ) self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) self._pair_prefix_template = str( self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE ) max_model_len = int(self._config.get("max_model_len", 4096)) max_doc_len = int(self._config.get("max_doc_len", 160)) max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0)) use_fp16 = bool(self._config.get("use_fp16", True)) device = self._config.get("device") attn_impl = str(self._config.get("attn_implementation") or "eager").strip() sort_by_doc_length = self._config.get("sort_by_doc_length", True) self._model_name = model_name self._max_model_len = max_model_len self._max_doc_len = max_doc_len self._max_docs_per_pack = max_docs_per_pack self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in { "1", "true", "yes", "y", "on", } self._attn_impl = attn_impl logger.info( "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, " "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)", model_name, max_model_len, max_doc_len, max_docs_per_pack, use_fp16, attn_impl, ) self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self._tokenizer.pad_token = self._tokenizer.eos_token self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False) self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False) self._suffix_len = len(self._suffix_tokens) if not torch.cuda.is_available(): raise RuntimeError( "qwen3_transformers_packed backend requires CUDA GPU, " "but torch.cuda.is_available() is False" ) kwargs: Dict[str, Any] = {} if use_fp16: kwargs["torch_dtype"] = torch.float16 if attn_impl: kwargs["attn_implementation"] = attn_impl self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval() target_device = str(device).strip() if device is not None else "cuda" if not target_device.startswith("cuda"): raise ValueError( "qwen3_transformers_packed backend is GPU-only. " f"Unsupported device setting: {target_device!r}" ) self._model = self._model.to(target_device) self._device = next(self._model.parameters()).device if self._device.type != "cuda": raise RuntimeError( "qwen3_transformers_packed backend failed to place model on CUDA. " f"Current device: {self._device}" ) self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes") self._token_false_id = self._tokenizer.convert_tokens_to_ids("no") if self._token_true_id is None or self._token_false_id is None: raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no") prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1 if self._max_model_len <= prefix_budget: raise ValueError( "max_model_len is too small for packed reranking. " f"Need > {prefix_budget}, got {self._max_model_len}." ) if self._max_doc_len <= 0: raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}") if self._max_docs_per_pack < 0: raise ValueError( f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}" ) self._infer_lock = threading.Lock() logger.info( "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s", model_name, self._device, ) def _build_pair_prefix_tokens(self, query: str) -> List[int]: pair_prefix = self._pair_prefix_template.format( prefix=self._prefix, instruction=self._instruction, query=query, ) return self._tokenizer.encode(pair_prefix, add_special_tokens=False) def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]: max_doc_tokens = min( self._max_doc_len, max(1, self._max_model_len - query_prefix_len - self._suffix_len), ) tokenized = self._tokenizer( list(docs), padding=False, truncation=True, max_length=max_doc_tokens, add_special_tokens=False, return_attention_mask=False, ) return [list(ids) for ids in tokenized["input_ids"]] def _build_pack_plan( self, query_prefix_len: int, doc_tokens: Sequence[Sequence[int]], ) -> List[List[int]]: order = list(range(len(doc_tokens))) if self._sort_by_doc_length and len(order) > 1: order.sort(key=lambda idx: len(doc_tokens[idx])) packs: List[List[int]] = [] current_pack: List[int] = [] current_len = query_prefix_len for idx in order: packed_doc_len = len(doc_tokens[idx]) + self._suffix_len if packed_doc_len <= 0: continue over_docs_cap = self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack over_token_cap = current_pack and (current_len + packed_doc_len > self._max_model_len) if over_docs_cap or over_token_cap: packs.append(current_pack) current_pack = [] current_len = query_prefix_len if query_prefix_len + packed_doc_len > self._max_model_len: raise ValueError( "Packed doc still exceeds max_model_len after truncation. " f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, " f"max_model_len={self._max_model_len}" ) current_pack.append(idx) current_len += packed_doc_len if current_pack: packs.append(current_pack) return packs def _build_pack_inputs( self, query_prefix_tokens: Sequence[int], doc_tokens: Sequence[Sequence[int]], doc_indices: Sequence[int], ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: prefix_len = len(query_prefix_tokens) input_ids_list = list(query_prefix_tokens) position_ids_list = list(range(prefix_len)) spans: List[Tuple[int, int]] = [] current_len = prefix_len for idx in doc_indices: doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens start = current_len end = start + len(doc_with_suffix) spans.append((start, end)) input_ids_list.extend(doc_with_suffix) position_ids_list.extend(range(prefix_len, prefix_len + len(doc_with_suffix))) current_len = end total_len = len(input_ids_list) device = self._device neg_inf = torch.finfo(torch.float32).min allowed = torch.zeros((total_len, total_len), dtype=torch.bool, device=device) prefix_causal = torch.tril( torch.ones((prefix_len, prefix_len), dtype=torch.bool, device=device) ) allowed[:prefix_len, :prefix_len] = prefix_causal for start, end in spans: allowed[start:end, :prefix_len] = True doc_len = end - start allowed[start:end, start:end] = torch.tril( torch.ones((doc_len, doc_len), dtype=torch.bool, device=device) ) attention_mask = torch.full( (total_len, total_len), neg_inf, dtype=torch.float32, device=device, ) attention_mask.masked_fill_(allowed, 0.0) inputs = { "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device), "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device), "attention_mask": attention_mask.view(1, 1, total_len, total_len), } logits_ids = torch.tensor( [end - 1 for _, end in spans], dtype=torch.long, device=device, ) return inputs, logits_ids @torch.no_grad() def _score_pack( self, query_prefix_tokens: Sequence[int], doc_tokens: Sequence[Sequence[int]], doc_indices: Sequence[int], ) -> Tuple[List[float], int]: inputs, logits_ids = self._build_pack_inputs( query_prefix_tokens=query_prefix_tokens, doc_tokens=doc_tokens, doc_indices=doc_indices, ) outputs = self._model(**inputs) scores = outputs.logits[0, logits_ids, :] true_vector = scores[:, self._token_true_id] false_vector = scores[:, self._token_false_id] pair_scores = torch.stack([false_vector, true_vector], dim=1) pair_scores = torch.nn.functional.log_softmax(pair_scores, dim=1) return pair_scores[:, 1].exp().tolist(), int(inputs["input_ids"].shape[1]) def score_with_meta( self, query: str, docs: List[str], normalize: bool = True, ) -> Tuple[List[float], Dict[str, Any]]: start_ts = time.time() total_docs = len(docs) if docs else 0 output_scores: List[float] = [0.0] * total_docs query = "" if query is None else str(query).strip() indexed: List[Tuple[int, str]] = [] for i, doc in enumerate(docs or []): if doc is None: continue text = str(doc).strip() if not text: continue indexed.append((i, text)) if not query or not indexed: elapsed_ms = (time.time() - start_ts) * 1000.0 return output_scores, { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": 0, "dedup_ratio": 0.0, "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_transformers_packed", "normalize": normalize, "packed_batches": 0, "max_model_len": self._max_model_len, "max_doc_len": self._max_doc_len, "sort_by_doc_length": self._sort_by_doc_length, } indexed_texts = [text for _, text in indexed] unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts) query_prefix_tokens = self._build_pair_prefix_tokens(query) doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens)) pack_plan = self._build_pack_plan( query_prefix_len=len(query_prefix_tokens), doc_tokens=doc_tokens, ) unique_scores: List[float] = [0.0] * len(unique_texts) pack_lengths: List[int] = [] with self._infer_lock: for pack_doc_indices in pack_plan: batch_scores, pack_seq_len = self._score_pack( query_prefix_tokens=query_prefix_tokens, doc_tokens=doc_tokens, doc_indices=pack_doc_indices, ) if len(batch_scores) != len(pack_doc_indices): raise RuntimeError( "Packed reranker score size mismatch: " f"expected {len(pack_doc_indices)}, got {len(batch_scores)}" ) for idx, score in zip(pack_doc_indices, batch_scores): unique_scores[idx] = float(score) pack_lengths.append(pack_seq_len) for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): output_scores[orig_idx] = float(unique_scores[unique_idx]) elapsed_ms = (time.time() - start_ts) * 1000.0 dedup_ratio = 0.0 if indexed: dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) meta = { "input_docs": total_docs, "usable_docs": len(indexed), "unique_docs": len(unique_texts), "dedup_ratio": round(dedup_ratio, 4), "elapsed_ms": round(elapsed_ms, 3), "model": self._model_name, "backend": "qwen3_transformers_packed", "normalize": normalize, "packed_batches": len(pack_plan), "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0, "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3) if pack_lengths else 0.0, "max_model_len": self._max_model_len, "max_doc_len": self._max_doc_len, "max_docs_per_pack": self._max_docs_per_pack, "sort_by_doc_length": self._sort_by_doc_length, "attn_implementation": self._attn_impl, } return output_scores, meta