# test_es_query_builder.py (7.98 KB)
from types import SimpleNamespace
from typing import Any, Dict

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Create the ``ESQueryBuilder`` instance shared by the tests in this module.

    The configuration uses ``en`` as the default language, boosted
    ``title``/``brief`` match fields, and separate text and image
    embedding fields.
    """
    settings = {
        "match_fields": ["title.en^3.0", "brief.en^1.0"],
        "multilingual_fields": ["title", "brief"],
        "core_multilingual_fields": ["title", "brief"],
        "shared_fields": [],
        "text_embedding_field": "title_embedding",
        "image_embedding_field": "image_embedding.vector",
        "default_language": "en",
    }
    return ESQueryBuilder(**settings)


def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]:
    query_root = es_body["query"]
    if "bool" in query_root and query_root["bool"].get("must"):
        query_root = query_root["bool"]["must"][0]
    if "function_score" in query_root:
        query_root = query_root["function_score"]["query"]
    return query_root


def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]:
    """Return the ``should`` clauses found under the recall root of ``es_body``.

    When the root has no (or an empty) ``bool.should``, the root itself is
    returned as a single-element list so callers can always iterate.
    """
    recall_node = _recall_root(es_body)
    return recall_node.get("bool", {}).get("should") or [recall_node]


def _recall_clause_name(clause: Dict[str, Any]) -> str | None:
    if "bool" in clause:
        return clause["bool"].get("_name")
    if "knn" in clause:
        return clause["knn"].get("_name")
    if "nested" in clause:
        return clause["nested"].get("_name")
    return None


def test_knn_clause_moves_under_query_should_and_uses_outer_filters():
    """The text KNN clause is a recall ``should`` clause, not a top-level ``knn`` key.

    The range filter is expected on the outer ``bool.filter``.
    """
    body = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        enable_knn=True,
    )

    # No top-level "knn" section in the request body.
    assert "knn" not in body

    recall_clauses = _recall_should_clauses(body)
    assert any(
        candidate.get("knn", {}).get("_name") == "knn_query"
        for candidate in recall_clauses
    )

    outer_filter = body["query"]["bool"]["filter"]
    assert outer_filter == [{"range": {"min_price": {"gte": 50, "lt": 100}}}]


def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present():
    """Disjunctive facet filters go to ``post_filter``; the rest stay on the outer query.

    ``category_name`` is configured as a disjunctive facet, so its filter
    is expected in ``post_filter``, while the ``vendor`` term and the price
    range are expected in the outer ``bool.filter``.
    """
    builder = _builder()
    disjunctive_facet = SimpleNamespace(field="category_name", disjunctive=True)

    body = builder.build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=[disjunctive_facet],
        enable_knn=True,
    )

    # No top-level "knn" section in the request body.
    assert "knn" not in body

    outer_filter = body["query"]["bool"]["filter"]
    assert outer_filter == [
        {"term": {"vendor": "Nike"}},
        {"range": {"min_price": {"gte": 50, "lt": 100}}},
    ]

    post_filter = body["post_filter"]
    assert post_filter == {"terms": {"category_name": ["A", "B"]}}


def test_knn_clause_has_name_and_no_embedded_filter():
    """The ``knn_query`` clause is named and carries no ``filter`` of its own."""
    body = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        enable_knn=True,
    )

    # Lazily pick the first recall clause whose knn body is named "knn_query";
    # next() raises StopIteration (failing the test) when there is none.
    named_knn_bodies = (
        candidate["knn"]
        for candidate in _recall_should_clauses(body)
        if candidate.get("knn", {}).get("_name") == "knn_query"
    )
    knn_body = next(named_knn_bodies)

    assert "filter" not in knn_body
    assert knn_body["_name"] == "knn_query"


def test_text_query_contains_only_base_and_translation_named_queries():
    """With KNN disabled, recall holds exactly the base and ``zh`` translation clauses.

    The base clause's ``multi_match`` sub-queries are expected to be
    ``best_fields`` followed by ``phrase``.
    """
    builder = _builder()
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress", "zh": "连衣裙"},
    )

    body = builder.build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
    )

    recall_clauses = _recall_should_clauses(body)
    clause_names = [entry["bool"]["_name"] for entry in recall_clauses]
    assert clause_names == ["base_query", "base_query_trans_zh"]

    base_clause = recall_clauses[0]
    multi_match_types = [
        sub_query["multi_match"]["type"]
        for sub_query in base_clause["bool"]["should"]
        if "multi_match" in sub_query
    ]
    assert multi_match_types == ["best_fields", "phrase"]


def test_text_query_skips_duplicate_translation_same_as_base():
    """A translation identical to the base query leaves only ``base_query``.

    The only translation is the ``en`` text, which equals the rewritten
    query, so the query root (after unwrapping an optional
    ``function_score``) is expected to be the ``base_query`` bool itself.
    """
    builder = _builder()
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress"},
    )

    body = builder.build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
    )

    root = body["query"]
    # Peel off the function_score wrapper when the builder added one.
    root = root["function_score"]["query"] if "function_score" in root else root

    base_clause = root["bool"]
    assert base_clause["_name"] == "base_query"

    multi_match_types = [
        sub_query["multi_match"]["type"]
        for sub_query in base_clause["should"]
        if "multi_match" in sub_query
    ]
    assert multi_match_types == ["best_fields", "phrase"]


def test_product_title_exclusion_filter_is_applied_once_on_outer_query():
    """Title exclusions appear as a ``must_not`` filter on the outer query only.

    The parsed query carries an active exclusion profile with one ``zh``
    and two ``en`` phrases. The resulting ``must_not`` filter is expected
    in the outer ``bool.filter``, and the ``knn_query`` clause is expected
    to have no ``filter`` of its own.
    """
    builder = _builder()
    exclusion_profile = SimpleNamespace(
        is_active=True,
        all_zh_title_exclusions=lambda: ["宽松"],
        all_en_title_exclusions=lambda: ["loose", "relaxed"],
    )
    parsed = SimpleNamespace(
        rewritten_query="fitted dress",
        detected_language="en",
        translations={"zh": "修身 连衣裙"},
        product_title_exclusion_profile=exclusion_profile,
    )

    body = builder.build_query(
        query_text="fitted dress",
        query_vector=np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed,
        enable_knn=True,
    )

    # Expected shape: must_not(any of the excluded title phrases matches).
    excluded_phrases = [
        {"match_phrase": {"title.zh": {"query": "宽松"}}},
        {"match_phrase": {"title.en": {"query": "loose"}}},
        {"match_phrase": {"title.en": {"query": "relaxed"}}},
    ]
    any_excluded_phrase = {
        "bool": {
            "should": excluded_phrases,
            "minimum_should_match": 1,
        }
    }
    expected_filter = {"bool": {"must_not": [any_excluded_phrase]}}

    assert expected_filter in body["query"]["bool"]["filter"]

    named_knn_bodies = (
        candidate["knn"]
        for candidate in _recall_should_clauses(body)
        if candidate.get("knn", {}).get("_name") == "knn_query"
    )
    knn_body = next(named_knn_bodies)
    assert "filter" not in knn_body


def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn():
    """An image query vector adds a nested ``image_knn_query`` recall clause.

    The recall clauses are expected in this order: base query, ``zh``
    translation, text KNN, image KNN. The image clause is a ``nested``
    query on ``image_embedding`` with ``score_mode`` ``max`` whose inner
    KNN targets ``image_embedding.vector``.
    """
    builder = _builder()
    parsed = SimpleNamespace(
        rewritten_query="street tee",
        detected_language="en",
        translations={"zh": "街头短袖"},
    )

    body = builder.build_query(
        query_text="street tee",
        query_vector=np.array([0.1, 0.2, 0.3]),
        image_query_vector=np.array([0.4, 0.5, 0.6]),
        parsed_query=parsed,
        enable_knn=True,
    )

    recall_clauses = _recall_should_clauses(body)
    clause_names = [_recall_clause_name(entry) for entry in recall_clauses]
    assert clause_names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"]

    named_nested_bodies = (
        entry["nested"]
        for entry in recall_clauses
        if entry.get("nested", {}).get("_name") == "image_knn_query"
    )
    nested_body = next(named_nested_bodies)
    assert nested_body["path"] == "image_embedding"
    assert nested_body["score_mode"] == "max"
    assert nested_body["query"]["knn"]["field"] == "image_embedding.vector"


def test_text_knn_plan_is_reused_for_ann_and_exact_rescore():
    """The ANN text clause and the exact rescore clause carry the same boost.

    For a parsed query with five tokens, the ANN clause is checked against
    the builder's ``knn_text_k_long`` and ``knn_text_num_candidates_long``
    settings. Both clauses are expected to carry ``knn_text_boost * 1.4``.
    """
    builder = _builder()
    parsed = SimpleNamespace(query_tokens=["a", "b", "c", "d", "e"])

    ann = builder.build_text_knn_clause(
        np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed,
    )
    exact = builder.build_exact_text_knn_rescore_clause(
        np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed,
    )

    assert ann is not None
    assert exact is not None

    ann_knn = ann["knn"]
    assert ann_knn["k"] == builder.knn_text_k_long
    assert ann_knn["num_candidates"] == builder.knn_text_num_candidates_long
    assert ann_knn["boost"] == builder.knn_text_boost * 1.4

    rescore_params = exact["script_score"]["script"]["params"]
    assert rescore_params["boost"] == builder.knn_text_boost * 1.4


def test_image_knn_plan_is_reused_for_ann_and_exact_rescore():
    """The ANN image clause and the exact rescore clause carry the same boost.

    Both clauses are built from the same image vector and are expected to
    carry the builder's ``knn_image_boost`` inside their ``nested`` query.
    """
    builder = _builder()

    ann = builder.build_image_knn_clause(np.array([0.4, 0.5, 0.6]))
    exact = builder.build_exact_image_knn_rescore_clause(np.array([0.4, 0.5, 0.6]))

    assert ann is not None
    assert exact is not None

    ann_inner = ann["nested"]["query"]
    assert ann_inner["knn"]["boost"] == builder.knn_image_boost

    exact_inner = exact["nested"]["query"]
    assert exact_inner["script_score"]["script"]["params"]["boost"] == builder.knn_image_boost