# test_reranker_dashscope_backend.py 8.35 KB
from __future__ import annotations

import time

import pytest

from reranker.backends import get_rerank_backend
from reranker.backends.dashscope_rerank import DashScopeRerankBackend


@pytest.fixture(autouse=True)
def _clear_global_dashscope_key(monkeypatch):
    """Drop ``DASHSCOPE_API_KEY`` from the environment before each test.

    Keeps an unrelated, globally exported key from leaking into the tests;
    a missing variable is not treated as an error.
    """
    global_key_env = "DASHSCOPE_API_KEY"
    monkeypatch.delenv(global_key_env, raising=False)


def test_dashscope_backend_factory_loads(monkeypatch):
    """The factory returns a ``DashScopeRerankBackend`` for ``"dashscope_rerank"``."""
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")

    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
    }
    created = get_rerank_backend("dashscope_rerank", config)

    assert isinstance(created, DashScopeRerankBackend)


def test_dashscope_backend_score_with_meta_dedup_and_restore(monkeypatch):
    """Duplicate/blank docs are collapsed for the request and restored in the output.

    The stubbed provider must only see the two distinct non-blank docs; the
    returned score list is expected to line up with the six original inputs,
    with duplicates sharing a score and blank/``None`` entries scoring ``0.0``.
    """
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 0,
    }
    backend = DashScopeRerankBackend(config)

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "wireless mouse"
        # Only the distinct, non-blank documents should reach the provider.
        assert docs == ["doc-a", "doc-b"]
        assert top_n == 2
        ranked = [(1, 0.9), (0, 0.2)]
        return {
            "results": [
                {"index": position, "relevance_score": relevance}
                for position, relevance in ranked
            ]
        }

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    raw_inputs = ["doc-a", "doc-b", "doc-a", "", "   ", None]
    scores, meta = backend.score_with_meta(
        query="wireless mouse",
        docs=raw_inputs,
        normalize=True,
    )

    assert scores == [0.2, 0.9, 0.2, 0.0, 0.0, 0.0]

    expected_meta = {
        "input_docs": 6,
        "usable_docs": 3,
        "unique_docs": 2,
        "top_n": 2,
        "response_results": 2,
        "backend": "dashscope_rerank",
    }
    for field, wanted in expected_meta.items():
        assert meta[field] == wanted


def test_dashscope_backend_top_n_cap_and_normalize_fallback(monkeypatch):
    """``top_n_cap`` limits the request, and out-of-range scores are normalized.

    With ``top_n_cap=1`` the stub must be asked for a single result. The doc
    missing from the response is expected to score ``0.0``; the raw score
    ``3.0`` is expected to land in ``(0.95, 0.96)`` when ``normalize=True``
    and to be returned unchanged when ``normalize=False``.
    """
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 1,
    }
    backend = DashScopeRerankBackend(config)

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "q"
        assert len(docs) == 2
        assert top_n == 1
        # Single top-1 hit whose score lies outside [0, 1], so the
        # normalize path has to fall back to a sigmoid.
        top_hit = {"index": 1, "score": 3.0}
        return {"results": [top_hit]}

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    inputs = ["a", "b"]
    scores_norm, _ = backend.score_with_meta(query="q", docs=inputs, normalize=True)
    scores_raw, _ = backend.score_with_meta(query="q", docs=inputs, normalize=False)

    assert scores_norm[0] == 0.0
    squashed = scores_norm[1]
    assert 0.95 < squashed < 0.96
    assert scores_raw == [0.0, 3.0]


def test_dashscope_backend_score_with_meta_topn_request(monkeypatch):
    """``score_with_meta_topn`` forwards the requested ``top_n`` to the provider.

    Docs absent from the stubbed response are expected to score ``0.0``, and
    the metadata is expected to report both the effective and requested top-n.
    """
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 0,
    }
    backend = DashScopeRerankBackend(config)

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "q"
        assert docs == ["d1", "d2", "d3"]
        assert top_n == 2
        best = {"index": 2, "relevance_score": 0.8}
        runner_up = {"index": 0, "relevance_score": 0.3}
        return {"results": [best, runner_up]}

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores, meta = backend.score_with_meta_topn(
        query="q",
        docs=["d1", "d2", "d3"],
        top_n=2,
    )

    assert scores == [0.3, 0.0, 0.8]
    for field in ("top_n", "requested_top_n"):
        assert meta[field] == 2


def test_dashscope_backend_batchsize_concurrent_full_topn(monkeypatch):
    """Batches are dispatched concurrently, each asking for its full local list.

    Six docs with ``batchsize=2`` are expected to yield three batches. Instead
    of inferring concurrency from wall-clock time (which is flaky on a loaded
    machine), every stubbed provider call waits on a shared barrier: the
    barrier only releases when all three calls are in flight at the same
    time. If the calls were issued one after another, the first wait times
    out, the barrier breaks, and the test fails with an explicit message.
    """
    import threading  # local import: only this test needs it

    monkeypatch.setenv("TEST_RERANK_DASHSCOPE_API_KEY", "test-key")
    backend = DashScopeRerankBackend(
        {
            "model_name": "qwen3-rerank",
            "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
            "api_key_env": "TEST_RERANK_DASHSCOPE_API_KEY",
            "top_n_cap": 0,
            "batchsize": 2,
        }
    )

    expected_batches = 3
    # Generous upper bound; it is only ever spent when the test is failing.
    rendezvous_timeout_s = 5.0
    rendezvous = threading.Barrier(expected_batches)
    not_concurrent = threading.Event()

    def _fake_post(query: str, docs: list[str], top_n: int):
        assert query == "q"
        # batching path asks every batch for full local list
        assert top_n == len(docs)
        try:
            # Returns immediately once all expected batches have arrived.
            rendezvous.wait(timeout=rendezvous_timeout_s)
        except threading.BrokenBarrierError:
            # Record the failure here and assert in the test body, so the
            # error is not masked by the backend's own exception handling.
            not_concurrent.set()
        return {
            "results": [
                {"index": i, "relevance_score": float(i + 1) / 10.0}
                for i, _ in enumerate(docs)
            ]
        }

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores, meta = backend.score_with_meta(query="q", docs=["d1", "d2", "d3", "d4", "d5", "d6"])

    assert not not_concurrent.is_set(), "batch requests were not in flight concurrently"
    assert len(scores) == 6
    assert meta["batches"] == 3
    assert meta["batch_concurrency"] == 3
    assert meta["response_results"] == 6


def test_dashscope_backend_batchsize_still_effective_when_topn_limited(monkeypatch):
    """Batching stays active when the caller limits ``top_n``.

    Four docs with ``batchsize=2`` and ``top_n=2`` are expected to trigger two
    provider calls, each asking for every doc in its batch. Only the two
    highest scores are expected to survive in the output; the rest are ``0.0``.
    """
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 0,
        "batchsize": 2,
    }
    backend = DashScopeRerankBackend(config)

    # One entry per provider call; its length is the call count.
    seen_batches = []
    score_by_doc = {"d1": 0.9, "d2": 0.1, "d3": 0.8, "d4": 0.2}

    def _fake_post(query: str, docs: list[str], top_n: int):
        seen_batches.append(docs)
        # Batching remains enabled: every batch requests its full local scores.
        assert top_n == len(docs)
        results = []
        for position, text in enumerate(docs):
            results.append({"index": position, "relevance_score": score_by_doc[text]})
        return {"results": results}

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)
    scores, meta = backend.score_with_meta_topn(
        query="q",
        docs=["d1", "d2", "d3", "d4"],
        top_n=2,
    )

    assert len(seen_batches) == 2
    assert scores == [0.9, 0.0, 0.8, 0.0]
    assert meta["batches"] == 2
    assert meta["top_n"] == 2


def test_dashscope_backend_batchsize_raises_when_one_batch_fails(monkeypatch):
    """A failure in any single batch surfaces as a ``RuntimeError`` from scoring.

    The stub raises for the ``["d3", "d4"]`` batch only; the call is expected
    to fail with a message matching ``"DashScope rerank batch failed"``.
    """
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    monkeypatch.setenv(key_env, "test-key")
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 0,
        "batchsize": 2,
    }
    backend = DashScopeRerankBackend(config)

    failing_batch = ["d3", "d4"]

    def _fake_post(query: str, docs: list[str], top_n: int):
        if docs != failing_batch:
            # Healthy batches get a flat score for every document.
            return {
                "results": [
                    {"index": position, "relevance_score": 0.1}
                    for position, _ in enumerate(docs)
                ]
            }
        raise RuntimeError("provider temporary error")

    monkeypatch.setattr(backend, "_post_rerank", _fake_post)

    all_docs = ["d1", "d2", "d3", "d4"]
    with pytest.raises(RuntimeError, match="DashScope rerank batch failed"):
        backend.score_with_meta(query="q", docs=all_docs)


def test_dashscope_backend_requires_api_key_env():
    """Constructing the backend without ``api_key_env`` raises ``ValueError``."""
    config_without_key_env = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "top_n_cap": 0,
    }
    with pytest.raises(ValueError, match="api_key_env is required"):
        DashScopeRerankBackend(config_without_key_env)


def test_dashscope_backend_requires_api_key_env_value(monkeypatch):
    """Naming an env var that is unset raises ``ValueError`` on construction."""
    key_env = "TEST_RERANK_DASHSCOPE_API_KEY"
    # Make sure the variable is really absent, whatever the host environment.
    monkeypatch.delenv(key_env, raising=False)
    config = {
        "model_name": "qwen3-rerank",
        "endpoint": "https://dashscope.aliyuncs.com/compatible-api/v1/reranks",
        "api_key_env": key_env,
        "top_n_cap": 0,
    }
    with pytest.raises(ValueError, match="set env TEST_RERANK_DASHSCOPE_API_KEY"):
        DashScopeRerankBackend(config)