"""Tests for how ``ESQueryBuilder.build_query`` attaches a kNN pre-filter."""

from types import SimpleNamespace

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Return a builder matching English title/brief fields with a title embedding."""
    return ESQueryBuilder(
        match_fields=["title.en^3.0", "brief.en^1.0"],
        text_embedding_field="title_embedding",
        default_language="en",
    )


def _query_vector() -> np.ndarray:
    """Return a fresh three-element query embedding (one new array per call)."""
    return np.array([0.1, 0.2, 0.3])


def test_knn_prefilter_includes_range_filters():
    """A single range filter becomes the kNN ``filter`` clause verbatim."""
    query = _builder().build_query(
        query_text="bags",
        query_vector=_query_vector(),
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        enable_knn=True,
    )

    assert "knn" in query
    expected_prefilter = {"range": {"min_price": {"gte": 50, "lt": 100}}}
    assert query["knn"]["filter"] == expected_prefilter


def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
    """Disjunctive facet filters are excluded from the kNN pre-filter.

    The conjunctive ``vendor`` term and the price range are combined under a
    ``bool.filter`` clause on the kNN section, while the disjunctive
    ``category_name`` selection is expected in ``post_filter`` instead.
    """
    facet_configs = [SimpleNamespace(field="category_name", disjunctive=True)]
    query = _builder().build_query(
        query_text="bags",
        query_vector=_query_vector(),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=facet_configs,
        enable_knn=True,
    )

    assert "knn" in query
    knn_section = query["knn"]
    assert "filter" in knn_section

    expected_prefilter = {
        "bool": {
            "filter": [
                {"term": {"vendor": "Nike"}},
                {"range": {"min_price": {"gte": 50, "lt": 100}}},
            ]
        }
    }
    assert knn_section["filter"] == expected_prefilter
    assert query["post_filter"] == {"terms": {"category_name": ["A", "B"]}}


def test_knn_prefilter_not_added_without_filters():
    """With no filters at all, the kNN section carries no ``filter`` key."""
    query = _builder().build_query(
        query_text="bags",
        query_vector=_query_vector(),
        enable_knn=True,
    )

    assert "knn" in query
    assert "filter" not in query["knn"]