"""Unit tests for the embedding encoders, query parser, and TEI text model.

All network and Redis traffic is replaced with in-process fakes:
``requests.post`` is monkeypatched per test, and ``RedisEmbeddingCache`` is
swapped for an in-memory dict-backed stub.
"""

from typing import Any, Dict, List, Optional

import numpy as np
import pytest

from config import (
    FunctionScoreConfig,
    IndexConfig,
    QueryConfig,
    RerankConfig,
    SPUConfig,
    SearchConfig,
)
from embeddings.text_encoder import TextEmbeddingEncoder
from embeddings.image_encoder import CLIPImageEncoder
from embeddings.text_embedding_tei import TEITextModel
from embeddings.bf16 import encode_embedding_for_redis
from embeddings.cache_keys import build_image_cache_key, build_text_cache_key
from query import QueryParser
from context.request_context import (
    create_request_context,
    set_current_request_context,
    clear_current_request_context,
)


class _FakeRedis:
    """Dict-backed stand-in for the small Redis surface the caches use."""

    def __init__(self):
        self.store: Dict[str, bytes] = {}

    def ping(self):
        return True

    def get(self, key: str):
        return self.store.get(key)

    def setex(self, key: str, _expire, value: bytes):
        # Expiry is ignored by the fake; only the value is retained.
        self.store[key] = value
        return True

    def expire(self, key: str, _expire):
        return key in self.store

    def delete(self, key: str):
        self.store.pop(key, None)
        return True


class _FakeResponse:
    """Mimics the subset of ``requests.Response`` the encoders touch."""

    def __init__(self, payload: List[Optional[List[float]]]):
        self._payload = payload

    def raise_for_status(self):
        return None

    def json(self):
        return self._payload


class _FakeTranslator:
    """Deterministic translator: suffixes the text with the target language."""

    def translate(
        self,
        text: str,
        target_lang: str,
        source_lang: Optional[str] = None,
        prompt: Optional[str] = None,
    ) -> str:
        return f"{text}-{target_lang}"


class _FakeQueryEncoder:
    """Records every ``encode`` call and emits a fixed 3-dim vector per input."""

    def __init__(self):
        self.calls = []

    def encode(self, sentences, **kwargs):
        self.calls.append({"sentences": sentences, "kwargs": dict(kwargs)})
        if isinstance(sentences, str):
            sentences = [sentences]
        return np.array(
            [np.array([0.11, 0.22, 0.33], dtype=np.float32) for _ in sentences],
            dtype=object,
        )


def _tokenizer(text):
    """Whitespace tokenizer used as the parser's tokenizer dependency."""
    return str(text).split()


class _FakeEmbeddingCache:
    """In-memory replacement for ``RedisEmbeddingCache``."""

    def __init__(self):
        self.store: Dict[str, np.ndarray] = {}

    def get(self, key: str):
        return self.store.get(key)

    def set(self, key: str, embedding: np.ndarray):
        self.store[key] = np.asarray(embedding, dtype=np.float32)
        return True


def _build_test_config() -> SearchConfig:
    """Assemble a minimal ``SearchConfig`` with text embeddings enabled."""
    query_config = QueryConfig(
        supported_languages=["en", "zh"],
        default_language="en",
        enable_text_embedding=True,
        enable_query_rewrite=False,
        rewrite_dictionary={},
        text_embedding_field="title_embedding",
        image_embedding_field=None,
    )
    return SearchConfig(
        field_boosts={"title.en": 3.0},
        indexes=[
            IndexConfig(name="default", label="default", fields=["title.en"], boost=1.0)
        ],
        query_config=query_config,
        function_score=FunctionScoreConfig(),
        rerank=RerankConfig(),
        spu_config=SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3),
        es_index_name="test_products",
        es_settings={},
    )


def test_text_embedding_encoder_response_alignment(monkeypatch):
    """Each input text gets its own vector, in request order."""
    fake_cache = _FakeEmbeddingCache()
    monkeypatch.setattr(
        "embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    def _fake_post(url, json, timeout, **kwargs):
        assert url.endswith("/embed/text")
        assert json == ["hello", "world"]
        assert kwargs["params"]["priority"] == 0
        return _FakeResponse([[0.1, 0.2], [0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["hello", "world"])

    assert len(out) == 2
    assert isinstance(out[0], np.ndarray)
    assert out[0].shape == (2,)
    assert isinstance(out[1], np.ndarray)
    assert out[1].shape == (2,)


def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
    """A ``None`` slot in the service payload must raise ``ValueError``."""
    fake_cache = _FakeEmbeddingCache()
    monkeypatch.setattr(
        "embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    def _fake_post(url, json, timeout, **kwargs):
        return _FakeResponse([[0.1, 0.2], None])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    with pytest.raises(ValueError):
        encoder.encode(["hello", "world"])


def test_text_embedding_encoder_cache_hit(monkeypatch):
    """Cached texts are served from the cache; only misses hit the service."""
    fake_cache = _FakeEmbeddingCache()
    cached = np.array([0.9, 0.8], dtype=np.float32)
    fake_cache.store[build_text_cache_key("cached-text", normalize=True)] = cached
    monkeypatch.setattr(
        "embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    calls = {"count": 0}

    def _fake_post(url, json, timeout, **kwargs):
        calls["count"] += 1
        return _FakeResponse([[0.3, 0.4]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
    out = encoder.encode(["cached-text", "new-text"])

    # Exactly one HTTP round-trip: the cache miss for "new-text".
    assert calls["count"] == 1
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32))


def test_text_embedding_encoder_forwards_request_headers(monkeypatch):
    """Request-context IDs are propagated as HTTP headers to the service."""
    fake_cache = _FakeEmbeddingCache()
    monkeypatch.setattr(
        "embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    captured = {}

    def _fake_post(url, json, timeout, **kwargs):
        captured["headers"] = dict(kwargs.get("headers") or {})
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post)

    context = create_request_context(reqid="req-ctx-1", uid="user-ctx-1")
    set_current_request_context(context)
    try:
        encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005")
        encoder.encode(["hello"])
    finally:
        # Always reset the ambient context so other tests are unaffected.
        clear_current_request_context()

    assert captured["headers"]["X-Request-ID"] == "req-ctx-1"
    assert captured["headers"]["X-User-ID"] == "user-ctx-1"


def test_image_embedding_encoder_cache_hit(monkeypatch):
    """Cached image URLs are served locally; only misses are posted."""
    fake_cache = _FakeEmbeddingCache()
    cached = np.array([0.5, 0.6], dtype=np.float32)
    url = "https://example.com/a.jpg"
    fake_cache.store[build_image_cache_key(url, normalize=True)] = cached
    monkeypatch.setattr(
        "embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    calls = {"count": 0}

    def _fake_post(url, params, json, timeout, **kwargs):
        calls["count"] += 1
        assert params["priority"] == 0
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)

    encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    out = encoder.encode_batch(
        ["https://example.com/a.jpg", "https://example.com/b.jpg"]
    )

    assert calls["count"] == 1
    assert np.allclose(out[0], cached)
    assert np.allclose(out[1], np.array([0.1, 0.2], dtype=np.float32))


def test_image_embedding_encoder_passes_priority(monkeypatch):
    """An explicit ``priority`` argument is forwarded as a query parameter."""
    fake_cache = _FakeEmbeddingCache()
    monkeypatch.setattr(
        "embeddings.image_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache
    )

    def _fake_post(url, params, json, timeout, **kwargs):
        assert params["priority"] == 1
        return _FakeResponse([[0.1, 0.2]])

    monkeypatch.setattr("embeddings.image_encoder.requests.post", _fake_post)

    encoder = CLIPImageEncoder(service_url="http://127.0.0.1:6008")
    out = encoder.encode_batch(["https://example.com/a.jpg"], priority=1)

    assert len(out) == 1
    assert np.allclose(out[0], np.array([0.1, 0.2], dtype=np.float32))


def test_query_parser_generates_query_vector_with_encoder():
    """With ``generate_vector=True`` the parser encodes the query text."""
    encoder = _FakeQueryEncoder()
    parser = QueryParser(
        config=_build_test_config(),
        text_encoder=encoder,
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    parsed = parser.parse("red dress", tenant_id="162", generate_vector=True)

    assert parsed.query_vector is not None
    assert parsed.query_vector.shape == (3,)
    assert encoder.calls
    assert encoder.calls[0]["kwargs"]["priority"] == 1


def test_query_parser_skips_query_vector_when_disabled():
    """With ``generate_vector=False`` no query vector is produced."""
    parser = QueryParser(
        config=_build_test_config(),
        text_encoder=_FakeQueryEncoder(),
        translator=_FakeTranslator(),
        tokenizer=_tokenizer,
    )

    parsed = parser.parse("red dress", tenant_id="162", generate_vector=False)

    assert parsed.query_vector is None


def test_tei_text_model_splits_batches_over_client_limit(monkeypatch):
    """25 inputs with a client batch limit of 24 yield two requests (24 + 1)."""
    monkeypatch.setattr(TEITextModel, "_health_check", lambda self: None)

    calls = []

    class _Response:
        def __init__(self, payload):
            self._payload = payload

        def raise_for_status(self):
            return None

        def json(self):
            return self._payload

    def _fake_post(url, json, timeout):
        inputs = list(json["inputs"])
        calls.append(inputs)
        # One distinct 1-dim vector per input, so counts stay traceable.
        return _Response([[float(idx)] for idx, _ in enumerate(inputs, start=1)])

    monkeypatch.setattr("embeddings.text_embedding_tei.requests.post", _fake_post)

    model = TEITextModel(
        base_url="http://127.0.0.1:8080",
        timeout_sec=20,
        max_client_batch_size=24,
    )
    vectors = model.encode(
        [f"text-{idx}" for idx in range(25)], normalize_embeddings=False
    )

    assert len(calls) == 2
    assert len(calls[0]) == 24
    assert len(calls[1]) == 1
    assert len(vectors) == 25