# test_rerank_provider_topn.py
from __future__ import annotations

from typing import Any, Dict

from providers.rerank import HttpRerankProvider


class _FakeResponse:
    """Minimal stand-in for the response object returned by ``requests.post``.

    Only ``status_code``, ``text`` and ``json()`` are provided; anything else
    the code under test touches will raise ``AttributeError``.
    """

    def __init__(self, status_code: int, data: Dict[str, Any]):
        # ``text`` is simply the ``str()`` rendering of the payload dict.
        self.text = str(data)
        self._data = data
        self.status_code = status_code

    def json(self) -> Dict[str, Any]:
        """Return the payload given at construction (the same object, not a copy)."""
        return self._data


def test_http_rerank_provider_includes_top_n(monkeypatch):
    """``HttpRerankProvider.rerank`` should put ``top_n`` into the POSTed JSON body.

    ``requests.post`` as reached through ``providers.rerank`` is replaced with a
    stub that records its arguments and answers with a canned 200 response.
    """
    seen: Dict[str, Any] = {}

    def stub_post(url, json, timeout):
        # Parameter names kept as url/json/timeout — presumably the provider
        # passes json=/timeout= as keywords; verify if the provider changes.
        seen.update(url=url, json=json, timeout=timeout)
        body = {"scores": [0.1, 0.2], "meta": {"ok": True}}
        return _FakeResponse(200, body)

    monkeypatch.setattr("providers.rerank.requests.post", stub_post)

    endpoint = "http://127.0.0.1:6007/rerank"
    provider = HttpRerankProvider(endpoint)
    scores, meta = provider.rerank("q", ["a", "b"], timeout_sec=3.0, top_n=2)

    # Scores and meta should come straight from the canned response body.
    assert scores == [0.1, 0.2]
    assert meta == {"ok": True}
    # The requested top_n must be forwarded in the request payload.
    assert seen["json"]["top_n"] == 2