# File: test_rerank_provider_topn.py
# Size (as originally listed): 982 bytes
from __future__ import annotations
from typing import Any, Dict
from providers.rerank import HttpRerankProvider
class _FakeResponse:
def __init__(self, status_code: int, data: Dict[str, Any]):
self.status_code = status_code
self._data = data
self.text = str(data)
def json(self):
return self._data
def test_http_rerank_provider_includes_top_n(monkeypatch):
    """Check that ``HttpRerankProvider.rerank`` forwards ``top_n`` in the request JSON.

    ``providers.rerank.requests.post`` is replaced with a recorder so that no
    real HTTP request is made; the recorder returns a canned successful
    response containing ``scores`` and ``meta``.
    """
    captured: Dict[str, Any] = {}

    def _fake_post(url, json, timeout):
        # Record the call arguments so they can be inspected after the call.
        captured["url"] = url
        captured["json"] = json
        captured["timeout"] = timeout
        return _FakeResponse(200, {"scores": [0.1, 0.2], "meta": {"ok": True}})

    monkeypatch.setattr("providers.rerank.requests.post", _fake_post)

    provider = HttpRerankProvider("http://127.0.0.1:6007/rerank")
    scores, meta = provider.rerank("q", ["a", "b"], timeout_sec=3.0, top_n=2)

    # The canned response is expected to be surfaced as (scores, meta).
    assert scores == [0.1, 0.2]
    assert meta == {"ok": True}
    # The point of the test: top_n must appear in the outgoing JSON body.
    assert captured["json"]["top_n"] == 2