Blame view

tests/test_es_query_builder.py 6.18 KB
7fbca0d7   tangwang   启动脚本优化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  from types import SimpleNamespace
  
  import numpy as np
  
  from search.es_query_builder import ESQueryBuilder
  
  
  def _builder() -> ESQueryBuilder:
      """Return an ``ESQueryBuilder`` with English title/brief match fields.

      The builder matches on ``title.en`` (boost 3.0) and ``brief.en``
      (boost 1.0), uses ``title_embedding`` as the text embedding field and
      ``en`` as its default language.
      """
      builder_kwargs = {
          "match_fields": ["title.en^3.0", "brief.en^1.0"],
          "text_embedding_field": "title_embedding",
          "default_language": "en",
      }
      return ESQueryBuilder(**builder_kwargs)
  
  
  def test_knn_prefilter_includes_range_filters():
      """A single range filter must appear verbatim as the kNN ``filter``."""
      es_query = _builder().build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          range_filters={"min_price": {"gte": 50, "lt": 100}},
          enable_knn=True,
      )

      # The expectation is a fresh literal so it cannot alias the input dict.
      expected_filter = {"range": {"min_price": {"gte": 50, "lt": 100}}}

      assert "knn" in es_query
      assert es_query["knn"]["filter"] == expected_filter
  
  
  def test_knn_prefilter_uses_only_conjunctive_filters_when_disjunctive_present():
      """Disjunctive facet filters go to ``post_filter``, not the kNN filter.

      ``category_name`` is configured as a disjunctive facet, so the kNN
      pre-filter is expected to contain only the ``vendor`` term and the
      price range, while the category terms land in ``post_filter``.
      """
      disjunctive_facet = SimpleNamespace(field="category_name", disjunctive=True)
      es_query = _builder().build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          filters={"category_name": ["A", "B"], "vendor": "Nike"},
          range_filters={"min_price": {"gte": 50, "lt": 100}},
          facet_configs=[disjunctive_facet],
          enable_knn=True,
      )

      expected_knn_filter = {
          "bool": {
              "filter": [
                  {"term": {"vendor": "Nike"}},
                  {"range": {"min_price": {"gte": 50, "lt": 100}}},
              ]
          }
      }
      expected_post_filter = {"terms": {"category_name": ["A", "B"]}}

      assert "knn" in es_query
      assert "filter" in es_query["knn"]
      assert es_query["knn"]["filter"] == expected_knn_filter
      assert es_query["post_filter"] == expected_post_filter
  
  
  def test_knn_prefilter_not_added_without_filters():
      """Without any filters the kNN clause has no ``filter`` key.

      The clause is still expected to carry the ``knn_query`` name.
      """
      es_query = _builder().build_query(
          query_text="bags",
          query_vector=np.array([0.1, 0.2, 0.3]),
          enable_knn=True,
      )

      assert "knn" in es_query
      knn_clause = es_query["knn"]
      assert "filter" not in knn_clause
      assert knn_clause["_name"] == "knn_query"
c90f80ed   tangwang   相关性优化
66
67
  
  
ef5baa86   tangwang   混杂语言处理
68
  def test_text_query_contains_only_base_and_translation_named_queries():
      """The ``should`` clauses are exactly the base and the zh translation.

      The parsed query is English with ``en`` and ``zh`` translations and the
      index languages are ``en``, ``zh`` and ``fr``; the expected named
      clauses, in order, are ``base_query`` and ``base_query_trans_zh``.
      """
      parsed = SimpleNamespace(
          rewritten_query="dress",
          detected_language="en",
          translations={"en": "dress", "zh": "连衣裙"},
      )

      es_query = _builder().build_query(
          query_text="dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["en", "zh", "fr"],
      )

      clause_names = []
      for clause in es_query["query"]["bool"]["should"]:
          clause_names.append(clause["multi_match"]["_name"])

      assert clause_names == ["base_query", "base_query_trans_zh"]
  
  
  def test_text_query_skips_duplicate_translation_same_as_base():
      """A translation identical to the base query yields a bare ``multi_match``.

      The only translation is the ``en`` text, which equals the base query,
      so the top-level query is expected to be the ``base_query`` clause.
      """
      parsed = SimpleNamespace(
          rewritten_query="dress",
          detected_language="en",
          translations={"en": "dress"},
      )

      es_query = _builder().build_query(
          query_text="dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["en", "zh"],
      )

      base_clause = es_query["query"]["multi_match"]
      assert base_clause["_name"] == "base_query"
6823fe3e   tangwang   feat(search): 混合语...
104
105
106
107
108
109
110
111
112
113
114
  
  
  def test_mixed_script_merges_en_fields_into_zh_clause():
      """A zh-detected query containing English also searches the ``.en`` fields.

      With ``title`` and ``brief`` as multilingual fields, the clause is
      expected to list both language variants of each, with the English ones
      at boost 0.6.
      """
      builder_kwargs = {
          "match_fields": ["title.en^3.0"],
          "multilingual_fields": ["title", "brief"],
          "shared_fields": [],
          "text_embedding_field": "title_embedding",
          "default_language": "en",
      }
      builder = ESQueryBuilder(**builder_kwargs)
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      es_query = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      fields = es_query["query"]["multi_match"]["fields"]
      # Strip the "^boost" suffix to compare bare field names.
      field_names = set()
      for spec in fields:
          field_names.add(spec.partition("^")[0])

      assert {"title.zh", "title.en"} <= field_names
      assert {"brief.zh", "brief.en"} <= field_names
      # Merged supplemental language fields use boost * 0.6 by default.
      assert "title.en^0.6" in fields
      assert "brief.en^0.6" in fields
6823fe3e   tangwang   feat(search): 混合语...
134
135
136
137
138
139
140
141
142
143
144
  
  
  def test_mixed_script_merges_zh_fields_into_en_clause():
      """An en-detected query containing Chinese also searches ``title.zh``.

      The merged Chinese field is expected at boost 0.6.
      """
      builder_kwargs = {
          "match_fields": ["title.en^3.0"],
          "multilingual_fields": ["title"],
          "shared_fields": [],
          "text_embedding_field": "title_embedding",
          "default_language": "en",
      }
      builder = ESQueryBuilder(**builder_kwargs)
      parsed = SimpleNamespace(
          rewritten_query="red 连衣裙",
          detected_language="en",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      es_query = builder.build_query(
          query_text="red 连衣裙",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      fields = es_query["query"]["multi_match"]["fields"]
      # Strip the "^boost" suffix to compare bare field names.
      field_names = set()
      for spec in fields:
          field_names.add(spec.partition("^")[0])

      assert {"title.en", "title.zh"} <= field_names
      assert "title.zh^0.6" in fields
6823fe3e   tangwang   feat(search): 混合语...
161
162
163
164
165
166
167
168
169
170
171
172
  
  
  def test_mixed_script_merged_fields_scale_configured_boosts():
      """Merged fields scale their configured boost instead of using a flat 0.6.

      With ``field_boosts`` of 5.0 for ``title.zh`` and 10.0 for ``title.en``,
      a zh-detected mixed query is expected to keep ``title.zh`` at 5.0 and
      list ``title.en`` at 6.0.
      """
      builder_kwargs = {
          "match_fields": ["title.en^3.0"],
          "multilingual_fields": ["title"],
          "shared_fields": [],
          "field_boosts": {"title.zh": 5.0, "title.en": 10.0},
          "text_embedding_field": "title_embedding",
          "default_language": "en",
      }
      builder = ESQueryBuilder(**builder_kwargs)
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      es_query = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh", "en"],
      )

      merged_fields = es_query["query"]["multi_match"]["fields"]
      assert "title.zh^5.0" in merged_fields
      assert "title.en^6.0" in merged_fields  # 10.0 * 0.6
6823fe3e   tangwang   feat(search): 混合语...
188
189
190
191
192
193
194
195
196
197
198
  
  
  def test_mixed_script_does_not_merge_en_when_not_in_index_languages():
      """English fields are not merged when ``en`` is not an index language.

      The query mixes Chinese and English but the index languages are only
      ``zh``, so ``title.en`` must be absent from the clause's fields.
      """
      builder_kwargs = {
          "match_fields": ["title.zh^3.0"],
          "multilingual_fields": ["title"],
          "shared_fields": [],
          "text_embedding_field": "title_embedding",
          "default_language": "zh",
      }
      builder = ESQueryBuilder(**builder_kwargs)
      parsed = SimpleNamespace(
          rewritten_query="法式 dress",
          detected_language="zh",
          translations={},
          contains_chinese=True,
          contains_english=True,
      )

      es_query = builder.build_query(
          query_text="法式 dress",
          parsed_query=parsed,
          enable_knn=False,
          index_languages=["zh"],
      )

      fields = es_query["query"]["multi_match"]["fields"]
      # Strip the "^boost" suffix to compare bare field names.
      field_names = set()
      for spec in fields:
          field_names.add(spec.partition("^")[0])

      assert "title.zh" in field_names
      assert "title.en" not in field_names