"""
Qwen3-Reranker backend using packed inference with Transformers.
This backend implements the sequence stitching optimization described in
Qwen3-Reranker packed inference examples:
1. Share the query/instruction prefix across many documents.
2. Reset document ``position_ids`` relative to the shared prefix.
3. Use a custom causal attention mask so each document can attend to the
prefix and itself, but never to other documents.
Compared with the standard per-pair batching path, this reduces repeated
prefix computation and removes inter-sample padding waste. For online search
requests like ``1 query + 400 docs``, the backend further packs documents into
multiple chunks under a configurable total token budget.
"""
from __future__ import annotations
import logging
import threading
import time
from typing import Any, Dict, List, Sequence, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
logger = logging.getLogger("reranker.backends.qwen3_transformers_packed")
_DEFAULT_PREFIX = (
"<|im_start|>system\n"
"Judge whether the Document meets the requirements based on the Query and the Instruct "
'provided. Note that the answer can only be "yes" or "no".'
"<|im_end|>\n<|im_start|>user\n"
)
_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
_DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}: {instruction}\n: {query}\n: "
def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
unique_texts: List[str] = []
position_to_unique: List[int] = []
seen: Dict[str, int] = {}
for text in texts:
idx = seen.get(text)
if idx is None:
idx = len(unique_texts)
seen[text] = idx
unique_texts.append(text)
position_to_unique.append(idx)
return unique_texts, position_to_unique
class Qwen3TransformersPackedRerankerBackend:
    """
    Qwen3-Reranker packed inference backend using Transformers.
    Config from ``services.rerank.backends.qwen3_transformers_packed``.
    """
    def __init__(self, config: Dict[str, Any]) -> None:
        """Load the tokenizer/model onto CUDA and validate packing limits.

        Args:
            config: Backend config mapping; every key is optional and falls
                back to the defaults visible below.

        Raises:
            RuntimeError: If CUDA is unavailable, the model does not end up
                on a CUDA device, or the "yes"/"no" classifier token ids
                cannot be resolved.
            ValueError: If a non-CUDA device is configured, or if
                max_model_len / max_doc_len / max_docs_per_pack are
                inconsistent.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        self._instruction = str(
            self._config.get("instruction")
            or "Rank products by query with category & style match prioritized"
        )
        self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._pair_prefix_template = str(
            self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE
        )
        max_model_len = int(self._config.get("max_model_len", 4096))
        max_doc_len = int(self._config.get("max_doc_len", 160))
        # 0 disables the per-pack document cap; only the token budget applies.
        max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0))
        use_fp16 = bool(self._config.get("use_fp16", True))
        device = self._config.get("device")
        attn_impl = str(self._config.get("attn_implementation") or "eager").strip()
        sort_by_doc_length = self._config.get("sort_by_doc_length", True)
        self._model_name = model_name
        self._max_model_len = max_model_len
        self._max_doc_len = max_doc_len
        self._max_docs_per_pack = max_docs_per_pack
        # Normalize bool/str/int config values: any common truthy spelling
        # (e.g. from YAML or env vars) enables length-sorted packing.
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        self._attn_impl = attn_impl
        logger.info(
            "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, "
            "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)",
            model_name,
            max_model_len,
            max_doc_len,
            max_docs_per_pack,
            use_fp16,
            attn_impl,
        )
        self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self._tokenizer.pad_token = self._tokenizer.eos_token
        # Tokenize the static prompt pieces once; they are reused verbatim for
        # every request, which is the core of the prefix-sharing optimization.
        self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
        self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
        self._suffix_len = len(self._suffix_tokens)
        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_transformers_packed backend requires CUDA GPU, "
                "but torch.cuda.is_available() is False"
            )
        kwargs: Dict[str, Any] = {}
        if use_fp16:
            kwargs["torch_dtype"] = torch.float16
        if attn_impl:
            kwargs["attn_implementation"] = attn_impl
        self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
        target_device = str(device).strip() if device is not None else "cuda"
        if not target_device.startswith("cuda"):
            raise ValueError(
                "qwen3_transformers_packed backend is GPU-only. "
                f"Unsupported device setting: {target_device!r}"
            )
        self._model = self._model.to(target_device)
        # Read the device back from the parameters rather than trusting the
        # config string, then double-check it actually landed on CUDA.
        self._device = next(self._model.parameters()).device
        if self._device.type != "cuda":
            raise RuntimeError(
                "qwen3_transformers_packed backend failed to place model on CUDA. "
                f"Current device: {self._device}"
            )
        # Relevance is read off the next-token logits of the literal tokens
        # "yes" and "no" at each document's final (suffix) position.
        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
        if self._token_true_id is None or self._token_false_id is None:
            raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no")
        # Lower-bound sanity check: static prefix + suffix + at least one doc
        # token must fit. NOTE(review): the real per-request prefix also
        # contains the instruction and the query, so this check is necessary
        # but not sufficient; over-long queries are caught later by
        # _build_pack_plan's per-doc ValueError.
        prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1
        if self._max_model_len <= prefix_budget:
            raise ValueError(
                "max_model_len is too small for packed reranking. "
                f"Need > {prefix_budget}, got {self._max_model_len}."
            )
        if self._max_doc_len <= 0:
            raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}")
        if self._max_docs_per_pack < 0:
            raise ValueError(
                f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}"
            )
        # Serializes GPU inference when the backend is shared across threads.
        self._infer_lock = threading.Lock()
        logger.info(
            "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s",
            model_name,
            self._device,
        )
    def _build_pair_prefix_tokens(self, query: str) -> List[int]:
        """Tokenize the per-query prefix (chat prefix + instruction + query)."""
        pair_prefix = self._pair_prefix_template.format(
            prefix=self._prefix,
            instruction=self._instruction,
            query=query,
        )
        return self._tokenizer.encode(pair_prefix, add_special_tokens=False)
    def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]:
        """Tokenize docs, truncating each so prefix + doc + suffix can fit.

        The per-doc cap is the stricter of the configured ``max_doc_len`` and
        whatever room remains next to this query's prefix plus the fixed
        suffix (floored at 1 token).
        """
        max_doc_tokens = min(
            self._max_doc_len,
            max(1, self._max_model_len - query_prefix_len - self._suffix_len),
        )
        tokenized = self._tokenizer(
            list(docs),
            padding=False,
            truncation=True,
            max_length=max_doc_tokens,
            add_special_tokens=False,
            return_attention_mask=False,
        )
        return [list(ids) for ids in tokenized["input_ids"]]
    def _build_pack_plan(
        self,
        query_prefix_len: int,
        doc_tokens: Sequence[Sequence[int]],
    ) -> List[List[int]]:
        """Greedily group document indices into packs under the budgets.

        Returns a list of packs; each pack is a list of indices into
        ``doc_tokens`` whose stitched sequence (prefix + each doc + suffix)
        stays within ``max_model_len`` and, when enabled, ``max_docs_per_pack``.

        Raises:
            ValueError: If a single doc + suffix cannot fit next to the query
                prefix even after truncation.
        """
        order = list(range(len(doc_tokens)))
        # Packing in length order groups similarly sized docs, which lets the
        # greedy fill below pack sequences more tightly.
        if self._sort_by_doc_length and len(order) > 1:
            order.sort(key=lambda idx: len(doc_tokens[idx]))
        packs: List[List[int]] = []
        current_pack: List[int] = []
        current_len = query_prefix_len
        for idx in order:
            packed_doc_len = len(doc_tokens[idx]) + self._suffix_len
            if packed_doc_len <= 0:
                continue
            over_docs_cap = self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack
            over_token_cap = current_pack and (current_len + packed_doc_len > self._max_model_len)
            if over_docs_cap or over_token_cap:
                # Close the current pack and start a fresh one for this doc.
                packs.append(current_pack)
                current_pack = []
                current_len = query_prefix_len
            if query_prefix_len + packed_doc_len > self._max_model_len:
                raise ValueError(
                    "Packed doc still exceeds max_model_len after truncation. "
                    f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, "
                    f"max_model_len={self._max_model_len}"
                )
            current_pack.append(idx)
            current_len += packed_doc_len
        if current_pack:
            packs.append(current_pack)
        return packs
    def _build_pack_inputs(
        self,
        query_prefix_tokens: Sequence[int],
        doc_tokens: Sequence[Sequence[int]],
        doc_indices: Sequence[int],
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Build stitched model inputs for one pack.

        The token sequence is ``prefix, doc0+suffix, doc1+suffix, ...``.
        Every document's position ids restart at ``prefix_len`` so each doc
        sees the same relative distance to the shared prefix, and a custom
        4D additive attention mask lets each document attend causally to
        itself plus the full prefix — never to the other documents.

        Returns:
            ``(inputs, logits_ids)`` where ``inputs`` holds ``input_ids``,
            ``position_ids`` and a ``(1, 1, T, T)`` float mask, and
            ``logits_ids`` are the sequence indices of each document's last
            token (where the yes/no logits are read).
        """
        prefix_len = len(query_prefix_tokens)
        input_ids_list = list(query_prefix_tokens)
        position_ids_list = list(range(prefix_len))
        # (start, end) token spans of each doc+suffix within the sequence.
        spans: List[Tuple[int, int]] = []
        current_len = prefix_len
        for idx in doc_indices:
            doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens
            start = current_len
            end = start + len(doc_with_suffix)
            spans.append((start, end))
            input_ids_list.extend(doc_with_suffix)
            # Positions restart right after the prefix for every document.
            position_ids_list.extend(range(prefix_len, prefix_len + len(doc_with_suffix)))
            current_len = end
        total_len = len(input_ids_list)
        device = self._device
        neg_inf = torch.finfo(torch.float32).min
        # allowed[i, j] == True where token i may attend to token j.
        allowed = torch.zeros((total_len, total_len), dtype=torch.bool, device=device)
        prefix_causal = torch.tril(
            torch.ones((prefix_len, prefix_len), dtype=torch.bool, device=device)
        )
        allowed[:prefix_len, :prefix_len] = prefix_causal
        for start, end in spans:
            # Each doc sees the entire prefix plus itself (causally), and
            # nothing from the other documents in the pack.
            allowed[start:end, :prefix_len] = True
            doc_len = end - start
            allowed[start:end, start:end] = torch.tril(
                torch.ones((doc_len, doc_len), dtype=torch.bool, device=device)
            )
        # Additive mask: 0.0 where attention is allowed, -inf where blocked.
        attention_mask = torch.full(
            (total_len, total_len),
            neg_inf,
            dtype=torch.float32,
            device=device,
        )
        attention_mask.masked_fill_(allowed, 0.0)
        inputs = {
            "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device),
            "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device),
            "attention_mask": attention_mask.view(1, 1, total_len, total_len),
        }
        # Last-token index of each doc span: where yes/no logits are read.
        logits_ids = torch.tensor(
            [end - 1 for _, end in spans],
            dtype=torch.long,
            device=device,
        )
        return inputs, logits_ids
    @torch.no_grad()
    def _score_pack(
        self,
        query_prefix_tokens: Sequence[int],
        doc_tokens: Sequence[Sequence[int]],
        doc_indices: Sequence[int],
    ) -> Tuple[List[float], int]:
        """Run one packed forward pass.

        Returns:
            ``(scores, seq_len)`` — per-document P("yes") probabilities in
            ``doc_indices`` order, and the packed sequence length (metrics).
        """
        inputs, logits_ids = self._build_pack_inputs(
            query_prefix_tokens=query_prefix_tokens,
            doc_tokens=doc_tokens,
            doc_indices=doc_indices,
        )
        outputs = self._model(**inputs)
        # Next-token logits at each document's final position.
        scores = outputs.logits[0, logits_ids, :]
        true_vector = scores[:, self._token_true_id]
        false_vector = scores[:, self._token_false_id]
        # Two-way softmax over (no, yes); the "yes" probability is the score.
        pair_scores = torch.stack([false_vector, true_vector], dim=1)
        pair_scores = torch.nn.functional.log_softmax(pair_scores, dim=1)
        return pair_scores[:, 1].exp().tolist(), int(inputs["input_ids"].shape[1])
    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query``; return scores plus metadata.

        None/empty docs keep a score of 0.0 in place; duplicate docs are
        scored once and the value fanned back out to every occurrence.
        NOTE(review): ``normalize`` is only echoed into the metadata — the
        scores are always two-way softmax probabilities.

        Returns:
            ``(scores, meta)`` with ``scores`` aligned 1:1 with ``docs``.
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs
        query = "" if query is None else str(query).strip()
        # Keep (original index, cleaned text) for non-empty docs only.
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))
        if not query or not indexed:
            # Nothing scorable: return all zeros with a minimal meta payload.
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_transformers_packed",
                "normalize": normalize,
                "packed_batches": 0,
                "max_model_len": self._max_model_len,
                "max_doc_len": self._max_doc_len,
                "sort_by_doc_length": self._sort_by_doc_length,
            }
        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts)
        query_prefix_tokens = self._build_pair_prefix_tokens(query)
        doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens))
        pack_plan = self._build_pack_plan(
            query_prefix_len=len(query_prefix_tokens),
            doc_tokens=doc_tokens,
        )
        unique_scores: List[float] = [0.0] * len(unique_texts)
        pack_lengths: List[int] = []
        # One forward pass per pack, serialized behind the inference lock.
        with self._infer_lock:
            for pack_doc_indices in pack_plan:
                batch_scores, pack_seq_len = self._score_pack(
                    query_prefix_tokens=query_prefix_tokens,
                    doc_tokens=doc_tokens,
                    doc_indices=pack_doc_indices,
                )
                if len(batch_scores) != len(pack_doc_indices):
                    raise RuntimeError(
                        "Packed reranker score size mismatch: "
                        f"expected {len(pack_doc_indices)}, got {len(batch_scores)}"
                    )
                for idx, score in zip(pack_doc_indices, batch_scores):
                    unique_scores[idx] = float(score)
                pack_lengths.append(pack_seq_len)
        # Fan unique scores back out to every original (possibly duplicate) slot.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])
        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_transformers_packed",
            "normalize": normalize,
            "packed_batches": len(pack_plan),
            "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0,
            "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3)
            if pack_lengths
            else 0.0,
            "max_model_len": self._max_model_len,
            "max_doc_len": self._max_doc_len,
            "max_docs_per_pack": self._max_docs_per_pack,
            "sort_by_doc_length": self._sort_by_doc_length,
            "attn_implementation": self._attn_impl,
        }
        return output_scores, meta