test_searcher.py
9.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""
Searcher单元测试
"""
from unittest.mock import Mock, patch, MagicMock

import numpy as np
import pytest

from context import RequestContext, RequestContextStage, create_request_context
from query import ParsedQuery
from search import Searcher
@pytest.mark.unit
class TestSearcher:
    """Unit tests for ``Searcher``.

    ``sample_customer_config``, ``mock_es_client`` and ``test_searcher`` are
    pytest fixtures defined outside this file (presumably in conftest.py --
    confirm there for how the searcher under test is wired up).
    """

    def test_searcher_initialization(self, sample_customer_config, mock_es_client):
        """A new Searcher keeps its config/client and builds its helper objects."""
        searcher = Searcher(sample_customer_config, mock_es_client)

        assert searcher.config == sample_customer_config
        assert searcher.es_client == mock_es_client
        assert searcher.query_parser is not None
        assert searcher.boolean_parser is not None
        assert searcher.ranking_engine is not None

    def test_search_without_context(self, test_searcher):
        """Searching without a context still works (backward compatibility)."""
        result = test_searcher.search("红色连衣裙", size=5)

        assert result.hits is not None
        assert result.total >= 0
        # A context is expected to be created automatically when none is passed.
        assert result.context is not None
        assert result.took_ms >= 0

    def test_search_with_context(self, test_searcher):
        """A caller-supplied context is returned on the result unchanged."""
        context = create_request_context("test-req", "test-user")

        result = test_searcher.search("红色连衣裙", context=context)

        assert result.hits is not None
        assert result.context == context
        assert context.reqid == "test-req"
        assert context.uid == "test-user"

    def test_search_with_parameters(self, test_searcher):
        """Search parameters and feature flags are recorded in context metadata."""
        context = create_request_context()

        result = test_searcher.search(
            query="红色连衣裙",
            size=15,
            from_=5,
            filters={"category_id": 1},
            enable_translation=False,
            enable_embedding=False,
            enable_rerank=False,
            min_score=1.0,
            context=context,
        )

        assert result is not None

        search_params = context.metadata['search_params']
        assert search_params['size'] == 15
        assert search_params['from'] == 5
        assert search_params['filters'] == {"category_id": 1}
        assert search_params['min_score'] == 1.0

        # Feature flags must mirror the values passed to search().
        feature_flags = context.metadata['feature_flags']
        assert feature_flags['enable_translation'] is False
        assert feature_flags['enable_embedding'] is False
        assert feature_flags['enable_rerank'] is False

    @patch('search.searcher.QueryParser')
    def test_search_query_parsing(self, mock_query_parser_class, test_searcher):
        """The raw query is handed to the query parser together with the context."""
        mock_parser = Mock()
        mock_query_parser_class.return_value = mock_parser
        # pytest builds the ``test_searcher`` fixture before the ``@patch``
        # decorator becomes active, so patching the class cannot replace the
        # parser the searcher already holds. Inject the mock on the instance
        # too, otherwise ``mock_parser.parse`` is never reached.
        test_searcher.query_parser = mock_parser

        parsed_query = ParsedQuery(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            domain="default",
        )
        mock_parser.parse.return_value = parsed_query

        context = create_request_context()
        test_searcher.search("红色连衣裙", context=context)

        # The parser must be invoked exactly once with the original query.
        mock_parser.parse.assert_called_once_with(
            "红色连衣裙", generate_vector=True, context=context
        )

    def test_search_error_handling(self, test_searcher):
        """An Elasticsearch failure propagates and is recorded on the context."""
        # Make the ES client raise on search.
        test_searcher.es_client.search.side_effect = Exception("ES连接失败")
        context = create_request_context()

        with pytest.raises(Exception, match="ES连接失败"):
            test_searcher.search("红色连衣裙", context=context)

        # These checks must sit outside the ``with`` block: statements placed
        # after the raising call inside it would never execute.
        assert context.has_error()
        assert "ES连接失败" in context.metadata['error_info']['message']

    def test_search_result_processing(self, test_searcher):
        """The result exposes the expected fields and the context keeps intermediates."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        # Result structure.
        expected_fields = (
            'hits',
            'total',
            'max_score',
            'took_ms',
            'aggregations',
            'query_info',
            'context',
        )
        for field_name in expected_fields:
            assert hasattr(result, field_name)

        # Intermediate results stored on the context.
        assert context.get_intermediate_result('es_response') is not None
        assert context.get_intermediate_result('raw_hits') is not None
        assert context.get_intermediate_result('processed_hits') is not None

    def test_boolean_query_handling(self, test_searcher):
        """A query with boolean operators is not classified as a simple query."""
        context = create_request_context()

        result = test_searcher.search("laptop AND (gaming OR professional)", context=context)

        assert result is not None
        # Complex queries are expected to go through the boolean parser.
        assert not context.query_analysis.is_simple_query

    def test_simple_query_handling(self, test_searcher):
        """A plain keyword query is classified as a simple query."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        assert context.query_analysis.is_simple_query

    @patch('search.searcher.RankingEngine')
    def test_reranking(self, mock_ranking_engine_class, test_searcher):
        """With rerank enabled, every hit carries custom and original scores."""
        mock_ranking = Mock()
        mock_ranking_engine_class.return_value = mock_ranking
        mock_ranking.calculate_score.return_value = 2.0
        # NOTE(review): as with the QueryParser patch, the fixture's searcher is
        # created before this patch starts, so it presumably still uses its own
        # ranking engine rather than ``mock_ranking`` -- confirm against the
        # fixture before relying on the mocked score.

        context = create_request_context()
        result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        hits = result.hits
        if hits:  # Only meaningful when the search returned something.
            assert all('_custom_score' in hit for hit in hits)
            assert all('_original_score' in hit for hit in hits)

    def test_spu_collapse(self, test_searcher):
        """Searching with SPU collapse configured still produces an ES query."""
        spu_config = test_searcher.config.spu_config
        spu_config.enabled = True
        spu_config.spu_field = "spu_id"
        spu_config.inner_hits_size = 3

        context = create_request_context()
        result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        # Only checks that a query was built; the collapse clause itself is
        # not inspected here.
        assert context.get_intermediate_result('es_query') is not None

    def test_embedding_search(self, test_searcher):
        """Searching with embedding enabled returns a result."""
        # Configure the text embedding field.
        test_searcher.text_embedding_field = "text_embedding"

        context = create_request_context()
        result = test_searcher.search("红色连衣裙", enable_embedding=True, context=context)

        # NOTE(review): nothing here verifies that the embedding path was
        # actually taken; only that the call succeeds.
        assert result is not None

    def test_search_by_image(self, test_searcher):
        """Image search reports its search type and the source image URL."""
        # Configure the image embedding field.
        test_searcher.image_embedding_field = "image_embedding"

        # Replace the image encoder so no real model or network is needed.
        with patch('search.searcher.CLIPImageEncoder') as mock_encoder_class:
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode_image_from_url.return_value = np.array([0.1, 0.2, 0.3])

            result = test_searcher.search_by_image("http://example.com/image.jpg")

        assert result is not None
        assert result.query_info['search_type'] == 'image_similarity'
        assert result.query_info['image_url'] == "http://example.com/image.jpg"

    def test_performance_monitoring(self, test_searcher):
        """Each pipeline stage and the total duration are timed on the context."""
        context = create_request_context()
        test_searcher.search("红色连衣裙", context=context)

        # Every stage should have a recorded (non-negative) duration.
        timed_stages = (
            RequestContextStage.QUERY_PARSING,
            RequestContextStage.QUERY_BUILDING,
            RequestContextStage.ELASTICSEARCH_SEARCH,
            RequestContextStage.RESULT_PROCESSING,
        )
        for stage in timed_stages:
            assert context.get_stage_duration(stage) >= 0

        # Total elapsed time.
        assert context.performance_metrics.total_duration > 0

    def test_context_storage(self, test_searcher):
        """Query analysis and intermediate results are stored on the context."""
        context = create_request_context()
        test_searcher.search("红色连衣裙", context=context)

        # Query analysis.
        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.domain is not None

        # Intermediate results.
        for key in ('parsed_query', 'es_query', 'es_response', 'processed_hits'):
            assert context.get_intermediate_result(key) is not None