"""
Qwen3-Reranker-0.6B backend using Transformers (direct usage). No vLLM required.

Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
Requires: transformers>=4.51.0, torch.
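
Example (a minimal usage sketch; the query/docs below are illustrative and the
model weights download on first use):

    backend = Qwen3TransformersRerankerBackend({"batch_size": 32})
    scores, meta = backend.score_with_meta(
        "wireless mouse", ["Logitech M185 Wireless Mouse", "HDMI cable 2m"]
    )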
"""

from __future__ import annotations

import logging
import time
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger("reranker.backends.qwen3_transformers")


def _format_instruction(instruction: str, query: str, doc: str) -> str:
    """Format (query, doc) pair per official Qwen3-Reranker spec."""
    return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instruction, query=query, doc=doc
    )


class Qwen3TransformersRerankerBackend:
    """
    Qwen3-Reranker-0.6B with Transformers (AutoModelForCausalLM) inference.
    Config from services.rerank.backends.qwen3_transformers.
    No vLLM dependency; lighter than qwen3_vllm, suitable for CPU or small GPU.
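
    Recognized config keys (all optional; see __init__ for defaults):
    model_name, instruction, max_length, batch_size, use_fp16, device,
    attn_implementation.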
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        self._instruction = str(
            self._config.get("instruction")
            or "Given a shopping query, rank product titles by relevance"
        )
        max_length = int(self._config.get("max_length", 8192))
        batch_size = int(self._config.get("batch_size", 64))
        use_fp16 = bool(self._config.get("use_fp16", True))
        device = self._config.get("device")
        attn_impl = self._config.get("attn_implementation")  # e.g. "flash_attention_2"

        self._model_name = model_name
        self._batch_size = batch_size

        logger.info(
            "[Qwen3_Transformers] Loading model %s (max_length=%s, batch=%s, fp16=%s)",
            model_name,
            max_length,
            batch_size,
            use_fp16,
        )

        self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self._tokenizer.pad_token = self._tokenizer.eos_token

        # Prefix/suffix from official reference
        prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._prefix_tokens = self._tokenizer.encode(prefix, add_special_tokens=False)
        self._suffix_tokens = self._tokenizer.encode(suffix, add_special_tokens=False)
        self._max_length = max_length
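        # Reserve token budget so prefix + truncated pair + suffix never
        # exceeds max_length.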
        self._effective_max_len = max_length - len(self._prefix_tokens) - len(self._suffix_tokens)

        # Token ids for the two admissible answers; _compute_scores compares
        # their logits to produce the relevance score.
        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")

        kwargs = {}
        if use_fp16 and torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.float16
        if attn_impl:
            kwargs["attn_implementation"] = attn_impl

        self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
        if device is not None:
            self._model = self._model.to(device)
        elif torch.cuda.is_available():
            self._model = self._model.cuda()

        logger.info(
            "[Qwen3_Transformers] Model ready | model=%s device=%s",
            model_name,
            next(self._model.parameters()).device,
        )

    def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize pairs and add prefix/suffix tokens. Returns batched tensors on model device."""
        inputs = self._tokenizer(
            pairs,
            padding=False,
            truncation="longest_first",
            return_attention_mask=False,
            max_length=self._effective_max_len,
        )
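        # Wrap each truncated pair in the chat prefix/suffix token ids.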
        for i, ele in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self._prefix_tokens + ele + self._suffix_tokens
        inputs = self._tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
        )
        for key in inputs:
            inputs[key] = inputs[key].to(self._model.device)
        return inputs

    @torch.no_grad()
    def _compute_scores(self, pairs: List[str]) -> List[float]:
        """Run forward pass and compute yes/no probability per pair."""
        if not pairs:
            return []
        inputs = self._process_inputs(pairs)
        outputs = self._model(**inputs)
        # Next-token logits at the final position; with left padding this is
        # the answer position for every row in the batch.
        batch_scores = outputs.logits[:, -1, :]
        true_vector = batch_scores[:, self._token_true_id]
        false_vector = batch_scores[:, self._token_false_id]
        # Two-way softmax over [no, yes]; the score is P("yes") in [0, 1].
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
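        """Score each doc against the query; returns (scores, meta).

        Scores are P("yes") probabilities in [0, 1]; None/empty docs score 0.0.
        `normalize` is only echoed in meta -- the scores are already normalized
        by the two-way softmax in _compute_scores.
        """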
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_transformers",
                "normalize": normalize,
            }

        # Deduplicate by text, keeping a map from each original position to its
        # unique index so every duplicate is scored exactly once (the previous
        # run-based check only caught consecutive duplicates).
        unique_texts: List[str] = []
        text_to_unique: Dict[str, int] = {}
        position_to_unique: List[int] = []
        for _idx, text in indexed:
            unique_idx = text_to_unique.get(text)
            if unique_idx is None:
                unique_idx = len(unique_texts)
                text_to_unique[text] = unique_idx
                unique_texts.append(text)
            position_to_unique.append(unique_idx)

        pairs = [
            _format_instruction(self._instruction, query, t)
            for t in unique_texts
        ]

        # Batch inference
        unique_scores: List[float] = []
        for i in range(0, len(pairs), self._batch_size):
            batch = pairs[i : i + self._batch_size]
            batch_scores = self._compute_scores(batch)
            unique_scores.extend(batch_scores)

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_transformers",
            "normalize": normalize,
        }
        return output_scores, meta
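

if __name__ == "__main__":
    # Minimal smoke test (a sketch: downloads the 0.6B weights on first run;
    # the sample query and docs are illustrative, not part of any real config).
    logging.basicConfig(level=logging.INFO)
    backend = Qwen3TransformersRerankerBackend({"batch_size": 8, "max_length": 1024})
    scores, meta = backend.score_with_meta(
        "wireless mouse",
        ["Logitech M185 Wireless Mouse", "HDMI cable 2m", "wireless optical mouse"],
    )
    print([round(s, 4) for s in scores])
    print(meta)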