Blame view

tests/test_embedding_service_limits.py 2.59 KB
4747e2f4   tangwang   embedding perform...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  import asyncio
  
  import numpy as np
  import pytest
  
  import embeddings.server as embedding_server
  
  
  class _DummyClient:
      """Stand-in for a request's ``client`` object; exposes only ``host``."""

      # Loopback address presented as the caller's host.
      host: str = "127.0.0.1"
  
  
  class _DummyRequest:
      """Minimal request stand-in carrying ``headers`` and a ``client``."""

      def __init__(self, headers=None):
          """Store ``headers`` (or a new empty dict) and attach a dummy client."""
          # Any falsy value (None, empty mapping) is replaced by a fresh dict.
          self.headers = headers if headers else {}
          self.client = _DummyClient()
  
  
  class _DummyResponse:
      """Minimal response stand-in exposing a mutable ``headers`` mapping."""

      def __init__(self):
          """Start with an empty, per-instance header dict."""
          self.headers = dict()
  
  
  class _FakeTextModel:
      """Text-model stub whose ``encode`` checks its inputs and returns a fixed vector."""

      def encode(self, texts, batch_size, device, normalize_embeddings):
          """Return one float32 vector ``[1.0, 2.0, 3.0]`` after validating the call.

          Asserts that ``texts`` equals ``["hello world"]`` and that
          ``normalize_embeddings`` is ``False``. ``batch_size`` and ``device``
          are accepted but not used.
          """
          expected_texts = ["hello world"]
          assert texts == expected_texts
          assert normalize_embeddings is False
          fixed_vector = np.array([1.0, 2.0, 3.0], dtype=np.float32)
          return [fixed_vector]
  
  
  def test_health_exposes_limit_stats(monkeypatch):
      """Check that ``health()`` reports ok status, both lane limits and a queue depth.

      Fresh text and image limiters are patched into the server module, then
      the payload is checked for the limit each limiter was built with and for
      a ``queue_depth`` key under ``text_microbatch``.
      """
      # (module attribute to patch, lane name, configured limit)
      lanes = (
          ("_text_request_limiter", "text", 2),
          ("_image_request_limiter", "image", 1),
      )
      for attr_name, lane, limit in lanes:
          fresh_limiter = embedding_server._InflightLimiter(lane, limit)
          monkeypatch.setattr(embedding_server, attr_name, fresh_limiter)

      payload = embedding_server.health()

      assert payload["status"] == "ok"
      reported_limits = payload["limits"]
      for _, lane, limit in lanes:
          assert reported_limits[lane]["limit"] == limit
      assert "queue_depth" in payload["text_microbatch"]
  
  
  def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
      """Check that ``embed_image`` raises the overload error when no slot is free.

      A limiter built with limit 1 is acquired once before the call, so the
      request is expected to fail with ``_OVERLOAD_STATUS_CODE``, a detail
      containing "busy", and one recorded rejection.
      """
      # Take the only slot up front; it is deliberately never released.
      full_limiter = embedding_server._InflightLimiter("image", 1)
      got_slot, _ = full_limiter.try_acquire()
      assert got_slot is True
      monkeypatch.setattr(embedding_server, "_image_request_limiter", full_limiter)

      http_response = _DummyResponse()
      image_urls = ["https://example.com/a.jpg"]
      with pytest.raises(embedding_server.HTTPException) as exc_info:
          pending = embedding_server.embed_image(
              image_urls,
              _DummyRequest(),
              http_response,
          )
          asyncio.run(pending)

      raised = exc_info.value
      assert raised.status_code == embedding_server._OVERLOAD_STATUS_CODE
      assert "busy" in raised.detail
      stats = full_limiter.snapshot()
      assert stats["rejected_total"] == 1
  
  
  def test_embed_text_returns_request_id_and_vector(monkeypatch):
      """Check that ``embed_text`` echoes the request id and returns the model's vector.

      The server module is patched with a fresh text limiter, the fake text
      model and the backend name "tei". The call is made with
      ``normalize=False``, which the fake model asserts on.
      """
      overrides = {
          "_text_request_limiter": embedding_server._InflightLimiter("text", 2),
          "_text_model": _FakeTextModel(),
          "_text_backend_name": "tei",
      }
      for attr_name, replacement in overrides.items():
          monkeypatch.setattr(embedding_server, attr_name, replacement)

      request_id = "req-123456"
      http_request = _DummyRequest(headers={"X-Request-ID": request_id})
      http_response = _DummyResponse()
      pending = embedding_server.embed_text(
          ["hello world"],
          http_request,
          http_response,
          normalize=False,
      )
      vectors = asyncio.run(pending)

      assert http_response.headers["X-Request-ID"] == request_id
      assert vectors == [[1.0, 2.0, 3.0]]