test_es_query_builder.py
3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from types import SimpleNamespace
from typing import Any, Dict
import numpy as np
from search.es_query_builder import ESQueryBuilder
def _builder() -> ESQueryBuilder:
    """Create the ESQueryBuilder configuration shared by every test below.

    Lexical matching targets the English ``title``/``brief`` subfields
    (``title`` boosted 3x), with ``title_embedding`` as the vector field and
    English as the default language.
    """
    builder_kwargs: Dict[str, Any] = {
        "match_fields": ["title.en^3.0", "brief.en^1.0"],
        "multilingual_fields": ["title", "brief"],
        "core_multilingual_fields": ["title", "brief"],
        "shared_fields": [],
        "text_embedding_field": "title_embedding",
        "default_language": "en",
    }
    return ESQueryBuilder(**builder_kwargs)
def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]:
"""Return the first named lexical bool clause from query_root."""
if "bool" in query_root and query_root["bool"].get("_name"):
return query_root["bool"]
for clause in query_root.get("bool", {}).get("should", []):
clause_bool = clause.get("bool") or {}
if clause_bool.get("_name"):
return clause_bool
raise AssertionError("no lexical bool clause in query_root")
def test_knn_prefilter_includes_range_filters():
    """A range filter supplied on its own is forwarded as the kNN prefilter.

    Only ``range_filters`` is passed, so ``knn.filter`` is expected to be the
    bare ``range`` clause rather than a ``bool`` wrapper.
    """
    qb = _builder()
    q = qb.build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        enable_knn=True,
    )
    assert "knn" in q
    # Check for the key before indexing, so a missing prefilter fails with a
    # clear assertion rather than an opaque KeyError.
    assert "filter" in q["knn"]
    assert q["knn"]["filter"] == {"range": {"min_price": {"gte": 50, "lt": 100}}}
def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
    """Disjunctive facet filters belong in post_filter, not the kNN prefilter.

    ``category_name`` is declared disjunctive, so it is expected only in
    ``post_filter``. The conjunctive ``vendor`` term and the price range are
    expected, in that order, in the kNN ``bool.filter`` list.
    """
    disjunctive_facets = [SimpleNamespace(field="category_name", disjunctive=True)]
    result = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        filters={"category_name": ["A", "B"], "vendor": "Nike"},
        range_filters={"min_price": {"gte": 50, "lt": 100}},
        facet_configs=disjunctive_facets,
        enable_knn=True,
    )
    assert "knn" in result
    knn_section = result["knn"]
    assert "filter" in knn_section
    expected_prefilter = {
        "bool": {
            "filter": [
                {"term": {"vendor": "Nike"}},
                {"range": {"min_price": {"gte": 50, "lt": 100}}},
            ]
        }
    }
    assert knn_section["filter"] == expected_prefilter
    assert result["post_filter"] == {"terms": {"category_name": ["A", "B"]}}
def test_knn_prefilter_not_added_without_filters():
    """With no filters, the kNN section has no ``filter`` key and is named ``knn_query``."""
    query = _builder().build_query(
        query_text="bags",
        query_vector=np.array([0.1, 0.2, 0.3]),
        enable_knn=True,
    )
    assert "knn" in query
    knn_section = query["knn"]
    assert "filter" not in knn_section
    assert knn_section["_name"] == "knn_query"
def test_text_query_contains_only_base_and_translation_named_queries():
    """Only the base query and the zh translation appear as named clauses.

    The ``en`` translation equals the rewritten query, and only
    ``base_query`` and ``base_query_trans_zh`` are expected at the top level.
    The base clause is expected to hold ``best_fields`` and ``phrase``
    multi_match queries, in that order.
    """
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress", "zh": "连衣裙"},
    )
    query = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
    )
    top_level_should = query["query"]["bool"]["should"]
    clause_names = [entry["bool"]["_name"] for entry in top_level_should]
    assert clause_names == ["base_query", "base_query_trans_zh"]
    base_should = top_level_should[0]["bool"]["should"]
    match_types = [entry["multi_match"]["type"] for entry in base_should]
    assert match_types == ["best_fields", "phrase"]
def test_text_query_skips_duplicate_translation_same_as_base():
    """A translation identical to the base query adds no extra clause.

    With only a duplicate ``en`` translation, the root ``bool`` is expected to
    be the named base query itself. It should hold ``best_fields`` and
    ``phrase`` multi_match clauses, in that order.
    """
    parsed = SimpleNamespace(
        rewritten_query="dress",
        detected_language="en",
        translations={"en": "dress"},
    )
    query = _builder().build_query(
        query_text="dress",
        parsed_query=parsed,
        enable_knn=False,
    )
    root_bool = query["query"]["bool"]
    assert root_bool["_name"] == "base_query"
    match_types = [entry["multi_match"]["type"] for entry in root_bool["should"]]
    assert match_types == ["best_fields", "phrase"]