# test_reranker_server_topn.py
from __future__ import annotations

from typing import Any, Dict, List

from fastapi.testclient import TestClient


class _FakeTopNReranker:
    """Test double exposing both the plain and the top-N scoring entry points.

    The two methods return distinguishable constant scores and a ``path`` tag
    in their meta dict, so a test can tell which one the server called:

    * ``score_with_meta`` -> 0.1 for every doc, ``path == "base"``.
    * ``score_with_meta_topn`` -> 1.0 for the first doc when there are docs
      and ``top_n`` is truthy, 0.0 everywhere else, ``path == "topn"``; the
      received ``top_n`` is echoed back in the meta.
    """

    _model_name = "fake-topn-reranker"

    def score_with_meta(self, query: str, docs: List[str], normalize: bool = True):
        """Return a flat 0.1 score per doc and meta tagged with the base path."""
        meta = {"input_docs": len(docs), "path": "base"}
        return [0.1] * len(docs), meta

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ):
        """Return 0.0 scores with the first doc bumped to 1.0 when top_n is set.

        ``query`` and ``normalize`` are accepted for interface compatibility
        and ignored.
        """
        doc_count = len(docs)
        result = [0.0] * doc_count
        # Only mark a "top" doc when a top-N request was actually made.
        if top_n and doc_count:
            result[0] = 1.0
        meta = {"input_docs": doc_count, "path": "topn", "top_n": top_n}
        return result, meta


def test_reranker_server_forwards_top_n():
    """``/rerank`` must route a ``top_n`` request to the top-N scoring path.

    Installs ``_FakeTopNReranker`` as the server's reranker and posts a
    request with ``top_n=2``. It then checks three things: the scores come
    from ``score_with_meta_topn``, the meta reports ``path == "topn"``, and
    the requested value is echoed in the response meta.

    The server module's globals and the app's startup hooks are patched only
    for the duration of this test. They are restored in ``finally`` so the
    fake does not leak into other tests that share the imported module.
    """
    import reranker.server as reranker_server

    router = reranker_server.app.router
    missing = object()
    patched_names = ("_reranker", "_backend_name")
    saved_globals = {
        name: getattr(reranker_server, name, missing) for name in patched_names
    }
    saved_startup = list(router.on_startup)

    try:
        # Drop the startup hooks so entering TestClient does not run them.
        # NOTE(review): presumably they load a real backend that would replace
        # the fake installed below -- confirm against reranker.server.
        router.on_startup.clear()
        reranker_server._reranker = _FakeTopNReranker()
        reranker_server._backend_name = "fake_topn"

        with TestClient(reranker_server.app) as client:
            response = client.post(
                "/rerank",
                json={
                    "query": "wireless mouse",
                    "docs": ["a", "b", "c"],
                    "top_n": 2,
                },
            )
            assert response.status_code == 200
            data: Dict[str, Any] = response.json()
            assert data["scores"] == [1.0, 0.0, 0.0]
            assert data["meta"]["path"] == "topn"
            assert data["meta"]["requested_top_n"] == 2
            assert data["meta"]["top_n"] == 2
    finally:
        # Restore shared module state so later tests see the real server setup.
        router.on_startup.clear()
        router.on_startup.extend(saved_startup)
        for name, value in saved_globals.items():
            if value is missing:
                # The attribute did not exist before this test; remove it again.
                if hasattr(reranker_server, name):
                    delattr(reranker_server, name)
            else:
                setattr(reranker_server, name, value)