# test_embedding_pipeline.py (10.6 KB)
from typing import Any, Dict, List, Optional

import numpy as np
import pytest

from config import (
    FunctionScoreConfig,
    IndexConfig,
    QueryConfig,
    RerankConfig,
    SPUConfig,
    SearchConfig,
)
from embeddings.text_encoder import TextEmbeddingEncoder
from embeddings.image_encoder import CLIPImageEncoder
from embeddings.text_embedding_tei import TEITextModel
from embeddings.bf16 import encode_embedding_for_redis
from embeddings.cache_keys import build_image_cache_key, build_text_cache_key
from embeddings.config import CONFIG
from query import QueryParser
from context.request_context import create_request_context, set_current_request_context, clear_current_request_context


class _FakeRedis:
    """In-memory stand-in for a Redis client, backed by a plain dict.

    Only the few client methods defined here are available. Expiry
    arguments are accepted for signature compatibility but never applied.
    """

    def __init__(self):
        self.store: Dict[str, bytes] = dict()

    def ping(self):
        """Always report the fake connection as reachable."""
        return True

    def get(self, key: str):
        """Return the value stored under ``key``, or ``None`` if absent."""
        if key in self.store:
            return self.store[key]
        return None

    def setex(self, key: str, _expire, value: bytes):
        """Store ``value`` under ``key``; the expiry argument is ignored."""
        self.store.update({key: value})
        return True

    def expire(self, key: str, _expire):
        """Return whether ``key`` exists; no TTL is actually recorded."""
        key_present = key in self.store
        return key_present

    def delete(self, key: str):
        """Remove ``key`` if present and return ``True`` in either case."""
        if key in self.store:
            del self.store[key]
        return True


class _FakeResponse:
    """Minimal HTTP response stub that replays a canned JSON payload."""

    def __init__(self, payload: List[Optional[List[float]]]):
        # Kept by reference; json() hands back this exact object.
        self._payload = payload

    def raise_for_status(self):
        """Do nothing: the stub never signals an HTTP error."""
        return

    def json(self):
        """Return the payload supplied at construction, unchanged."""
        canned = self._payload
        return canned


class _FakeTranslator:
    """Translator stub producing a deterministic, inspectable result."""

    def translate(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> str:
        """Return ``text`` suffixed with ``-<target_lang>``.

        ``source_lang`` and ``prompt`` are accepted but have no effect on
        the result.
        """
        translated = f"{text}-{target_lang}"
        return translated


class _FakeQueryEncoder:
    """Text-encoder stub that logs each call and returns fixed vectors."""

    def __init__(self):
        # One dict per encode() call, in call order.
        self.calls = list()

    def encode(self, sentences, **kwargs):
        """Record the call and return one fixed 3-value row per sentence.

        The ``sentences`` argument is logged exactly as received (before a
        bare string is wrapped into a one-element list), together with a
        copy of the keyword arguments.
        """
        record = {"sentences": sentences, "kwargs": dict(kwargs)}
        self.calls.append(record)

        batch = [sentences] if isinstance(sentences, str) else sentences
        rows = [np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in batch]
        return np.array(rows, dtype=object)


class _FakeClipTextEncoder:
    """CLIP text-encoder stub that logs calls and returns a fixed vector."""

    def __init__(self):
        # One dict per encode_clip_text() call, in call order.
        self.calls = list()

    def encode_clip_text(self, text, **kwargs):
        """Record ``text`` and a copy of the kwargs; return a float32 vector."""
        record = {"text": text, "kwargs": {**kwargs}}
        self.calls.append(record)
        return np.asarray([0.44, 0.55, 0.66], dtype=np.float32)


def _tokenizer(text):
    """Split the string form of ``text`` on runs of whitespace."""
    as_text = str(text)
    tokens = as_text.split()
    return tokens


class _FakeEmbeddingCache:
    """Dict-backed embedding cache stub exposing ``get`` and ``set``."""

    def __init__(self):
        self.store: Dict[str, np.ndarray] = dict()

    def get(self, key: str):
        """Return the cached embedding for ``key``, or ``None`` on a miss."""
        return self.store.get(key, None)

    def set(self, key: str, embedding: np.ndarray):
        """Store ``embedding`` as a float32 array under ``key``; return ``True``."""
        as_float32 = np.asarray(embedding, dtype=np.float32)
        self.store[key] = as_float32
        return True


def _build_test_config(*, image_embedding_field: Optional[str] = None) -> SearchConfig:
    """Build the ``SearchConfig`` shared by the query-parser tests.

    Args:
        image_embedding_field: Value forwarded to
            ``QueryConfig.image_embedding_field``; defaults to ``None``.

    Returns:
        A ``SearchConfig`` with one ``title.en`` index, text embedding
        enabled, and query rewrite disabled.
    """
    default_index = IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)
    query_settings = QueryConfig(
        supported_languages=["en", "zh"],
        default_language="en",
        enable_text_embedding=True,
        enable_query_rewrite=False,
        rewrite_dictionary={},
        text_embedding_field="title_embedding",
        image_embedding_field=image_embedding_field,
    )
    scoring = FunctionScoreConfig()
    reranking = RerankConfig()
    spu_settings = SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3)

    return SearchConfig(
        field_boosts={"title.en": 3.0},
        indexes=[default_index],
        query_config=query_settings,
        function_score=scoring,
        rerank=reranking,
        spu_config=spu_settings,
        es_index_name="test_products",
        es_settings={},
    )


def test_text_embedding_encoder_response_alignment(monkeypatch):
    """Two input texts should yield two 2-element ndarrays.

    The cache class and ``requests.post`` inside ``embeddings.text_encoder``
    are replaced with stubs; the stubbed POST also checks the URL suffix,
    the JSON body, and the default ``priority`` query parameter.
    """
    cache_stub = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    def _fake_post(url, json, timeout, **kwargs):
        assert url.endswith("/embed/text")
        assert json == ["hello", "world"]
        assert kwargs["params"]["priority"] == 0
        return _FakeResponse([[0.1, 0.2], [0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    texts = ["hello", "world"]
    vectors = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005").encode(texts)

    assert len(vectors) == 2
    for position in (0, 1):
        vector = vectors[position]
        assert isinstance(vector, np.ndarray)
        assert vector.shape == (2,)


def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
    """A ``None`` entry in the service response must raise ``ValueError``."""
    cache_stub = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    def _fake_post(url, json, timeout, **kwargs):
        # Second slot is None: the stubbed service returned no vector for it.
        incomplete_payload = [[0.1, 0.2], None]
        return _FakeResponse(incomplete_payload)

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    text_encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    with pytest.raises(ValueError):
        text_encoder.encode(["hello", "world"])


def test_text_embedding_encoder_cache_hit(monkeypatch):
    """A pre-cached text is served from the cache; the other text is fetched.

    The cache is seeded for ``"cached-text"``. Exactly one POST is expected,
    and the outputs must equal the cached vector followed by the stubbed
    service vector.
    """
    cached_vector = np.array([0.9, 0.8], dtype=np.float32)
    cache_stub = _FakeEmbeddingCache()
    cache_key = build_text_cache_key("cached-text", normalize=True)
    cache_stub.store[cache_key] = cached_vector
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    # One entry is appended per stubbed POST, so the length is the call count.
    post_calls = []

    def _fake_post(url, json, timeout, **kwargs):
        post_calls.append(url)
        return _FakeResponse([[0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    text_encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    vectors = text_encoder.encode(["cached-text", "new-text"])

    assert len(post_calls) == 1
    assert np.allclose(vectors[0], cached_vector)
    fetched_vector = np.array([0.3, 0.4], dtype=np.float32)
    assert np.allclose(vectors[1], fetched_vector)


def test_text_embedding_encoder_forwards_request_headers(monkeypatch):
    """Request-context ids must be sent as ``X-Request-ID`` / ``X-User-ID``.

    A request context is installed for the duration of the encode call and
    cleared afterwards even if encoding fails; the stubbed POST captures
    the headers it receives.
    """
    cache_stub = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    seen = {}

    def _fake_post(url, json, timeout, **kwargs):
        # Copy so later mutation by the caller cannot affect the assertions.
        seen["headers"] = dict(kwargs.get("headers") or {})
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    request_context = create_request_context(reqid="req-ctx-1", uid="user-ctx-1")
    set_current_request_context(request_context)
    try:
        TextEmbeddingEncoder(service_url="http://127.0.0.1:6005").encode(["hello"])
    finally:
        clear_current_request_context()

    sent_headers = seen["headers"]
    assert sent_headers["X-Request-ID"] == "req-ctx-1"
    assert sent_headers["X-User-ID"] == "user-ctx-1"


def test_image_embedding_encoder_cache_hit(monkeypatch):
    """A pre-cached image URL is served from the cache; the other is fetched.

    The cache is seeded for ``a.jpg``. Exactly one POST (with priority 0)
    is expected, and the outputs must equal the cached vector followed by
    the stubbed service vector.
    """
    cached_url = "https://example.com/a.jpg"
    uncached_url = "https://example.com/b.jpg"
    cached_vector = np.array([0.5, 0.6], dtype=np.float32)

    cache_stub = _FakeEmbeddingCache()
    cache_key = build_image_cache_key(cached_url, normalize=True, model_name=CONFIG.MULTIMODAL_MODEL_NAME)
    cache_stub.store[cache_key] = cached_vector
    monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    # One entry is appended per stubbed POST, so the length is the call count.
    post_calls = []

    def _fake_post(url, params, json, timeout, **kwargs):
        post_calls.append(url)
        assert params["priority"] == 0
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)

    image_encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    vectors = image_encoder.encode_batch([cached_url, uncached_url])

    assert len(post_calls) == 1
    assert np.allclose(vectors[0], cached_vector)
    fetched_vector = np.array([0.1, 0.2], dtype=np.float32)
    assert np.allclose(vectors[1], fetched_vector)


def test_image_embedding_encoder_passes_priority(monkeypatch):
    """``encode_batch(..., priority=1)`` must send ``priority == 1`` in params."""
    cache_stub = _FakeEmbeddingCache()
    monkeypatch.setattr("embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: cache_stub)

    def _fake_post(url, params, json, timeout, **kwargs):
        assert params["priority"] == 1
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)

    image_encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    vectors = image_encoder.encode_batch(["https://example.com/a.jpg"], priority=1)

    assert len(vectors) == 1
    expected_vector = np.array([0.1, 0.2], dtype=np.float32)
    assert np.allclose(vectors[0], expected_vector)


def test_query_parser_generates_query_vector_with_encoder():
    """Parsing with ``generate_vector=True`` should produce a 3-value vector.

    Also checks that the text encoder was called and that its first call
    received ``priority=1``.
    """
    text_encoder = _FakeQueryEncoder()
    query_parser = QueryParser(
        config=_build_test_config(),
        text_encoder=text_encoder,
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    result = query_parser.parse("red dress", tenant_id="162", generate_vector=True)

    assert result.query_vector is not None
    assert result.query_vector.shape == (3,)
    assert text_encoder.calls
    first_call = text_encoder.calls[0]
    assert first_call["kwargs"]["priority"] == 1


def test_query_parser_generates_image_query_vector_with_clip_text_encoder():
    """With an image embedding field set, parsing also yields an image vector.

    The CLIP text-encoder stub must be called with the raw query text and
    ``priority=1``.
    """
    clip_encoder = _FakeClipTextEncoder()
    query_parser = QueryParser(
        config=_build_test_config(image_embedding_field="image_embedding.vector"),
        text_encoder=_FakeQueryEncoder(),
        image_encoder=clip_encoder,
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    result = query_parser.parse("red dress", tenant_id="162", generate_vector=True)

    assert result.query_vector is not None
    assert result.image_query_vector is not None
    assert result.image_query_vector.shape == (3,)
    assert clip_encoder.calls
    first_call = clip_encoder.calls[0]
    assert first_call["text"] == "red dress"
    assert first_call["kwargs"]["priority"] == 1


def test_query_parser_skips_query_vector_when_disabled():
    """With ``generate_vector=False`` both query vectors must stay ``None``."""
    query_parser = QueryParser(
        config=_build_test_config(),
        text_encoder=_FakeQueryEncoder(),
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    result = query_parser.parse("red dress", tenant_id="162", generate_vector=False)

    for vector in (result.query_vector, result.image_query_vector):
        assert vector is None


def test_tei_text_model_splits_batches_over_client_limit(monkeypatch):
    """Encoding 25 texts with ``max_client_batch_size=24`` must issue two POSTs.

    The first request should carry 24 inputs and the second the remaining
    one, and the combined result must contain 25 vectors. The stubbed POST
    returns one single-element vector per input it receives.
    """
    # Replace the health check with a no-op (presumably it would otherwise
    # contact the server -- TODO confirm against TEITextModel).
    monkeypatch.setattr(TEITextModel, "_health_check", lambda self: None)

    # Each stubbed POST appends the list of inputs it was sent.
    calls = []

    def _fake_post(url, json, timeout):
        inputs = list(json["inputs"])
        calls.append(inputs)
        # Reuse the module-level response stub rather than a local duplicate.
        return _FakeResponse([[float(idx)] for idx, _ in enumerate(inputs, start=1)])

    monkeypatch.setattr("embeddings.text_embedding_tei.requests.post", _fake_post)

    model = TEITextModel(
        base_url="http://127.0.0.1:8080",
        timeout_sec=20,
        max_client_batch_size=24,
    )
    vectors = model.encode([f"text-{idx}" for idx in range(25)], normalize_embeddings=False)

    assert len(calls) == 2
    assert len(calls[0]) == 24
    assert len(calls[1]) == 1
    assert len(vectors) == 25