# test_embedding_service_limits.py
# Tests for the per-lane in-flight request limits in embeddings.server.
import asyncio
import numpy as np
import pytest
import embeddings.server as embedding_server
class _DummyClient:
    """Minimal stand-in for a request's ``client`` attribute.

    Only ``host`` is provided; it is fixed to the loopback address.
    """

    host: str = "127.0.0.1"
class _DummyRequest:
    """Minimal stand-in for an incoming request object.

    Exposes ``headers`` (a plain dict) and ``client`` (a ``_DummyClient``),
    which are the only attributes the tests in this file provide.
    """

    def __init__(self, headers=None):
        # Any falsy value (None or an empty mapping) becomes a fresh empty dict.
        self.headers = headers if headers else {}
        self.client = _DummyClient()
class _DummyResponse:
    """Minimal stand-in for an outgoing response; handlers may write ``headers``."""

    def __init__(self):
        self.headers = dict()
class _FakeTextModel:
    """Test double for the text model patched in as ``embedding_server._text_model``.

    ``encode`` checks the arguments it receives and returns a list holding one
    fixed three-element float32 vector.
    """

    def encode(self, texts, batch_size, device, normalize_embeddings):
        # The arguments must match what
        # test_embed_text_returns_request_id_and_vector sends:
        # a single "hello world" input and normalize=False.
        assert texts == ["hello world"]
        assert normalize_embeddings is False
        vector = np.array([1.0, 2.0, 3.0], dtype=np.float32)
        return [vector]
def test_health_exposes_limit_stats(monkeypatch):
    """``health()`` reports each lane's configured limit and the microbatch queue depth."""
    # (module attribute, lane name, limit), patched in this order: text, then image.
    patched_lanes = (
        ("_text_request_limiter", "text", 2),
        ("_image_request_limiter", "image", 1),
    )
    for attr_name, lane, cap in patched_lanes:
        monkeypatch.setattr(
            embedding_server,
            attr_name,
            embedding_server._InflightLimiter(lane, cap),
        )

    report = embedding_server.health()

    assert report["status"] == "ok"
    for _, lane, cap in patched_lanes:
        assert report["limits"][lane]["limit"] == cap
    assert "queue_depth" in report["text_microbatch"]
def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
    """A saturated image lane makes ``embed_image`` raise with the overload status."""
    image_limiter = embedding_server._InflightLimiter("image", 1)
    # Take the lane's only slot up front so the request below cannot acquire one.
    got_slot, _ = image_limiter.try_acquire()
    assert got_slot is True
    monkeypatch.setattr(embedding_server, "_image_request_limiter", image_limiter)

    outgoing = _DummyResponse()
    with pytest.raises(embedding_server.HTTPException) as exc_info:
        asyncio.run(
            embedding_server.embed_image(
                ["https://example.com/a.jpg"], _DummyRequest(), outgoing
            )
        )

    rejection = exc_info.value
    assert rejection.status_code == embedding_server._OVERLOAD_STATUS_CODE
    assert "busy" in rejection.detail
    assert image_limiter.snapshot()["rejected_total"] == 1
def test_embed_text_returns_request_id_and_vector(monkeypatch):
    """``embed_text`` copies ``X-Request-ID`` onto the response and returns the model vector."""
    monkeypatch.setattr(
        embedding_server,
        "_text_request_limiter",
        embedding_server._InflightLimiter("text", 2),
    )
    monkeypatch.setattr(embedding_server, "_text_model", _FakeTextModel())
    monkeypatch.setattr(embedding_server, "_text_backend_name", "tei")

    request_id = "req-123456"
    incoming = _DummyRequest(headers={"X-Request-ID": request_id})
    outgoing = _DummyResponse()
    vectors = asyncio.run(
        embedding_server.embed_text(
            ["hello world"], incoming, outgoing, normalize=False
        )
    )

    # The caller-supplied request id must appear unchanged on the response.
    assert outgoing.headers["X-Request-ID"] == request_id
    assert vectors == [[1.0, 2.0, 3.0]]