"""Unit tests for QueryParser."""

from contextlib import contextmanager
from unittest.mock import Mock, patch, MagicMock

import numpy as np
import pytest

from query import QueryParser, ParsedQuery
from context import RequestContext, create_request_context


@contextmanager
def _stubbed_text_pipeline(parser, normalized, language,
                           domain="default", domain_query=None):
    """Swap the parser's normalizer and language detector for mocks.

    Args:
        parser: The QueryParser instance under test.
        normalized: Value returned by the mocked ``normalizer.normalize``.
        language: Value returned by the mocked ``language_detector.detect``.
        domain: First element of the tuple returned by the mocked
            ``normalizer.extract_domain_query``.
        domain_query: Second element of that tuple; defaults to
            ``normalized`` when omitted.

    Yields:
        A ``(mock_normalizer, mock_detector)`` tuple. Both attributes are
        restored on the parser when the ``with`` block exits.
    """
    if domain_query is None:
        domain_query = normalized
    with patch.object(parser, 'normalizer') as mock_normalizer, \
            patch.object(parser, 'language_detector') as mock_detector:
        mock_normalizer.normalize.return_value = normalized
        mock_normalizer.extract_domain_query.return_value = (domain, domain_query)
        mock_detector.detect.return_value = language
        yield mock_normalizer, mock_detector


@pytest.mark.unit
class TestQueryParser:
    """Test cases for QueryParser.

    Collaborators are mocked on the parser *instance* via ``patch.object``.
    The ``test_query_parser`` fixture has already built the parser by the
    time a test body runs, so patching a class inside ``query.query_parser``
    would not replace a collaborator the instance already holds.
    """

    def test_parser_initialization(self, sample_customer_config):
        """A new parser keeps its config and exposes all collaborators."""
        parser = QueryParser(sample_customer_config)

        assert parser.config == sample_customer_config
        assert parser.query_config is not None
        assert parser.normalizer is not None
        assert parser.rewriter is not None
        assert parser.language_detector is not None
        assert parser.translator is not None

    def test_parse_without_context(self, test_query_parser):
        """Parsing works when no request context is supplied."""
        with _stubbed_text_pipeline(test_query_parser, "红色 连衣裙", "zh"):
            result = test_query_parser.parse("红色连衣裙")

        assert isinstance(result, ParsedQuery)
        assert result.original_query == "红色连衣裙"
        assert result.normalized_query == "红色 连衣裙"
        assert result.rewritten_query == "红色 连衣裙"  # no rewrite applied
        assert result.detected_language == "zh"

    def test_parse_with_context(self, test_query_parser):
        """Parsing with a context fills the result and the context."""
        context = create_request_context("parse-001", "parse-user")

        # Mock every collaborator the parse call may touch.
        with _stubbed_text_pipeline(test_query_parser, "红色 连衣裙", "zh"), \
                patch.object(test_query_parser, 'translator') as mock_translator, \
                patch.object(test_query_parser, 'text_encoder') as mock_encoder:
            mock_translator.translate_multi.return_value = {"en": "red dress"}
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3])]

            result = test_query_parser.parse(
                "红色连衣裙", generate_vector=True, context=context
            )

        # The parsed result carries the mocked values.
        assert isinstance(result, ParsedQuery)
        assert result.original_query == "红色连衣裙"
        assert result.detected_language == "zh"
        assert result.translations["en"] == "red dress"
        assert result.query_vector is not None

        # The context was updated with the same analysis.
        assert context.query_analysis.original_query == "红色连衣裙"
        assert context.query_analysis.normalized_query == "红色 连衣裙"
        assert context.query_analysis.detected_language == "zh"
        assert context.query_analysis.translations["en"] == "red dress"
        assert context.query_analysis.domain == "default"

        # A duration was recorded for the parsing stage.
        assert context.get_stage_duration("query_parsing") > 0

    def test_query_rewriting(self, test_query_parser):
        """The rewriter's output is used when query rewriting is enabled."""
        context = create_request_context()

        # Enable query rewriting.
        test_query_parser.query_config.enable_query_rewrite = True

        with patch.object(test_query_parser, 'rewriter') as mock_rewriter:
            mock_rewriter.rewrite.return_value = "红色 女 连衣裙"

            result = test_query_parser.parse("红色连衣裙", context=context)

        assert result.rewritten_query == "红色 女 连衣裙"
        assert context.query_analysis.rewritten_query == "红色 女 连衣裙"

    def test_language_detection(self, test_query_parser):
        """The detected language is stored on the result and the context."""
        context = create_request_context()

        with _stubbed_text_pipeline(test_query_parser, "red dress", "en"):
            result = test_query_parser.parse("red dress", context=context)

        assert result.detected_language == "en"
        assert context.query_analysis.detected_language == "en"

    def test_query_translation(self, test_query_parser):
        """Translations from the translator reach the result and context."""
        context = create_request_context()

        # Enable translation.
        test_query_parser.query_config.enable_translation = True
        test_query_parser.query_config.supported_languages = ["zh", "en"]

        with patch.object(test_query_parser, 'translator') as mock_translator, \
                _stubbed_text_pipeline(test_query_parser, "红色 连衣裙", "zh"):
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "red dress"}

            result = test_query_parser.parse("红色连衣裙", context=context)

        assert result.translations["en"] == "red dress"
        assert context.query_analysis.translations["en"] == "red dress"

    def test_text_embedding(self, test_query_parser):
        """A query vector is produced when text embedding is enabled."""
        context = create_request_context()

        # Enable text embedding.
        test_query_parser.query_config.enable_text_embedding = True

        with patch.object(test_query_parser, 'text_encoder') as mock_encoder, \
                _stubbed_text_pipeline(test_query_parser, "红色 连衣裙", "zh"):
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3])]

            result = test_query_parser.parse(
                "红色连衣裙", generate_vector=True, context=context
            )

        assert result.query_vector is not None
        assert isinstance(result.query_vector, np.ndarray)
        assert context.query_analysis.query_vector is not None

    def test_domain_extraction(self, test_query_parser):
        """The domain returned by the normalizer is exposed on the result."""
        context = create_request_context()

        # Query that carries an explicit domain prefix.
        with _stubbed_text_pipeline(
            test_query_parser,
            "brand:nike 鞋子",
            "zh",
            domain="brand",
            domain_query="nike 鞋子",
        ):
            result = test_query_parser.parse("brand:nike 鞋子", context=context)

        assert result.domain == "brand"
        assert context.query_analysis.domain == "brand"

    def test_parse_with_disabled_features(self, test_query_parser):
        """With every optional feature off, only the basic fields are set."""
        context = create_request_context()

        # Disable all optional features.
        test_query_parser.query_config.enable_query_rewrite = False
        test_query_parser.query_config.enable_translation = False
        test_query_parser.query_config.enable_text_embedding = False

        with _stubbed_text_pipeline(test_query_parser, "红色 连衣裙", "zh"):
            result = test_query_parser.parse(
                "红色连衣裙", generate_vector=False, context=context
            )

        assert result.original_query == "红色连衣裙"
        assert result.rewritten_query == "红色 连衣裙"  # no rewrite applied
        assert result.detected_language == "zh"
        assert len(result.translations) == 0  # no translations
        assert result.query_vector is None  # no vector

    def test_get_search_queries(self, test_query_parser):
        """get_search_queries returns the rewritten query plus translations."""
        parsed_query = ParsedQuery(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 连衣裙",
            detected_language="zh",
            translations={"en": "red dress", "fr": "robe rouge"},
        )

        queries = test_query_parser.get_search_queries(parsed_query)

        assert len(queries) == 3
        assert "红色 连衣裙" in queries
        assert "red dress" in queries
        assert "robe rouge" in queries

    def test_empty_query_handling(self, test_query_parser):
        """An empty query yields empty original and normalized queries."""
        result = test_query_parser.parse("")

        assert result.original_query == ""
        assert result.normalized_query == ""

    def test_whitespace_query_handling(self, test_query_parser):
        """A whitespace-only query is kept unchanged as the original query."""
        result = test_query_parser.parse(" ")

        assert result.original_query == " "

    def test_error_handling_in_parsing(self, test_query_parser):
        """An exception raised by the normalizer propagates out of parse."""
        context = create_request_context()

        # Make the mocked normalizer raise.
        with patch.object(test_query_parser, 'normalizer') as mock_normalizer:
            mock_normalizer.normalize.side_effect = Exception("Normalization failed")

            with pytest.raises(Exception, match="Normalization failed"):
                test_query_parser.parse("红色连衣裙", context=context)

    def test_performance_timing(self, test_query_parser):
        """Parsing records a stage duration and stores the parsed query."""
        context = create_request_context()

        with _stubbed_text_pipeline(test_query_parser, "test", "zh"):
            result = test_query_parser.parse("test", context=context)

        # Timing was recorded and the result was stored on the context.
        assert context.get_stage_duration("query_parsing") > 0
        assert context.get_intermediate_result('parsed_query') == result