Blame view

tests/unit/test_searcher.py 9.21 KB
16c42787   tangwang   feat: implement r...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
  """
  Searcher单元测试
  """
  
  import pytest
  from unittest.mock import Mock, patch, MagicMock
  import numpy as np
  
  from search import Searcher
  from query import ParsedQuery
  from context import RequestContext, create_request_context
  
  
  @pytest.mark.unit
  class TestSearcher:
      """Searcher测试用例"""
  
      def test_searcher_initialization(self, sample_customer_config, mock_es_client):
          """测试Searcher初始化"""
          searcher = Searcher(sample_customer_config, mock_es_client)
  
          assert searcher.config == sample_customer_config
          assert searcher.es_client == mock_es_client
          assert searcher.query_parser is not None
          assert searcher.boolean_parser is not None
          assert searcher.ranking_engine is not None
  
      def test_search_without_context(self, test_searcher):
          """测试不带context的搜索(向后兼容)"""
          result = test_searcher.search("红色连衣裙", size=5)
  
          assert result.hits is not None
          assert result.total >= 0
          assert result.context is not None  # 应该自动创建context
          assert result.took_ms >= 0
  
      def test_search_with_context(self, test_searcher):
          """测试带context的搜索"""
          context = create_request_context("test-req", "test-user")
  
          result = test_searcher.search("红色连衣裙", context=context)
  
          assert result.hits is not None
          assert result.context == context
          assert context.reqid == "test-req"
          assert context.uid == "test-user"
  
      def test_search_with_parameters(self, test_searcher):
          """测试带各种参数的搜索"""
          context = create_request_context()
  
          result = test_searcher.search(
              query="红色连衣裙",
              size=15,
              from_=5,
              filters={"category_id": 1},
              enable_translation=False,
              enable_embedding=False,
              enable_rerank=False,
              min_score=1.0,
              context=context
          )
  
          assert result is not None
          assert context.metadata['search_params']['size'] == 15
          assert context.metadata['search_params']['from'] == 5
          assert context.metadata['search_params']['filters'] == {"category_id": 1}
          assert context.metadata['search_params']['min_score'] == 1.0
  
          # 验证feature flags
          assert context.metadata['feature_flags']['enable_translation'] is False
          assert context.metadata['feature_flags']['enable_embedding'] is False
          assert context.metadata['feature_flags']['enable_rerank'] is False
  
      @patch('search.searcher.QueryParser')
      def test_search_query_parsing(self, mock_query_parser_class, test_searcher):
          """测试查询解析流程"""
          # 设置mock
          mock_parser = Mock()
          mock_query_parser_class.return_value = mock_parser
  
          parsed_query = ParsedQuery(
              original_query="红色连衣裙",
              normalized_query="红色 连衣裙",
              rewritten_query="红色 女 连衣裙",
              detected_language="zh",
              domain="default"
          )
          mock_parser.parse.return_value = parsed_query
  
          context = create_request_context()
          test_searcher.search("红色连衣裙", context=context)
  
          # 验证query parser被调用
          mock_parser.parse.assert_called_once_with("红色连衣裙", generate_vector=True, context=context)
  
      def test_search_error_handling(self, test_searcher):
          """测试搜索错误处理"""
          # 设置ES客户端抛出异常
          test_searcher.es_client.search.side_effect = Exception("ES连接失败")
  
          context = create_request_context()
  
          with pytest.raises(Exception, match="ES连接失败"):
              test_searcher.search("红色连衣裙", context=context)
  
          # 验证错误被记录到context
          assert context.has_error()
          assert "ES连接失败" in context.metadata['error_info']['message']
  
      def test_search_result_processing(self, test_searcher):
          """测试搜索结果处理"""
          context = create_request_context()
  
          result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)
  
          # 验证结果结构
          assert hasattr(result, 'hits')
          assert hasattr(result, 'total')
          assert hasattr(result, 'max_score')
          assert hasattr(result, 'took_ms')
          assert hasattr(result, 'aggregations')
          assert hasattr(result, 'query_info')
          assert hasattr(result, 'context')
  
          # 验证context中有中间结果
          assert context.get_intermediate_result('es_response') is not None
          assert context.get_intermediate_result('raw_hits') is not None
          assert context.get_intermediate_result('processed_hits') is not None
  
      def test_boolean_query_handling(self, test_searcher):
          """测试布尔查询处理"""
          context = create_request_context()
  
          # 测试复杂布尔查询
          result = test_searcher.search("laptop AND (gaming OR professional)", context=context)
  
          assert result is not None
          # 对于复杂查询,应该调用boolean parser
          assert not context.query_analysis.is_simple_query
  
      def test_simple_query_handling(self, test_searcher):
          """测试简单查询处理"""
          context = create_request_context()
  
          # 测试简单查询
          result = test_searcher.search("红色连衣裙", context=context)
  
          assert result is not None
          # 简单查询应该标记为simple
          assert context.query_analysis.is_simple_query
  
      @patch('search.searcher.RankingEngine')
      def test_reranking(self, mock_ranking_engine_class, test_searcher):
          """测试重排序功能"""
          # 设置mock
          mock_ranking = Mock()
          mock_ranking_engine_class.return_value = mock_ranking
          mock_ranking.calculate_score.return_value = 2.0
  
          context = create_request_context()
          result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)
  
          # 验证重排序被调用
          hits = result.hits
          if hits:  # 如果有结果
              # 应该有自定义分数
              assert all('_custom_score' in hit for hit in hits)
              assert all('_original_score' in hit for hit in hits)
  
      def test_spu_collapse(self, test_searcher):
          """测试SPU折叠功能"""
          # 配置SPU
          test_searcher.config.spu_config.enabled = True
          test_searcher.config.spu_config.spu_field = "spu_id"
          test_searcher.config.spu_config.inner_hits_size = 3
  
          context = create_request_context()
          result = test_searcher.search("红色连衣裙", context=context)
  
          assert result is not None
          # 验证SPU折叠配置被应用
          assert context.get_intermediate_result('es_query') is not None
  
      def test_embedding_search(self, test_searcher):
          """测试向量搜索功能"""
          # 配置embedding字段
          test_searcher.text_embedding_field = "text_embedding"
  
          context = create_request_context()
          result = test_searcher.search("红色连衣裙", enable_embedding=True, context=context)
  
          assert result is not None
          # embedding搜索应该被启用
  
      def test_search_by_image(self, test_searcher):
          """测试图片搜索功能"""
          # 配置图片embedding字段
          test_searcher.image_embedding_field = "image_embedding"
  
          # Mock图片编码器
          with patch('search.searcher.CLIPImageEncoder') as mock_encoder_class:
              mock_encoder = Mock()
              mock_encoder_class.return_value = mock_encoder
              mock_encoder.encode_image_from_url.return_value = np.array([0.1, 0.2, 0.3])
  
              result = test_searcher.search_by_image("http://example.com/image.jpg")
  
              assert result is not None
              assert result.query_info['search_type'] == 'image_similarity'
              assert result.query_info['image_url'] == "http://example.com/image.jpg"
  
      def test_performance_monitoring(self, test_searcher):
          """测试性能监控"""
          context = create_request_context()
  
          result = test_searcher.search("红色连衣裙", context=context)
  
          # 验证各阶段都被计时
          assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) >= 0
          assert context.get_stage_duration(RequestContextStage.QUERY_BUILDING) >= 0
          assert context.get_stage_duration(RequestContextStage.ELASTICSEARCH_SEARCH) >= 0
          assert context.get_stage_duration(RequestContextStage.RESULT_PROCESSING) >= 0
  
          # 验证总耗时
          assert context.performance_metrics.total_duration > 0
  
      def test_context_storage(self, test_searcher):
          """测试context存储功能"""
          context = create_request_context()
  
          result = test_searcher.search("红色连衣裙", context=context)
  
          # 验证查询分析结果被存储
          assert context.query_analysis.original_query == "红色连衣裙"
          assert context.query_analysis.domain is not None
  
          # 验证中间结果被存储
          assert context.get_intermediate_result('parsed_query') is not None
          assert context.get_intermediate_result('es_query') is not None
          assert context.get_intermediate_result('es_response') is not None
          assert context.get_intermediate_result('processed_hits') is not None