Blame view

tests/test_es_query_builder.py 7.1 KB
7fbca0d7   tangwang   启动脚本优化
1
  from types import SimpleNamespace
a3d3fb11   tangwang   加phrase提权
2
  from typing import Any, Dict
7fbca0d7   tangwang   启动脚本优化
3
4
5
6
7
8
9
10
11
12
13
14
15
16
  
  import numpy as np
  
  from search.es_query_builder import ESQueryBuilder
  
  
  def _builder() -> ESQueryBuilder:
      """Create the default English ESQueryBuilder used by the kNN and text-query tests.

      Matches against ``title.en`` (boost 3.0) and ``brief.en`` (boost 1.0) and
      uses ``title_embedding`` as the dense-vector field.
      """
      builder_kwargs = dict(
          match_fields=["title.en^3.0", "brief.en^1.0"],
          text_embedding_field="title_embedding",
          default_language="en",
      )
      return ESQueryBuilder(**builder_kwargs)
  
  
e756b18e   tangwang   重构了文本召回构建器,现在每个 b...
17
18
19
20
  def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]:
      """Return the first named lexical bool clause from query_root."""
      if "bool" in query_root and query_root["bool"].get("_name"):
          return query_root["bool"]
a3d3fb11   tangwang   加phrase提权
21
      for clause in query_root.get("bool", {}).get("should", []):
e756b18e   tangwang   重构了文本召回构建器,现在每个 b...
22
23
24
25
26
27
28
29
          clause_bool = clause.get("bool") or {}
          if clause_bool.get("_name"):
              return clause_bool
      raise AssertionError("no lexical bool clause in query_root")
  
  
  def _lexical_combined_fields(query_root: Dict[str, Any]) -> list:
      """Return the ``fields`` list of the lexical clause's first ``must`` entry.

      Assumes that entry is a ``combined_fields`` query. Otherwise a
      ``KeyError`` or ``IndexError`` propagates to the calling test.
      """
      first_must = _lexical_clause(query_root)["must"][0]
      return first_must["combined_fields"]["fields"]
a3d3fb11   tangwang   加phrase提权
30
31
  
  
7fbca0d7   tangwang   启动脚本优化
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  def test_knn_prefilter_includes_range_filters():
      """A lone range filter becomes the kNN pre-filter as a single ``range`` clause."""
      price_bounds = {"gte": 50, "lt": 100}
      body = _builder().build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          range_filters={"min_price": price_bounds},
          enable_knn=True,
      )

      assert "knn" in body
      expected_prefilter = {"range": {"min_price": {"gte": 50, "lt": 100}}}
      assert body["knn"]["filter"] == expected_prefilter
  
  
  def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
      """Disjunctive facet filters move to ``post_filter``; the rest pre-filter kNN."""
      builder = _builder()
      disjunctive_facet = SimpleNamespace(field="category_name", disjunctive=True)
      body = builder.build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          filters={"category_name": ["A", "B"], "vendor": "Nike"},
          range_filters={"min_price": {"gte": 50, "lt": 100}},
          facet_configs=[disjunctive_facet],
          enable_knn=True,
      )

      assert "knn" in body
      knn_section = body["knn"]
      assert "filter" in knn_section

      expected_prefilter = {
          "bool": {
              "filter": [
                  {"term": {"vendor": "Nike"}},
                  {"range": {"min_price": {"gte": 50, "lt": 100}}},
              ]
          }
      }
      assert knn_section["filter"] == expected_prefilter

      expected_post_filter = {"terms": {"category_name": ["A", "B"]}}
      assert body["post_filter"] == expected_post_filter
  
  
  def test_knn_prefilter_not_added_without_filters():
      """Without any filters the kNN section has no ``filter`` key but is still named."""
      body = _builder().build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          enable_knn=True,
      )

      assert "knn" in body
      knn_section = body["knn"]
      assert "filter" not in knn_section
      assert knn_section["_name"] == "knn_query"
c90f80ed   tangwang   相关性优化
82
83
  
  
ef5baa86   tangwang   混杂语言处理
84
  def test_text_query_contains_only_base_and_translation_named_queries():
      """Top-level ``should`` holds only the base clause and the zh translation clause.

      The ``en`` translation matches the base query's language and text and adds
      no clause of its own. The base clause combines a ``best_fields`` and a
      ``phrase`` multi_match.
      """
      builder = _builder()
      parsed = SimpleNamespace(
          rewritten_query="dress",
          detected_language="en",
          translations={"en": "dress", "zh": "连衣裙"},
      )

      body = builder.build_query(
          query_text="dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["en", "zh", "fr"],
      )

      top_should = body["query"]["bool"]["should"]
      clause_names = [entry["bool"]["_name"] for entry in top_should]
      assert clause_names == ["base_query", "base_query_trans_zh"]

      base_clauses = top_should[0]["bool"]["should"]
      match_types = [entry["multi_match"]["type"] for entry in base_clauses]
      assert match_types == ["best_fields", "phrase"]
ef5baa86   tangwang   混杂语言处理
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  
  
  def test_text_query_skips_duplicate_translation_same_as_base():
      """A translation identical to the base query leaves only the named base clause."""
      builder = _builder()
      parsed = SimpleNamespace(
          rewritten_query="dress",
          detected_language="en",
          translations={"en": "dress"},
      )

      body = builder.build_query(
          query_text="dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["en", "zh"],
      )

      # The root ``bool`` is the base clause itself rather than a should-wrapper.
      base_bool = body["query"]["bool"]
      assert base_bool["_name"] == "base_query"
      match_types = [entry["multi_match"]["type"] for entry in base_bool["should"]]
      assert match_types == ["best_fields", "phrase"]
6823fe3e   tangwang   feat(search): 混合语...
124
125
126
127
128
129
130
131
132
133
134
  
  
  def test_mixed_script_merges_en_fields_into_zh_clause():
      """A zh-detected query containing English also searches the en fields."""
      builder = ESQueryBuilder(
          match_fields=["title.en^3.0"],
          multilingual_fields=["title", "brief"],
          shared_fields=[],
          text_embedding_field="title_embedding",
          default_language="en",
      )
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      body = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      fields = _lexical_combined_fields(body["query"])
      field_names = {entry.partition("^")[0] for entry in fields}
      for expected_name in ("title.zh", "title.en", "brief.zh", "brief.en"):
          assert expected_name in field_names, expected_name

      # Merged supplemental language fields use boost * 0.6 by default.
      for merged_field in ("title.en^0.6", "brief.en^0.6"):
          assert merged_field in fields, merged_field
6823fe3e   tangwang   feat(search): 混合语...
154
155
156
157
158
159
160
161
162
163
164
  
  
  def test_mixed_script_merges_zh_fields_into_en_clause():
      """An en-detected query containing Chinese also searches ``title.zh`` at 0.6."""
      builder = ESQueryBuilder(
          match_fields=["title.en^3.0"],
          multilingual_fields=["title"],
          shared_fields=[],
          text_embedding_field="title_embedding",
          default_language="en",
      )
      parsed = SimpleNamespace(
          rewritten_query="red 连衣裙",
          detected_language="en",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      body = builder.build_query(
          query_text="red 连衣裙",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      fields = _lexical_combined_fields(body["query"])
      field_names = {entry.partition("^")[0] for entry in fields}
      assert "title.en" in field_names
      assert "title.zh" in field_names
      assert "title.zh^0.6" in fields
6823fe3e   tangwang   feat(search): 混合语...
181
182
183
184
185
186
187
188
189
190
191
192
  
  
  def test_mixed_script_merged_fields_scale_configured_boosts():
      """Configured field boosts apply, and merged fields scale their boost by 0.6."""
      builder = ESQueryBuilder(
          match_fields=["title.en^3.0"],
          multilingual_fields=["title"],
          shared_fields=[],
          field_boosts={"title.zh": 5.0, "title.en": 10.0},
          text_embedding_field="title_embedding",
          default_language="en",
      )
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      body = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      fields = _lexical_combined_fields(body["query"])
      assert "title.zh^5.0" in fields
      assert "title.en^6.0" in fields  # configured 10.0 scaled by 0.6
6823fe3e   tangwang   feat(search): 混合语...
208
209
210
211
212
213
214
215
216
217
218
  
  
  def test_mixed_script_does_not_merge_en_when_not_in_index_languages():
      """English fields are not merged when ``en`` is absent from ``index_languages``."""
      builder = ESQueryBuilder(
          match_fields=["title.zh^3.0"],
          multilingual_fields=["title"],
          shared_fields=[],
          text_embedding_field="title_embedding",
          default_language="zh",
      )
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      body = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh"],
      )

      field_names = {
          entry.partition("^")[0] for entry in _lexical_combined_fields(body["query"])
      }
      assert "title.zh" in field_names
      assert "title.en" not in field_names