# test_search_integration.py 11.9 KB
"""
搜索集成测试

测试搜索流程的完整集成,包括QueryParser、BooleanParser、ESQueryBuilder等组件的协同工作
"""

import pytest
from unittest.mock import Mock, patch, AsyncMock
import json
import numpy as np

from search import Searcher
from query import QueryParser
from search.boolean_parser import BooleanParser, QueryNode
from search.multilang_query_builder import MultiLanguageQueryBuilder
from context import RequestContext, create_request_context


@pytest.mark.integration
@pytest.mark.slow
class TestSearchIntegration:
    """Integration tests for the full search pipeline.

    Each test drives ``test_searcher.search(...)`` end to end, then inspects
    the returned result and the ``RequestContext`` the searcher populated
    (query analysis, intermediate results, stage timings, metadata).

    ``test_searcher`` is a fixture defined outside this file and its scope is
    not visible here. Tests that tweak searcher/config attributes therefore do
    so through pytest's ``monkeypatch`` fixture, which restores the original
    values after the test, so the change cannot leak into other tests if the
    fixture happens to be shared.
    """

    def test_end_to_end_search_flow(self, test_searcher):
        """A plain query yields a well-formed result and a populated context."""
        context = create_request_context("e2e-001", "e2e-user")

        # Run the search.
        result = test_searcher.search("红色连衣裙", context=context)

        # Result structure.
        assert result.hits is not None
        assert isinstance(result.hits, list)
        assert result.total >= 0
        assert result.took_ms >= 0
        assert result.context == context

        # The context summary carries query analysis and performance data.
        summary = context.get_summary()
        assert summary['query_analysis']['original_query'] == "红色连衣裙"
        assert 'performance' in summary
        assert summary['performance']['total_duration_ms'] > 0

        # Every pipeline stage reports a (non-negative) duration.
        assert context.get_stage_duration("query_parsing") >= 0
        assert context.get_stage_duration("query_building") >= 0
        assert context.get_stage_duration("elasticsearch_search") >= 0
        assert context.get_stage_duration("result_processing") >= 0

    def test_complex_boolean_query_integration(self, test_searcher):
        """A boolean query is parsed into an AST and stored on the context."""
        context = create_request_context("boolean-001")

        # Compound boolean query: AND, OR with grouping, and ANDNOT.
        result = test_searcher.search("手机 AND (华为 OR 苹果) ANDNOT 二手", context=context)

        assert result is not None
        assert context.query_analysis.is_simple_query is False
        assert context.query_analysis.boolean_ast is not None

        # The parsed query tree is exposed as an intermediate result.
        query_node = context.get_intermediate_result('query_node')
        assert query_node is not None
        assert isinstance(query_node, QueryNode)

    def test_multilingual_search_integration(self, test_searcher):
        """Detected language and translations end up in the query analysis."""
        context = create_request_context("multilang-001")

        # NOTE(review): these patches replace the classes in
        # query.query_parser only for the duration of the `with`. They affect
        # the search only if the parser constructs Translator/LanguageDetector
        # at call time rather than at fixture setup; confirm against
        # QueryParser.
        with patch('query.query_parser.Translator') as mock_translator_class, \
             patch('query.query_parser.LanguageDetector') as mock_detector_class:

            # Configure the mocks.
            mock_translator = Mock()
            mock_translator_class.return_value = mock_translator
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "red dress"}

            mock_detector = Mock()
            mock_detector_class.return_value = mock_detector
            mock_detector.detect.return_value = "zh"

            test_searcher.search("红色连衣裙", enable_translation=True, context=context)

            # The mocked translation and detection results are recorded.
            assert context.query_analysis.translations.get("en") == "red dress"
            assert context.query_analysis.detected_language == "zh"

    def test_embedding_search_integration(self, test_searcher, monkeypatch):
        """Embedding search stores the query vector on the context."""
        # Configure the embedding field; monkeypatch restores it afterwards.
        monkeypatch.setattr(test_searcher, "text_embedding_field", "text_embedding",
                            raising=False)

        context = create_request_context("embedding-001")

        with patch('query.query_parser.BgeEncoder') as mock_encoder_class:
            # Configure the mock encoder to return a single 4-d vector.
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3, 0.4])]

            test_searcher.search("智能手机", enable_embedding=True, context=context)

            # The vector was produced and recorded.
            assert context.query_analysis.query_vector is not None
            assert len(context.query_analysis.query_vector) == 4

            # If the ES query contains a KNN clause, it targets the configured
            # field. NOTE(review): this check is conditional and passes
            # silently when no 'knn' key is present.
            es_query = context.get_intermediate_result('es_query')
            if es_query and 'knn' in es_query:
                assert 'text_embedding' in es_query['knn']

    def test_spu_collapse_integration(self, test_searcher, monkeypatch):
        """Searching with SPU collapsing enabled still builds an ES query."""
        # Enable SPU collapsing; monkeypatch restores the config afterwards.
        spu_config = test_searcher.config.spu_config
        monkeypatch.setattr(spu_config, "enabled", True, raising=False)
        monkeypatch.setattr(spu_config, "spu_field", "spu_id", raising=False)
        monkeypatch.setattr(spu_config, "inner_hits_size", 3, raising=False)

        context = create_request_context("spu-001")

        test_searcher.search("手机", context=context)

        # An ES query was built.
        es_query = context.get_intermediate_result('es_query')
        assert es_query is not None

        # If the ES query is built correctly it should include a collapse
        # clause. Note: that depends on the ESQueryBuilder implementation, so
        # it is not asserted here.

    def test_reranking_integration(self, test_searcher):
        """With reranking enabled, every hit carries custom and original scores."""
        context = create_request_context("rerank-001")

        # Enable reranking.
        result = test_searcher.search("笔记本电脑", enable_rerank=True, context=context)

        # Only checked when there are hits.
        if result.hits:
            # Each hit should have a custom score and its original score.
            assert all('_custom_score' in hit for hit in result.hits)
            assert all('_original_score' in hit for hit in result.hits)

            custom_scores = [hit['_custom_score'] for hit in result.hits]
            original_scores = [hit['_original_score'] for hit in result.hits]
            assert len(custom_scores) == len(original_scores)

    def test_error_propagation_integration(self, test_searcher, monkeypatch):
        """An ES failure propagates to the caller and is recorded on the context."""
        context = create_request_context("error-001")

        # Simulate an ES failure. Assumes es_client is a Mock (as the original
        # side_effect assignment did); monkeypatch restores the previous
        # side_effect so the failure cannot leak into later tests.
        monkeypatch.setattr(test_searcher.es_client.search, "side_effect",
                            Exception("ES连接失败"))

        with pytest.raises(Exception, match="ES连接失败"):
            test_searcher.search("测试查询", context=context)

        # The error is recorded on the context.
        assert context.has_error()
        assert "ES连接失败" in context.metadata['error_info']['message']

    def test_performance_monitoring_integration(self, test_searcher):
        """Performance summary includes timings for every main stage."""
        context = create_request_context("perf-001")

        # NOTE(review): patching query.query_parser.QueryParser likely does not
        # affect a searcher whose parser instance already exists. The
        # assertions below hold either way; confirm whether the mock is used.
        with patch('query.query_parser.QueryParser') as mock_parser_class:
            mock_parser = Mock()
            mock_parser_class.return_value = mock_parser
            mock_parser.parse.side_effect = lambda q, **kwargs: Mock(
                original_query=q,
                normalized_query=q,
                rewritten_query=q,
                detected_language="zh",
                domain="default",
                translations={},
                query_vector=None
            )

            # Run the search.
            test_searcher.search("性能测试查询", context=context)

            # Performance data was collected.
            summary = context.get_summary()
            assert summary['performance']['total_duration_ms'] > 0
            assert 'stage_timings_ms' in summary['performance']
            assert 'stage_percentages' in summary['performance']

            # Each main stage was timed.
            stages = ['query_parsing', 'query_building', 'elasticsearch_search', 'result_processing']
            for stage in stages:
                assert stage in summary['performance']['stage_timings_ms']

    def test_context_data_persistence_integration(self, test_searcher):
        """All key intermediate results and metadata are stored on the context."""
        context = create_request_context("persist-001")

        test_searcher.search("数据持久化测试", context=context)

        # Key intermediate data was stored.
        assert context.query_analysis.original_query == "数据持久化测试"
        assert context.get_intermediate_result('parsed_query') is not None
        assert context.get_intermediate_result('es_query') is not None
        assert context.get_intermediate_result('es_response') is not None
        assert context.get_intermediate_result('processed_hits') is not None

        # Metadata.
        assert 'search_params' in context.metadata
        assert 'feature_flags' in context.metadata
        assert context.metadata['search_params']['query'] == "数据持久化测试"

    # Every case must be a (query, expected_simple) tuple. A bare
    # `"q", False,` entry would be split into two malformed cases.
    @pytest.mark.parametrize("query,expected_simple", [
        ("红色连衣裙", True),
        ("手机 AND 电脑", False),
        ("(华为 OR 苹果) ANDNOT 二手", False),
        ("laptop RANK gaming", False),
        ("简单查询", True)
    ])
    def test_query_complexity_detection(self, test_searcher, query, expected_simple):
        """Simple vs. boolean/operator queries are classified correctly."""
        context = create_request_context(f"complexity-{hash(query)}")

        test_searcher.search(query, context=context)

        assert context.query_analysis.is_simple_query == expected_simple

    def test_search_with_all_features_enabled(self, test_searcher, monkeypatch):
        """Translation, embedding and reranking together populate the context."""
        # Enable all features; monkeypatch restores the attributes afterwards.
        monkeypatch.setattr(test_searcher, "text_embedding_field", "text_embedding",
                            raising=False)
        spu_config = test_searcher.config.spu_config
        monkeypatch.setattr(spu_config, "enabled", True, raising=False)
        monkeypatch.setattr(spu_config, "spu_field", "spu_id", raising=False)

        context = create_request_context("all-features-001")

        with patch('query.query_parser.BgeEncoder') as mock_encoder_class, \
             patch('query.query_parser.Translator') as mock_translator_class, \
             patch('query.query_parser.LanguageDetector') as mock_detector_class:

            # Configure all mocks.
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode.return_value = [np.array([0.1, 0.2])]

            mock_translator = Mock()
            mock_translator_class.return_value = mock_translator
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "test query"}

            mock_detector = Mock()
            mock_detector_class.return_value = mock_detector
            mock_detector.detect.return_value = "zh"

            # Run a search with every feature switched on.
            test_searcher.search(
                "完整功能测试",
                enable_translation=True,
                enable_embedding=True,
                enable_rerank=True,
                context=context
            )

            # Every feature contributed to the query analysis.
            assert context.query_analysis.detected_language == "zh"
            assert context.query_analysis.translations.get("en") == "test query"
            assert context.query_analysis.query_vector is not None

            # Every main stage has a timing entry.
            summary = context.get_summary()
            expected_stages = [
                'query_parsing', 'query_building',
                'elasticsearch_search', 'result_processing'
            ]
            for stage in expected_stages:
                assert stage in summary['performance']['stage_timings_ms']

    def test_search_result_context_integration(self, test_searcher):
        """The result carries its context and exposes a performance summary."""
        context = create_request_context("result-context-001")

        result = test_searcher.search("结果上下文集成测试", context=context)

        # The result references the request context.
        assert result.context == context

        # to_dict() includes a performance summary tied to this request.
        result_dict = result.to_dict()
        assert 'performance_summary' in result_dict
        assert result_dict['performance_summary']['request_info']['reqid'] == context.reqid

        # The performance summary has all top-level sections.
        perf_summary = result_dict['performance_summary']
        assert 'query_analysis' in perf_summary
        assert 'performance' in perf_summary
        assert 'results' in perf_summary
        assert 'metadata' in perf_summary