# test_searcher.py 9.21 KB
"""
Searcher单元测试
"""

import pytest
from unittest.mock import Mock, patch, MagicMock
import numpy as np

from search import Searcher
from query import ParsedQuery
from context import RequestContext, create_request_context


@pytest.mark.unit
class TestSearcher:
    """Unit tests for ``Searcher``.

    The ``sample_customer_config``, ``mock_es_client`` and ``test_searcher``
    arguments are pytest fixtures defined outside this file.
    """

    def test_searcher_initialization(self, sample_customer_config, mock_es_client):
        """A new Searcher keeps its config/client and builds its collaborators."""
        searcher = Searcher(sample_customer_config, mock_es_client)

        assert searcher.config == sample_customer_config
        assert searcher.es_client == mock_es_client
        assert searcher.query_parser is not None
        assert searcher.boolean_parser is not None
        assert searcher.ranking_engine is not None

    def test_search_without_context(self, test_searcher):
        """Searching without a context still works (backward compatibility)."""
        result = test_searcher.search("红色连衣裙", size=5)

        assert result.hits is not None
        assert result.total >= 0
        assert result.context is not None  # a context should be created automatically
        assert result.took_ms >= 0

    def test_search_with_context(self, test_searcher):
        """A caller-supplied context is returned on the result unchanged."""
        context = create_request_context("test-req", "test-user")

        result = test_searcher.search("红色连衣裙", context=context)

        assert result.hits is not None
        assert result.context == context
        assert context.reqid == "test-req"
        assert context.uid == "test-user"

    def test_search_with_parameters(self, test_searcher):
        """Search parameters and feature flags are recorded in context metadata."""
        context = create_request_context()

        result = test_searcher.search(
            query="红色连衣裙",
            size=15,
            from_=5,
            filters={"category_id": 1},
            enable_translation=False,
            enable_embedding=False,
            enable_rerank=False,
            min_score=1.0,
            context=context
        )

        assert result is not None

        search_params = context.metadata['search_params']
        assert search_params['size'] == 15
        assert search_params['from'] == 5
        assert search_params['filters'] == {"category_id": 1}
        assert search_params['min_score'] == 1.0

        # Verify the feature flags.
        feature_flags = context.metadata['feature_flags']
        assert feature_flags['enable_translation'] is False
        assert feature_flags['enable_embedding'] is False
        assert feature_flags['enable_rerank'] is False

    @patch('search.searcher.QueryParser')
    def test_search_query_parsing(self, mock_query_parser_class, test_searcher):
        """The query parser is invoked exactly once with the raw query."""
        # Set up the mock parser and the value it returns.
        mock_parser = Mock()
        mock_query_parser_class.return_value = mock_parser

        parsed_query = ParsedQuery(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            domain="default"
        )
        mock_parser.parse.return_value = parsed_query

        context = create_request_context()

        # pytest resolves the test_searcher fixture before the class-level
        # patch becomes active, so the already-built instance would never see
        # the mock. Swap the instance's parser as well (restored on exit).
        with patch.object(test_searcher, 'query_parser', mock_parser):
            test_searcher.search("红色连衣裙", context=context)

        # Verify the query parser was called.
        mock_parser.parse.assert_called_once_with("红色连衣裙", generate_vector=True, context=context)

    def test_search_error_handling(self, test_searcher):
        """An ES failure propagates to the caller and is recorded on the context."""
        context = create_request_context()

        # Make the ES client raise, and restore the previous side effect
        # afterwards so the failure cannot leak into other tests if the
        # fixture is shared.
        es_search = test_searcher.es_client.search
        previous_side_effect = es_search.side_effect
        es_search.side_effect = Exception("ES连接失败")
        try:
            with pytest.raises(Exception, match="ES连接失败"):
                test_searcher.search("红色连衣裙", context=context)
        finally:
            es_search.side_effect = previous_side_effect

        # Verify the error was recorded on the context.
        assert context.has_error()
        assert "ES连接失败" in context.metadata['error_info']['message']

    def test_search_result_processing(self, test_searcher):
        """The result exposes the expected fields and intermediates are stored."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        # Verify the result structure.
        expected_fields = (
            'hits',
            'total',
            'max_score',
            'took_ms',
            'aggregations',
            'query_info',
            'context',
        )
        for field_name in expected_fields:
            assert hasattr(result, field_name)

        # Verify that intermediate results were stored on the context.
        for key in ('es_response', 'raw_hits', 'processed_hits'):
            assert context.get_intermediate_result(key) is not None

    def test_boolean_query_handling(self, test_searcher):
        """A query with boolean operators is not flagged as a simple query."""
        context = create_request_context()

        # Exercise a complex boolean query.
        result = test_searcher.search("laptop AND (gaming OR professional)", context=context)

        assert result is not None
        # Complex queries should go through the boolean parser.
        assert not context.query_analysis.is_simple_query

    def test_simple_query_handling(self, test_searcher):
        """A plain keyword query is flagged as a simple query."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        # Simple queries should be marked as simple.
        assert context.query_analysis.is_simple_query

    @patch('search.searcher.RankingEngine')
    def test_reranking(self, mock_ranking_engine_class, test_searcher):
        """With reranking enabled, every hit carries custom and original scores."""
        # Set up the mock ranking engine.
        mock_ranking = Mock()
        mock_ranking_engine_class.return_value = mock_ranking
        mock_ranking.calculate_score.return_value = 2.0

        context = create_request_context()

        # The fixture's Searcher was constructed before the class-level patch
        # took effect, so install the mock on the instance too (restored on
        # exit).
        with patch.object(test_searcher, 'ranking_engine', mock_ranking):
            result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        # Verify reranking was applied.
        hits = result.hits
        if hits:  # only meaningful when the backend returned results
            # Each hit should carry the custom and the original score.
            assert all('_custom_score' in hit for hit in hits)
            assert all('_original_score' in hit for hit in hits)

    def test_spu_collapse(self, test_searcher):
        """Searching with SPU collapse configured still builds an ES query."""
        # Configure SPU collapsing.
        test_searcher.config.spu_config.enabled = True
        test_searcher.config.spu_config.spu_field = "spu_id"
        test_searcher.config.spu_config.inner_hits_size = 3

        context = create_request_context()
        result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        # Verify an ES query was built with the SPU configuration in place.
        assert context.get_intermediate_result('es_query') is not None

    def test_embedding_search(self, test_searcher):
        """Searching with embeddings enabled completes and returns a result."""
        # Configure the text embedding field.
        test_searcher.text_embedding_field = "text_embedding"

        context = create_request_context()
        result = test_searcher.search("红色连衣裙", enable_embedding=True, context=context)

        # NOTE: this only checks that the search completes; it does not yet
        # assert that the embedding path was actually taken.
        assert result is not None

    def test_search_by_image(self, test_searcher):
        """Image search reports its search type and the source image URL."""
        # Configure the image embedding field.
        test_searcher.image_embedding_field = "image_embedding"

        # Mock the image encoder.
        with patch('search.searcher.CLIPImageEncoder') as mock_encoder_class:
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode_image_from_url.return_value = np.array([0.1, 0.2, 0.3])

            result = test_searcher.search_by_image("http://example.com/image.jpg")

            assert result is not None
            assert result.query_info['search_type'] == 'image_similarity'
            assert result.query_info['image_url'] == "http://example.com/image.jpg"

    def test_performance_monitoring(self, test_searcher):
        """Every pipeline stage is timed and a total duration is recorded."""
        # RequestContextStage is not imported at module level. This assumes it
        # is exported by the ``context`` package alongside RequestContext.
        # TODO confirm the import path.
        from context import RequestContextStage

        context = create_request_context()

        test_searcher.search("红色连衣裙", context=context)

        # Verify that each stage was timed.
        timed_stages = (
            RequestContextStage.QUERY_PARSING,
            RequestContextStage.QUERY_BUILDING,
            RequestContextStage.ELASTICSEARCH_SEARCH,
            RequestContextStage.RESULT_PROCESSING,
        )
        for stage in timed_stages:
            assert context.get_stage_duration(stage) >= 0

        # Verify the total duration.
        assert context.performance_metrics.total_duration > 0

    def test_context_storage(self, test_searcher):
        """Query analysis and intermediate results are stored on the context."""
        context = create_request_context()

        test_searcher.search("红色连衣裙", context=context)

        # Verify the query analysis was stored.
        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.domain is not None

        # Verify the intermediate results were stored.
        for key in ('parsed_query', 'es_query', 'es_response', 'processed_hits'):
            assert context.get_intermediate_result(key) is not None