Blame view

tests/test_es_query_builder.py 7.04 KB
7fbca0d7   tangwang   启动脚本优化
1
  from types import SimpleNamespace
a3d3fb11   tangwang   加phrase提权
2
  from typing import Any, Dict
7fbca0d7   tangwang   启动脚本优化
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  
  import numpy as np
  
  from search.es_query_builder import ESQueryBuilder
  
  
  def _builder() -> ESQueryBuilder:
      """Return the default English-only query builder shared by these tests.

      Lexical matching uses ``title.en`` (boost 3.0) and ``brief.en``
      (boost 1.0); ``title_embedding`` is passed as the text-embedding field.
      """
      english_match_fields = ["title.en^3.0", "brief.en^1.0"]
      return ESQueryBuilder(
          default_language="en",
          text_embedding_field="title_embedding",
          match_fields=english_match_fields,
      )
  
  
a3d3fb11   tangwang   加phrase提权
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
  def _lexical_multi_match_fields(query_root: Dict[str, Any]) -> list:
      """Return the ``fields`` of the first non-phrase ``multi_match`` clause.

      The lexical clause is looked up either directly at the root of
      *query_root* or among its ``bool.should`` clauses. Phrase-type
      ``multi_match`` clauses are skipped, and so are ``should`` clauses that
      are not ``multi_match`` at all (e.g. a ``term`` or a nested ``bool``).

      Args:
          query_root: the ``"query"`` part of a request body produced by
              ``ESQueryBuilder.build_query``.

      Returns:
          The ``fields`` list of the lexical ``multi_match`` clause.

      Raises:
          AssertionError: if the root ``multi_match`` is phrase-only, or if no
              non-phrase ``multi_match`` clause exists.
      """
      if "multi_match" in query_root:
          root_mm = query_root["multi_match"]
          if root_mm.get("type") == "phrase":
              raise AssertionError("root multi_match is phrase-only")
          return root_mm["fields"]
      for clause in query_root.get("bool", {}).get("should", []):
          mm = clause.get("multi_match")
          # Skip clauses that are not multi_match (previously they became {}
          # and crashed on mm["fields"] with KeyError), and skip the
          # phrase-boost companion clauses.
          if not mm or mm.get("type") == "phrase":
              continue
          return mm["fields"]
      raise AssertionError("no lexical multi_match in query_root")
  
  
7fbca0d7   tangwang   启动脚本优化
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  def test_knn_prefilter_includes_range_filters():
      """A lone range filter is used verbatim as the kNN clause's filter."""
      query = _builder().build_query(
          enable_knn=True,
          query_text="bags",
          range_filters={"min_price": {"gte": 50, "lt": 100}},
          query_vector=np.array([0.1, 0.2, 0.3]),
      )

      assert "knn" in query
      knn_clause = query["knn"]
      expected_filter = {"range": {"min_price": {"gte": 50, "lt": 100}}}
      assert knn_clause["filter"] == expected_filter
  
  
  def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
      """Disjunctive facet filters go to ``post_filter``, not the kNN prefilter.

      ``category_name`` is configured as a disjunctive facet, so its terms
      filter is expected in ``post_filter``. The ``vendor`` term and the price
      range are expected together in the kNN clause's ``bool.filter``.
      """
      builder = _builder()
      disjunctive_facets = [SimpleNamespace(field="category_name", disjunctive=True)]
      result = builder.build_query(
          enable_knn=True,
          facet_configs=disjunctive_facets,
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          filters={"category_name": ["A", "B"], "vendor": "Nike"},
          range_filters={"min_price": {"gte": 50, "lt": 100}},
      )

      assert "knn" in result
      knn = result["knn"]
      assert "filter" in knn

      vendor_term = {"term": {"vendor": "Nike"}}
      price_range = {"range": {"min_price": {"gte": 50, "lt": 100}}}
      assert knn["filter"] == {"bool": {"filter": [vendor_term, price_range]}}

      expected_post_filter = {"terms": {"category_name": ["A", "B"]}}
      assert result["post_filter"] == expected_post_filter
  
  
  def test_knn_prefilter_not_added_without_filters():
      """Without any filters the kNN clause has no ``filter`` but is still named."""
      result = _builder().build_query(
          enable_knn=True,
          query_vector=np.array([0.1, 0.2, 0.3]),
          query_text="bags",
      )

      assert "knn" in result
      knn_clause = result["knn"]
      assert "filter" not in knn_clause
      assert knn_clause["_name"] == "knn_query"
c90f80ed   tangwang   相关性优化
82
83
  
  
ef5baa86   tangwang   混杂语言处理
84
  def test_text_query_contains_only_base_and_translation_named_queries():
      """The should clauses are exactly the base and zh-translation queries.

      The ``en`` translation equals the base query and ``fr`` has no
      translation at all. Only ``base_query`` and ``base_query_trans_zh`` are
      therefore expected, each followed by its ``_phrase`` variant.
      """
      builder = _builder()
      parsed = SimpleNamespace(
          translations={"en": "dress", "zh": "连衣裙"},
          detected_language="en",
          rewritten_query="dress",
      )

      result = builder.build_query(
          index_languages=["en", "zh", "fr"],
          enable_knn=False,
          parsed_query=parsed,
          query_text="dress",
      )

      actual_names = []
      for clause in result["query"]["bool"]["should"]:
          actual_names.append(clause["multi_match"]["_name"])

      expected_names = [
          "base_query",
          "base_query_phrase",
          "base_query_trans_zh",
          "base_query_trans_zh_phrase",
      ]
      assert actual_names == expected_names
ef5baa86   tangwang   混杂语言处理
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
  
  
  def test_text_query_skips_duplicate_translation_same_as_base():
      """A translation identical to the base query must not add its own clause.

      The only translation (``en``) equals the rewritten query. The should
      list must start with the base clause and its phrase variant and must
      contain no ``base_query_trans*`` clause anywhere.
      """
      qb = _builder()
      parsed_query = SimpleNamespace(
          rewritten_query="dress",
          detected_language="en",
          translations={"en": "dress"},
      )

      q = qb.build_query(
          query_text="dress",
          parsed_query=parsed_query,
          enable_knn=False,
          index_languages=["en", "zh"],
      )

      should = q["query"]["bool"]["should"]
      names = [clause.get("multi_match", {}).get("_name", "") for clause in should]
      assert names[:2] == ["base_query", "base_query_phrase"]
      # Checking only the first two clauses let a duplicated translation
      # clause appended after them slip through; assert its absence directly.
      assert not any(name.startswith("base_query_trans") for name in names), names
6823fe3e   tangwang   feat(search): 混合语...
127
128
129
130
131
132
133
134
135
136
137
  
  
  def test_mixed_script_merges_en_fields_into_zh_clause():
      """A zh-detected query containing English also searches the en fields.

      The lexical clause must cover both the ``.zh`` and ``.en`` variants of
      ``title`` and ``brief``. The merged English fields are expected to carry
      ``^0.6`` boosts.
      """
      builder = ESQueryBuilder(
          default_language="en",
          text_embedding_field="title_embedding",
          shared_fields=[],
          multilingual_fields=["title", "brief"],
          match_fields=["title.en^3.0"],
      )
      mixed_query = SimpleNamespace(
          contains_english=True,
          contains_chinese=True,
          translations={},
          detected_language="zh",
          rewritten_query="法式 dress",
      )

      result = builder.build_query(
          index_languages=["zh", "en"],
          enable_knn=False,
          parsed_query=mixed_query,
          query_text="法式 dress",
      )
      lexical_fields = _lexical_multi_match_fields(result["query"])

      field_names = set()
      for spec in lexical_fields:
          field_names.add(spec.partition("^")[0])

      for zh_field, en_field in (("title.zh", "title.en"), ("brief.zh", "brief.en")):
          assert zh_field in field_names and en_field in field_names
      # Supplemental-language fields are merged with boost * 0.6 by default.
      for merged_spec in ("title.en^0.6", "brief.en^0.6"):
          assert merged_spec in lexical_fields
6823fe3e   tangwang   feat(search): 混合语...
157
158
159
160
161
162
163
164
165
166
167
  
  
  def test_mixed_script_merges_zh_fields_into_en_clause():
      """An en-detected query containing Chinese also searches the zh field.

      The lexical clause must include both ``title.en`` and ``title.zh``, and
      the merged Chinese field is expected as ``title.zh^0.6``.
      """
      builder = ESQueryBuilder(
          default_language="en",
          text_embedding_field="title_embedding",
          shared_fields=[],
          multilingual_fields=["title"],
          match_fields=["title.en^3.0"],
      )
      mixed_query = SimpleNamespace(
          contains_english=True,
          contains_chinese=True,
          translations={},
          detected_language="en",
          rewritten_query="red 连衣裙",
      )

      result = builder.build_query(
          index_languages=["zh", "en"],
          enable_knn=False,
          parsed_query=mixed_query,
          query_text="red 连衣裙",
      )
      lexical_fields = _lexical_multi_match_fields(result["query"])
      field_names = {spec.partition("^")[0] for spec in lexical_fields}

      assert "title.en" in field_names
      assert "title.zh" in field_names
      assert "title.zh^0.6" in lexical_fields
6823fe3e   tangwang   feat(search): 混合语...
184
185
186
187
188
189
190
191
192
193
194
195
  
  
  def test_mixed_script_merged_fields_scale_configured_boosts():
      """Merged supplemental fields scale their configured boost by 0.6.

      With ``field_boosts`` of 5.0 for ``title.zh`` and 10.0 for ``title.en``,
      the assertions expect ``title.zh^5.0`` and a merged ``title.en^6.0``
      (10.0 * 0.6).
      """
      builder = ESQueryBuilder(
          default_language="en",
          text_embedding_field="title_embedding",
          field_boosts={"title.zh": 5.0, "title.en": 10.0},
          shared_fields=[],
          multilingual_fields=["title"],
          match_fields=["title.en^3.0"],
      )
      mixed_query = SimpleNamespace(
          contains_english=True,
          contains_chinese=True,
          translations={},
          detected_language="zh",
          rewritten_query="法式 dress",
      )

      result = builder.build_query(
          index_languages=["zh", "en"],
          enable_knn=False,
          parsed_query=mixed_query,
          query_text="法式 dress",
      )
      lexical_fields = _lexical_multi_match_fields(result["query"])

      assert "title.zh^5.0" in lexical_fields
      # Configured en boost 10.0 scaled by the 0.6 merge factor.
      assert "title.en^6.0" in lexical_fields
6823fe3e   tangwang   feat(search): 混合语...
211
212
213
214
215
216
217
218
219
220
221
  
  
  def test_mixed_script_does_not_merge_en_when_not_in_index_languages():
      """English fields are not merged when ``en`` is not an index language.

      The query mixes scripts, but ``index_languages`` is only ``["zh"]``.
      The lexical clause must therefore include ``title.zh`` and no
      ``title.en``.
      """
      qb = ESQueryBuilder(
          default_language="zh",
          text_embedding_field="title_embedding",
          shared_fields=[],
          multilingual_fields=["title"],
          match_fields=["title.zh^3.0"],
      )
      parsed_query = SimpleNamespace(
          contains_english=True,
          contains_chinese=True,
          translations={},
          detected_language="zh",
          rewritten_query="法式 dress",
      )

      q = qb.build_query(
          index_languages=["zh"],
          enable_knn=False,
          parsed_query=parsed_query,
          query_text="法式 dress",
      )
      fields = _lexical_multi_match_fields(q["query"])
      bases = {f.partition("^")[0] for f in fields}
      assert "title.zh" in bases
      assert "title.en" not in bases