# test_api_integration.py 11.2 KB
"""
API集成测试

测试API接口的完整集成,包括请求处理、响应格式、错误处理等
"""

import pytest
import json
import asyncio
from unittest.mock import patch, Mock, AsyncMock
from fastapi.testclient import TestClient

# 导入API应用
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))

from api.app import app


@pytest.mark.integration
@pytest.mark.api
class TestAPIIntegration:
    """Integration tests for the search HTTP API.

    Exercises request handling, response structure, propagation of request
    parameters into ``performance_summary``, and error handling, all through
    an in-process ``TestClient``.
    """

    @pytest.fixture
    def client(self):
        """Yield a TestClient bound to the app and close it after the test.

        The client is entered as a context manager so that Starlette runs the
        app's startup/shutdown (lifespan) handlers; a bare ``TestClient(app)``
        skips them. NOTE(review): if the app initialises its searcher at
        startup, the search tests below depend on this.
        """
        with TestClient(app) as test_client:
            yield test_client

    def test_search_api_basic(self, client):
        """A plain query returns 200 with the expected top-level fields."""
        response = client.get("/search", params={"q": "红色连衣裙"})

        assert response.status_code == 200
        data = response.json()

        # Response envelope.
        assert "hits" in data
        assert "total" in data
        assert "max_score" in data
        assert "took_ms" in data
        assert "query_info" in data
        assert "performance_summary" in data

        # Field types.
        assert isinstance(data["hits"], list)
        assert isinstance(data["total"], int)
        assert isinstance(data["max_score"], (int, float))
        assert isinstance(data["took_ms"], int)

    def test_search_api_with_parameters(self, client):
        """Paging/score params and feature flags are echoed in the metadata."""
        params = {
            "q": "智能手机",
            "size": 15,
            "from": 5,
            "enable_translation": False,
            "enable_embedding": False,
            "enable_rerank": True,
            "min_score": 1.0
        }

        response = client.get("/search", params=params)

        assert response.status_code == 200
        data = response.json()

        # The request parameters are reported back under
        # performance_summary.metadata.
        metadata = data.get("performance_summary", {}).get("metadata", {})
        search_params = metadata.get("search_params", {})

        assert search_params.get("size") == 15
        assert search_params.get("from") == 5
        assert search_params.get("min_score") == 1.0

        feature_flags = metadata.get("feature_flags", {})
        assert feature_flags.get("enable_translation") is False
        assert feature_flags.get("enable_embedding") is False
        assert feature_flags.get("enable_rerank") is True

    def test_search_api_complex_query(self, client):
        """A query using boolean operators is not classified as simple."""
        response = client.get("/search", params={"q": "手机 AND (华为 OR 苹果) ANDNOT 二手"})

        assert response.status_code == 200
        data = response.json()

        query_analysis = data.get("performance_summary", {}).get("query_analysis", {})

        # AND / OR / ANDNOT make this a complex query.
        assert query_analysis.get("is_simple_query") is False

    def test_search_api_missing_query(self, client):
        """Omitting the required ``q`` parameter yields a 422 validation error."""
        response = client.get("/search")

        assert response.status_code == 422  # request validation error
        data = response.json()

        # The validation error body carries a "detail" field.
        assert "detail" in data

    def test_search_api_empty_query(self, client):
        """An empty query string is accepted and returns a hits list."""
        response = client.get("/search", params={"q": ""})

        assert response.status_code == 200
        data = response.json()

        # An empty query must still produce a well-formed result.
        assert "hits" in data
        assert isinstance(data["hits"], list)

    def test_search_api_with_filters(self, client):
        """JSON-encoded ``filters`` are parsed and echoed in the metadata."""
        response = client.get("/search", params={
            "q": "连衣裙",
            "filters": json.dumps({"category_id": 1, "brand": "测试品牌"})
        })

        assert response.status_code == 200
        data = response.json()

        # The decoded filters appear under metadata.search_params.filters.
        metadata = data.get("performance_summary", {}).get("metadata", {})
        filters = metadata.get("search_params", {}).get("filters", {})

        assert filters.get("category_id") == 1
        assert filters.get("brand") == "测试品牌"

    def test_search_api_performance_summary(self, client):
        """``performance_summary`` has the expected sections and timing data."""
        response = client.get("/search", params={"q": "性能测试查询"})

        assert response.status_code == 200
        data = response.json()

        performance = data.get("performance_summary", {})

        # Top-level sections.
        assert "request_info" in performance
        assert "query_analysis" in performance
        assert "performance" in performance
        assert "results" in performance
        assert "metadata" in performance

        # request_info: request id and user id; reqid is 8 characters long.
        request_info = performance["request_info"]
        assert "reqid" in request_info
        assert "uid" in request_info
        assert len(request_info["reqid"]) == 8

        # performance: total duration plus per-stage breakdowns.
        perf_data = performance["performance"]
        assert "total_duration_ms" in perf_data
        assert "stage_timings_ms" in perf_data
        assert "stage_percentages" in perf_data
        assert isinstance(perf_data["total_duration_ms"], (int, float))
        assert perf_data["total_duration_ms"] >= 0

    def test_search_api_error_handling(self, client):
        """An exception inside the searcher becomes a 500 with an error body."""
        # Make the searcher raise. NOTE(review): assumes the request handler
        # looks up the module-level ``api.app._searcher`` at request time.
        with patch('api.app._searcher') as mock_searcher:
            mock_searcher.search.side_effect = Exception("内部服务错误")

            response = client.get("/search", params={"q": "错误测试"})

            assert response.status_code == 500
            data = response.json()

            # Error body: message plus an 8-character request id.
            assert "error" in data
            assert "request_id" in data
            assert len(data["request_id"]) == 8

    def test_health_check_api(self, client):
        """``/health`` reports status, timestamp, service name and version."""
        response = client.get("/health")

        assert response.status_code == 200
        data = response.json()

        assert "status" in data
        assert "timestamp" in data
        assert "service" in data
        assert "version" in data

        assert data["status"] in ["healthy", "unhealthy"]
        assert data["service"] == "search-engine-api"

    def test_metrics_api(self, client):
        """``/metrics`` either responds or is absent (404)."""
        response = client.get("/metrics")

        # The body format (JSON or Prometheus text) depends on the
        # implementation; 404 is tolerated when the endpoint is not provided.
        assert response.status_code in [200, 404]

    def test_concurrent_search_api(self, client):
        """Ten searches issued from parallel threads all succeed."""
        from concurrent.futures import ThreadPoolExecutor

        queries = [f"并发测试查询-{i}" for i in range(10)]

        def run_query(query):
            # TestClient.get is blocking, so concurrency comes from threads.
            return client.get("/search", params={"q": query})

        with ThreadPoolExecutor(max_workers=len(queries)) as pool:
            responses = list(pool.map(run_query, queries))

        for response in responses:
            assert response.status_code == 200
            data = response.json()
            assert "hits" in data
            assert "performance_summary" in data

    def test_search_api_response_time(self, client):
        """End-to-end latency is bounded and reported timings are sane."""
        import time

        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for measuring elapsed durations.
        start = time.perf_counter()
        response = client.get("/search", params={"q": "响应时间测试"})
        response_time_ms = (time.perf_counter() - start) * 1000

        assert response.status_code == 200

        # Generous upper bound on wall-clock latency (5 seconds).
        assert response_time_ms < 5000

        data = response.json()
        assert data["took_ms"] >= 0

        perf_data = data.get("performance_summary", {}).get("performance", {})
        # Defaults to 0, so this also fails if total_duration_ms is missing.
        assert perf_data.get("total_duration_ms", 0) > 0

    def test_search_api_large_query(self, client):
        """A very long query string is accepted and echoed back unchanged."""
        # 1000 leading spaces followed by a short query: long on the wire,
        # but only a few meaningful characters.
        long_query = " " * 1000 + "红色连衣裙"

        response = client.get("/search", params={"q": long_query})

        assert response.status_code == 200
        data = response.json()

        # The original query must be reported verbatim, whitespace included.
        query_analysis = data.get("performance_summary", {}).get("query_analysis", {})
        assert query_analysis.get("original_query") == long_query

    def test_search_api_unicode_support(self, client):
        """Queries in several scripts (and emoji) round-trip unchanged."""
        unicode_queries = [
            "红色连衣裙",  # Chinese
            "red dress",   # English
            "robe rouge",  # French
            "赤いドレス",  # Japanese
            "أحمر فستان",  # Arabic
            "👗🔴",        # Emoji
        ]

        for query in unicode_queries:
            response = client.get("/search", params={"q": query})

            assert response.status_code == 200
            data = response.json()

            query_analysis = data.get("performance_summary", {}).get("query_analysis", {})
            assert query_analysis.get("original_query") == query

    def test_search_api_request_id_tracking(self, client):
        """Every search gets its own 8-character alphanumeric reqid."""

        def fetch_reqid():
            # Issue one search and pull reqid out of performance_summary.
            response = client.get("/search", params={"q": "请求ID测试"})
            assert response.status_code == 200
            performance = response.json().get("performance_summary", {})
            return performance.get("request_info", {}).get("reqid")

        first_reqid = fetch_reqid()
        second_reqid = fetch_reqid()

        for reqid in (first_reqid, second_reqid):
            assert reqid is not None
            assert len(reqid) == 8
            assert reqid.isalnum()

        # Uniqueness can only be checked by comparing two separate requests.
        assert first_reqid != second_reqid

    def test_search_api_rate_limiting(self, client):
        """A burst of requests either all succeed or is throttled with 429."""
        status_codes = [
            client.get("/search", params={"q": f"速率限制测试-{i}"}).status_code
            for i in range(20)
        ]

        # With or without rate limiting, the only acceptable outcomes are
        # success (200) and throttling (429). Without a limiter this reduces
        # to "all 200"; with one, server errors still fail the test.
        unexpected = sorted({code for code in status_codes if code not in (200, 429)})
        assert not unexpected, f"unexpected status codes: {unexpected}"

    def test_search_api_cors_headers(self, client):
        """A search request succeeds; CORS headers are not yet asserted."""
        response = client.get("/search", params={"q": "CORS测试"})

        assert response.status_code == 200

        # TODO: assert Access-Control-Allow-Origin (and related headers) once
        # the expected CORS configuration for the API is fixed.