__init__.py 3 KB
"""
Rerank backends - pluggable implementations of the rerank protocol.

Each backend implements score_with_meta(query, docs, normalize) -> (scores, meta).
The service selects a single backend by name from its configuration and instantiates it via get_rerank_backend(name, config).
"""

from __future__ import annotations

from typing import Any, Dict, List, Protocol, Tuple


class RerankBackendProtocol(Protocol):
    """Structural interface that every rerank backend satisfies (service-internal).

    Being a ``typing.Protocol``, conformance is structural: a backend does not
    need to subclass this type, it only needs a compatible ``score_with_meta``.
    """

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score every document in ``docs`` against ``query``.

        Args:
            query: The search query string.
            docs: Documents to score; returned scores must align 1:1 with
                this list.
            normalize: Whether to normalize the scores (e.g. via sigmoid).

        Returns:
            A ``(scores, meta)`` pair. ``scores`` has the same length and
            order as ``docs``; ``meta`` carries at least the keys
            ``input_docs``, ``usable_docs``, ``unique_docs`` and
            ``elapsed_ms``.
        """
        ...


# Supported backends: registry name -> (module path, class name).
# Insertion order defines the order shown in the "Supported:" error message.
# Modules are imported lazily, so only the selected backend's module is loaded.
_BACKEND_REGISTRY: Dict[str, Tuple[str, str]] = {
    "bge": ("reranker.backends.bge", "BGERerankerBackend"),
    "qwen3_vllm": ("reranker.backends.qwen3_vllm", "Qwen3VLLMRerankerBackend"),
    "qwen3_vllm_score": (
        "reranker.backends.qwen3_vllm_score",
        "Qwen3VLLMScoreRerankerBackend",
    ),
    "qwen3_transformers": (
        "reranker.backends.qwen3_transformers",
        "Qwen3TransformersRerankerBackend",
    ),
    "qwen3_transformers_packed": (
        "reranker.backends.qwen3_transformers_packed",
        "Qwen3TransformersPackedRerankerBackend",
    ),
    "qwen3_gguf": ("reranker.backends.qwen3_gguf", "Qwen3GGUFRerankerBackend"),
    "qwen3_gguf_06b": ("reranker.backends.qwen3_gguf", "Qwen3GGUFRerankerBackend"),
    "dashscope_rerank": (
        "reranker.backends.dashscope_rerank",
        "DashScopeRerankBackend",
    ),
}

# Backends that share one implementation class; the selected registry name is
# injected as config["_backend_name"] (unless already set) so the instance can
# tell which variant it is.
_NAME_TAGGED_BACKENDS = frozenset({"qwen3_gguf", "qwen3_gguf_06b"})

# Backend used when no (or a blank) name is configured.
_DEFAULT_BACKEND = "bge"


def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtocol:
    """
    Factory: return a reranker backend instance for the given name and config.

    Args:
        name: Backend name, matched case-insensitively after stripping
            surrounding whitespace. ``None``, empty and whitespace-only names
            fall back to ``"bge"``.
        config: The corresponding block from services.rerank.backends.<name>.
            It is passed to the backend constructor as-is, except for the
            GGUF backends, which receive a shallow copy (``None`` -> ``{}``)
            with ``_backend_name`` defaulted to the selected name.

    Returns:
        An instance of the selected backend class.

    Raises:
        ValueError: If ``name`` is not a supported backend.
        Import errors from the backend's module propagate unchanged.
    """
    import importlib

    # Normalize before defaulting so a blank name like "   " also selects the
    # default instead of being rejected as an unknown backend.
    key = (name or "").strip().lower() or _DEFAULT_BACKEND
    try:
        module_path, class_name = _BACKEND_REGISTRY[key]
    except KeyError:
        # Supported list is derived from the registry so it cannot drift.
        supported = ", ".join(_BACKEND_REGISTRY)
        raise ValueError(
            f"Unknown rerank backend: {key!r}. Supported: {supported}"
        ) from None

    backend_cls = getattr(importlib.import_module(module_path), class_name)
    if key in _NAME_TAGGED_BACKENDS:
        # Copy so the caller's config dict is never mutated by setdefault.
        tagged_config = dict(config or {})
        tagged_config.setdefault("_backend_name", key)
        return backend_cls(tagged_config)
    return backend_cls(config)


# Public API of the rerank backends package: the protocol and its factory.
__all__ = [
    "RerankBackendProtocol",
    "get_rerank_backend",
]