# qwen3_transformers_packed.py
"""
Qwen3-Reranker backend using packed inference with Transformers.

This backend implements the sequence stitching optimization described in
Qwen3-Reranker packed inference examples:
1. Share the query/instruction prefix across many documents.
2. Reset document ``position_ids`` relative to the shared prefix.
3. Use a custom causal attention mask so each document can attend to the
   prefix and itself, but never to other documents.

Compared with the standard per-pair batching path, this reduces repeated
prefix computation and removes inter-sample padding waste. For online search
requests like ``1 query + 400 docs``, the backend further packs documents into
multiple chunks under a configurable total token budget.
"""

from __future__ import annotations

import logging
import threading
import time
from typing import Any, Dict, List, Sequence, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger("reranker.backends.qwen3_transformers_packed")

# Chat-template scaffolding expected by Qwen3-Reranker checkpoints. These
# strings are part of the model's input contract — do not reword them.
_DEFAULT_PREFIX = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct "
    'provided. Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Assistant turn (with an empty <think> block) appended after every document;
# the yes/no logit is read at the last token of this suffix.
_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Shared per-request prefix: system prompt + instruction + query. The document
# text is concatenated directly after the template's rendered output.
_DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "


def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
    unique_texts: List[str] = []
    position_to_unique: List[int] = []
    seen: Dict[str, int] = {}

    for text in texts:
        idx = seen.get(text)
        if idx is None:
            idx = len(unique_texts)
            seen[text] = idx
            unique_texts.append(text)
        position_to_unique.append(idx)

    return unique_texts, position_to_unique


class Qwen3TransformersPackedRerankerBackend:
    """
    Qwen3-Reranker packed inference backend using Transformers.

    Packs one query against many documents into a few long sequences: the
    query/instruction prefix is encoded once per sequence, each document
    (plus the fixed chat suffix) is appended after it, document
    ``position_ids`` restart right after the prefix, and a custom additive
    attention mask lets every document token see the shared prefix and its
    own document only. Relevance is read from the yes/no logits at the last
    suffix token of each document span.

    GPU-only. A single lock serializes forward passes, so one instance can
    be shared across request threads.

    Config from ``services.rerank.backends.qwen3_transformers_packed``.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Load tokenizer + model on CUDA and validate packing limits.

        Raises:
            RuntimeError: when CUDA is unavailable, the model cannot be
                placed on a CUDA device, or the yes/no token ids cannot be
                resolved.
            ValueError: for a non-CUDA ``device`` setting or inconsistent
                length limits.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        self._instruction = str(
            self._config.get("instruction")
            or "Rank products by query with category & style match prioritized"
        )
        self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._pair_prefix_template = str(
            self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE
        )

        # Length limits: total tokens per packed sequence and a per-document
        # truncation cap; max_docs_per_pack == 0 means "no per-pack doc cap".
        max_model_len = int(self._config.get("max_model_len", 4096))
        max_doc_len = int(self._config.get("max_doc_len", 160))
        max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0))
        use_fp16 = bool(self._config.get("use_fp16", True))
        device = self._config.get("device")
        attn_impl = str(self._config.get("attn_implementation") or "eager").strip()
        sort_by_doc_length = self._config.get("sort_by_doc_length", True)

        self._model_name = model_name
        self._max_model_len = max_model_len
        self._max_doc_len = max_doc_len
        self._max_docs_per_pack = max_docs_per_pack
        # Accept bool-ish config values (True, "1", "yes", "on", ...) as truthy.
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        self._attn_impl = attn_impl

        logger.info(
            "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, "
            "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)",
            model_name,
            max_model_len,
            max_doc_len,
            max_docs_per_pack,
            use_fp16,
            attn_impl,
        )

        self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # Reuse EOS as the pad token — presumably the reranker checkpoint ships
        # without one; TODO confirm against the tokenizer config.
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Pre-tokenize the static prompt scaffolding once; reused on every call.
        self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
        self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
        self._suffix_len = len(self._suffix_tokens)

        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_transformers_packed backend requires CUDA GPU, "
                "but torch.cuda.is_available() is False"
            )

        kwargs: Dict[str, Any] = {}
        if use_fp16:
            kwargs["torch_dtype"] = torch.float16
        if attn_impl:
            kwargs["attn_implementation"] = attn_impl

        self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
        target_device = str(device).strip() if device is not None else "cuda"
        if not target_device.startswith("cuda"):
            raise ValueError(
                "qwen3_transformers_packed backend is GPU-only. "
                f"Unsupported device setting: {target_device!r}"
            )
        self._model = self._model.to(target_device)
        self._device = next(self._model.parameters()).device
        if self._device.type != "cuda":
            raise RuntimeError(
                "qwen3_transformers_packed backend failed to place model on CUDA. "
                f"Current device: {self._device}"
            )

        # NOTE(review): convert_tokens_to_ids usually returns the unk id (not
        # None) for unknown tokens, so this guard may never fire — verify
        # against the tokenizer implementation.
        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
        if self._token_true_id is None or self._token_false_id is None:
            raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no")

        # +1 leaves room for at least one document token after prefix + suffix.
        prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1
        if self._max_model_len <= prefix_budget:
            raise ValueError(
                "max_model_len is too small for packed reranking. "
                f"Need > {prefix_budget}, got {self._max_model_len}."
            )
        if self._max_doc_len <= 0:
            raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}")
        if self._max_docs_per_pack < 0:
            raise ValueError(
                f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}"
            )

        # Serializes forward passes so concurrent requests share one GPU safely.
        self._infer_lock = threading.Lock()

        logger.info(
            "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s",
            model_name,
            self._device,
        )

    def _build_pair_prefix_tokens(self, query: str) -> List[int]:
        """Token ids of the shared prefix (system + instruction + query).

        The document text is appended directly after these tokens.
        """
        pair_prefix = self._pair_prefix_template.format(
            prefix=self._prefix,
            instruction=self._instruction,
            query=query,
        )
        return self._tokenizer.encode(pair_prefix, add_special_tokens=False)

    def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]:
        """Tokenize docs, truncating so prefix + doc + suffix fits max_model_len."""
        max_doc_tokens = min(
            self._max_doc_len,
            max(1, self._max_model_len - query_prefix_len - self._suffix_len),
        )
        tokenized = self._tokenizer(
            list(docs),
            padding=False,
            truncation=True,
            max_length=max_doc_tokens,
            add_special_tokens=False,
            return_attention_mask=False,
        )
        return [list(ids) for ids in tokenized["input_ids"]]

    def _build_pack_plan(
        self,
        query_prefix_len: int,
        doc_tokens: Sequence[Sequence[int]],
    ) -> List[List[int]]:
        """Greedily group doc indices into packs under the token/doc budgets.

        Returns a list of packs, each a list of indices into ``doc_tokens``.
        Every pack pays ``query_prefix_len`` once plus doc + suffix tokens per
        document.
        """
        order = list(range(len(doc_tokens)))
        # Sorting by length groups similar-sized docs, reducing pack spread.
        if self._sort_by_doc_length and len(order) > 1:
            order.sort(key=lambda idx: len(doc_tokens[idx]))

        packs: List[List[int]] = []
        current_pack: List[int] = []
        current_len = query_prefix_len
        for idx in order:
            packed_doc_len = len(doc_tokens[idx]) + self._suffix_len
            if packed_doc_len <= 0:
                continue

            # Flush the current pack when either cap would be exceeded.
            over_docs_cap = self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack
            over_token_cap = current_pack and (current_len + packed_doc_len > self._max_model_len)
            if over_docs_cap or over_token_cap:
                packs.append(current_pack)
                current_pack = []
                current_len = query_prefix_len

            # A single doc that cannot fit even in an empty pack is a config error.
            if query_prefix_len + packed_doc_len > self._max_model_len:
                raise ValueError(
                    "Packed doc still exceeds max_model_len after truncation. "
                    f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, "
                    f"max_model_len={self._max_model_len}"
                )

            current_pack.append(idx)
            current_len += packed_doc_len

        if current_pack:
            packs.append(current_pack)
        return packs

    def _build_pack_inputs(
        self,
        query_prefix_tokens: Sequence[int],
        doc_tokens: Sequence[Sequence[int]],
        doc_indices: Sequence[int],
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Build input_ids / position_ids / 4D additive mask for one pack.

        Returns the model inputs plus the gather indices of each document's
        last suffix token (where the yes/no logits are read).
        """
        prefix_len = len(query_prefix_tokens)
        input_ids_list = list(query_prefix_tokens)
        position_ids_list = list(range(prefix_len))
        spans: List[Tuple[int, int]] = []
        current_len = prefix_len

        for idx in doc_indices:
            doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens
            start = current_len
            end = start + len(doc_with_suffix)
            spans.append((start, end))
            input_ids_list.extend(doc_with_suffix)
            # Positions restart right after the prefix: every document sees
            # the same relative offsets, as if scored alone with the prefix.
            position_ids_list.extend(range(prefix_len, prefix_len + len(doc_with_suffix)))
            current_len = end

        total_len = len(input_ids_list)
        device = self._device
        neg_inf = torch.finfo(torch.float32).min

        # allowed[i, j] == True  <=>  token i may attend to token j.
        allowed = torch.zeros((total_len, total_len), dtype=torch.bool, device=device)
        # Prefix tokens: plain causal attention among themselves.
        prefix_causal = torch.tril(
            torch.ones((prefix_len, prefix_len), dtype=torch.bool, device=device)
        )
        allowed[:prefix_len, :prefix_len] = prefix_causal
        # Document tokens: full prefix visibility + causal within their own span;
        # never across spans (docs stay independent of each other).
        for start, end in spans:
            allowed[start:end, :prefix_len] = True
            doc_len = end - start
            allowed[start:end, start:end] = torch.tril(
                torch.ones((doc_len, doc_len), dtype=torch.bool, device=device)
            )

        # Additive mask: 0 where allowed, float32-min where blocked.
        # NOTE(review): mask is float32 while the model may run fp16 — confirm
        # the chosen attn implementation handles the dtype mix.
        attention_mask = torch.full(
            (total_len, total_len),
            neg_inf,
            dtype=torch.float32,
            device=device,
        )
        attention_mask.masked_fill_(allowed, 0.0)

        inputs = {
            "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device),
            "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device),
            "attention_mask": attention_mask.view(1, 1, total_len, total_len),
        }
        logits_ids = torch.tensor(
            [end - 1 for _, end in spans],
            dtype=torch.long,
            device=device,
        )
        return inputs, logits_ids

    @torch.no_grad()
    def _score_pack(
        self,
        query_prefix_tokens: Sequence[int],
        doc_tokens: Sequence[Sequence[int]],
        doc_indices: Sequence[int],
    ) -> Tuple[List[float], int]:
        """Run one packed forward pass; return per-doc P(yes) and the seq length."""
        inputs, logits_ids = self._build_pack_inputs(
            query_prefix_tokens=query_prefix_tokens,
            doc_tokens=doc_tokens,
            doc_indices=doc_indices,
        )
        outputs = self._model(**inputs)
        # Logits at each document's final suffix token; softmax over [no, yes]
        # yields P(yes) as the relevance score.
        scores = outputs.logits[0, logits_ids, :]
        true_vector = scores[:, self._token_true_id]
        false_vector = scores[:, self._token_false_id]
        pair_scores = torch.stack([false_vector, true_vector], dim=1)
        pair_scores = torch.nn.functional.log_softmax(pair_scores, dim=1)
        return pair_scores[:, 1].exp().tolist(), int(inputs["input_ids"].shape[1])

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query``; return per-input scores + metadata.

        Blank/None docs keep a 0.0 score at their original position.
        Duplicate docs are scored once and the score is fanned back out.
        ``normalize`` is accepted for interface parity and echoed into meta;
        scores are always softmax probabilities in [0, 1].
        """
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        # Nothing to score: return zeros with a meta dict whose schema matches
        # the full-path meta (packed_* fields are simply zero).
        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_transformers_packed",
                "normalize": normalize,
                "packed_batches": 0,
                "packed_max_seq_len": 0,
                "packed_avg_seq_len": 0.0,
                "max_model_len": self._max_model_len,
                "max_doc_len": self._max_doc_len,
                "max_docs_per_pack": self._max_docs_per_pack,
                "sort_by_doc_length": self._sort_by_doc_length,
                "attn_implementation": self._attn_impl,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts)

        query_prefix_tokens = self._build_pair_prefix_tokens(query)
        doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens))
        pack_plan = self._build_pack_plan(
            query_prefix_len=len(query_prefix_tokens),
            doc_tokens=doc_tokens,
        )

        unique_scores: List[float] = [0.0] * len(unique_texts)
        pack_lengths: List[int] = []
        with self._infer_lock:
            for pack_doc_indices in pack_plan:
                batch_scores, pack_seq_len = self._score_pack(
                    query_prefix_tokens=query_prefix_tokens,
                    doc_tokens=doc_tokens,
                    doc_indices=pack_doc_indices,
                )
                if len(batch_scores) != len(pack_doc_indices):
                    raise RuntimeError(
                        "Packed reranker score size mismatch: "
                        f"expected {len(pack_doc_indices)}, got {len(batch_scores)}"
                    )
                for idx, score in zip(pack_doc_indices, batch_scores):
                    unique_scores[idx] = float(score)
                pack_lengths.append(pack_seq_len)

        # Fan unique scores back out to every original (possibly duplicate) slot.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_transformers_packed",
            "normalize": normalize,
            "packed_batches": len(pack_plan),
            "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0,
            "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3)
            if pack_lengths
            else 0.0,
            "max_model_len": self._max_model_len,
            "max_doc_len": self._max_doc_len,
            "max_docs_per_pack": self._max_docs_per_pack,
            "sort_by_doc_length": self._sort_by_doc_length,
            "attn_implementation": self._attn_impl,
        }
        return output_scores, meta