test_search_integration.py
11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
搜索集成测试
测试搜索流程的完整集成,包括QueryParser、BooleanParser、ESQueryBuilder等组件的协同工作
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
import json
import numpy as np
from search import Searcher
from query import QueryParser
from search.boolean_parser import BooleanParser, QueryNode
from search.multilang_query_builder import MultiLanguageQueryBuilder
from context import RequestContext, create_request_context
@pytest.mark.integration
@pytest.mark.slow
class TestSearchIntegration:
    """Integration tests for the search flow.

    Every test receives the ``test_searcher`` fixture, runs
    ``test_searcher.search(...)`` with a fresh request context, and then
    inspects the result object and/or the data recorded on that context.
    """

    def test_end_to_end_search_flow(self, test_searcher):
        """Run a plain search and check the result shape and context summary."""
        context = create_request_context("e2e-001", "e2e-user")

        result = test_searcher.search("红色连衣裙", context=context)

        # Result structure.
        assert result.hits is not None
        assert isinstance(result.hits, list)
        assert result.total >= 0
        assert result.took_ms >= 0
        assert result.context == context

        # The context summary carries the query analysis and timing data.
        summary = context.get_summary()
        assert summary['query_analysis']['original_query'] == "红色连衣裙"
        assert 'performance' in summary
        assert summary['performance']['total_duration_ms'] > 0

        # Every pipeline stage reports a non-negative duration.
        assert context.get_stage_duration("query_parsing") >= 0
        assert context.get_stage_duration("query_building") >= 0
        assert context.get_stage_duration("elasticsearch_search") >= 0
        assert context.get_stage_duration("result_processing") >= 0

    def test_complex_boolean_query_integration(self, test_searcher):
        """Search with a nested boolean expression and check the parsed AST."""
        context = create_request_context("boolean-001")

        result = test_searcher.search("手机 AND (华为 OR 苹果) ANDNOT 二手", context=context)

        assert result is not None
        assert context.query_analysis.is_simple_query is False
        assert context.query_analysis.boolean_ast is not None

        # The parsed tree is stored as an intermediate result.
        query_node = context.get_intermediate_result('query_node')
        assert query_node is not None
        assert isinstance(query_node, QueryNode)

    def test_multilingual_search_integration(self, test_searcher):
        """Search with translation enabled, using mocked translator/detector."""
        context = create_request_context("multilang-001")

        with patch('query.query_parser.Translator') as mock_translator_class, \
                patch('query.query_parser.LanguageDetector') as mock_detector_class:
            # Translator mock: one target language, fixed translation.
            mock_translator = Mock()
            mock_translator_class.return_value = mock_translator
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "red dress"}

            # Language detector mock: always reports Chinese.
            mock_detector = Mock()
            mock_detector_class.return_value = mock_detector
            mock_detector.detect.return_value = "zh"

            test_searcher.search("红色连衣裙", enable_translation=True, context=context)

            # The mocked translation and language end up in the query analysis.
            assert context.query_analysis.translations.get("en") == "red dress"
            assert context.query_analysis.detected_language == "zh"

    def test_embedding_search_integration(self, test_searcher):
        """Search with embeddings enabled, using a mocked encoder."""
        # Configure the embedding field on the searcher.
        test_searcher.text_embedding_field = "text_embedding"
        context = create_request_context("embedding-001")

        with patch('query.query_parser.BgeEncoder') as mock_encoder_class:
            # Encoder mock returns a single 4-dimensional vector.
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3, 0.4])]

            test_searcher.search("智能手机", enable_embedding=True, context=context)

            # A query vector of the mocked dimensionality was recorded.
            assert context.query_analysis.query_vector is not None
            assert len(context.query_analysis.query_vector) == 4

            # The KNN clause is only checked when the built query has one.
            es_query = context.get_intermediate_result('es_query')
            if es_query and 'knn' in es_query:
                assert 'text_embedding' in es_query['knn']

    def test_spu_collapse_integration(self, test_searcher):
        """Search with SPU collapsing enabled in the searcher config."""
        # Enable SPU collapsing.
        test_searcher.config.spu_config.enabled = True
        test_searcher.config.spu_config.spu_field = "spu_id"
        test_searcher.config.spu_config.inner_hits_size = 3
        context = create_request_context("spu-001")

        test_searcher.search("手机", context=context)

        es_query = context.get_intermediate_result('es_query')
        assert es_query is not None
        # NOTE: a correctly built ES query should presumably contain a
        # collapse section, but that depends on the ESQueryBuilder
        # implementation, so it is not asserted here.

    def test_reranking_integration(self, test_searcher):
        """Search with reranking enabled and check the per-hit score fields."""
        context = create_request_context("rerank-001")

        result = test_searcher.search("笔记本电脑", enable_rerank=True, context=context)

        # Score fields can only be checked when there are hits.
        if result.hits:
            assert all('_custom_score' in hit for hit in result.hits)
            assert all('_original_score' in hit for hit in result.hits)

            custom_scores = [hit['_custom_score'] for hit in result.hits]
            original_scores = [hit['_original_score'] for hit in result.hits]
            assert len(custom_scores) == len(original_scores)

    def test_error_propagation_integration(self, test_searcher):
        """Make the ES client raise and check the error surfaces and is recorded."""
        context = create_request_context("error-001")

        # Simulate an Elasticsearch failure.
        test_searcher.es_client.search.side_effect = Exception("ES连接失败")

        with pytest.raises(Exception, match="ES连接失败"):
            test_searcher.search("测试查询", context=context)

        # Checked after the with-block: statements placed after the raising
        # call inside it would never execute.
        assert context.has_error()
        assert "ES连接失败" in context.metadata['error_info']['message']

    def test_performance_monitoring_integration(self, test_searcher):
        """Search with a mocked QueryParser and check the collected timings."""
        context = create_request_context("perf-001")

        with patch('query.query_parser.QueryParser') as mock_parser_class:
            mock_parser = Mock()
            mock_parser_class.return_value = mock_parser
            # parse() echoes the query back in a Mock shaped like a parsed query.
            mock_parser.parse.side_effect = lambda q, **kwargs: Mock(
                original_query=q,
                normalized_query=q,
                rewritten_query=q,
                detected_language="zh",
                domain="default",
                translations={},
                query_vector=None
            )

            test_searcher.search("性能测试查询", context=context)

            # Performance data was collected.
            summary = context.get_summary()
            assert summary['performance']['total_duration_ms'] > 0
            assert 'stage_timings_ms' in summary['performance']
            assert 'stage_percentages' in summary['performance']

            # Each main stage has a timing entry.
            stages = ['query_parsing', 'query_building', 'elasticsearch_search', 'result_processing']
            for stage in stages:
                assert stage in summary['performance']['stage_timings_ms']

    def test_context_data_persistence_integration(self, test_searcher):
        """Check intermediate results and metadata are stored on the context."""
        context = create_request_context("persist-001")

        test_searcher.search("数据持久化测试", context=context)

        # Key data from each step is stored.
        assert context.query_analysis.original_query == "数据持久化测试"
        assert context.get_intermediate_result('parsed_query') is not None
        assert context.get_intermediate_result('es_query') is not None
        assert context.get_intermediate_result('es_response') is not None
        assert context.get_intermediate_result('processed_hits') is not None

        # Request metadata.
        assert 'search_params' in context.metadata
        assert 'feature_flags' in context.metadata
        assert context.metadata['search_params']['query'] == "数据持久化测试"

    # Every case must be a (query, expected_simple) tuple. A bare
    # `"...", False,` pair adds two separate one-value entries, which pytest
    # rejects at collection time for a two-name parametrization.
    @pytest.mark.parametrize("query,expected_simple", [
        ("红色连衣裙", True),
        ("手机 AND 电脑", False),
        ("(华为 OR 苹果) ANDNOT 二手", False),
        ("laptop RANK gaming", False),
        ("简单查询", True),
    ])
    def test_query_complexity_detection(self, test_searcher, query, expected_simple):
        """Check that ``is_simple_query`` matches the expected flag per query."""
        context = create_request_context(f"complexity-{hash(query)}")

        test_searcher.search(query, context=context)

        assert context.query_analysis.is_simple_query == expected_simple

    def test_search_with_all_features_enabled(self, test_searcher):
        """Search with translation, embedding and rerank all enabled."""
        # Configure every optional feature on the searcher.
        test_searcher.text_embedding_field = "text_embedding"
        test_searcher.config.spu_config.enabled = True
        test_searcher.config.spu_config.spu_field = "spu_id"
        context = create_request_context("all-features-001")

        with patch('query.query_parser.BgeEncoder') as mock_encoder_class, \
                patch('query.query_parser.Translator') as mock_translator_class, \
                patch('query.query_parser.LanguageDetector') as mock_detector_class:
            # Encoder mock returns a single 2-dimensional vector.
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode.return_value = [np.array([0.1, 0.2])]

            # Translator mock: one target language, fixed translation.
            mock_translator = Mock()
            mock_translator_class.return_value = mock_translator
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "test query"}

            # Language detector mock: always reports Chinese.
            mock_detector = Mock()
            mock_detector_class.return_value = mock_detector
            mock_detector.detect.return_value = "zh"

            test_searcher.search(
                "完整功能测试",
                enable_translation=True,
                enable_embedding=True,
                enable_rerank=True,
                context=context
            )

            # Each feature left its trace in the query analysis.
            assert context.query_analysis.detected_language == "zh"
            assert context.query_analysis.translations.get("en") == "test query"
            assert context.query_analysis.query_vector is not None

            # Each main stage has a timing entry.
            summary = context.get_summary()
            expected_stages = [
                'query_parsing', 'query_building',
                'elasticsearch_search', 'result_processing'
            ]
            for stage in expected_stages:
                assert stage in summary['performance']['stage_timings_ms']

    def test_search_result_context_integration(self, test_searcher):
        """Check the result exposes its context and a performance summary."""
        context = create_request_context("result-context-001")

        result = test_searcher.search("结果上下文集成测试", context=context)

        # The result references the context that was passed in.
        assert result.context == context

        # to_dict() includes the performance summary, tagged with the request id.
        result_dict = result.to_dict()
        assert 'performance_summary' in result_dict
        assert result_dict['performance_summary']['request_info']['reqid'] == context.reqid

        # Sections of the performance summary.
        perf_summary = result_dict['performance_summary']
        assert 'query_analysis' in perf_summary
        assert 'performance' in perf_summary
        assert 'results' in perf_summary
        assert 'metadata' in perf_summary