# test_es_query_builder.py — tests for ESQueryBuilder kNN pre-filter handling.
from types import SimpleNamespace

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Return a new ESQueryBuilder configured for the English test fixtures.

    Text matching uses ``title.en`` (boost ``^3.0``) and ``brief.en``
    (boost ``^1.0``). ``title_embedding`` is the text-embedding field, and
    the default language is ``"en"``.
    """
    fields = ["title.en^3.0", "brief.en^1.0"]
    return ESQueryBuilder(
        default_language="en",
        text_embedding_field="title_embedding",
        match_fields=fields,
    )


def test_knn_prefilter_includes_range_filters():
    """A lone range filter becomes the kNN ``filter`` clause as-is.

    When it is the only filter, it is expected without a ``bool`` wrapper.
    """
    price_range = {"gte": 50, "lt": 100}
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        range_filters={"min_price": price_range},
        enable_knn=True,
    )

    assert "knn" in query
    expected = {"range": {"min_price": {"gte": 50, "lt": 100}}}
    assert query["knn"]["filter"] == expected


def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
    """Disjunctive facet filters are routed to ``post_filter``, not the kNN prefilter.

    ``category_name`` is configured as a disjunctive facet, so its ``terms``
    filter is expected under ``post_filter``. The remaining conjunctive term
    filter (``vendor``) and the range filter are combined, in that order,
    in a ``bool.filter`` list on the kNN clause.
    """
    category_facet = SimpleNamespace(field="category_name", disjunctive=True)
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=[category_facet],
        enable_knn=True,
    )

    assert "knn" in query
    knn_clause = query["knn"]
    assert "filter" in knn_clause
    expected_prefilter = {
        "bool": {
            "filter": [
                {"term": {"vendor": "Nike"}},
                {"range": {"min_price": {"gte": 50, "lt": 100}}},
            ]
        }
    }
    assert knn_clause["filter"] == expected_prefilter
    assert query["post_filter"] == {"terms": {"category_name": ["A", "B"]}}


def test_knn_prefilter_not_added_without_filters():
    """Without filters, the kNN clause has no ``filter`` key.

    The clause is still expected to carry the name ``"knn_query"``.
    """
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        enable_knn=True,
    )

    assert "knn" in query
    knn_clause = query["knn"]
    assert "filter" not in knn_clause
    assert knn_clause["name"] == "knn_query"