# test_es_query_builder.py: unit tests for ESQueryBuilder query construction.
from types import SimpleNamespace
from typing import Any, Dict

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Create the English-default ESQueryBuilder used by the basic tests.

    Lexical matching targets ``title.en`` (boost 3.0) and ``brief.en``
    (boost 1.0); the embedding field is ``title_embedding``.
    """
    config = dict(
        match_fields=["title.en^3.0", "brief.en^1.0"],
        text_embedding_field="title_embedding",
        default_language="en",
    )
    return ESQueryBuilder(**config)


def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]:
    """Return the first named lexical bool clause from query_root."""
    if "bool" in query_root and query_root["bool"].get("_name"):
        return query_root["bool"]
    for clause in query_root.get("bool", {}).get("should", []):
        clause_bool = clause.get("bool") or {}
        if clause_bool.get("_name"):
            return clause_bool
    raise AssertionError("no lexical bool clause in query_root")


def _lexical_combined_fields(query_root: Dict[str, Any]) -> list:
    """Return the ``combined_fields.fields`` list of the named lexical clause.

    Expects the clause's first ``must`` entry to be a ``combined_fields``
    query; any other shape raises ``KeyError``/``IndexError``.
    """
    first_must = _lexical_clause(query_root)["must"][0]
    return first_must["combined_fields"]["fields"]


def test_knn_prefilter_includes_range_filters():
    """A lone range filter becomes the kNN section's ``filter`` clause."""
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        enable_knn=True,
    )

    assert "knn" in query
    knn_section = query["knn"]
    assert knn_section["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}}


def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
    """Disjunctive facet filters go to post_filter; the rest pre-filter kNN.

    ``category_name`` is configured as a disjunctive facet, so it must not
    appear in the kNN filter. The conjunctive ``vendor`` term and the price
    range are combined under a ``bool.filter``.
    """
    category_facet = SimpleNamespace(field="category_name", disjunctive=True)
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=[category_facet],
        enable_knn=True,
    )

    expected_knn_filter = {
        "bool": {
            "filter": [
                {"term": {"vendor": "Nike"}},
                {"range": {"min_price": {"gte": 50, "lt": 100}}},
            ]
        }
    }
    expected_post_filter = {"terms": {"category_name": ["A", "B"]}}

    assert "knn" in query
    knn_section = query["knn"]
    assert "filter" in knn_section
    assert knn_section["filter"] == expected_knn_filter
    assert query["post_filter"] == expected_post_filter


def test_knn_prefilter_not_added_without_filters():
    """With no filters supplied, the kNN section has no ``filter`` key."""
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        enable_knn=True,
    )

    assert "knn" in query
    knn_section = query["knn"]
    assert "filter" not in knn_section
    assert knn_section["_name"] == "knn_query"


def test_text_query_contains_only_base_and_translation_named_queries():
    """Only the base clause and the ``zh`` translation clause are emitted.

    The ``en`` translation equals the base text, and ``fr`` (an index
    language) has no translation. The expected top-level ``should`` names
    are therefore exactly ``base_query`` and ``base_query_trans_zh``.
    """
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress", "zh": "连衣裙"},
    )
    query = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["en", "zh", "fr"],
    )

    top_level = query["query"]["bool"]["should"]
    clause_names = [entry["bool"]["_name"] for entry in top_level]
    assert clause_names == ["base_query", "base_query_trans_zh"]

    # The base clause itself is a best_fields + phrase multi_match pair.
    base_clauses = top_level[0]["bool"]["should"]
    match_types = [entry["multi_match"]["type"] for entry in base_clauses]
    assert match_types == ["best_fields", "phrase"]


def test_text_query_skips_duplicate_translation_same_as_base():
    """A translation identical to the base text adds no extra clause.

    With only the duplicate ``en`` translation, the query root is expected
    to be the named ``base_query`` bool itself.
    """
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress"},
    )
    query = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["en", "zh"],
    )

    root_bool = query["query"]["bool"]
    assert root_bool["_name"] == "base_query"
    match_types = [entry["multi_match"]["type"] for entry in root_bool["should"]]
    assert match_types == ["best_fields", "phrase"]


def test_mixed_script_merges_en_fields_into_zh_clause():
    """A zh-detected query that also contains English searches ``*.en`` too."""
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title", "brief"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="en",
    )
    mixed_text = "法式 dress"
    parsed = SimpleNamespace(
        rewritten_query=mixed_text,
        detected_language="zh",
        translations={},
        contains_chinese=True,
        contains_english=True,
    )
    query = builder.build_query(
        query_text=mixed_text,
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = _lexical_combined_fields(query["query"])
    # Strip any "^boost" suffix to compare bare field names.
    field_names = {spec.partition("^")[0] for spec in fields}
    for expected in ("title.zh", "title.en", "brief.zh", "brief.en"):
        assert expected in field_names
    # Merged supplemental language fields use boost * 0.6 by default.
    assert "title.en^0.6" in fields
    assert "brief.en^0.6" in fields


def test_mixed_script_merges_zh_fields_into_en_clause():
    """An en-detected query that also contains Chinese searches ``*.zh`` too."""
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="en",
    )
    mixed_text = "red 连衣裙"
    parsed = SimpleNamespace(
        rewritten_query=mixed_text,
        detected_language="en",
        translations={},
        contains_chinese=True,
        contains_english=True,
    )
    query = builder.build_query(
        query_text=mixed_text,
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = _lexical_combined_fields(query["query"])
    # Strip any "^boost" suffix to compare bare field names.
    field_names = {spec.partition("^")[0] for spec in fields}
    for expected in ("title.en", "title.zh"):
        assert expected in field_names
    assert "title.zh^0.6" in fields


def test_mixed_script_merged_fields_scale_configured_boosts():
    """Merged-in fields scale their configured boost rather than a default.

    ``title.zh`` is the detected-language field and keeps boost 5.0.
    ``title.en`` is merged in, so its configured 10.0 is expected at 6.0.
    """
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        field_boosts={"title.zh": 5.0, "title.en": 10.0},
        text_embedding_field="title_embedding",
        default_language="en",
    )
    mixed_text = "法式 dress"
    parsed = SimpleNamespace(
        rewritten_query=mixed_text,
        detected_language="zh",
        translations={},
        contains_chinese=True,
        contains_english=True,
    )
    query = builder.build_query(
        query_text=mixed_text,
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = _lexical_combined_fields(query["query"])
    assert "title.zh^5.0" in fields
    assert "title.en^6.0" in fields  # 10.0 * 0.6


def test_mixed_script_does_not_merge_en_when_not_in_index_languages():
    """English fields are not merged when ``en`` is not an index language."""
    builder = ESQueryBuilder(
        match_fields=["title.zh^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="zh",
    )
    mixed_text = "法式 dress"
    parsed = SimpleNamespace(
        rewritten_query=mixed_text,
        detected_language="zh",
        translations={},
        contains_chinese=True,
        contains_english=True,
    )
    query = builder.build_query(
        query_text=mixed_text,
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["zh"],
    )

    fields = _lexical_combined_fields(query["query"])
    # Strip any "^boost" suffix to compare bare field names.
    field_names = {spec.partition("^")[0] for spec in fields}
    assert "title.zh" in field_names
    assert "title.en" not in field_names