# test_embedding_service_limits.py (2.59 KB)
import asyncio

import numpy as np
import pytest

import embeddings.server as embedding_server


class _DummyClient:
    host = "127.0.0.1"


class _DummyRequest:
    """Minimal request double carrying ``headers`` and a ``client`` stub."""

    def __init__(self, headers=None):
        # Every instance gets its own client stub.
        self.client = _DummyClient()
        # Any falsy ``headers`` (None, {}) is replaced by a fresh, unshared dict.
        self.headers = headers if headers else {}


class _DummyResponse:
    def __init__(self):
        self.headers = {}


class _FakeTextModel:
    def encode(self, texts, batch_size, device, normalize_embeddings):
        assert texts == ["hello world"]
        assert normalize_embeddings is False
        return [np.array([1.0, 2.0, 3.0], dtype=np.float32)]


def test_health_exposes_limit_stats(monkeypatch):
    """The health payload reports status, per-lane limits and microbatch stats."""
    lanes = (
        ("_text_request_limiter", embedding_server._InflightLimiter("text", 2)),
        ("_image_request_limiter", embedding_server._InflightLimiter("image", 1)),
    )
    for attr_name, lane_limiter in lanes:
        monkeypatch.setattr(embedding_server, attr_name, lane_limiter)

    health = embedding_server.health()

    assert health["status"] == "ok"
    limits = health["limits"]
    assert limits["text"]["limit"] == 2
    assert limits["image"]["limit"] == 1
    assert "queue_depth" in health["text_microbatch"]


def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
    """A saturated image lane makes ``embed_image`` fail with an overload error."""
    image_lane = embedding_server._InflightLimiter("image", 1)
    # Occupy the lane's single slot so the next request has nowhere to go.
    got_slot, _unused = image_lane.try_acquire()
    assert got_slot is True
    monkeypatch.setattr(embedding_server, "_image_request_limiter", image_lane)

    response = _DummyResponse()
    with pytest.raises(embedding_server.HTTPException) as raised:
        request = _DummyRequest()
        pending = embedding_server.embed_image(
            ["https://example.com/a.jpg"], request, response
        )
        asyncio.run(pending)

    error = raised.value
    assert error.status_code == embedding_server._OVERLOAD_STATUS_CODE
    assert "busy" in error.detail
    assert image_lane.snapshot()["rejected_total"] == 1


def test_embed_text_returns_request_id_and_vector(monkeypatch):
    """``embed_text`` echoes the caller's request id and returns list vectors."""
    patches = (
        ("_text_request_limiter", embedding_server._InflightLimiter("text", 2)),
        ("_text_model", _FakeTextModel()),
        ("_text_backend_name", "tei"),
    )
    for attr_name, replacement in patches:
        monkeypatch.setattr(embedding_server, attr_name, replacement)

    incoming = _DummyRequest(headers={"X-Request-ID": "req-123456"})
    outgoing = _DummyResponse()
    vectors = asyncio.run(
        embedding_server.embed_text(
            ["hello world"], incoming, outgoing, normalize=False
        )
    )

    assert outgoing.headers["X-Request-ID"] == "req-123456"
    assert vectors == [[1.0, 2.0, 3.0]]