# test_search_rerank_window.py (10.6 KB)
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import yaml

from config import (
    ConfigLoader,
    FunctionScoreConfig,
    IndexConfig,
    QueryConfig,
    RerankConfig,
    SPUConfig,
    SearchConfig,
)
from context import create_request_context
from search.searcher import Searcher


@dataclass
class _FakeParsedQuery:
    original_query: str
    query_normalized: str
    rewritten_query: str
    detected_language: str = "en"
    translations: Dict[str, str] = None
    query_vector: Any = None
    domain: str = "default"

    def to_dict(self) -> Dict[str, Any]:
        return {
            "original_query": self.original_query,
            "query_normalized": self.query_normalized,
            "rewritten_query": self.rewritten_query,
            "detected_language": self.detected_language,
            "translations": self.translations or {},
            "domain": self.domain,
        }


class _FakeQueryParser:
    """Query parser stub that echoes the raw query into every text field."""

    def parse(self, query: str, tenant_id: str, generate_vector: bool, context: Any):
        """Return a ``_FakeParsedQuery`` built from *query*; other arguments are ignored."""
        # The same raw string serves as original, normalized and rewritten query.
        text_fields = dict.fromkeys(
            ("original_query", "query_normalized", "rewritten_query"),
            query,
        )
        return _FakeParsedQuery(translations={}, **text_fields)


class _FakeQueryBuilder:
    def build_query(self, **kwargs):
        return {
            "query": {"match_all": {}},
            "size": kwargs["size"],
            "from": kwargs["from_"],
        }

    def build_facets(self, facets: Any):
        return {}

    def add_sorting(self, es_query: Dict[str, Any], sort_by: str, sort_order: str):
        return es_query


class _FakeESClient:
    def __init__(self, total_hits: int = 5000):
        self.calls: List[Dict[str, Any]] = []
        self.total_hits = total_hits

    @staticmethod
    def _apply_source_filter(src: Dict[str, Any], source_spec: Any) -> Dict[str, Any]:
        if source_spec is None:
            return dict(src)
        if source_spec is False:
            return {}
        if isinstance(source_spec, dict):
            includes = source_spec.get("includes") or []
        elif isinstance(source_spec, list):
            includes = source_spec
        else:
            includes = []
        if not includes:
            return dict(src)
        return {k: v for k, v in src.items() if k in set(includes)}

    @staticmethod
    def _full_source(doc_id: str) -> Dict[str, Any]:
        return {
            "spu_id": doc_id,
            "title": {"en": f"product-{doc_id}"},
            "brief": {"en": f"brief-{doc_id}"},
            "vendor": {"en": f"vendor-{doc_id}"},
            "skus": [],
        }

    def search(
        self,
        index_name: str,
        body: Dict[str, Any],
        size: int,
        from_: int,
        include_named_queries_score: bool = False,
    ):
        self.calls.append(
            {
                "index_name": index_name,
                "body": body,
                "size": size,
                "from_": from_,
                "include_named_queries_score": include_named_queries_score,
            }
        )
        ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values")
        source_spec = (body or {}).get("_source")

        if isinstance(ids_query, list):
            # Return reversed order intentionally; caller should restore original ranking order.
            ids = [str(i) for i in ids_query][::-1]
            hits = []
            for doc_id in ids:
                src = self._apply_source_filter(self._full_source(doc_id), source_spec)
                hit = {"_id": doc_id, "_score": 1.0}
                if source_spec is not False:
                    hit["_source"] = src
                hits.append(hit)
        else:
            end = min(from_ + size, self.total_hits)
            hits = []
            for i in range(from_, end):
                doc_id = str(i)
                src = self._apply_source_filter(self._full_source(doc_id), source_spec)
                hit = {"_id": doc_id, "_score": float(self.total_hits - i)}
                if source_spec is not False:
                    hit["_source"] = src
                hits.append(hit)

        return {
            "took": 8,
            "hits": {
                "total": {"value": self.total_hits},
                "max_score": hits[0]["_score"] if hits else 0.0,
                "hits": hits,
            },
        }


def _build_search_config(*, rerank_enabled: bool = True, rerank_window: int = 384):
    """Assemble a ``SearchConfig`` for these tests.

    The config has a single ``default`` index over ``title.en``, text embedding
    and query rewrite switched off, and SPU config disabled. Only the rerank
    settings vary, via the two keyword-only arguments.
    """
    boosts = {"title.en": 3.0}
    index_list = [IndexConfig(name="default", label="default", fields=["title.en"])]
    query_cfg = QueryConfig(enable_text_embedding=False, enable_query_rewrite=False)
    scoring_cfg = FunctionScoreConfig()
    rerank_cfg = RerankConfig(enabled=rerank_enabled, rerank_window=rerank_window)
    spu_cfg = SPUConfig(enabled=False)

    return SearchConfig(
        field_boosts=boosts,
        indexes=index_list,
        query_config=query_cfg,
        function_score=scoring_cfg,
        rerank=rerank_cfg,
        spu_config=spu_cfg,
        es_index_name="test_products",
        tenant_config={},
        es_settings={},
        services={},
    )


def _build_searcher(config: SearchConfig, es_client: _FakeESClient) -> Searcher:
    """Create a ``Searcher`` wired to the fake parser, fake builder and *es_client*."""
    instance = Searcher(
        query_parser=_FakeQueryParser(),
        config=config,
        es_client=es_client,
    )
    # The builder is not a constructor argument here, so it is swapped in afterwards.
    instance.query_builder = _FakeQueryBuilder()
    return instance


def test_config_loader_rerank_enabled_defaults_true(tmp_path: Path):
    """A ``rerank`` section without an ``enabled`` key loads with ``enabled`` True."""
    # Deliberately carries no "enabled" key.
    rerank_section = {"rerank_window": 384}
    raw_config = {
        "es_index_name": "test_products",
        "field_boosts": {"title.en": 3.0},
        "indexes": [{"name": "default", "label": "default", "fields": ["title.en"]}],
        "query_config": {"supported_languages": ["en"], "default_language": "en"},
        "spu_config": {"enabled": False},
        "function_score": {"score_mode": "sum", "boost_mode": "multiply", "functions": []},
        "rerank": rerank_section,
    }
    yaml_file = tmp_path / "config.yaml"
    yaml_file.write_text(yaml.safe_dump(raw_config), encoding="utf-8")

    loaded = ConfigLoader(yaml_file).load_config(validate=False)

    assert loaded.rerank.enabled is True


def test_searcher_reranks_top_window_by_default(monkeypatch):
    """With rerank enabled in config and ``enable_rerank=None``, the window is reranked.

    Checks that ``run_rerank`` is called once with ``rerank_window`` hits, that
    the first ES call prefetches from offset 0 with only ``title`` in
    ``_source``, and that a second ES call fetches the requested page by ids.
    """
    fake_es = _FakeESClient()
    searcher = _build_searcher(_build_search_config(rerank_enabled=True), fake_es)
    ctx = create_request_context(reqid="t1", uid="u1")

    def _tenant_config(tenant_id):
        return {"index_languages": ["en"]}

    def _tenant_loader():
        return SimpleNamespace(get_tenant_config=_tenant_config)

    monkeypatch.setattr("search.searcher.get_tenant_config_loader", _tenant_loader)

    rerank_invocations = 0
    docs_seen = 0

    def _fake_run_rerank(**kwargs):
        nonlocal rerank_invocations, docs_seen
        rerank_invocations += 1
        docs_seen = len(kwargs["es_response"]["hits"]["hits"])
        return kwargs["es_response"], None, []

    monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank)

    result = searcher.search(
        query="toy",
        tenant_id="162",
        from_=20,
        size=10,
        context=ctx,
        enable_rerank=None,
    )

    assert rerank_invocations == 1
    # The rerank prefetch should cover the configured rerank_window documents.
    window = searcher.config.rerank.rerank_window
    assert docs_seen == window

    prefetch_call = fake_es.calls[0]
    assert prefetch_call["from_"] == 0
    assert prefetch_call["size"] == window
    assert prefetch_call["include_named_queries_score"] is True
    assert prefetch_call["body"]["_source"] == {"includes": ["title"]}

    assert len(fake_es.calls) == 2
    page_call = fake_es.calls[1]
    assert page_call["size"] == 10
    assert page_call["from_"] == 0
    assert page_call["body"]["query"]["ids"]["values"] == [str(i) for i in range(20, 30)]

    assert len(result.results) == 10
    top_result = result.results[0]
    assert top_result.spu_id == "20"
    assert top_result.brief == "brief-20"


def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch):
    """The prefetch ``_source`` includes are the sorted fields named in the doc template."""
    fake_es = _FakeESClient()
    searcher = _build_searcher(_build_search_config(rerank_enabled=True), fake_es)
    ctx = create_request_context(reqid="t1b", uid="u1b")

    def _tenant_config(tenant_id):
        return {"index_languages": ["en"]}

    def _tenant_loader():
        return SimpleNamespace(get_tenant_config=_tenant_config)

    def _passthrough_rerank(**kwargs):
        # Hand the ES response back unchanged.
        return (kwargs["es_response"], None, [])

    monkeypatch.setattr("search.searcher.get_tenant_config_loader", _tenant_loader)
    monkeypatch.setattr("search.rerank_client.run_rerank", _passthrough_rerank)

    searcher.search(
        query="toy",
        tenant_id="162",
        from_=0,
        size=5,
        context=ctx,
        enable_rerank=None,
        rerank_doc_template="{title} {vendor} {brief}",
    )

    prefetch_source = fake_es.calls[0]["body"]["_source"]
    assert prefetch_source == {"includes": ["brief", "title", "vendor"]}


def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch):
    """``enable_rerank=False`` on the request bypasses rerank even if config enables it.

    Expects no ``run_rerank`` call and exactly one ES call that uses the
    caller's own ``from_``/``size`` without named-query scores.
    """
    fake_es = _FakeESClient()
    searcher = _build_searcher(_build_search_config(rerank_enabled=True), fake_es)
    ctx = create_request_context(reqid="t2", uid="u2")

    def _tenant_config(tenant_id):
        return {"index_languages": ["en"]}

    def _tenant_loader():
        return SimpleNamespace(get_tenant_config=_tenant_config)

    monkeypatch.setattr("search.searcher.get_tenant_config_loader", _tenant_loader)

    rerank_invocations = 0

    def _fake_run_rerank(**kwargs):
        nonlocal rerank_invocations
        rerank_invocations += 1
        return kwargs["es_response"], None, []

    monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank)

    searcher.search(
        query="toy",
        tenant_id="162",
        from_=20,
        size=10,
        context=ctx,
        enable_rerank=False,
    )

    assert rerank_invocations == 0
    only_call = fake_es.calls[0]
    assert only_call["from_"] == 20
    assert only_call["size"] == 10
    assert only_call["include_named_queries_score"] is False
    assert len(fake_es.calls) == 1


def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch):
    """A page starting at 995 with ``rerank_window=384`` is served without rerank.

    Expects no ``run_rerank`` call and exactly one ES call that uses the
    caller's own ``from_``/``size`` without named-query scores.
    """
    fake_es = _FakeESClient()
    search_config = _build_search_config(rerank_enabled=True, rerank_window=384)
    searcher = _build_searcher(search_config, fake_es)
    ctx = create_request_context(reqid="t3", uid="u3")

    def _tenant_config(tenant_id):
        return {"index_languages": ["en"]}

    def _tenant_loader():
        return SimpleNamespace(get_tenant_config=_tenant_config)

    monkeypatch.setattr("search.searcher.get_tenant_config_loader", _tenant_loader)

    rerank_invocations = 0

    def _fake_run_rerank(**kwargs):
        nonlocal rerank_invocations
        rerank_invocations += 1
        return kwargs["es_response"], None, []

    monkeypatch.setattr("search.rerank_client.run_rerank", _fake_run_rerank)

    searcher.search(
        query="toy",
        tenant_id="162",
        from_=995,
        size=10,
        context=ctx,
        enable_rerank=None,
    )

    assert rerank_invocations == 0
    only_call = fake_es.calls[0]
    assert only_call["from_"] == 995
    assert only_call["size"] == 10
    assert only_call["include_named_queries_score"] is False
    assert len(fake_es.calls) == 1