# test_embedding_pipeline.py (5.01 KB)
from typing import Any, Dict, List, Optional

import numpy as np
import pytest

from config import (
    FunctionScoreConfig,
    IndexConfig,
    QueryConfig,
    RerankConfig,
    SPUConfig,
    SearchConfig,
)
from embeddings.text_encoder import TextEmbeddingEncoder
from embeddings.bf16 import encode_embedding_for_redis
from query import QueryParser


class _FakeRedis:
    """Dictionary-backed stand-in for a Redis client.

    Implements ``ping``, ``get``, ``setex``, ``expire`` and ``delete``.
    Expiry arguments are accepted and ignored, so stored values never
    time out.
    """

    def __init__(self):
        # Tests may seed or inspect this mapping directly.
        self.store: Dict[str, bytes] = dict()

    def ping(self):
        """Always report success."""
        return True

    def get(self, key: str):
        """Return the value stored under ``key``, or ``None`` when absent."""
        if key in self.store:
            return self.store[key]
        return None

    def setex(self, key: str, _expire, value: bytes):
        """Store ``value`` under ``key``; the expiry argument is ignored."""
        self.store.update({key: value})
        return True

    def expire(self, key: str, _expire):
        """Return whether ``key`` is currently stored; no TTL is applied."""
        is_known = key in self.store
        return is_known

    def delete(self, key: str):
        """Remove ``key`` if present; returns ``True`` either way."""
        if key in self.store:
            del self.store[key]
        return True


class _FakeResponse:
    """Canned HTTP response that hands back a fixed JSON payload."""

    def __init__(self, payload: List[Optional[List[float]]]):
        # One entry per requested text; ``None`` marks a missing vector.
        self._payload = payload

    def raise_for_status(self):
        """Do nothing: this fake never signals an HTTP error."""
        return

    def json(self):
        """Return the payload object exactly as given to the constructor."""
        body = self._payload
        return body


class _FakeTranslator:
    """Translator stub that produces a deterministic, inspectable result."""

    def translate(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> str:
        """Return ``text`` followed by ``-`` and ``target_lang``.

        ``source_lang`` and ``prompt`` are accepted but ignored.
        """
        translated = f"{text}-{target_lang}"
        return translated


class _FakeQueryEncoder:
    """Text-encoder stub that yields the same fixed vector for every input."""

    def encode(self, sentences, **kwargs):
        """Return an object-dtype array holding one vector per sentence.

        A bare string is treated as a one-element batch. Extra keyword
        arguments are accepted and ignored.
        """
        batch = [sentences] if isinstance(sentences, str) else sentences

        vectors = []
        for _ in batch:
            # Fresh float32 vector per sentence; the values are arbitrary.
            vectors.append(np.array([0.11, 0.22, 0.33], dtype=np.float32))
        return np.array(vectors, dtype=object)


def _build_test_config() -> SearchConfig:
    """Build a minimal ``SearchConfig`` for the query-parser tests.

    The config has a single ``default`` index over ``title.en``, supports
    ``en`` and ``zh`` with ``en`` as default, enables text embedding on
    ``title_embedding``, disables query rewriting, and enables SPU handling.
    """
    boosts = {"title.en": 3.0}
    default_index = IndexConfig(
        name="default",
        label="default",
        fields=["title.en"],
        boost=1.0,
    )
    query_settings = QueryConfig(
        supported_languages=["en", "zh"],
        default_language="en",
        enable_text_embedding=True,
        enable_query_rewrite=False,
        rewrite_dictionary={},
        text_embedding_field="title_embedding",
        image_embedding_field=None,
    )
    scoring = FunctionScoreConfig()
    reranking = RerankConfig()
    spu_settings = SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3)

    return SearchConfig(
        field_boosts=boosts,
        indexes=[default_index],
        query_config=query_settings,
        function_score=scoring,
        rerank=reranking,
        spu_config=spu_settings,
        es_index_name="test_products",
        tenant_config={},
        es_settings={},
        services={},
    )


def test_text_embedding_encoder_response_alignment(monkeypatch):
    """Each input text must receive the vector at its own response position.

    Redis and ``requests.post`` are replaced with in-memory fakes. The fake
    endpoint returns a distinct vector per text so that a positional mix-up
    is detectable by value, not just by shape.
    """
    fake_redis = _FakeRedis()
    monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis)

    def _fake_post(url, json, timeout, **kwargs):
        assert url.endswith("/embed/text")
        assert json == ["hello", "world"]
        return _FakeResponse([[0.1, 0.2], [0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["hello", "world"])

    # Kept separate from the fake's payload so the expectation cannot be
    # affected by anything the encoder does to the response object.
    expected_vectors = ([0.1, 0.2], [0.3, 0.4])

    assert len(out) == 2
    for position, expected in enumerate(expected_vectors):
        vector = out[position]
        assert isinstance(vector, np.ndarray)
        assert vector.shape == (2,)
        # Both vectors share a shape, so only a value check proves that
        # output i really belongs to input i.
        assert np.allclose(vector, np.array(expected, dtype=np.float32))


def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
    """``encode`` must raise ``ValueError`` when the service returns ``None`` for a text."""
    redis_stub = _FakeRedis()
    monkeypatch.setattr(
        "embeddings.text_encoder.redis.Redis",
        lambda **kwargs: redis_stub,
    )
    # Second entry is None: the service produced no vector for "world".
    monkeypatch.setattr(
        "embeddings.text_encoder.requests.post",
        lambda url, json, timeout, **kwargs: _FakeResponse([[0.1, 0.2], None]),
    )

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    texts = ["hello", "world"]
    with pytest.raises(ValueError):
        encoder.encode(texts)


def test_text_embedding_encoder_cache_hit(monkeypatch):
    """A cached text is served from Redis; only the uncached text is requested.

    The fake Redis is pre-seeded under ``embedding:cached-text``. The fake
    endpoint records every payload it receives so the test can verify that
    the cached text was not sent to the service again.
    """
    fake_redis = _FakeRedis()
    cached = np.array([0.9, 0.8], dtype=np.float32)
    fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached)
    monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis)

    sent_payloads = []

    def _fake_post(url, json, timeout, **kwargs):
        # Record instead of asserting here, so a mismatch is reported by the
        # test body rather than from inside the encoder's call stack.
        sent_payloads.append(json)
        return _FakeResponse([[0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["cached-text", "new-text"])

    # Exactly one service call, carrying only the text that missed the cache.
    assert len(sent_payloads) == 1
    assert sent_payloads[0] == ["new-text"]
    # NOTE(review): default allclose tolerances assume the Redis codec
    # round-trips these values almost exactly — confirm for
    # encode_embedding_for_redis.
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32))


def test_query_parser_generates_query_vector_with_encoder():
    """With ``generate_vector=True`` the parsed query carries a 3-element vector."""
    search_config = _build_test_config()
    encoder_stub = _FakeQueryEncoder()
    translator_stub = _FakeTranslator()
    parser = QueryParser(
        config=search_config,
        text_encoder=encoder_stub,
        translator=translator_stub,
    )

    result = parser.parse("red dress", tenant_id="162", generate_vector=True)

    assert result.query_vector is not None
    # The fake encoder emits three values per sentence.
    assert result.query_vector.shape == (3,)


def test_query_parser_skips_query_vector_when_disabled():
    """With ``generate_vector=False`` the parsed query has no vector."""
    search_config = _build_test_config()
    encoder_stub = _FakeQueryEncoder()
    translator_stub = _FakeTranslator()
    parser = QueryParser(
        config=search_config,
        text_encoder=encoder_stub,
        translator=translator_stub,
    )

    result = parser.parse("red dress", tenant_id="162", generate_vector=False)

    assert result.query_vector is None