test_context.py
8.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
RequestContext单元测试
"""
import pytest
import time
from context import RequestContext, RequestContextStage, create_request_context
@pytest.mark.unit
class TestRequestContext:
    """Unit tests for the RequestContext object itself."""

    def test_create_context(self):
        """The factory stores the given reqid/uid and starts with no error."""
        context = create_request_context("req-001", "user-123")
        assert context.reqid == "req-001"
        assert context.uid == "user-123"
        assert not context.has_error()

    def test_auto_generated_reqid(self):
        """A bare RequestContext gets an 8-char reqid and the 'anonymous' uid."""
        context = RequestContext()
        assert context.reqid is not None
        assert len(context.reqid) == 8
        assert context.uid == "anonymous"

    def test_stage_timing(self):
        """start_stage/end_stage report the elapsed stage time in milliseconds."""
        context = create_request_context()
        # Start timing the stage
        context.start_stage(RequestContextStage.QUERY_PARSING)
        time.sleep(0.05)  # 50 ms
        duration = context.end_stage(RequestContextStage.QUERY_PARSING)
        # Lower bound leaves some slack for timer resolution.
        assert duration >= 40
        # The upper bound only guards against a wrong unit (microseconds would
        # be ~50000; seconds already fail the lower bound). A tight bound such
        # as 100 ms is flaky when the test machine is under load.
        assert duration < 1000
        assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) == duration

    def test_store_query_analysis(self):
        """store_query_analysis() exposes its inputs via context.query_analysis."""
        context = create_request_context()
        context.store_query_analysis(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            translations={"en": "red dress"},
            domain="default",
            is_simple_query=True
        )
        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.detected_language == "zh"
        assert context.query_analysis.translations["en"] == "red dress"
        assert context.query_analysis.is_simple_query is True

    def test_store_intermediate_results(self):
        """Intermediate results round-trip by key; missing keys use the default."""
        context = create_request_context()
        # Store intermediate results of several different types
        context.store_intermediate_result('parsed_query', {'query': 'test'})
        context.store_intermediate_result('es_query', {'bool': {'must': []}})
        context.store_intermediate_result('hits', [{'_id': '1', '_score': 1.0}])
        assert context.get_intermediate_result('parsed_query') == {'query': 'test'}
        assert context.get_intermediate_result('es_query') == {'bool': {'must': []}}
        assert context.get_intermediate_result('hits') == [{'_id': '1', '_score': 1.0}]
        # A key that was never stored
        assert context.get_intermediate_result('nonexistent') is None
        assert context.get_intermediate_result('nonexistent', 'default') == 'default'

    def test_error_handling(self):
        """set_error() flags the context and records the exception type/message."""
        context = create_request_context()
        assert not context.has_error()
        # Raise and capture a real exception so set_error receives a live instance.
        # Catch only the exception we raise rather than every Exception.
        try:
            raise ValueError("测试错误")
        except ValueError as e:
            context.set_error(e)
        assert context.has_error()
        error_info = context.metadata['error_info']
        assert error_info['type'] == 'ValueError'
        assert error_info['message'] == '测试错误'

    def test_warnings(self):
        """add_warning() appends messages to metadata['warnings']."""
        context = create_request_context()
        assert len(context.metadata['warnings']) == 0
        # Add two warnings
        context.add_warning("第一个警告")
        context.add_warning("第二个警告")
        assert len(context.metadata['warnings']) == 2
        assert "第一个警告" in context.metadata['warnings']
        assert "第二个警告" in context.metadata['warnings']

    def test_stage_percentages(self):
        """calculate_stage_percentages() returns each stage's share of the total."""
        context = create_request_context()
        context.performance_metrics.total_duration = 100.0
        # Per-stage durations in ms
        context.performance_metrics.stage_timings = {
            'query_parsing': 25.0,
            'elasticsearch_search': 50.0,
            'result_processing': 25.0
        }
        percentages = context.calculate_stage_percentages()
        # Percentages are computed floats; compare with a tolerance.
        assert percentages['query_parsing'] == pytest.approx(25.0)
        assert percentages['elasticsearch_search'] == pytest.approx(50.0)
        assert percentages['result_processing'] == pytest.approx(25.0)

    def test_get_summary(self):
        """get_summary() exposes request, query, performance, results and metadata."""
        context = create_request_context("test-req", "test-user")
        # Populate some data
        context.store_query_analysis(
            original_query="测试查询",
            detected_language="zh",
            domain="default"
        )
        context.store_intermediate_result('test_key', 'test_value')
        context.performance_metrics.total_duration = 150.0
        context.performance_metrics.stage_timings = {
            'query_parsing': 30.0,
            'elasticsearch_search': 80.0
        }
        summary = context.get_summary()
        # Top-level structure
        assert 'request_info' in summary
        assert 'query_analysis' in summary
        assert 'performance' in summary
        assert 'results' in summary
        assert 'metadata' in summary
        # Specific contents
        assert summary['request_info']['reqid'] == 'test-req'
        assert summary['request_info']['uid'] == 'test-user'
        assert summary['query_analysis']['original_query'] == '测试查询'
        assert summary['query_analysis']['detected_language'] == 'zh'
        assert summary['performance']['total_duration_ms'] == 150.0
        assert 'query_parsing' in summary['performance']['stage_timings_ms']

    def test_context_manager(self):
        """The context works as a `with` manager and records total time on exit."""
        with create_request_context("cm-test", "cm-user") as context:
            assert context.reqid == "cm-test"
            assert context.uid == "cm-user"
            # Do some work inside the managed block
            context.start_stage(RequestContextStage.QUERY_PARSING)
            time.sleep(0.01)
            context.end_stage(RequestContextStage.QUERY_PARSING)
            # The context is still active here
            assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) > 0
        # After leaving the block, the total duration should have been recorded
        assert context.performance_metrics.total_duration > 0
@pytest.mark.unit
class TestContextFactory:
    """Tests for the create_request_context() factory function."""

    def test_create_request_context_with_params(self):
        """Explicit positional reqid and uid are stored verbatim."""
        ctx = create_request_context("custom-req", "custom-user")
        assert (ctx.reqid, ctx.uid) == ("custom-req", "custom-user")

    def test_create_request_context_without_params(self):
        """With no arguments, an 8-character reqid is generated and uid is 'anonymous'."""
        ctx = create_request_context()
        assert ctx.reqid is not None
        assert len(ctx.reqid) == 8
        assert ctx.uid == "anonymous"

    def test_create_request_context_with_partial_params(self):
        """Either keyword may be passed on its own; the other is still populated."""
        reqid_only = create_request_context(reqid="partial-req")
        assert (reqid_only.reqid, reqid_only.uid) == ("partial-req", "anonymous")

        uid_only = create_request_context(uid="partial-user")
        assert uid_only.reqid is not None
        assert uid_only.uid == "partial-user"
@pytest.mark.unit
class TestContextStages:
    """Tests for the RequestContextStage enumeration."""

    def test_stage_values(self):
        """Each stage member carries its expected string value."""
        assert RequestContextStage.TOTAL.value == "total_search"
        assert RequestContextStage.QUERY_PARSING.value == "query_parsing"
        assert RequestContextStage.BOOLEAN_PARSING.value == "boolean_parsing"
        assert RequestContextStage.QUERY_BUILDING.value == "query_building"
        assert RequestContextStage.ELASTICSEARCH_SEARCH.value == "elasticsearch_search"
        assert RequestContextStage.RESULT_PROCESSING.value == "result_processing"
        assert RequestContextStage.RERANKING.value == "reranking"

    def test_stage_uniqueness(self):
        """No two stage names share the same value.

        Iterating an Enum class yields only canonical members. A name whose
        value duplicates an earlier member becomes an alias and is skipped, so
        iterating the class can never expose a duplicate. ``__members__``
        includes aliases, so a duplicated value shows up twice here.
        Assumes RequestContextStage is an ``enum.Enum`` subclass.
        """
        values = [stage.value for stage in RequestContextStage.__members__.values()]
        assert len(values) == len(set(values)), "阶段值应该是唯一的"