"""Unit tests for Searcher."""
import pytest
from unittest.mock import Mock, patch, MagicMock
import numpy as np

from search import Searcher
from query import ParsedQuery
from context import RequestContext, RequestContextStage, create_request_context


@pytest.mark.unit
class TestSearcher:
    """Test cases for Searcher.

    Tests that need to alter the ``test_searcher`` fixture (its ES client,
    config, or collaborators) do so through ``unittest.mock.patch`` context
    managers, so the original values are restored when the test finishes.
    This matters if the fixture is shared between tests; its scope is defined
    outside this file.
    """

    def test_searcher_initialization(self, sample_customer_config, mock_es_client):
        """Searcher stores its config and client and builds its sub-components."""
        searcher = Searcher(sample_customer_config, mock_es_client)

        assert searcher.config == sample_customer_config
        assert searcher.es_client == mock_es_client
        assert searcher.query_parser is not None
        assert searcher.boolean_parser is not None
        assert searcher.ranking_engine is not None

    def test_search_without_context(self, test_searcher):
        """Searching without a context still works (backward compatibility)."""
        result = test_searcher.search("红色连衣裙", size=5)

        assert result.hits is not None
        assert result.total >= 0
        assert result.context is not None  # a context should be created automatically
        assert result.took_ms >= 0

    def test_search_with_context(self, test_searcher):
        """A caller-supplied context is attached to the result unchanged."""
        context = create_request_context("test-req", "test-user")

        result = test_searcher.search("红色连衣裙", context=context)

        assert result.hits is not None
        assert result.context == context
        assert context.reqid == "test-req"
        assert context.uid == "test-user"

    def test_search_with_parameters(self, test_searcher):
        """Search parameters and feature flags are recorded in the context metadata."""
        context = create_request_context()

        result = test_searcher.search(
            query="红色连衣裙",
            size=15,
            from_=5,
            filters={"category_id": 1},
            enable_translation=False,
            enable_embedding=False,
            enable_rerank=False,
            min_score=1.0,
            context=context,
        )

        assert result is not None

        search_params = context.metadata['search_params']
        assert search_params['size'] == 15
        assert search_params['from'] == 5
        assert search_params['filters'] == {"category_id": 1}
        assert search_params['min_score'] == 1.0

        # Feature flags passed to search() must be reflected in the metadata.
        feature_flags = context.metadata['feature_flags']
        assert feature_flags['enable_translation'] is False
        assert feature_flags['enable_embedding'] is False
        assert feature_flags['enable_rerank'] is False

    def test_search_query_parsing(self, test_searcher):
        """search() delegates parsing of the raw query to the searcher's query parser."""
        parsed_query = ParsedQuery(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 女 连衣裙",
            detected_language="zh",
            domain="default",
        )
        mock_parser = Mock()
        mock_parser.parse.return_value = parsed_query
        context = create_request_context()

        # Patch the parser on the already-constructed instance. Patching the
        # QueryParser class (search.searcher.QueryParser) would be too late,
        # because the fixture builds the Searcher, and with it query_parser,
        # before this test body runs, so that mock would never be called.
        with patch.object(test_searcher, "query_parser", mock_parser):
            test_searcher.search("红色连衣裙", context=context)

        mock_parser.parse.assert_called_once_with(
            "红色连衣裙", generate_vector=True, context=context
        )

    def test_search_error_handling(self, test_searcher):
        """An ES failure propagates to the caller and is recorded in the context."""
        context = create_request_context()

        # Scope the failure to this test. Assigning side_effect directly on the
        # fixture's client would leave it failing for any later user of it.
        with patch.object(
            test_searcher.es_client, "search", side_effect=Exception("ES连接失败")
        ):
            with pytest.raises(Exception, match="ES连接失败"):
                test_searcher.search("红色连衣裙", context=context)

        # The error must have been recorded on the context.
        assert context.has_error()
        assert "ES连接失败" in context.metadata['error_info']['message']

    def test_search_result_processing(self, test_searcher):
        """The result exposes the expected fields and intermediates are stored on the context."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        # Result structure.
        for attr in (
            'hits',
            'total',
            'max_score',
            'took_ms',
            'aggregations',
            'query_info',
            'context',
        ):
            assert hasattr(result, attr), attr

        # Intermediate results recorded on the context.
        assert context.get_intermediate_result('es_response') is not None
        assert context.get_intermediate_result('raw_hits') is not None
        assert context.get_intermediate_result('processed_hits') is not None

    def test_boolean_query_handling(self, test_searcher):
        """A query with boolean operators is not classified as a simple query."""
        context = create_request_context()

        result = test_searcher.search("laptop AND (gaming OR professional)", context=context)

        assert result is not None
        # Only the classification flag is checked here; this test does not
        # verify that the boolean parser itself was invoked.
        assert not context.query_analysis.is_simple_query

    def test_simple_query_handling(self, test_searcher):
        """A plain keyword query is classified as a simple query."""
        context = create_request_context()

        result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        assert context.query_analysis.is_simple_query

    def test_reranking(self, test_searcher):
        """With rerank enabled, every hit carries both a custom and an original score."""
        context = create_request_context()

        # Patch the live ranking engine rather than the RankingEngine class:
        # the fixture has already built the engine, so a class-level patch
        # would never take effect.
        # NOTE(review): assumes the engine exposes calculate_score, as the
        # previous mock setup implied. TODO confirm.
        with patch.object(test_searcher.ranking_engine, "calculate_score", return_value=2.0):
            result = test_searcher.search("红色连衣裙", enable_rerank=True, context=context)

        hits = result.hits
        if hits:  # only meaningful when the search returned results
            assert all('_custom_score' in hit for hit in hits)
            assert all('_original_score' in hit for hit in hits)

    def test_spu_collapse(self, test_searcher):
        """Searching with SPU collapse configured still produces an ES query."""
        spu_config = test_searcher.config.spu_config

        # Temporarily configure SPU collapsing; the original values are
        # restored on exit so other tests see the fixture's own config.
        with patch.multiple(spu_config, enabled=True, spu_field="spu_id", inner_hits_size=3):
            context = create_request_context()
            result = test_searcher.search("红色连衣裙", context=context)

        assert result is not None
        # The ES query (where SPU collapsing would be applied) must be recorded.
        assert context.get_intermediate_result('es_query') is not None

    def test_embedding_search(self, test_searcher):
        """Searching with embeddings enabled succeeds when a text embedding field is configured."""
        context = create_request_context()

        with patch.object(test_searcher, "text_embedding_field", "text_embedding", create=True):
            result = test_searcher.search("红色连衣裙", enable_embedding=True, context=context)

        assert result is not None
        # NOTE(review): nothing here verifies that vector search was actually
        # used; only that the call succeeds with embeddings enabled.

    def test_search_by_image(self, test_searcher):
        """Image search reports the image-similarity search type and the image URL."""
        # NOTE(review): patching search.searcher.CLIPImageEncoder only takes
        # effect if search_by_image constructs the encoder lazily rather than
        # in Searcher.__init__. Confirm against the Searcher implementation.
        with patch.object(
            test_searcher, "image_embedding_field", "image_embedding", create=True
        ), patch('search.searcher.CLIPImageEncoder') as mock_encoder_class:
            mock_encoder = Mock()
            mock_encoder_class.return_value = mock_encoder
            mock_encoder.encode_image_from_url.return_value = np.array([0.1, 0.2, 0.3])

            result = test_searcher.search_by_image("http://example.com/image.jpg")

        assert result is not None
        assert result.query_info['search_type'] == 'image_similarity'
        assert result.query_info['image_url'] == "http://example.com/image.jpg"

    def test_performance_monitoring(self, test_searcher):
        """Each pipeline stage is timed and a positive total duration is recorded."""
        context = create_request_context()

        test_searcher.search("红色连衣裙", context=context)

        # Every stage must have a recorded (non-negative) duration.
        for stage in (
            RequestContextStage.QUERY_PARSING,
            RequestContextStage.QUERY_BUILDING,
            RequestContextStage.ELASTICSEARCH_SEARCH,
            RequestContextStage.RESULT_PROCESSING,
        ):
            assert context.get_stage_duration(stage) >= 0, stage

        # Total elapsed time.
        assert context.performance_metrics.total_duration > 0

    def test_context_storage(self, test_searcher):
        """Query analysis and all intermediate results are stored on the context."""
        context = create_request_context()

        test_searcher.search("红色连衣裙", context=context)

        # Query analysis.
        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.domain is not None

        # Intermediate results.
        for key in ('parsed_query', 'es_query', 'es_response', 'processed_hits'):
            assert context.get_intermediate_result(key) is not None, key