# test_query_parser.py (11.4 KB)
"""
QueryParser单元测试
"""

import pytest
from unittest.mock import Mock, patch, MagicMock
import numpy as np

from query import QueryParser, ParsedQuery
from context import RequestContext, create_request_context


@pytest.mark.unit
class TestQueryParser:
    """Unit tests for ``QueryParser``.

    Collaborators (normalizer, language detector, rewriter, translator, text
    encoder) are replaced on the parser *instance* with ``patch.object``.
    pytest resolves the ``test_query_parser`` fixture before the test function
    is called, so a ``@patch('query.query_parser.<Class>')`` decorator would
    only start after the parser has been constructed and would not replace
    collaborators the parser already holds.
    """

    def test_parser_initialization(self, sample_customer_config):
        """A freshly constructed parser exposes its config and collaborators."""
        parser = QueryParser(sample_customer_config)

        assert parser.config == sample_customer_config
        assert parser.query_config is not None
        assert parser.normalizer is not None
        assert parser.rewriter is not None
        assert parser.language_detector is not None
        assert parser.translator is not None

    def test_parse_without_context(self, test_query_parser):
        """Parsing works when no request context is supplied."""
        # Patch the instance attributes: the fixture already built the parser,
        # so class-level patches would not reach it.
        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector:

            mock_normalizer.normalize.return_value = "红色 连衣裙"
            mock_normalizer.extract_domain_query.return_value = ("default", "红色 连衣裙")
            mock_detector.detect.return_value = "zh"

            result = test_query_parser.parse("红色连衣裙")

        assert isinstance(result, ParsedQuery)
        assert result.original_query == "红色连衣裙"
        assert result.normalized_query == "红色 连衣裙"
        assert result.rewritten_query == "红色 连衣裙"  # no rewrite applied
        assert result.detected_language == "zh"

    def test_parse_with_context(self, test_query_parser):
        """Parsing with a request context fills both the result and the context."""
        context = create_request_context("parse-001", "parse-user")

        # Replace every collaborator so the test controls each pipeline stage.
        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector, \
             patch.object(test_query_parser, 'translator') as mock_translator, \
             patch.object(test_query_parser, 'text_encoder') as mock_encoder:

            # Canned return values for each stage.
            mock_normalizer.normalize.return_value = "红色 连衣裙"
            mock_normalizer.extract_domain_query.return_value = ("default", "红色 连衣裙")
            mock_detector.detect.return_value = "zh"
            mock_translator.translate_multi.return_value = {"en": "red dress"}
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3])]

            result = test_query_parser.parse("红色连衣裙", generate_vector=True, context=context)

            # The parsed result carries the stage outputs.
            assert isinstance(result, ParsedQuery)
            assert result.original_query == "红色连衣裙"
            assert result.detected_language == "zh"
            assert result.translations["en"] == "red dress"
            assert result.query_vector is not None

            # The context was updated with the same analysis.
            assert context.query_analysis.original_query == "红色连衣裙"
            assert context.query_analysis.normalized_query == "红色 连衣裙"
            assert context.query_analysis.detected_language == "zh"
            assert context.query_analysis.translations["en"] == "red dress"
            assert context.query_analysis.domain == "default"

            # A duration was recorded for the parsing stage.
            assert context.get_stage_duration("query_parsing") > 0

    def test_query_rewriting(self, test_query_parser):
        """The rewriter's output is used when query rewriting is enabled.

        NOTE(review): assumes ``rewriter`` is a settable attribute on the
        parser instance (it is asserted non-None right after construction in
        ``test_parser_initialization``) -- confirm against ``QueryParser``.
        """
        context = create_request_context()

        # Enable query rewriting.
        test_query_parser.query_config.enable_query_rewrite = True

        with patch.object(test_query_parser, 'rewriter') as mock_rewriter:
            mock_rewriter.rewrite.return_value = "红色 女 连衣裙"

            result = test_query_parser.parse("红色连衣裙", context=context)

        assert result.rewritten_query == "红色 女 连衣裙"
        assert context.query_analysis.rewritten_query == "红色 女 连衣裙"

    def test_language_detection(self, test_query_parser):
        """The detected language is reported on the result and the context."""
        context = create_request_context()

        with patch.object(test_query_parser, 'language_detector') as mock_detector, \
             patch.object(test_query_parser, 'normalizer') as mock_normalizer:

            mock_normalizer.normalize.return_value = "red dress"
            mock_normalizer.extract_domain_query.return_value = ("default", "red dress")
            mock_detector.detect.return_value = "en"

            result = test_query_parser.parse("red dress", context=context)

            assert result.detected_language == "en"
            assert context.query_analysis.detected_language == "en"

    def test_query_translation(self, test_query_parser):
        """Translations from the translator end up on the result and the context.

        NOTE(review): patches ``translator`` on the instance, the same way
        ``test_parse_with_context`` does -- confirm it is a settable attribute.
        """
        context = create_request_context()

        # Enable translation.
        test_query_parser.query_config.enable_translation = True
        test_query_parser.query_config.supported_languages = ["zh", "en"]

        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector, \
             patch.object(test_query_parser, 'translator') as mock_translator:

            mock_normalizer.normalize.return_value = "红色 连衣裙"
            mock_normalizer.extract_domain_query.return_value = ("default", "红色 连衣裙")
            mock_detector.detect.return_value = "zh"
            mock_translator.get_translation_needs.return_value = ["en"]
            mock_translator.translate_multi.return_value = {"en": "red dress"}

            result = test_query_parser.parse("红色连衣裙", context=context)

            assert result.translations["en"] == "red dress"
            assert context.query_analysis.translations["en"] == "red dress"

    def test_text_embedding(self, test_query_parser):
        """A query vector is produced when text embedding is enabled.

        NOTE(review): patches ``text_encoder`` on the instance, the same way
        ``test_parse_with_context`` does -- confirm it is a settable attribute.
        """
        context = create_request_context()

        # Enable text embedding.
        test_query_parser.query_config.enable_text_embedding = True

        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector, \
             patch.object(test_query_parser, 'text_encoder') as mock_encoder:

            mock_normalizer.normalize.return_value = "红色 连衣裙"
            mock_normalizer.extract_domain_query.return_value = ("default", "红色 连衣裙")
            mock_detector.detect.return_value = "zh"
            mock_encoder.encode.return_value = [np.array([0.1, 0.2, 0.3])]

            result = test_query_parser.parse("红色连衣裙", generate_vector=True, context=context)

            assert result.query_vector is not None
            assert isinstance(result.query_vector, np.ndarray)
            assert context.query_analysis.query_vector is not None

    def test_domain_extraction(self, test_query_parser):
        """The domain returned by the normalizer is reported for a prefixed query."""
        context = create_request_context()

        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector:

            # Query that carries a domain prefix.
            mock_normalizer.normalize.return_value = "brand:nike 鞋子"
            mock_normalizer.extract_domain_query.return_value = ("brand", "nike 鞋子")
            mock_detector.detect.return_value = "zh"

            result = test_query_parser.parse("brand:nike 鞋子", context=context)

            assert result.domain == "brand"
            assert context.query_analysis.domain == "brand"

    def test_parse_with_disabled_features(self, test_query_parser):
        """With rewrite, translation and embedding off, only basic fields are set."""
        context = create_request_context()

        # Disable every optional feature.
        test_query_parser.query_config.enable_query_rewrite = False
        test_query_parser.query_config.enable_translation = False
        test_query_parser.query_config.enable_text_embedding = False

        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector:

            mock_normalizer.normalize.return_value = "红色 连衣裙"
            mock_normalizer.extract_domain_query.return_value = ("default", "红色 连衣裙")
            mock_detector.detect.return_value = "zh"

            result = test_query_parser.parse("红色连衣裙", generate_vector=False, context=context)

            assert result.original_query == "红色连衣裙"
            assert result.rewritten_query == "红色 连衣裙"  # no rewrite applied
            assert result.detected_language == "zh"
            assert len(result.translations) == 0  # no translations
            assert result.query_vector is None  # no vector

    def test_get_search_queries(self, test_query_parser):
        """``get_search_queries`` returns the rewritten query plus each translation."""
        parsed_query = ParsedQuery(
            original_query="红色连衣裙",
            normalized_query="红色 连衣裙",
            rewritten_query="红色 连衣裙",
            detected_language="zh",
            translations={"en": "red dress", "fr": "robe rouge"}
        )

        queries = test_query_parser.get_search_queries(parsed_query)

        assert len(queries) == 3
        assert "红色 连衣裙" in queries
        assert "red dress" in queries
        assert "robe rouge" in queries

    def test_empty_query_handling(self, test_query_parser):
        """An empty query parses to empty original and normalized queries."""
        result = test_query_parser.parse("")

        assert result.original_query == ""
        assert result.normalized_query == ""

    def test_whitespace_query_handling(self, test_query_parser):
        """A whitespace-only query keeps its original text on the result."""
        result = test_query_parser.parse("   ")

        assert result.original_query == "   "

    def test_error_handling_in_parsing(self, test_query_parser):
        """An exception raised by the normalizer propagates out of ``parse``."""
        context = create_request_context()

        # Make the normalizer raise.
        with patch.object(test_query_parser, 'normalizer') as mock_normalizer:
            mock_normalizer.normalize.side_effect = Exception("Normalization failed")

            with pytest.raises(Exception, match="Normalization failed"):
                test_query_parser.parse("红色连衣裙", context=context)

    def test_performance_timing(self, test_query_parser):
        """Parsing records a stage duration and stores the result on the context."""
        context = create_request_context()

        with patch.object(test_query_parser, 'normalizer') as mock_normalizer, \
             patch.object(test_query_parser, 'language_detector') as mock_detector:

            mock_normalizer.normalize.return_value = "test"
            mock_normalizer.extract_domain_query.return_value = ("default", "test")
            mock_detector.detect.return_value = "zh"

            result = test_query_parser.parse("test", context=context)

            # Timing was recorded and the parsed query was stored.
            assert context.get_stage_duration("query_parsing") > 0
            assert context.get_intermediate_result('parsed_query') == result