# test_process_products_batching.py
from __future__ import annotations

from typing import Any, Dict, List

import indexer.product_enrich as process_products


def _mk_products(n: int) -> List[Dict[str, str]]:
    return [{"id": str(i), "title": f"title-{i}"} for i in range(n)]


def test_analyze_products_caps_batch_size_to_20(monkeypatch):
    """analyze_products must clamp an oversized batch_size down to 20 per batch."""
    monkeypatch.setattr(process_products, "API_KEY", "fake-key")
    observed_sizes: List[int] = []

    # Fields the fake analyzer leaves blank in every returned row.
    empty_fields = (
        "title",
        "category_path",
        "tags",
        "target_audience",
        "usage_scene",
        "season",
        "key_attributes",
        "material",
        "features",
        "anchor_text",
    )

    def _fake_process_batch(
        batch_data: List[Dict[str, str]],
        batch_num: int,
        target_lang: str = "zh",
        analysis_kind: str = "content",
        category_taxonomy_profile=None,
    ):
        assert analysis_kind == "content"
        assert category_taxonomy_profile is None
        observed_sizes.append(len(batch_data))
        rows = []
        for entry in batch_data:
            row = {
                "id": entry["id"],
                "lang": target_lang,
                "title_input": entry["title"],
            }
            row.update((field, "") for field in empty_fields)
            rows.append(row)
        return rows

    monkeypatch.setattr(process_products, "process_batch", _fake_process_batch)
    monkeypatch.setattr(process_products, "_set_cached_analysis_result", lambda *args, **kwargs: None)

    out = process_products.analyze_products(
        products=_mk_products(45),
        target_lang="zh",
        batch_size=200,
        tenant_id="162",
    )

    assert len(out) == 45
    # Batches may run concurrently, so call order is not deterministic:
    # compare the multiset of batch sizes rather than their sequence.
    assert sorted(observed_sizes) == [5, 20, 20]


def test_analyze_products_uses_min_batch_size_1(monkeypatch):
    """analyze_products must raise a non-positive batch_size up to the minimum of 1."""
    monkeypatch.setattr(process_products, "API_KEY", "fake-key")
    observed_sizes: List[int] = []

    # Fields the fake analyzer leaves blank in every returned row.
    empty_fields = (
        "title",
        "category_path",
        "tags",
        "target_audience",
        "usage_scene",
        "season",
        "key_attributes",
        "material",
        "features",
        "anchor_text",
    )

    def _fake_process_batch(
        batch_data: List[Dict[str, str]],
        batch_num: int,
        target_lang: str = "zh",
        analysis_kind: str = "content",
        category_taxonomy_profile=None,
    ):
        assert analysis_kind == "content"
        assert category_taxonomy_profile is None
        observed_sizes.append(len(batch_data))
        rows = []
        for entry in batch_data:
            row = {
                "id": entry["id"],
                "lang": target_lang,
                "title_input": entry["title"],
            }
            row.update((field, "") for field in empty_fields)
            rows.append(row)
        return rows

    monkeypatch.setattr(process_products, "process_batch", _fake_process_batch)
    monkeypatch.setattr(process_products, "_set_cached_analysis_result", lambda *args, **kwargs: None)

    out = process_products.analyze_products(
        products=_mk_products(3),
        target_lang="zh",
        batch_size=0,
        tenant_id="162",
    )

    assert len(out) == 3
    assert observed_sizes == [1, 1, 1]