"""Tests for ``TextEmbeddingEncoder`` and ``QueryParser`` query-vector generation.

Redis, the HTTP embedding service, the translator and the query encoder are
replaced by the small in-process fakes defined in this module.
"""

import pickle
from typing import Any, Dict, List, Optional

import numpy as np
import pytest

from config import (
    FunctionScoreConfig,
    IndexConfig,
    QueryConfig,
    RerankConfig,
    SPUConfig,
    SearchConfig,
)
from embeddings.text_encoder import TextEmbeddingEncoder
from query import QueryParser


class _FakeRedis:
    """In-memory stand-in for ``redis.Redis``.

    Expiry arguments are accepted but ignored: entries live until deleted.
    """

    def __init__(self):
        self.store: Dict[str, bytes] = {}

    def ping(self):
        return True

    def get(self, key: str):
        return self.store.get(key)

    def setex(self, key: str, _expire, value: bytes):
        self.store[key] = value
        return True

    def expire(self, key: str, _expire):
        # Truthy only when the key exists, mirroring Redis' EXPIRE result.
        return key in self.store

    def delete(self, key: str):
        self.store.pop(key, None)
        return True


class _FakeResponse:
    """Minimal ``requests.Response`` stand-in returning a fixed JSON payload."""

    def __init__(self, payload: List[Optional[List[float]]]):
        self._payload = payload

    def raise_for_status(self):
        return None

    def json(self):
        return self._payload


class _FakeTranslator:
    """Deterministic translator: returns ``"<text>-<target_lang>"``."""

    def translate(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> str:
        return f"{text}-{target_lang}"


class _FakeQueryEncoder:
    """Query encoder returning the same 3-dim float32 vector per sentence."""

    def encode(self, sentences, **kwargs):
        if isinstance(sentences, str):
            sentences = [sentences]
        # NOTE: every row has the same length, so NumPy builds a 2-D (n, 3)
        # object-dtype array here rather than a 1-D array of arrays.
        return np.array(
            [np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences],
            dtype=object,
        )


def _build_test_config() -> SearchConfig:
    """Return a minimal ``SearchConfig`` with text embeddings enabled."""
    return SearchConfig(
        field_boosts={"title.en": 3.0},
        indexes=[
            IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)
        ],
        query_config=QueryConfig(
            supported_languages=["en", "zh"],
            default_language="en",
            enable_text_embedding=True,
            enable_query_rewrite=False,
            rewrite_dictionary={},
            translation_prompts={
                "query_zh": "e-commerce domain",
                "query_en": "e-commerce domain",
            },
            text_embedding_field="title_embedding",
            image_embedding_field=None,
        ),
        # Passed exactly once: a repeated keyword argument is a SyntaxError.
        function_score=FunctionScoreConfig(),
        rerank=RerankConfig(),
        spu_config=SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3),
        es_index_name="test_products",
        tenant_config={},
        es_settings={},
        services={},
    )


def _install_fake_redis(monkeypatch) -> _FakeRedis:
    """Patch the encoder's ``redis.Redis`` constructor and return the fake instance."""
    fake_redis = _FakeRedis()
    monkeypatch.setattr(
        "embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis
    )
    return fake_redis


def _build_test_parser() -> QueryParser:
    """Return a ``QueryParser`` wired to the fake encoder and translator."""
    return QueryParser(
        config=_build_test_config(),
        text_encoder=_FakeQueryEncoder(),
        translator=_FakeTranslator(),
    )


def test_text_embedding_encoder_response_alignment(monkeypatch):
    """Each returned vector must correspond to the input at the same position."""
    _install_fake_redis(monkeypatch)

    def _fake_post(url, json, timeout, **kwargs):
        assert url.endswith("/embed/text")
        assert json == ["hello", "world"]
        return _FakeResponse([[0.1, 0.2], [0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["hello", "world"])

    assert len(out) == 2
    assert isinstance(out[0], np.ndarray)
    assert out[0].shape == (2,)
    assert isinstance(out[1], np.ndarray)
    assert out[1].shape == (2,)
    # Shapes alone cannot detect swapped/misaligned outputs; pin the values.
    assert np.allclose(out[0], np.array([0.1, 0.2], dtype=np.float32))
    assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32))


def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
    """A ``None`` entry in the service response must raise ``ValueError``."""
    _install_fake_redis(monkeypatch)

    def _fake_post(url, json, timeout, **kwargs):
        return _FakeResponse([[0.1, 0.2], None])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    with pytest.raises(ValueError):
        encoder.encode(["hello", "world"])


def test_text_embedding_encoder_cache_hit(monkeypatch):
    """Cached texts are served from Redis; only misses hit the service."""
    fake_redis = _install_fake_redis(monkeypatch)
    cached = np.array([0.9, 0.8], dtype=np.float32)
    fake_redis.store["embedding:generic:cached-text"] = pickle.dumps(cached)

    calls = {"count": 0}

    def _fake_post(url, json, timeout, **kwargs):
        calls["count"] += 1
        return _FakeResponse([[0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["cached-text", "new-text"])

    assert calls["count"] == 1
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32))


def test_query_parser_generates_query_vector_with_encoder():
    """With ``generate_vector=True`` the parser attaches a 3-dim query vector."""
    parser = _build_test_parser()
    parsed = parser.parse("red dress", tenant_id="162", generate_vector=True)
    assert parsed.query_vector is not None
    assert parsed.query_vector.shape == (3,)


def test_query_parser_skips_query_vector_when_disabled():
    """With ``generate_vector=False`` no query vector is produced."""
    parser = _build_test_parser()
    parsed = parser.parse("red dress", tenant_id="162", generate_vector=False)
    assert parsed.query_vector is None