# test_embedding_service_limits.py
import asyncio
import numpy as np
import pytest
import embeddings.server as embedding_server
class _DummyClient:
    """Stand-in for a request's client info, exposing a fixed loopback host."""

    host: str = "127.0.0.1"
class _DummyRequest:
    """Minimal request double carrying a headers mapping and a dummy client."""

    def __init__(self, headers=None):
        # Any falsy input (None or an empty mapping) becomes a fresh empty dict;
        # a non-empty mapping is used as-is.
        self.headers = headers if headers else {}
        self.client = _DummyClient()
class _DummyResponse:
    """Response double; the handler under test writes response headers into it."""

    def __init__(self):
        self.headers = dict()
class _FakeTextModel:
    """Text backend double that expects a single, un-normalized input."""

    def encode(self, texts, batch_size, device, normalize_embeddings):
        """Check the call arguments, then return one fixed float32 vector.

        ``batch_size`` and ``device`` are accepted but ignored here.
        """
        assert texts == ["hello world"]
        assert normalize_embeddings is False
        fixed_vector = np.asarray((1.0, 2.0, 3.0), dtype=np.float32)
        return [fixed_vector]
class _FakeImageModel:
    """Image backend double that fails the test if it is ever invoked."""

    def encode_image_urls(self, urls, batch_size, normalize_embeddings):
        """Always raise: reaching the backend means the cache was not used."""
        failure = "image backend should not be called on cache hit"
        raise AssertionError(failure)
class _FakeCache:
    """In-memory stand-in for the server's embedding cache.

    Vectors written through ``set`` are stored as float32 NumPy arrays.
    """

    def __init__(self, store=None):
        # Honour the caller's mapping even when it is empty: ``store or {}``
        # would silently swap an empty dict for a new one, so later writes
        # would be invisible to whoever passed it in.
        self.store = {} if store is None else store
        # Non-None placeholder; presumably the server checks this attribute to
        # decide the cache is usable -- confirm against embeddings.server.
        self.redis_client = object()

    def get(self, key):
        """Return the stored vector for ``key``, or ``None`` on a miss."""
        return self.store.get(key)

    def set(self, key, value):
        """Store ``value`` as a float32 array under ``key`` and return ``True``."""
        self.store[key] = np.asarray(value, dtype=np.float32)
        return True
def test_health_exposes_limit_stats(monkeypatch):
    """health() reports each lane's configured in-flight limit and microbatch stats."""
    overrides = (
        ("_text_request_limiter", embedding_server._InflightLimiter("text", 2)),
        ("_image_request_limiter", embedding_server._InflightLimiter("image", 1)),
        ("_text_model", object()),
        ("_image_model", object()),
    )
    for attr_name, replacement in overrides:
        monkeypatch.setattr(embedding_server, attr_name, replacement)

    payload = embedding_server.health()

    assert payload["status"] == "ok"
    lane_limits = payload["limits"]
    assert lane_limits["text"]["limit"] == 2
    assert lane_limits["image"]["limit"] == 1
    assert "queue_depth" in payload["text_microbatch"]
def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
    """A saturated image lane rejects new requests with the overload status code."""
    # Swap in an empty cache so a URL cached by another test (the module-level
    # Redis cache) cannot short-circuit the limiter check.
    monkeypatch.setattr(embedding_server, "_image_cache", _FakeCache({}))

    # Occupy the lane's only slot before the request arrives.
    busy_limiter = embedding_server._InflightLimiter("image", 1)
    got_slot, _ = busy_limiter.try_acquire()
    assert got_slot is True

    monkeypatch.setattr(embedding_server, "_image_request_limiter", busy_limiter)
    monkeypatch.setattr(embedding_server, "_image_model", object())

    response = _DummyResponse()
    with pytest.raises(embedding_server.HTTPException) as exc_info:
        asyncio.run(
            embedding_server.embed_image(
                ["https://example.com/a.jpg"],
                _DummyRequest(),
                response,
            )
        )

    rejection = exc_info.value
    assert rejection.status_code == embedding_server._OVERLOAD_STATUS_CODE
    assert "busy" in rejection.detail
    assert busy_limiter.snapshot()["rejected_total"] == 1
def test_embed_text_returns_request_id_and_vector(monkeypatch):
    """embed_text echoes X-Request-ID and returns the backend vector as plain floats."""
    replacements = {
        "_text_request_limiter": embedding_server._InflightLimiter("text", 2),
        "_text_model": _FakeTextModel(),
        "_text_backend_name": "tei",
    }
    for attr_name, value in replacements.items():
        monkeypatch.setattr(embedding_server, attr_name, value)
    # NOTE(review): unlike the image tests, no cache is swapped out here. If the
    # server keeps a module-level text cache, confirm a stale entry for
    # "hello world" cannot bypass _FakeTextModel.

    request_id = "req-123456"
    response = _DummyResponse()
    vectors = asyncio.run(
        embedding_server.embed_text(
            ["hello world"],
            _DummyRequest(headers={"X-Request-ID": request_id}),
            response,
            normalize=False,
        )
    )

    assert response.headers["X-Request-ID"] == request_id
    assert vectors == [[1.0, 2.0, 3.0]]
def test_embed_image_service_cache_hit_bypasses_backend(monkeypatch):
    """A cached image embedding is served without calling the image backend."""
    image_url = "https://example.com/a.jpg"
    cached_vector = np.array([0.7, 0.8], dtype=np.float32)
    key = embedding_server.build_image_cache_key(image_url, normalize=True)
    fake_cache = _FakeCache({key: cached_vector})

    monkeypatch.setattr(
        embedding_server,
        "_image_request_limiter",
        embedding_server._InflightLimiter("image", 1),
    )
    # _FakeImageModel raises if touched, so success proves the cache answered.
    monkeypatch.setattr(embedding_server, "_image_model", _FakeImageModel())
    monkeypatch.setattr(embedding_server, "_image_cache", fake_cache)

    response = _DummyResponse()
    vectors = asyncio.run(
        embedding_server.embed_image(
            [image_url],
            _DummyRequest(headers={"X-Request-ID": "img-cache-hit"}),
            response,
            normalize=True,
        )
    )

    assert response.headers["X-Request-ID"] == "img-cache-hit"
    # The cached values are float32, so the expected floats are the float32
    # roundings of 0.7 / 0.8 widened to Python floats; compare them exactly.
    assert vectors == [[float(np.float32(0.7)), float(np.float32(0.8))]]