"""
Unit tests for RequestContext.

Covers context creation, per-stage timing, storage of query-analysis and
intermediate results, error/warning bookkeeping, summary generation, the
context-manager protocol, the ``create_request_context`` factory, and the
``RequestContextStage`` enum.
"""

import time

import pytest

from context import RequestContext, RequestContextStage, create_request_context


@pytest.mark.unit
class TestRequestContext:
    """Test cases for RequestContext."""

    def test_create_context(self):
        """Creating a context stores reqid/uid and starts without an error."""
        context = create_request_context("req-001", "user-123")

        assert context.reqid == "req-001"
        assert context.uid == "user-123"
        assert not context.has_error()

    def test_auto_generated_reqid(self):
        """A context built with no arguments gets an 8-char reqid and uid 'anonymous'."""
        context = RequestContext()

        assert context.reqid is not None
        assert len(context.reqid) == 8
        assert context.uid == "anonymous"

    def test_stage_timing(self):
        """start_stage/end_stage measure the elapsed time of a stage in milliseconds."""
        context = create_request_context()

        # Start timing, sleep, then stop timing.
        context.start_stage(RequestContextStage.QUERY_PARSING)
        time.sleep(0.05)  # 50 ms
        duration = context.end_stage(RequestContextStage.QUERY_PARSING)

        # The lower bound leaves a little slack for clock granularity. The upper
        # bound is deliberately generous: sleep() can overshoot considerably on a
        # loaded CI machine, so the former 100 ms ceiling made this test flaky.
        # 1000 ms still catches a unit mix-up such as returning microseconds.
        assert duration >= 40
        assert duration < 1000
        assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) == duration

    def test_store_query_analysis(self):
        """store_query_analysis populates the query_analysis fields."""
        context = create_request_context()

        context.store_query_analysis(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            translations={"en": "red dress"},
            domain="default",
            is_simple_query=True,
        )

        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.detected_language == "zh"
        assert context.query_analysis.translations["en"] == "red dress"
        assert context.query_analysis.is_simple_query is True

    def test_store_intermediate_results(self):
        """Intermediate results round-trip by key, with an optional default for misses."""
        context = create_request_context()

        # Store intermediate results of several different value types.
        context.store_intermediate_result('parsed_query', {'query': 'test'})
        context.store_intermediate_result('es_query', {'bool': {'must': []}})
        context.store_intermediate_result('hits', [{'_id': '1', '_score': 1.0}])

        assert context.get_intermediate_result('parsed_query') == {'query': 'test'}
        assert context.get_intermediate_result('es_query') == {'bool': {'must': []}}
        assert context.get_intermediate_result('hits') == [{'_id': '1', '_score': 1.0}]

        # Missing keys return None, or the supplied default.
        assert context.get_intermediate_result('nonexistent') is None
        assert context.get_intermediate_result('nonexistent', 'default') == 'default'

    def test_error_handling(self):
        """set_error records the exception's type name and message in metadata."""
        context = create_request_context()
        assert not context.has_error()

        # Raise for real so set_error receives a genuinely raised exception;
        # catch only the type we raise.
        try:
            raise ValueError("测试错误")
        except ValueError as e:
            context.set_error(e)

        assert context.has_error()
        error_info = context.metadata['error_info']
        assert error_info['type'] == 'ValueError'
        assert error_info['message'] == '测试错误'

    def test_warnings(self):
        """add_warning appends messages to metadata['warnings']."""
        context = create_request_context()
        assert len(context.metadata['warnings']) == 0

        context.add_warning("第一个警告")
        context.add_warning("第二个警告")

        assert len(context.metadata['warnings']) == 2
        assert "第一个警告" in context.metadata['warnings']
        assert "第二个警告" in context.metadata['warnings']

    def test_stage_percentages(self):
        """calculate_stage_percentages expresses each stage as a share of the total."""
        context = create_request_context()
        context.performance_metrics.total_duration = 100.0

        # Set per-stage durations directly.
        context.performance_metrics.stage_timings = {
            'query_parsing': 25.0,
            'elasticsearch_search': 50.0,
            'result_processing': 25.0,
        }

        percentages = context.calculate_stage_percentages()

        # The percentages are computed floats; compare approximately so the test
        # does not depend on the exact arithmetic order in the implementation.
        assert percentages['query_parsing'] == pytest.approx(25.0)
        assert percentages['elasticsearch_search'] == pytest.approx(50.0)
        assert percentages['result_processing'] == pytest.approx(25.0)

    def test_get_summary(self):
        """get_summary exposes request, query, performance, result and metadata sections."""
        context = create_request_context("test-req", "test-user")

        # Populate some data.
        context.store_query_analysis(
            original_query="测试查询",
            detected_language="zh",
            domain="default",
        )
        context.store_intermediate_result('test_key', 'test_value')
        context.performance_metrics.total_duration = 150.0
        context.performance_metrics.stage_timings = {
            'query_parsing': 30.0,
            'elasticsearch_search': 80.0,
        }

        summary = context.get_summary()

        # Top-level structure.
        assert 'request_info' in summary
        assert 'query_analysis' in summary
        assert 'performance' in summary
        assert 'results' in summary
        assert 'metadata' in summary

        # Specific contents.
        assert summary['request_info']['reqid'] == 'test-req'
        assert summary['request_info']['uid'] == 'test-user'
        assert summary['query_analysis']['original_query'] == '测试查询'
        assert summary['query_analysis']['detected_language'] == 'zh'
        assert summary['performance']['total_duration_ms'] == pytest.approx(150.0)
        assert 'query_parsing' in summary['performance']['stage_timings_ms']

    def test_context_manager(self):
        """Using the context in a with-statement records the total duration on exit."""
        with create_request_context("cm-test", "cm-user") as context:
            assert context.reqid == "cm-test"
            assert context.uid == "cm-user"

            # Time a stage while inside the with-block.
            context.start_stage(RequestContextStage.QUERY_PARSING)
            time.sleep(0.01)
            context.end_stage(RequestContextStage.QUERY_PARSING)

            # Still inside the with-block: the stage duration is already available.
            assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) > 0

        # After leaving the with-block the total duration should have been recorded.
        assert context.performance_metrics.total_duration > 0


@pytest.mark.unit
class TestContextFactory:
    """Tests for the create_request_context factory function."""

    def test_create_request_context_with_params(self):
        """Explicit reqid and uid are used as given."""
        context = create_request_context("custom-req", "custom-user")

        assert context.reqid == "custom-req"
        assert context.uid == "custom-user"

    def test_create_request_context_without_params(self):
        """Without arguments an 8-char reqid is generated and uid is 'anonymous'."""
        context = create_request_context()

        assert context.reqid is not None
        assert len(context.reqid) == 8
        assert context.uid == "anonymous"

    def test_create_request_context_with_partial_params(self):
        """Each omitted argument falls back to its default independently."""
        context = create_request_context(reqid="partial-req")
        assert context.reqid == "partial-req"
        assert context.uid == "anonymous"

        context2 = create_request_context(uid="partial-user")
        assert context2.reqid is not None
        assert context2.uid == "partial-user"


@pytest.mark.unit
class TestContextStages:
    """Tests for the RequestContextStage enum."""

    def test_stage_values(self):
        """Each stage maps to its expected string value."""
        assert RequestContextStage.TOTAL.value == "total_search"
        assert RequestContextStage.QUERY_PARSING.value == "query_parsing"
        assert RequestContextStage.BOOLEAN_PARSING.value == "boolean_parsing"
        assert RequestContextStage.QUERY_BUILDING.value == "query_building"
        assert RequestContextStage.ELASTICSEARCH_SEARCH.value == "elasticsearch_search"
        assert RequestContextStage.RESULT_PROCESSING.value == "result_processing"
        assert RequestContextStage.RERANKING.value == "reranking"

    def test_stage_uniqueness(self):
        """No two stages share the same value."""
        values = [stage.value for stage in RequestContextStage]
        assert len(values) == len(set(values)), "阶段值应该是唯一的"