Blame view

tests/test_rerank_provider_topn.py 982 Bytes
d31c7f65   tangwang   补充云服务reranker
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  from __future__ import annotations
  
  from typing import Any, Dict
  
  from providers.rerank import HttpRerankProvider
  
  
  class _FakeResponse:
      def __init__(self, status_code: int, data: Dict[str, Any]):
          self.status_code = status_code
          self._data = data
          self.text = str(data)
  
      def json(self):
          return self._data
  
  
  def test_http_rerank_provider_includes_top_n(monkeypatch):
      """``HttpRerankProvider.rerank`` forwards ``top_n`` in the POST JSON body.

      ``requests.post`` as seen from ``providers.rerank`` is replaced by a stub
      that records its arguments and answers with a canned 200 response; the
      test then checks the scores/meta returned by the provider and the
      ``top_n`` value that reached the request payload.
      """
      seen: Dict[str, Any] = {}

      def _record_post(url, json, timeout):
          # Parameter names are kept as-is: the provider presumably passes
          # ``json`` and ``timeout`` by keyword — confirm before renaming.
          seen.update(url=url, json=json, timeout=timeout)
          return _FakeResponse(200, {"scores": [0.1, 0.2], "meta": {"ok": True}})

      monkeypatch.setattr("providers.rerank.requests.post", _record_post)

      reranker = HttpRerankProvider("http://127.0.0.1:6007/rerank")
      result_scores, result_meta = reranker.rerank(
          "q", ["a", "b"], timeout_sec=3.0, top_n=2
      )

      assert result_scores == [0.1, 0.2]
      assert result_meta == {"ok": True}
      assert seen["json"]["top_n"] == 2