# test_embedding_service_limits.py (4.36 KB)
import asyncio

import numpy as np
import pytest

import embeddings.server as embedding_server


class _DummyClient:
    host = "127.0.0.1"


class _DummyRequest:
    """Minimal request double carrying only ``headers`` and ``client``.

    Args:
        headers: Mapping of request headers. Any falsy value (``None`` or an
            empty mapping) is replaced by a new empty dict.
    """

    def __init__(self, headers=None):
        self.headers = headers if headers else {}
        self.client = _DummyClient()


class _DummyResponse:
    def __init__(self):
        self.headers = {}


class _FakeTextModel:
    def encode(self, texts, batch_size, device, normalize_embeddings):
        assert texts == ["hello world"]
        assert normalize_embeddings is False
        return [np.array([1.0, 2.0, 3.0], dtype=np.float32)]


class _FakeImageModel:
    def encode_image_urls(self, urls, batch_size, normalize_embeddings):
        raise AssertionError("image backend should not be called on cache hit")


class _FakeCache:
    def __init__(self, store=None):
        self.store = store or {}
        self.redis_client = object()

    def get(self, key):
        return self.store.get(key)

    def set(self, key, value):
        self.store[key] = np.asarray(value, dtype=np.float32)
        return True


def test_health_exposes_limit_stats(monkeypatch):
    """health() reports each lane's limiter capacity and micro-batch queue depth."""
    overrides = {
        "_text_request_limiter": embedding_server._InflightLimiter("text", 2),
        "_image_request_limiter": embedding_server._InflightLimiter("image", 1),
        # Opaque placeholders: the handler is only expected to see models present.
        "_text_model": object(),
        "_image_model": object(),
    }
    for attr_name, replacement in overrides.items():
        monkeypatch.setattr(embedding_server, attr_name, replacement)

    report = embedding_server.health()

    assert report["status"] == "ok"
    lane_limits = report["limits"]
    assert lane_limits["text"]["limit"] == 2
    assert lane_limits["image"]["limit"] == 1
    assert "queue_depth" in report["text_microbatch"]


def test_embed_image_rejects_when_image_lane_is_full(monkeypatch):
    """With the image limiter saturated, embed_image fails with the overload status."""
    # Start from an empty cache. The module-level Redis cache may already hold
    # this URL from other tests, which would avoid the path under test.
    monkeypatch.setattr(embedding_server, "_image_cache", _FakeCache({}))

    # Capacity is 1 and the test takes that slot itself, so the request cannot.
    saturated_lane = embedding_server._InflightLimiter("image", 1)
    slot_taken, _ = saturated_lane.try_acquire()
    assert slot_taken is True
    monkeypatch.setattr(embedding_server, "_image_request_limiter", saturated_lane)
    monkeypatch.setattr(embedding_server, "_image_model", object())

    resp = _DummyResponse()
    with pytest.raises(embedding_server.HTTPException) as raised:
        asyncio.run(
            embedding_server.embed_image(
                ["https://example.com/a.jpg"], _DummyRequest(), resp
            )
        )

    error = raised.value
    assert error.status_code == embedding_server._OVERLOAD_STATUS_CODE
    assert "busy" in error.detail
    assert saturated_lane.snapshot()["rejected_total"] == 1


def test_embed_text_returns_request_id_and_vector(monkeypatch):
    """embed_text echoes the caller's X-Request-ID and returns the backend vector."""
    text_lane = embedding_server._InflightLimiter("text", 2)
    monkeypatch.setattr(embedding_server, "_text_request_limiter", text_lane)
    monkeypatch.setattr(embedding_server, "_text_model", _FakeTextModel())
    monkeypatch.setattr(embedding_server, "_text_backend_name", "tei")

    incoming = _DummyRequest(headers={"X-Request-ID": "req-123456"})
    outgoing = _DummyResponse()
    vectors = asyncio.run(
        embedding_server.embed_text(["hello world"], incoming, outgoing, normalize=False)
    )

    assert outgoing.headers["X-Request-ID"] == "req-123456"
    assert vectors == [[1.0, 2.0, 3.0]]


def test_embed_image_service_cache_hit_bypasses_backend(monkeypatch):
    """A cached image embedding is returned without invoking the image backend.

    ``_FakeImageModel`` raises ``AssertionError`` if its encoder is called. The
    test therefore fails if the backend is reached, assuming the server does not
    swallow that error.
    """
    # Use one URL for both the cache key and the request so the two cannot drift apart.
    image_url = "https://example.com/a.jpg"
    cache_key = embedding_server.build_image_cache_key(image_url, normalize=True)
    cached_vector = np.array([0.7, 0.8], dtype=np.float32)
    # Snapshot the expected payload before the call. Converting the float32
    # values with tolist() yields their exact float64 renderings, so no
    # 0.699999988079071-style literals are needed.
    expected = [cached_vector.tolist()]
    fake_cache = _FakeCache({cache_key: cached_vector})

    monkeypatch.setattr(
        embedding_server,
        "_image_request_limiter",
        embedding_server._InflightLimiter("image", 1),
    )
    monkeypatch.setattr(embedding_server, "_image_model", _FakeImageModel())
    monkeypatch.setattr(embedding_server, "_image_cache", fake_cache)

    request = _DummyRequest(headers={"X-Request-ID": "img-cache-hit"})
    response = _DummyResponse()
    result = asyncio.run(
        embedding_server.embed_image(
            [image_url],
            request,
            response,
            normalize=True,
        )
    )

    assert response.headers["X-Request-ID"] == "img-cache-hit"
    assert result == expected