Blame view

tests/test_embedding_service_limits.py 4.2 KB
4747e2f4   tangwang   embedding perform...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  import asyncio
  
  import numpy as np
  import pytest
  
  import embeddings.server as embedding_server
  
  
  class _DummyClient:
      """Stand-in for ``request.client``; only a fixed ``host`` is provided."""

      host = "127.0.0.1"
  
  
  class _DummyRequest:
      """Minimal request double exposing ``headers`` and a ``client`` with a host."""

      def __init__(self, headers=None):
          # Any falsy ``headers`` argument (None or empty) becomes a fresh dict.
          if not headers:
              headers = {}
          self.headers = headers
          self.client = _DummyClient()
  
  
  class _DummyResponse:
      """Response double; handlers under test write headers into ``headers``."""

      def __init__(self):
          self.headers = dict()
  
  
  class _FakeTextModel:
      """Text backend double that accepts only the un-normalized input ``["hello world"]``."""

      def encode(self, texts, batch_size, device, normalize_embeddings):
          """Check the call arguments, then return a single fixed float32 vector."""
          assert texts == ["hello world"]
          assert normalize_embeddings is False
          vector = np.array([1.0, 2.0, 3.0], dtype=np.float32)
          return [vector]
  
  
7214c2e7   tangwang   mplemented**
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  class _FakeImageModel:
      """Image backend double that fails the test if the service ever calls it."""

      def encode_image_urls(self, urls, batch_size, normalize_embeddings):
          message = "image backend should not be called on cache hit"
          raise AssertionError(message)
  
  
  class _FakeCache:
      """In-memory stand-in for the image embedding cache.

      ``set`` stores values as float32 numpy arrays. ``redis_client`` is a
      non-None sentinel, presumably so the server treats the cache as
      connected (TODO confirm against ``embeddings.server``).
      """

      def __init__(self, store=None):
          # Use ``is None`` rather than ``or``: a caller-supplied empty dict must
          # be kept by reference so writes through ``set`` stay visible to it.
          self.store = {} if store is None else store
          self.redis_client = object()

      def get(self, key):
          """Return the stored value for ``key``, or ``None`` on a miss."""
          return self.store.get(key)

      def set(self, key, value):
          """Store ``value`` as a float32 array under ``key``; always returns True."""
          self.store[key] = np.asarray(value, dtype=np.float32)
          return True
  
  
4747e2f4   tangwang   embedding perform...
49
50
51
52
53
54
55
56
57
58
59
  def test_health_exposes_limit_stats(monkeypatch):
      """health() reports per-lane limiter capacities and the microbatch queue depth."""
      patches = (
          ("_text_request_limiter", embedding_server._InflightLimiter("text", 2)),
          ("_image_request_limiter", embedding_server._InflightLimiter("image", 1)),
          # Bare placeholders, presumably so health() sees both models as loaded.
          ("_text_model", object()),
          ("_image_model", object()),
      )
      for attr_name, replacement in patches:
          monkeypatch.setattr(embedding_server, attr_name, replacement)

      payload = embedding_server.health()

      assert payload["status"] == "ok"
      limits = payload["limits"]
      assert limits["text"]["limit"] == 2
      assert limits["image"]["limit"] == 1
      assert "queue_depth" in payload["text_microbatch"]
  
  
  def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
      """A saturated image lane rejects new requests with the overload status code."""
      # Capacity is 1 and the test takes that slot up front, so the request cannot get one.
      saturated = embedding_server._InflightLimiter("image", 1)
      got_slot, _ = saturated.try_acquire()
      assert got_slot is True

      monkeypatch.setattr(embedding_server, "_image_request_limiter", saturated)
      monkeypatch.setattr(embedding_server, "_image_model", object())

      outgoing = _DummyResponse()
      with pytest.raises(embedding_server.HTTPException) as exc_info:
          asyncio.run(
              embedding_server.embed_image(
                  ["https://example.com/a.jpg"],
                  _DummyRequest(),
                  outgoing,
              )
          )

      overloaded = exc_info.value
      assert overloaded.status_code == embedding_server._OVERLOAD_STATUS_CODE
      assert "busy" in overloaded.detail
      assert saturated.snapshot()["rejected_total"] == 1
  
  
  def test_embed_text_returns_request_id_and_vector(monkeypatch):
      """embed_text echoes X-Request-ID onto the response and returns nested-list vectors."""
      # NOTE(review): unlike the image cache-hit test, no text cache is patched
      # here; confirm embed_text cannot reach a real cache in this setup.
      for attr_name, replacement in (
          ("_text_request_limiter", embedding_server._InflightLimiter("text", 2)),
          ("_text_model", _FakeTextModel()),
          ("_text_backend_name", "tei"),
      ):
          monkeypatch.setattr(embedding_server, attr_name, replacement)

      incoming = _DummyRequest(headers={"X-Request-ID": "req-123456"})
      outgoing = _DummyResponse()
      vectors = asyncio.run(
          embedding_server.embed_text(["hello world"], incoming, outgoing, normalize=False)
      )

      assert outgoing.headers["X-Request-ID"] == "req-123456"
      assert vectors == [[1.0, 2.0, 3.0]]
7214c2e7   tangwang   mplemented**
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  
  
  def test_embed_image_service_cache_hit_bypasses_backend(monkeypatch):
      """A cached image embedding is served without calling the image backend."""
      image_url = "https://example.com/a.jpg"
      cache_key = embedding_server.build_image_cache_key(image_url, normalize=True)
      fake_cache = _FakeCache({cache_key: np.array([0.7, 0.8], dtype=np.float32)})

      monkeypatch.setattr(
          embedding_server,
          "_image_request_limiter",
          embedding_server._InflightLimiter("image", 1),
      )
      # _FakeImageModel raises if called, so any backend call fails this test.
      monkeypatch.setattr(embedding_server, "_image_model", _FakeImageModel())
      monkeypatch.setattr(embedding_server, "_image_cache", fake_cache)

      request = _DummyRequest(headers={"X-Request-ID": "img-cache-hit"})
      response = _DummyResponse()
      vectors = asyncio.run(
          embedding_server.embed_image([image_url], request, response, normalize=True)
      )

      # The cache holds float32 values, so the expected numbers are the
      # float32-rounded 0.7 / 0.8 rather than the exact decimals.
      expected = [[float(np.float32(0.7)), float(np.float32(0.8))]]
      assert response.headers["X-Request-ID"] == "img-cache-hit"
      assert vectors == expected