from __future__ import annotations from dataclasses import dataclass from pathlib import Path from types import SimpleNamespace from typing import Any, Dict, List import numpy as np import yaml from config import ( ConfigLoader, FunctionScoreConfig, IndexConfig, QueryConfig, RerankConfig, SPUConfig, SearchConfig, ) from context import create_request_context from search.searcher import Searcher @dataclass class _FakeParsedQuery: original_query: str query_normalized: str rewritten_query: str detected_language: str = "en" translations: Dict[str, str] = None query_vector: Any = None domain: str = "default" def to_dict(self) -> Dict[str, Any]: return { "original_query": self.original_query, "query_normalized": self.query_normalized, "rewritten_query": self.rewritten_query, "detected_language": self.detected_language, "translations": self.translations or {}, "domain": self.domain, } class _FakeQueryParser: def parse(self, query: str, tenant_id: str, generate_vector: bool, context: Any): return _FakeParsedQuery( original_query=query, query_normalized=query, rewritten_query=query, translations={}, ) class _FakeQueryBuilder: def build_query(self, **kwargs): return { "query": {"match_all": {}}, "size": kwargs["size"], "from": kwargs["from_"], } def build_facets(self, facets: Any): return {} def add_sorting(self, es_query: Dict[str, Any], sort_by: str, sort_order: str): return es_query class _FakeESClient: def __init__(self, total_hits: int = 5000): self.calls: List[Dict[str, Any]] = [] self.total_hits = total_hits @staticmethod def _apply_source_filter(src: Dict[str, Any], source_spec: Any) -> Dict[str, Any]: if source_spec is None: return dict(src) if source_spec is False: return {} if isinstance(source_spec, dict): includes = source_spec.get("includes") or [] elif isinstance(source_spec, list): includes = source_spec else: includes = [] if not includes: return dict(src) return {k: v for k, v in src.items() if k in set(includes)} @staticmethod def _full_source(doc_id: str) -> Dict[str, Any]: return { "spu_id": doc_id, "title": {"en": f"product-{doc_id}"}, "brief": {"en": f"brief-{doc_id}"}, "vendor": {"en": f"vendor-{doc_id}"}, "skus": [], } def search( self, index_name: str, body: Dict[str, Any], size: int, from_: int, include_named_queries_score: bool = False, ): self.calls.append( { "index_name": index_name, "body": body, "size": size, "from_": from_, "include_named_queries_score": include_named_queries_score, } ) ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values") source_spec = (body or {}).get("_source") if isinstance(ids_query, list): # Return reversed order intentionally; caller should restore original ranking order. ids = [str(i) for i in ids_query][::-1] hits = [] for doc_id in ids: src = self._apply_source_filter(self._full_source(doc_id), source_spec) hit = {"_id": doc_id, "_score": 1.0} if source_spec is not False: hit["_source"] = src hits.append(hit) else: end = min(from_ + size, self.total_hits) hits = [] for i in range(from_, end): doc_id = str(i) src = self._apply_source_filter(self._full_source(doc_id), source_spec) hit = {"_id": doc_id, "_score": float(self.total_hits - i)} if source_spec is not False: hit["_source"] = src hits.append(hit) return { "took": 8, "hits": { "total": {"value": self.total_hits}, "max_score": hits[0]["_score"] if hits else 0.0, "hits": hits, }, } def _build_search_config(*, rerank_enabled: bool = True, rerank_window: int = 384): return SearchConfig( field_boosts={"title.en": 3.0}, indexes=[IndexConfig(name="default", label="default", fields=["title.en"])], query_config=QueryConfig(enable_text_embedding=False, enable_query_rewrite=False), function_score=FunctionScoreConfig(), rerank=RerankConfig(enabled=rerank_enabled, rerank_window=rerank_window), spu_config=SPUConfig(enabled=False), es_index_name="test_products", es_settings={}, ) def _build_searcher(config: SearchConfig, es_client: _FakeESClient) -> Searcher: searcher = Searcher( es_client=es_client, config=config, query_parser=_FakeQueryParser(), ) searcher.query_builder = _FakeQueryBuilder() return searcher class _FakeTextEncoder: def __init__(self, vectors: Dict[str, List[float]]): self.vectors = { key: np.array(value, dtype=np.float32) for key, value in vectors.items() } def encode(self, sentences, priority: int = 0, **kwargs): if isinstance(sentences, str): sentences = [sentences] return np.array([self.vectors[text] for text in sentences], dtype=object) def test_config_loader_rerank_enabled_defaults_true(tmp_path: Path): config_data = { "es_index_name": "test_products", "field_boosts": {"title.en": 3.0}, "indexes": [{"name": "default", "label": "default", "fields": ["title.en"]}], "query_config": {"supported_languages": ["en"], "default_language": "en"}, "spu_config": {"enabled": False}, "function_score": {"score_mode": "sum", "boost_mode": "multiply", "functions": []}, "rerank": {"rerank_window": 384}, } config_path = tmp_path / "config.yaml" config_path.write_text(yaml.safe_dump(config_data), encoding="utf-8") loader = ConfigLoader(config_path) loaded = loader.load_config(validate=False) assert loaded.rerank.enabled is True def test_searcher_reranks_top_window_by_default(monkeypatch): es_client = _FakeESClient() searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) context = create_request_context(reqid="t1", uid="u1") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), ) called: Dict[str, Any] = {"count": 0, "docs": 0} def _fake_run_rerank(**kwargs): called["count"] += 1 called["docs"] = len(kwargs["es_response"]["hits"]["hits"]) return kwargs["es_response"], None, [] monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) result = searcher.search( query="toy", tenant_id="162", from_=20, size=10, context=context, enable_rerank=None, ) assert called["count"] == 1 # 应当对配置的 rerank_window 条文档做重排预取 window = searcher.config.rerank.rerank_window assert called["docs"] == window assert es_client.calls[0]["from_"] == 0 assert es_client.calls[0]["size"] == window assert es_client.calls[0]["include_named_queries_score"] is True assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]} assert len(es_client.calls) == 2 assert es_client.calls[1]["size"] == 10 assert es_client.calls[1]["from_"] == 0 assert es_client.calls[1]["body"]["query"]["ids"]["values"] == [str(i) for i in range(20, 30)] assert len(result.results) == 10 assert result.results[0].spu_id == "20" assert result.results[0].brief == "brief-20" def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): es_client = _FakeESClient() searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) context = create_request_context(reqid="t1b", uid="u1b") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), ) monkeypatch.setattr("search.rerank_client.run_rerank", lambda **kwargs: (kwargs["es_response"], None, [])) searcher.search( query="toy", tenant_id="162", from_=0, size=5, context=context, enable_rerank=None, rerank_doc_template="{title} {vendor} {brief}", ) assert es_client.calls[0]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): es_client = _FakeESClient() searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) context = create_request_context(reqid="t2", uid="u2") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), ) called: Dict[str, int] = {"count": 0} def _fake_run_rerank(**kwargs): called["count"] += 1 return kwargs["es_response"], None, [] monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) searcher.search( query="toy", tenant_id="162", from_=20, size=10, context=context, enable_rerank=False, ) assert called["count"] == 0 assert es_client.calls[0]["from_"] == 20 assert es_client.calls[0]["size"] == 10 assert es_client.calls[0]["include_named_queries_score"] is False assert len(es_client.calls) == 1 def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch): es_client = _FakeESClient() searcher = _build_searcher(_build_search_config(rerank_enabled=True, rerank_window=384), es_client) context = create_request_context(reqid="t3", uid="u3") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), ) called: Dict[str, int] = {"count": 0} def _fake_run_rerank(**kwargs): called["count"] += 1 return kwargs["es_response"], None, [] monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank) searcher.search( query="toy", tenant_id="162", from_=995, size=10, context=context, enable_rerank=None, ) assert called["count"] == 0 assert es_client.calls[0]["from_"] == 995 assert es_client.calls[0]["size"] == 10 assert es_client.calls[0]["include_named_queries_score"] is False assert len(es_client.calls) == 1 def test_searcher_promotes_sku_when_option1_matches_translated_query(monkeypatch): es_client = _FakeESClient(total_hits=1) searcher = _build_searcher(_build_search_config(rerank_enabled=False), es_client) context = create_request_context(reqid="sku-text", uid="u-sku-text") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en", "zh"]}), ) class _TranslatedQueryParser: text_encoder = None def parse(self, query: str, tenant_id: str, generate_vector: bool, context: Any): return _FakeParsedQuery( original_query=query, query_normalized=query, rewritten_query=query, translations={"en": "black dress"}, ) searcher.query_parser = _TranslatedQueryParser() def _full_source_with_skus(doc_id: str) -> Dict[str, Any]: return { "spu_id": doc_id, "title": {"en": f"product-{doc_id}"}, "brief": {"en": f"brief-{doc_id}"}, "vendor": {"en": f"vendor-{doc_id}"}, "image_url": "https://img/default.jpg", "skus": [ {"sku_id": "sku-red", "option1_value": "Red", "image_src": "https://img/red.jpg"}, {"sku_id": "sku-black", "option1_value": "Black", "image_src": "https://img/black.jpg"}, ], } monkeypatch.setattr(_FakeESClient, "_full_source", staticmethod(_full_source_with_skus)) result = searcher.search( query="黑色 连衣裙", tenant_id="162", from_=0, size=1, context=context, enable_rerank=False, ) assert len(result.results) == 1 assert result.results[0].skus[0].sku_id == "sku-black" assert result.results[0].image_url == "https://img/black.jpg" def test_searcher_promotes_sku_by_embedding_when_query_has_no_direct_option_match(monkeypatch): es_client = _FakeESClient(total_hits=1) searcher = _build_searcher(_build_search_config(rerank_enabled=False), es_client) context = create_request_context(reqid="sku-embed", uid="u-sku-embed") monkeypatch.setattr( "search.searcher.get_tenant_config_loader", lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), ) encoder = _FakeTextEncoder( { "linen summer dress": [0.8, 0.2], "Red": [1.0, 0.0], "Blue": [0.0, 1.0], } ) class _EmbeddingQueryParser: text_encoder = encoder def parse(self, query: str, tenant_id: str, generate_vector: bool, context: Any): return _FakeParsedQuery( original_query=query, query_normalized=query, rewritten_query=query, translations={}, query_vector=np.array([0.0, 1.0], dtype=np.float32), ) searcher.query_parser = _EmbeddingQueryParser() def _full_source_with_skus(doc_id: str) -> Dict[str, Any]: return { "spu_id": doc_id, "title": {"en": f"product-{doc_id}"}, "brief": {"en": f"brief-{doc_id}"}, "vendor": {"en": f"vendor-{doc_id}"}, "image_url": "https://img/default.jpg", "skus": [ {"sku_id": "sku-red", "option1_value": "Red", "image_src": "https://img/red.jpg"}, {"sku_id": "sku-blue", "option1_value": "Blue", "image_src": "https://img/blue.jpg"}, ], } monkeypatch.setattr(_FakeESClient, "_full_source", staticmethod(_full_source_with_skus)) result = searcher.search( query="linen summer dress", tenant_id="162", from_=0, size=1, context=context, enable_rerank=False, ) assert len(result.results) == 1 assert result.results[0].skus[0].sku_id == "sku-blue" assert result.results[0].image_url == "https://img/blue.jpg"