# conftest.py (7.78 KB)
"""
pytest配置文件

提供测试夹具和共享配置
"""

import os
import sys
import pytest
import tempfile
from typing import Dict, Any, Generator
from unittest.mock import Mock, MagicMock

# Make the project root (the parent of the directory holding this file)
# importable, so the project packages imported below resolve without
# installing the project.
_conftest_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_conftest_dir)
sys.path.insert(0, project_root)

from config import SearchConfig, QueryConfig, IndexConfig, FieldConfig, SPUConfig, RankingConfig, FunctionScoreConfig, RerankConfig
from config.field_types import FieldType, AnalyzerType
from utils.es_client import ESClient
from search import Searcher
from query import QueryParser
from context import RequestContext, create_request_context


@pytest.fixture
def sample_field_config() -> FieldConfig:
    """Sample config for a single searchable text field named ``name``.

    Uses the same ``field_type=`` / ``AnalyzerType`` keywords as the field
    definitions in ``sample_search_config``. The original ``type="TEXT"``
    keyword did not match any other ``FieldConfig`` call in this file.
    """
    return FieldConfig(
        name="name",
        field_type=FieldType.TEXT,
        # NOTE(review): this was the string "ansj". The YAML fixture pairs
        # "ansj" with the index that uses CHINESE_ECOMMERCE here, so they
        # appear equivalent. Confirm against config.field_types.
        analyzer=AnalyzerType.CHINESE_ECOMMERCE,
        searchable=True,
        filterable=False,
    )


@pytest.fixture
def sample_index_config() -> IndexConfig:
    """Default index over name / brand_name / tags, with a zh/en field mapping."""
    lang_fields = {
        "zh": ["name", "brand_name"],
        "en": ["name_en", "brand_name_en"],
    }
    return IndexConfig(
        name="default",
        label="默认索引",
        fields=["name", "brand_name", "tags"],
        analyzer=AnalyzerType.CHINESE_ECOMMERCE,
        language_field_mapping=lang_fields,
    )


@pytest.fixture
def sample_search_config(sample_index_config) -> SearchConfig:
    """Build a SearchConfig on top of the ``sample_index_config`` fixture.

    Enables query rewrite, translation and text embedding for zh/en. Sets SPU
    options (``spu_id``, 3 inner hits) and a BM25 + embedding ranking
    expression. Function-score and rerank settings use their defaults.
    """
    queries = QueryConfig(
        enable_query_rewrite=True,
        enable_translation=True,
        enable_text_embedding=True,
        supported_languages=["zh", "en"],
    )
    spu = SPUConfig(enabled=True, spu_field="spu_id", inner_hits_size=3)
    ranking = RankingConfig(
        expression="static_bm25() + text_embedding_relevance() * 0.2",
        description="Test ranking",
    )
    scoring = FunctionScoreConfig()
    reranking = RerankConfig()

    def text_field(field_name: str) -> FieldConfig:
        # All text fields in this fixture share one analyzer.
        return FieldConfig(
            name=field_name,
            field_type=FieldType.TEXT,
            analyzer=AnalyzerType.CHINESE_ECOMMERCE,
        )

    schema = [
        FieldConfig(name="tenant_id", field_type=FieldType.KEYWORD, required=True),
        text_field("name"),
        text_field("brand_name"),
        text_field("tags"),
        FieldConfig(name="price", field_type=FieldType.DOUBLE),
        FieldConfig(name="category_id", field_type=FieldType.INT),
    ]

    return SearchConfig(
        es_index_name="test_products",
        fields=schema,
        indexes=[sample_index_config],
        query_config=queries,
        ranking=ranking,
        function_score=scoring,
        rerank=reranking,
        spu_config=spu,
    )


@pytest.fixture
def mock_es_client() -> Mock:
    """Return a ``Mock`` restricted to the ``ESClient`` interface.

    ``search()`` returns a canned Elasticsearch-style response containing two
    hits (ids "1" and "2", same brand and category) and ``took`` = 15.
    """
    client = Mock(spec=ESClient)

    def make_hit(doc_id, score, title, price):
        # One ES hit; only the title, score and price vary between hits.
        return {
            "_id": doc_id,
            "_score": score,
            "_source": {
                "name": title,
                "brand_name": "测试品牌",
                "price": price,
                "category_id": 1,
            },
        }

    canned_hits = [
        make_hit("1", 2.5, "红色连衣裙", 299.0),
        make_hit("2", 2.2, "蓝色连衣裙", 399.0),
    ]
    # NOTE(review): total.value is 10 while only two hits are returned, and
    # sample_search_results expects a total of 2. Confirm which is intended.
    canned_response = {
        "hits": {
            "total": {"value": 10},
            "max_score": 2.5,
            "hits": canned_hits,
        },
        "took": 15,
    }

    client.search.return_value = canned_response
    return client


@pytest.fixture
def test_searcher(sample_search_config, mock_es_client) -> Searcher:
    """Searcher built from the sample config and backed by the mocked ES client."""
    searcher = Searcher(config=sample_search_config, es_client=mock_es_client)
    return searcher


@pytest.fixture
def test_query_parser(sample_search_config) -> QueryParser:
    """QueryParser built from the sample search config."""
    parser = QueryParser(sample_search_config)
    return parser


@pytest.fixture
def test_request_context() -> RequestContext:
    """RequestContext created via ``create_request_context`` with fixed test ids."""
    req_id, user = "test-req-001", "test-user"
    return create_request_context(req_id, user)


@pytest.fixture
def sample_search_results() -> Dict[str, Any]:
    """Query string plus the expected total and products for that query."""
    products = [
        {"name": "红色连衣裙", "price": 299.0},
        {"name": "蓝色连衣裙", "price": 399.0},
    ]
    result = {
        "query": "红色连衣裙",
        "expected_total": 2,
        "expected_products": products,
    }
    return result


@pytest.fixture
def temp_config_file() -> Generator[str, None, None]:
    """Write a sample search configuration to a temporary YAML file.

    Yields:
        Path to the temporary ``.yaml`` file. The file is removed on
        teardown, and teardown tolerates a test that already deleted it.

    Raises:
        Any exception from ``yaml.dump``. The partially written file is
        removed before the exception propagates.
    """
    import yaml

    config_data = {
        "es_index_name": "test_products",
        "query_config": {
            "enable_query_rewrite": True,
            "enable_translation": True,
            "enable_text_embedding": True,
            "supported_languages": ["zh", "en"]
        },
        "fields": [
            {"name": "tenant_id", "type": "KEYWORD", "required": True},
            {"name": "name", "type": "TEXT", "analyzer": "ansj"},
            {"name": "brand_name", "type": "TEXT", "analyzer": "ansj"}
        ],
        "indexes": [
            {
                "name": "default",
                "label": "默认索引",
                "fields": ["name", "brand_name"],
                "analyzer": "ansj",
                "language_field_mapping": {
                    "zh": ["name", "brand_name"],
                    "en": ["name_en", "brand_name_en"]
                }
            }
        ],
        "spu_config": {
            "enabled": True,
            "spu_field": "spu_id",
            "inner_hits_size": 3
        },
        "ranking": {
            "expression": "static_bm25() + text_embedding_relevance() * 0.2",
            "description": "Test ranking"
        },
        "function_score": {
            "score_mode": "sum",
            "boost_mode": "multiply",
            "functions": []
        },
        "rerank": {
            "enabled": False,
            "expression": "",
            "description": ""
        }
    }

    # delete=False keeps the file after the handle closes so tests can reopen
    # it by path. That means this fixture is responsible for removing it.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
        temp_file = f.name
        try:
            yaml.dump(config_data, f)
        except Exception:
            # Don't leak the half-written file when serialization fails.
            f.close()
            os.unlink(temp_file)
            raise

    yield temp_file

    # Clean up. The test under fixture may already have removed the file.
    try:
        os.unlink(temp_file)
    except FileNotFoundError:
        pass


@pytest.fixture
def mock_env_variables(monkeypatch):
    """Set the ES connection environment variables for the duration of a test.

    Uses ``monkeypatch.setenv``, which undoes the changes on teardown.
    """
    es_env = {
        "ES_HOST": "http://localhost:9200",
        "ES_USERNAME": "elastic",
        "ES_PASSWORD": "changeme",
    }
    for var_name, var_value in es_env.items():
        monkeypatch.setenv(var_name, var_value)


# Extra plugin modules pytest should load for this conftest (none at present).
# Custom markers are registered in pytest_configure, not here.
pytest_plugins = []

# Custom marker registration.
def pytest_configure(config):
    """Register the project's custom pytest markers.

    Each entry is appended to the ``markers`` ini value, so tests can use
    ``@pytest.mark.<name>`` for unit, integration, API, end-to-end,
    performance and slow tests.
    """
    marker_lines = (
        "unit: 单元测试",
        "integration: 集成测试",
        "api: API测试",
        "e2e: 端到端测试",
        "performance: 性能测试",
        "slow: 慢速测试",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)


# Test data.
@pytest.fixture
def test_queries():
    """Representative raw query strings for parser/searcher tests.

    Mixes Chinese and English input, multi-term input, and strings containing
    operator-like syntax (``AND``/``OR``, parentheses, a leading ``-``).
    """
    queries = [
        "红色连衣裙",
        "wireless bluetooth headphones",
        "手机 手机壳",
        "laptop AND (gaming OR professional)",
        "运动鞋 -价格:0-500",
    ]
    return queries


@pytest.fixture
def expected_response_structure():
    """Map each expected top-level API response key to its expected Python type."""
    structure = {
        "hits": list,
        "total": int,
        "max_score": float,
        "took_ms": int,
        "aggregations": dict,
        "query_info": dict,
        "performance_summary": dict,
    }
    return structure