# test_context.py
"""
RequestContext单元测试
"""

import pytest
import time
from context import RequestContext, RequestContextStage, create_request_context


@pytest.mark.unit
class TestRequestContext:
    """Unit tests for RequestContext."""

    def test_create_context(self):
        """The factory stores the given reqid/uid and the context starts without an error."""
        context = create_request_context("req-001", "user-123")

        assert context.reqid == "req-001"
        assert context.uid == "user-123"
        assert not context.has_error()

    def test_auto_generated_reqid(self):
        """A bare RequestContext() gets an 8-character reqid and the "anonymous" uid."""
        context = RequestContext()

        assert context.reqid is not None
        assert len(context.reqid) == 8
        assert context.uid == "anonymous"

    def test_stage_timing(self):
        """end_stage() returns the elapsed stage time (ms) and get_stage_duration() reports it."""
        context = create_request_context()

        # Time a stage around a known sleep.
        context.start_stage(RequestContextStage.QUERY_PARSING)
        time.sleep(0.05)  # 50 ms
        duration = context.end_stage(RequestContextStage.QUERY_PARSING)

        # The lower bound leaves ~10 ms of slack for coarse clocks. The upper
        # bound is only a units sanity check (a seconds or microseconds result
        # would fail it); a tight bound such as 100 ms is flaky on busy CI hosts
        # where the sleeping thread may be rescheduled late.
        assert duration >= 40
        assert duration < 1000
        assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) == duration

    def test_store_query_analysis(self):
        """store_query_analysis() exposes its arguments via context.query_analysis."""
        context = create_request_context()

        context.store_query_analysis(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            translations={"en": "red dress"},
            domain="default",
            is_simple_query=True
        )

        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.detected_language == "zh"
        assert context.query_analysis.translations["en"] == "red dress"
        assert context.query_analysis.is_simple_query is True

    def test_store_intermediate_results(self):
        """Intermediate results round-trip by key; missing keys yield None or the given default."""
        context = create_request_context()

        # Store intermediate results of several different types.
        context.store_intermediate_result('parsed_query', {'query': 'test'})
        context.store_intermediate_result('es_query', {'bool': {'must': []}})
        context.store_intermediate_result('hits', [{'_id': '1', '_score': 1.0}])

        assert context.get_intermediate_result('parsed_query') == {'query': 'test'}
        assert context.get_intermediate_result('es_query') == {'bool': {'must': []}}
        assert context.get_intermediate_result('hits') == [{'_id': '1', '_score': 1.0}]

        # Keys that were never stored.
        assert context.get_intermediate_result('nonexistent') is None
        assert context.get_intermediate_result('nonexistent', 'default') == 'default'

    def test_error_handling(self):
        """set_error() flips has_error() and records the exception type and message."""
        context = create_request_context()

        assert not context.has_error()

        # Catch only the exception raised here so any other failure surfaces
        # instead of being silently recorded. set_error() is called inside the
        # handler in case it inspects the currently-handled exception.
        try:
            raise ValueError("测试错误")
        except ValueError as e:
            context.set_error(e)

        assert context.has_error()
        error_info = context.metadata['error_info']
        assert error_info['type'] == 'ValueError'
        assert error_info['message'] == '测试错误'

    def test_warnings(self):
        """add_warning() appends messages to metadata['warnings']."""
        context = create_request_context()

        assert len(context.metadata['warnings']) == 0

        context.add_warning("第一个警告")
        context.add_warning("第二个警告")

        assert len(context.metadata['warnings']) == 2
        assert "第一个警告" in context.metadata['warnings']
        assert "第二个警告" in context.metadata['warnings']

    def test_stage_percentages(self):
        """calculate_stage_percentages() returns each stage's share of total_duration."""
        context = create_request_context()
        context.performance_metrics.total_duration = 100.0

        # Per-stage durations in ms, summing to the total above.
        context.performance_metrics.stage_timings = {
            'query_parsing': 25.0,
            'elasticsearch_search': 50.0,
            'result_processing': 25.0
        }

        percentages = context.calculate_stage_percentages()

        # Percentages are derived by float division; compare approximately.
        assert percentages['query_parsing'] == pytest.approx(25.0)
        assert percentages['elasticsearch_search'] == pytest.approx(50.0)
        assert percentages['result_processing'] == pytest.approx(25.0)

    def test_get_summary(self):
        """get_summary() returns the expected top-level sections and stored values."""
        context = create_request_context("test-req", "test-user")

        # Populate some data.
        context.store_query_analysis(
            original_query="测试查询",
            detected_language="zh",
            domain="default"
        )
        context.store_intermediate_result('test_key', 'test_value')
        context.performance_metrics.total_duration = 150.0
        context.performance_metrics.stage_timings = {
            'query_parsing': 30.0,
            'elasticsearch_search': 80.0
        }

        summary = context.get_summary()

        # Top-level structure.
        assert 'request_info' in summary
        assert 'query_analysis' in summary
        assert 'performance' in summary
        assert 'results' in summary
        assert 'metadata' in summary

        # Concrete values.
        assert summary['request_info']['reqid'] == 'test-req'
        assert summary['request_info']['uid'] == 'test-user'
        assert summary['query_analysis']['original_query'] == '测试查询'
        assert summary['query_analysis']['detected_language'] == 'zh'
        assert summary['performance']['total_duration_ms'] == 150.0
        assert 'query_parsing' in summary['performance']['stage_timings_ms']

    def test_context_manager(self):
        """Used as a context manager, the context records a total duration on exit."""
        with create_request_context("cm-test", "cm-user") as context:
            assert context.reqid == "cm-test"
            assert context.uid == "cm-user"

            # Do some timed work inside the with-block.
            context.start_stage(RequestContextStage.QUERY_PARSING)
            time.sleep(0.01)
            context.end_stage(RequestContextStage.QUERY_PARSING)

            # The stage timing is available while the context is still active.
            assert context.get_stage_duration(RequestContextStage.QUERY_PARSING) > 0

        # After leaving the with-block, the total duration has been recorded.
        assert context.performance_metrics.total_duration > 0


@pytest.mark.unit
class TestContextFactory:
    """Tests for the create_request_context() factory function."""

    def test_create_request_context_with_params(self):
        """Explicitly passed reqid and uid are stored verbatim."""
        ctx = create_request_context("custom-req", "custom-user")

        assert (ctx.reqid, ctx.uid) == ("custom-req", "custom-user")

    def test_create_request_context_without_params(self):
        """With no arguments, an 8-character reqid is generated and uid is "anonymous"."""
        ctx = create_request_context()
        generated = ctx.reqid

        assert generated is not None
        assert len(generated) == 8
        assert ctx.uid == "anonymous"

    def test_create_request_context_with_partial_params(self):
        """Each argument may be passed alone; the omitted one falls back to its default."""
        only_reqid = create_request_context(reqid="partial-req")
        assert only_reqid.reqid == "partial-req"
        assert only_reqid.uid == "anonymous"

        only_uid = create_request_context(uid="partial-user")
        assert only_uid.reqid is not None
        assert only_uid.uid == "partial-user"


@pytest.mark.unit
class TestContextStages:
    """Tests for the RequestContextStage enumeration."""

    def test_stage_values(self):
        """Each stage member maps to its expected string value."""
        assert RequestContextStage.TOTAL.value == "total_search"
        assert RequestContextStage.QUERY_PARSING.value == "query_parsing"
        assert RequestContextStage.BOOLEAN_PARSING.value == "boolean_parsing"
        assert RequestContextStage.QUERY_BUILDING.value == "query_building"
        assert RequestContextStage.ELASTICSEARCH_SEARCH.value == "elasticsearch_search"
        assert RequestContextStage.RESULT_PROCESSING.value == "result_processing"
        assert RequestContextStage.RERANKING.value == "reranking"

    def test_stage_uniqueness(self):
        """No two stage names share the same value.

        Iterating an ``enum.Enum`` class skips aliases (members whose value
        duplicates an earlier member), so a duplicate value could never show up
        when iterating the class directly and the check would always pass.
        ``__members__`` includes aliases, so duplicates are actually detected.
        Assumes RequestContextStage is an ``enum.Enum`` subclass (it exposes
        ``.value`` and is iterable) — TODO confirm if that ever changes.
        """
        values = [stage.value for stage in RequestContextStage.__members__.values()]
        assert len(values) == len(set(values)), "阶段值应该是唯一的"