# qwen3_gguf.py
"""
Qwen3-Reranker-4B GGUF backend using llama-cpp-python.

Reference:
- https://huggingface.co/DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF
- https://huggingface.co/Qwen/Qwen3-Reranker-4B
"""

from __future__ import annotations

import logging
import math
import os
import threading
import time
from typing import Any, Dict, List, Tuple


logger = logging.getLogger("reranker.backends.qwen3_gguf")


def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]:
    """Collapse duplicate texts while keeping first-seen order.

    Returns the unique texts plus, for every input position, the index of
    that text within the unique list (so scores computed once per unique
    text can be fanned back out to all original positions).
    """
    first_index: Dict[str, int] = {}
    unique_texts: List[str] = []
    mapping: List[int] = []

    for text in texts:
        if text not in first_index:
            first_index[text] = len(unique_texts)
            unique_texts.append(text)
        mapping.append(first_index[text])

    return unique_texts, mapping


def _format_instruction(instruction: str, query: str, doc: str) -> str:
    return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
        instruction=instruction,
        query=query,
        doc=doc,
    )


class Qwen3GGUFRerankerBackend:
    """
    Qwen3-Reranker-4B GGUF backend using llama.cpp through llama-cpp-python.

    Each (query, doc) pair is wrapped in a yes/no judgement chat prompt and
    scored as softmax over the "yes"/"no" logits at the final position.

    Tuned for short-query / short-doc reranking on a memory-constrained single T4.
    Config from services.rerank.backends.qwen3_gguf.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Validate config, load the GGUF model, and precompute prompt tokens.

        Args:
            config: Backend settings (repo_id/filename or model_path, context
                and batching sizes, GPU offload options). Environment variables
                RERANK_GGUF_INFER_BATCH_SIZE and RERANK_GGUF_SORT_BY_DOC_LENGTH
                override the corresponding config keys.

        Raises:
            ValueError: If infer_batch_size, n_ctx, n_batch, or n_ubatch is <= 0.
            RuntimeError: If llama-cpp-python is not installed, if n_ctx leaves
                almost no room for pair tokens after the fixed prompt overhead,
                or if "yes"/"no" do not tokenize to single tokens.
        """
        self._config = config or {}
        self._repo_id = str(
            self._config.get("repo_id") or "DevQuasar/Qwen.Qwen3-Reranker-4B-GGUF"
        ).strip()
        self._filename = str(self._config.get("filename") or "*Q8_0.gguf").strip()
        # A non-empty model_path bypasses the HF Hub download path entirely.
        self._model_path = str(self._config.get("model_path") or "").strip()
        self._cache_dir = str(self._config.get("cache_dir") or "").strip() or None
        self._local_dir = str(self._config.get("local_dir") or "").strip() or None
        self._instruction = str(
            self._config.get("instruction")
            or "Rank products by query with category & style match prioritized"
        )
        # Env var takes precedence over config for live operational tuning.
        self._infer_batch_size = int(
            os.getenv("RERANK_GGUF_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 8)
        )
        sort_by_doc_length = os.getenv("RERANK_GGUF_SORT_BY_DOC_LENGTH")
        if sort_by_doc_length is None:
            sort_by_doc_length = self._config.get("sort_by_doc_length", True)
        # Accept common truthy spellings whether the value came from env (str)
        # or config (bool/str).
        self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {
            "1",
            "true",
            "yes",
            "y",
            "on",
        }
        # "token" sorts by tokenized length (accurate, slower); anything else
        # falls back to character length.
        self._length_sort_mode = str(self._config.get("length_sort_mode") or "char").strip().lower()

        # max_model_len is accepted as a legacy alias for n_ctx.
        n_ctx = int(self._config.get("n_ctx", self._config.get("max_model_len", 384)))
        n_batch = int(self._config.get("n_batch", min(n_ctx, 384)))
        n_ubatch = int(self._config.get("n_ubatch", min(n_batch, 128)))
        n_gpu_layers = int(self._config.get("n_gpu_layers", 24))
        main_gpu = int(self._config.get("main_gpu", 0))
        n_threads = int(self._config.get("n_threads", 2))
        n_threads_batch = int(self._config.get("n_threads_batch", 4))
        flash_attn = bool(self._config.get("flash_attn", True))
        offload_kqv = bool(self._config.get("offload_kqv", True))
        use_mmap = bool(self._config.get("use_mmap", True))
        use_mlock = bool(self._config.get("use_mlock", False))
        verbose = bool(self._config.get("verbose", False))
        enable_warmup = bool(self._config.get("enable_warmup", True))

        # Fail fast on impossible sizes before paying the model-load cost.
        if self._infer_batch_size <= 0:
            raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
        if n_ctx <= 0:
            raise ValueError(f"n_ctx must be > 0, got {n_ctx}")
        if n_batch <= 0 or n_ubatch <= 0:
            raise ValueError(f"n_batch/n_ubatch must be > 0, got {n_batch}/{n_ubatch}")

        try:
            from llama_cpp import Llama
        except Exception as exc:  # pragma: no cover - depends on optional dependency
            raise RuntimeError(
                "qwen3_gguf backend requires llama-cpp-python. "
                "Install the qwen3_gguf backend venv first via scripts/setup_reranker_venv.sh qwen3_gguf."
            ) from exc

        self._llama_class = Llama
        self._n_ctx = n_ctx
        self._n_batch = n_batch
        self._n_ubatch = n_ubatch
        self._n_gpu_layers = n_gpu_layers
        self._enable_warmup = enable_warmup
        # Serializes eval calls: a single Llama context is not safe for
        # concurrent evaluation.
        self._infer_lock = threading.Lock()

        logger.info(
            "[Qwen3_GGUF] Loading model repo=%s filename=%s model_path=%s n_ctx=%s n_batch=%s n_ubatch=%s n_gpu_layers=%s flash_attn=%s offload_kqv=%s",
            self._repo_id,
            self._filename,
            self._model_path or None,
            n_ctx,
            n_batch,
            n_ubatch,
            n_gpu_layers,
            flash_attn,
            offload_kqv,
        )

        llm_kwargs = {
            "n_ctx": n_ctx,
            "n_batch": n_batch,
            "n_ubatch": n_ubatch,
            "n_gpu_layers": n_gpu_layers,
            "main_gpu": main_gpu,
            "n_threads": n_threads,
            "n_threads_batch": n_threads_batch,
            # logits_all is required so eval_logits is populated after eval().
            "logits_all": True,
            "offload_kqv": offload_kqv,
            "flash_attn": flash_attn,
            "use_mmap": use_mmap,
            "use_mlock": use_mlock,
            "verbose": verbose,
        }
        llm_kwargs = {key: value for key, value in llm_kwargs.items() if value is not None}
        self._llm = self._load_model(llm_kwargs)
        self._model_name = self._model_path or f"{self._repo_id}:{self._filename}"

        # Fixed chat scaffold from the Qwen3-Reranker usage recipe; the empty
        # <think> block disables thinking-mode generation.
        self._prefix = (
            "<|im_start|>system\n"
            "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
            'Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n<|im_start|>user\n"
        )
        self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self._prefix_tokens = self._tokenize(self._prefix, special=True)
        self._suffix_tokens = self._tokenize(self._suffix, special=True)
        # Token budget left for the instruction/query/document pair.
        self._effective_max_len = self._n_ctx - len(self._prefix_tokens) - len(self._suffix_tokens)
        if self._effective_max_len <= 16:
            raise RuntimeError(
                f"n_ctx={self._n_ctx} is too small after prompt overhead; effective={self._effective_max_len}"
            )

        self._true_token = self._single_token_id("yes")
        self._false_token = self._single_token_id("no")

        if self._enable_warmup:
            self._warmup()

        logger.info(
            "[Qwen3_GGUF] Model ready | model=%s effective_max_len=%s infer_batch_size=%s sort_by_doc_length=%s",
            self._model_name,
            self._effective_max_len,
            self._infer_batch_size,
            self._sort_by_doc_length,
        )

    def _load_model(self, llm_kwargs: Dict[str, Any]):
        """Instantiate Llama from a local path or by downloading from the Hub."""
        if self._model_path:
            return self._llama_class(model_path=self._model_path, **llm_kwargs)
        return self._llama_class.from_pretrained(
            repo_id=self._repo_id,
            filename=self._filename,
            local_dir=self._local_dir,
            cache_dir=self._cache_dir,
            **llm_kwargs,
        )

    def _tokenize(self, text: str, *, special: bool) -> List[int]:
        """Tokenize UTF-8 text without BOS; special=True parses chat markers."""
        return list(
            self._llm.tokenize(
                text.encode("utf-8"),
                add_bos=False,
                special=special,
            )
        )

    def _single_token_id(self, text: str) -> int:
        """Return the single token id for text, raising if it splits."""
        token_ids = self._tokenize(text, special=False)
        if len(token_ids) != 1:
            raise RuntimeError(f"Expected {text!r} to be one token, got {token_ids}")
        return int(token_ids[0])

    def _warmup(self) -> None:
        """Run one dummy prompt to pay CUDA/graph init cost before real traffic."""
        try:
            prompt = self._build_prompt_tokens("warmup query", "warmup document")
            with self._infer_lock:
                self._eval_logits(prompt)
        except Exception as exc:  # pragma: no cover - defensive
            # Warmup is best-effort; a failure here must not block serving.
            logger.warning("[Qwen3_GGUF] Warmup failed: %s", exc)

    def _build_prompt_tokens(self, query: str, doc: str) -> List[int]:
        """Build prefix + truncated pair + suffix token ids for one (query, doc)."""
        pair = _format_instruction(self._instruction, query, doc)
        pair_tokens = self._tokenize(pair, special=False)
        # Truncate the pair so the full prompt always fits inside n_ctx.
        pair_tokens = pair_tokens[: self._effective_max_len]
        return self._prefix_tokens + pair_tokens + self._suffix_tokens

    def _eval_logits(self, prompt_tokens: List[int]) -> List[float]:
        """Evaluate the prompt and return the logits of the last position.

        Caller must hold self._infer_lock: reset() + eval() mutate the shared
        llama.cpp context.
        """
        self._llm.reset()
        self._llm.eval(prompt_tokens)
        logits = self._llm.eval_logits
        if not logits:
            raise RuntimeError("llama.cpp returned empty logits")
        return list(logits[-1])

    def _score_prompt(self, prompt_tokens: List[int]) -> float:
        """Return P("yes") from a numerically stable 2-way softmax over yes/no."""
        logits = self._eval_logits(prompt_tokens)
        true_logit = float(logits[self._true_token])
        false_logit = float(logits[self._false_token])
        # Subtract the max before exp() to avoid overflow on large logits.
        max_logit = max(true_logit, false_logit)
        true_exp = math.exp(true_logit - max_logit)
        false_exp = math.exp(false_logit - max_logit)
        return float(true_exp / (true_exp + false_exp))

    def _estimate_doc_lengths(self, docs: List[str]) -> List[int]:
        """Length proxy per doc: token count in "token" mode, else char count."""
        if self._length_sort_mode == "token":
            return [len(self._tokenize(text, special=False)) for text in docs]
        return [len(text) for text in docs]

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score docs against query; return per-doc scores plus diagnostics.

        Blank/None docs keep a score of 0.0; duplicate docs are scored once
        and the result is fanned back out to every original position.

        Args:
            query: Search query; a blank query short-circuits to all zeros.
            docs: Candidate documents, aligned with the returned score list.
            normalize: Echoed in meta only; scores are always softmax
                probabilities of the "yes" token.

        Returns:
            (scores, meta) where scores[i] corresponds to docs[i] and meta
            carries timing, dedup, and model-configuration diagnostics.
        """
        # perf_counter is monotonic, so elapsed_ms is immune to wall-clock jumps.
        start_ts = time.perf_counter()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        # Keep (original_index, cleaned_text) for every usable doc.
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            # Nothing scorable: return zeros with the same meta schema as the
            # normal path so downstream consumers see a consistent shape.
            elapsed_ms = (time.perf_counter() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "qwen3_gguf",
                "normalize": normalize,
                "infer_batch_size": self._infer_batch_size,
                "inference_batches": 0,
                "sort_by_doc_length": self._sort_by_doc_length,
                "length_sort_mode": self._length_sort_mode,
                "n_ctx": self._n_ctx,
                "n_batch": self._n_batch,
                "n_ubatch": self._n_ubatch,
                "n_gpu_layers": self._n_gpu_layers,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        # Optionally process shortest docs first for steadier per-call latency.
        lengths = self._estimate_doc_lengths(unique_texts)
        order = list(range(len(unique_texts)))
        if self._sort_by_doc_length and len(unique_texts) > 1:
            order = sorted(order, key=lambda i: lengths[i])

        unique_scores: List[float] = [0.0] * len(unique_texts)
        inference_batches = 0
        for start in range(0, len(order), self._infer_batch_size):
            batch_indices = order[start : start + self._infer_batch_size]
            inference_batches += 1
            # Prompts are still evaluated one at a time; batching only groups
            # work for accounting (inference_batches in meta).
            for idx in batch_indices:
                prompt = self._build_prompt_tokens(query, unique_texts[idx])
                with self._infer_lock:
                    unique_scores[idx] = self._score_prompt(prompt)

        # Fan unique scores back out to every original doc position.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.perf_counter() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        meta = {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_gguf",
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
            "length_sort_mode": self._length_sort_mode,
            "n_ctx": self._n_ctx,
            "n_batch": self._n_batch,
            "n_ubatch": self._n_ubatch,
            "n_gpu_layers": self._n_gpu_layers,
        }
        return output_scores, meta