"""Tests for the query bodies returned by ``ESQueryBuilder.build_query``."""

from types import SimpleNamespace

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Return the builder shared by the kNN and named-text-query tests."""
    return ESQueryBuilder(
        match_fields=["title.en^3.0", "brief.en^1.0"],
        text_embedding_field="title_embedding",
        default_language="en",
    )


def _mixed_script_parsed_query(text: str, detected_language: str) -> SimpleNamespace:
    """Build a parsed-query stub flagged as containing Chinese and English.

    The stub carries no translations, so the mixed-script tests only see
    the clause built from the query text itself.
    """
    return SimpleNamespace(
        rewritten_query=text,
        detected_language=detected_language,
        translations={},
        contains_chinese=True,
        contains_english=True,
    )


def _field_bases(fields) -> set:
    """Strip the ``^boost`` suffix from each field spec and return the names."""
    return {spec.split("^", 1)[0] for spec in fields}


def test_knn_prefilter_includes_range_filters():
    """A lone range filter is used directly as the kNN filter."""
    price_range = {"min_price": {"gte": 50, "lt": 100}}
    vector = np.array([0.1, 0.2, 0.3])

    body = _builder().build_query(
        query_text="bags",
        query_vector=vector,
        range_filters=price_range,
        enable_knn=True,
    )

    assert "knn" in body
    assert body["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}}


def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
    """Disjunctive facet filters go to ``post_filter``, not the kNN filter."""
    builder = _builder()
    disjunctive_facets = [SimpleNamespace(field="category_name", disjunctive=True)]
    vector = np.array([0.1, 0.2, 0.3])

    body = builder.build_query(
        query_text="bags",
        query_vector=vector,
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=disjunctive_facets,
        enable_knn=True,
    )

    assert "knn" in body
    assert "filter" in body["knn"]

    expected_clauses = [
        {"term": {"vendor": "Nike"}},
        {"range": {"min_price": {"gte": 50, "lt": 100}}},
    ]
    knn_filter = body["knn"]["filter"]
    assert knn_filter == {"bool": {"filter": expected_clauses}}
    assert body["post_filter"] == {"terms": {"category_name": ["A", "B"]}}


def test_knn_prefilter_not_added_without_filters():
    """Without any filters the kNN clause has no ``filter`` key."""
    vector = np.array([0.1, 0.2, 0.3])

    body = _builder().build_query(
        query_text="bags",
        query_vector=vector,
        enable_knn=True,
    )

    assert "knn" in body
    knn_clause = body["knn"]
    assert "filter" not in knn_clause
    assert knn_clause["_name"] == "knn_query"


def test_text_query_contains_only_base_and_translation_named_queries():
    """Only the base clause and the ``zh`` translation clause are emitted."""
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress", "zh": "连衣裙"},
    )

    body = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["en", "zh", "fr"],
    )

    should_clauses = body["query"]["bool"]["should"]
    clause_names = [entry["multi_match"]["_name"] for entry in should_clauses]
    assert clause_names == ["base_query", "base_query_trans_zh"]


def test_text_query_skips_duplicate_translation_same_as_base():
    """A translation in the detected language leaves a single base clause."""
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress"},
    )

    body = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
        index_languages=["en", "zh"],
    )

    assert body["query"]["multi_match"]["_name"] == "base_query"


def test_mixed_script_merges_en_fields_into_zh_clause():
    """A ``zh`` query containing English also searches the ``en`` fields."""
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title", "brief"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="en",
    )

    body = builder.build_query(
        query_text="法式 dress",
        parsed_query=_mixed_script_parsed_query("法式 dress", "zh"),
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = body["query"]["multi_match"]["fields"]
    bases = _field_bases(fields)
    assert {"title.zh", "title.en"} <= bases
    assert {"brief.zh", "brief.en"} <= bases
    # Merged supplemental language fields use boost * 0.6 by default.
    assert "title.en^0.6" in fields
    assert "brief.en^0.6" in fields


def test_mixed_script_merges_zh_fields_into_en_clause():
    """An ``en`` query containing Chinese also searches the ``zh`` fields."""
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="en",
    )

    body = builder.build_query(
        query_text="red 连衣裙",
        parsed_query=_mixed_script_parsed_query("red 连衣裙", "en"),
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = body["query"]["multi_match"]["fields"]
    bases = _field_bases(fields)
    assert {"title.en", "title.zh"} <= bases
    assert "title.zh^0.6" in fields


def test_mixed_script_merged_fields_scale_configured_boosts():
    """Merged supplemental fields scale the boost given in ``field_boosts``."""
    builder = ESQueryBuilder(
        match_fields=["title.en^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        field_boosts={"title.zh": 5.0, "title.en": 10.0},
        text_embedding_field="title_embedding",
        default_language="en",
    )

    body = builder.build_query(
        query_text="法式 dress",
        parsed_query=_mixed_script_parsed_query("法式 dress", "zh"),
        enable_knn=False,
        index_languages=["zh", "en"],
    )

    fields = body["query"]["multi_match"]["fields"]
    assert "title.zh^5.0" in fields
    assert "title.en^6.0" in fields  # 10.0 * 0.6


def test_mixed_script_does_not_merge_en_when_not_in_index_languages():
    """``en`` fields are left out when ``en`` is not in ``index_languages``."""
    builder = ESQueryBuilder(
        match_fields=["title.zh^3.0"],
        multilingual_fields=["title"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        default_language="zh",
    )

    body = builder.build_query(
        query_text="法式 dress",
        parsed_query=_mixed_script_parsed_query("法式 dress", "zh"),
        enable_knn=False,
        index_languages=["zh"],
    )

    bases = _field_bases(body["query"]["multi_match"]["fields"])
    assert "title.zh" in bases
    assert "title.en" not in bases