from types import SimpleNamespace
from typing import Any, Dict

import numpy as np

from search.es_query_builder import ESQueryBuilder


def _builder() -> ESQueryBuilder:
    """Return an ``ESQueryBuilder`` wired with the field layout shared by every test here."""
    return ESQueryBuilder(
        match_fields=["title.en^3.0", "brief.en^1.0"],
        multilingual_fields=["title", "brief"],
        core_multilingual_fields=["title", "brief"],
        shared_fields=[],
        text_embedding_field="title_embedding",
        image_embedding_field="image_embedding.vector",
        default_language="en",
    )


def _recall_root(es_body: Dict[str, Any]) -> Dict[str, Any]:
    """Return the innermost recall query of *es_body*.

    Starting at ``es_body["query"]``, step into the first ``bool.must`` entry
    when the root is a ``bool`` with a non-empty ``must``, then step into
    ``function_score.query`` when the current node is a ``function_score``.
    """
    query_root = es_body["query"]
    if "bool" in query_root and query_root["bool"].get("must"):
        query_root = query_root["bool"]["must"][0]
    if "function_score" in query_root:
        query_root = query_root["function_score"]["query"]
    return query_root


def _recall_should_clauses(es_body: Dict[str, Any]) -> list[Dict[str, Any]]:
    """Return the recall ``should`` clauses of *es_body*.

    When the recall root has no (or an empty) ``bool.should``, the root itself
    is returned as a one-element list so callers can always iterate.
    """
    root = _recall_root(es_body)
    should = root.get("bool", {}).get("should")
    if should:
        return should
    return [root]


def _recall_clause_name(clause: Dict[str, Any]) -> str | None:
    """Return the ``_name`` of a ``bool``, ``knn`` or ``nested`` clause, else ``None``."""
    # Checked in this order; the first wrapper key present decides the result.
    for kind in ("bool", "knn", "nested"):
        if kind in clause:
            return clause[kind].get("_name")
    return None


def _named_clause(
    should: list[Dict[str, Any]], kind: str, name: str
) -> Dict[str, Any]:
    """Return the body of the first *kind* clause in *should* whose ``_name`` is *name*.

    Raises:
        AssertionError: if no such clause exists, so the failing test reports
            which clause was missing instead of a bare ``StopIteration``.
    """
    for clause in should:
        if clause.get(kind, {}).get("_name") == name:
            return clause[kind]
    raise AssertionError(
        f"no {kind!r} clause named {name!r} among recall should clauses"
    )


def test_knn_clause_moves_under_query_should_and_uses_outer_filters():
    """The kNN clause sits in the recall ``should`` list; range filter is on the outer bool."""
    qb = _builder()
    q = qb.build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        enable_knn=True,
    )

    # No top-level "knn" section in the request body.
    assert "knn" not in q
    should = _recall_should_clauses(q)
    assert any(clause.get("knn", {}).get("_name") == "knn_query" for clause in should)
    assert q["query"]["bool"]["filter"] == [{"range": {"min_price": {"gte": 50, "lt": 100}}}]


def test_knn_clause_uses_outer_query_filter_when_disjunctive_filters_present():
    """Non-disjunctive filters go to the outer bool filter; the disjunctive facet to ``post_filter``."""
    qb = _builder()
    facets = [SimpleNamespace(field="category_name", disjunctive=True)]
    q = qb.build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=facets,
        enable_knn=True,
    )

    assert "knn" not in q
    assert q["query"]["bool"]["filter"] == [
        {"term": {"vendor": "Nike"}},
        {"range": {"min_price": {"gte": 50, "lt": 100}}},
    ]
    assert q["post_filter"] == {"terms": {"category_name": ["A", "B"]}}


def test_knn_clause_has_name_and_no_embedded_filter():
    """The text kNN clause is named ``knn_query`` and carries no ``filter`` of its own."""
    qb = _builder()
    q = qb.build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        enable_knn=True,
    )

    should = _recall_should_clauses(q)
    knn_clause = _named_clause(should, "knn", "knn_query")
    assert "filter" not in knn_clause
    assert knn_clause["_name"] == "knn_query"


def test_text_query_contains_only_base_and_translation_named_queries():
    """With kNN disabled, only the base query and the ``zh`` translation query are emitted."""
    qb = _builder()
    parsed_query = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress", "zh": "连衣裙"},
    )

    q = qb.build_query(
        query_text="dress",
        parsed_query=parsed_query,
        enable_knn=False,
    )

    should = _recall_should_clauses(q)
    names = [clause["bool"]["_name"] for clause in should]
    assert names == ["base_query", "base_query_trans_zh"]

    base_should = should[0]["bool"]["should"]
    mm_types = [c["multi_match"]["type"] for c in base_should if "multi_match" in c]
    assert mm_types == ["best_fields", "phrase"]


def test_text_query_skips_duplicate_translation_same_as_base():
    """A translation identical to the base query yields just the ``base_query`` bool."""
    qb = _builder()
    parsed_query = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress"},
    )

    q = qb.build_query(
        query_text="dress",
        parsed_query=parsed_query,
        enable_knn=False,
    )

    # Only a function_score wrapper is unwrapped here; the root is then
    # expected to be the named base bool itself.
    query_root = q["query"]
    if "function_score" in query_root:
        query_root = query_root["function_score"]["query"]
    base_bool = query_root["bool"]
    assert base_bool["_name"] == "base_query"
    mm_types = [c["multi_match"]["type"] for c in base_bool["should"] if "multi_match" in c]
    assert mm_types == ["best_fields", "phrase"]


def test_product_title_exclusion_filter_is_applied_once_on_outer_query():
    """Title exclusions appear as a ``must_not`` in the outer filter, not inside the kNN clause."""
    qb = _builder()
    parsed_query = SimpleNamespace(
        rewritten_query="fitted dress",
        detected_language="en",
        translations={"zh": "修身 连衣裙"},
        product_title_exclusion_profile=SimpleNamespace(
            is_active=True,
            all_zh_title_exclusions=lambda: ["宽松"],
            all_en_title_exclusions=lambda: ["loose", "relaxed"],
        ),
    )

    q = qb.build_query(
        query_text="fitted dress",
        query_vector=np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed_query,
        enable_knn=True,
    )

    expected_filter = {
        "bool": {
            "must_not": [
                {
                    "bool": {
                        "should": [
                            {"match_phrase": {"title.zh": {"query": "宽松"}}},
                            {"match_phrase": {"title.en": {"query": "loose"}}},
                            {"match_phrase": {"title.en": {"query": "relaxed"}}},
                        ],
                        "minimum_should_match": 1,
                    }
                }
            ]
        }
    }

    assert expected_filter in q["query"]["bool"]["filter"]
    should = _recall_should_clauses(q)
    knn_clause = _named_clause(should, "knn", "knn_query")
    assert "filter" not in knn_clause


def test_image_knn_clause_is_added_alongside_base_translation_and_text_knn():
    """An image vector adds a nested ``image_knn_query`` after base, translation and text kNN."""
    qb = _builder()
    parsed_query = SimpleNamespace(
        rewritten_query="street tee",
        detected_language="en",
        translations={"zh": "街头短袖"},
    )

    q = qb.build_query(
        query_text="street tee",
        query_vector=np.array([0.1, 0.2, 0.3]),
        image_query_vector=np.array([0.4, 0.5, 0.6]),
        parsed_query=parsed_query,
        enable_knn=True,
    )

    should = _recall_should_clauses(q)
    names = [_recall_clause_name(clause) for clause in should]

    assert names == ["base_query", "base_query_trans_zh", "knn_query", "image_knn_query"]
    image_knn = _named_clause(should, "nested", "image_knn_query")
    assert image_knn["path"] == "image_embedding"
    assert image_knn["score_mode"] == "max"
    assert image_knn["query"]["knn"]["field"] == "image_embedding.vector"


def test_text_knn_plan_is_reused_for_ann_and_exact_rescore():
    """ANN and exact-rescore text clauses agree on boost for a five-token parsed query."""
    qb = _builder()
    parsed_query = SimpleNamespace(query_tokens=["a", "b", "c", "d", "e"])

    ann_clause = qb.build_text_knn_clause(
        np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed_query,
    )
    exact_clause = qb.build_exact_text_knn_rescore_clause(
        np.array([0.1, 0.2, 0.3]),
        parsed_query=parsed_query,
    )

    assert ann_clause is not None
    assert exact_clause is not None
    # Expected values are the builder's own "*_long" settings and its text
    # boost scaled by 1.4.
    assert ann_clause["knn"]["k"] == qb.knn_text_k_long
    assert ann_clause["knn"]["num_candidates"] == qb.knn_text_num_candidates_long
    assert ann_clause["knn"]["boost"] == qb.knn_text_boost * 1.4
    assert exact_clause["script_score"]["script"]["params"]["boost"] == qb.knn_text_boost * 1.4


def test_image_knn_plan_is_reused_for_ann_and_exact_rescore():
    """ANN and exact-rescore image clauses both carry the builder's image boost."""
    qb = _builder()

    ann_clause = qb.build_image_knn_clause(np.array([0.4, 0.5, 0.6]))
    exact_clause = qb.build_exact_image_knn_rescore_clause(np.array([0.4, 0.5, 0.6]))

    assert ann_clause is not None
    assert exact_clause is not None
    assert ann_clause["nested"]["query"]["knn"]["boost"] == qb.knn_image_boost
    assert exact_clause["nested"]["query"]["script_score"]["script"]["params"]["boost"] == qb.knn_image_boost