test_aggregation_api.py 7.2 KB
"""
Tests for aggregation API functionality.
"""

import pytest
from fastapi.testclient import TestClient
from api.app import app

# Single module-level client shared by every test below. TestClient calls the
# FastAPI ``app`` directly, so the API process does not need to be started
# separately; any backend the app queries (presumably the search engine behind
# /search/) must still be reachable for the integration tests — TODO confirm.
client = TestClient(app)


@pytest.mark.integration
@pytest.mark.api
def test_search_with_aggregations():
    """Request terms and range aggregations alongside a keyword search.

    The response must have the standard search shape plus an ``aggregations``
    key. Each requested aggregation that appears in the response must expose
    a ``buckets`` list. Aggregations missing from the response are not
    checked.
    """
    price_buckets = [
        {"key": "0-50", "to": 50},
        {"key": "50-100", "from": 50, "to": 100},
        {"key": "100-200", "from": 100, "to": 200},
        {"key": "200+", "from": 200},
    ]
    requested = {
        "category_name": {"type": "terms", "field": "categoryName_keyword", "size": 10},
        "brand_name": {"type": "terms", "field": "brandName_keyword", "size": 10},
        "price_ranges": {"type": "range", "field": "price", "ranges": price_buckets},
    }
    payload = {"query": "芭比娃娃", "size": 10, "aggregations": requested}

    resp = client.post("/search/", json=payload)

    assert resp.status_code == 200
    body = resp.json()

    # Standard search envelope plus the aggregations section.
    for key in ("hits", "total", "aggregations", "query_info"):
        assert key in body

    returned = body["aggregations"]
    # NOTE: each per-aggregation check only runs when the API actually returned
    # that aggregation. The test therefore does not prove the aggregations
    # were computed; it only validates the shape of those that were.
    for name in requested:
        if name not in returned:
            continue
        agg_result = returned[name]
        assert "buckets" in agg_result
        assert isinstance(agg_result["buckets"], list)


def _collect_source_values(hits, field):
    """Return the non-null ``hit["_source"][field]`` values, in hit order.

    Hits without ``_source``, without *field*, or whose value is ``None`` are
    skipped. Keeping ``None`` would make ``sorted()`` compare it against
    numbers or strings and raise ``TypeError`` instead of checking the order.
    """
    return [
        hit["_source"][field]
        for hit in hits
        if "_source" in hit and hit["_source"].get(field) is not None
    ]


@pytest.mark.integration
@pytest.mark.api
def test_search_with_sorting():
    """Check that each supported ``sort_by`` option orders the returned hits.

    Every request must succeed with HTTP 200. When at least two hits carry a
    non-null value for the sort field, those values must appear in the
    expected direction. Hits lacking the field are ignored, because their
    position cannot be verified.
    """
    # (sort_by option, _source field it orders by, expect descending order?)
    cases = [
        ("price_asc", "price", False),
        ("price_desc", "price", True),
        ("time_desc", "create_time", True),
    ]

    for sort_by, field, descending in cases:
        request_data = {"query": "玩具", "size": 5, "sort_by": sort_by}

        response = client.post("/search/", json=request_data)
        assert response.status_code == 200, f"sort_by={sort_by!r} returned {response.status_code}"

        # A null/absent hits list simply means nothing to verify.
        hits = response.json()["hits"] or []
        values = _collect_source_values(hits, field)

        if len(values) > 1:
            direction = "descending" if descending else "ascending"
            assert values == sorted(values, reverse=descending), (
                f"Results should be sorted by {field} {direction} (sort_by={sort_by!r})"
            )


@pytest.mark.integration
@pytest.mark.api
def test_search_with_filters_and_aggregations():
    """Combine a category filter with a brand terms aggregation in one request.

    Every returned hit that exposes ``categoryName`` must contain "芭比", and
    the response must still include an ``aggregations`` key.
    """
    brand_agg = {"type": "terms", "field": "brandName_keyword", "size": 10}
    payload = {
        "query": "玩具",
        "size": 10,
        "filters": {"category_name": ["芭比"]},
        "aggregations": {"brand_name": brand_agg},
    }

    resp = client.post("/search/", json=payload)
    assert resp.status_code == 200
    body = resp.json()

    # The filter must hold for every hit that reports a category.
    assert "hits" in body
    categories = (
        hit["_source"]["categoryName"]
        for hit in body["hits"]
        if "_source" in hit and "categoryName" in hit["_source"]
    )
    for category in categories:
        assert "芭比" in category

    # Applying a filter must not drop the aggregations section.
    assert "aggregations" in body


@pytest.mark.integration
@pytest.mark.api
def test_search_without_aggregations():
    """Plain search whose request omits the ``aggregations`` field.

    The response must still have the standard shape. That shape includes an
    ``aggregations`` key, whose contents are not checked and may be empty.
    """
    resp = client.post("/search/", json={"query": "玩具", "size": 10})
    assert resp.status_code == 200
    body = resp.json()

    # "aggregations" is required to be present even when none were requested;
    # only its presence is asserted here, not its contents.
    for key in ("hits", "total", "query_info", "aggregations"):
        assert key in body


@pytest.mark.integration
@pytest.mark.api
def test_aggregation_edge_cases():
    """Edge-case aggregation requests must be handled without a server error.

    Each payload may be accepted (200) or rejected by request validation (422).
    Any other status, such as a 500, fails the test.
    """
    acceptable_statuses = [200, 422]

    def _category_agg(agg_type):
        # Same field/size for both cases; only the aggregation type varies.
        return {"type": agg_type, "field": "categoryName_keyword", "size": 10}

    payloads = [
        # Empty query string with an otherwise valid terms aggregation.
        {"query": "", "size": 10, "aggregations": {"category_name": _category_agg("terms")}},
        # Unknown aggregation type.
        {"query": "玩具", "size": 10, "aggregations": {"invalid_agg": _category_agg("invalid_type")}},
    ]

    for payload in payloads:
        response = client.post("/search/", json=payload)
        assert response.status_code in acceptable_statuses


@pytest.mark.unit
def test_aggregation_spec_validation():
    """``AggregationSpec`` can be built for both terms and range aggregations."""
    from api.models import AggregationSpec

    # Terms aggregation: field, type and size are stored as given.
    terms_spec = AggregationSpec(field="categoryName_keyword", type="terms", size=10)
    assert terms_spec.field == "categoryName_keyword"
    assert terms_spec.type == "terms"
    assert terms_spec.size == 10

    # Range aggregation: both bucket definitions are retained.
    price_buckets = [
        {"key": "0-50", "to": 50},
        {"key": "50-100", "from": 50, "to": 100},
    ]
    range_spec = AggregationSpec(field="price", type="range", ranges=price_buckets)
    assert range_spec.field == "price"
    assert range_spec.type == "range"
    assert len(range_spec.ranges) == 2


if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit, so
    # propagate it. Otherwise running this file directly would always exit 0,
    # even when tests fail.
    raise SystemExit(pytest.main([__file__]))