# test_product_enrich_partial_mode.py — tests for indexer/product_enrich.py
from __future__ import annotations

import importlib.util
import io
import json
import logging
import sys
import types
from pathlib import Path
from unittest import mock


def _load_product_enrich_module():
    if "dotenv" not in sys.modules:
        fake_dotenv = types.ModuleType("dotenv")
        fake_dotenv.load_dotenv = lambda *args, **kwargs: None
        sys.modules["dotenv"] = fake_dotenv

    if "redis" not in sys.modules:
        fake_redis = types.ModuleType("redis")

        class _FakeRedisClient:
            def __init__(self, *args, **kwargs):
                pass

            def ping(self):
                return True

        fake_redis.Redis = _FakeRedisClient
        sys.modules["redis"] = fake_redis

    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))

    module_path = repo_root / "indexer" / "product_enrich.py"
    spec = importlib.util.spec_from_file_location("product_enrich_under_test", module_path)
    module = importlib.util.module_from_spec(spec)
    assert spec and spec.loader
    spec.loader.exec_module(module)
    return module


# Load the module under test once, at import time of this test file. Every test
# below reaches the production code through this object and patches attributes
# on it directly (mock.patch.object), so no dotted import path is needed.
product_enrich = _load_product_enrich_module()


def _attach_stream(logger_obj: logging.Logger):
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(message)s"))
    logger_obj.addHandler(handler)
    return stream, handler


def test_create_prompt_splits_shared_context_and_localized_tail():
    """create_prompt keeps the product list in a language-independent part.

    The shared context (which carries the numbered product list) must be
    identical for "zh" and "en"; only the user tail carries the language
    requirement, and the assistant prefix is a header row in that language.
    """
    products = [{"id": "1", "title": "dress"}, {"id": "2", "title": "linen shirt"}]

    prompts = {
        lang: product_enrich.create_prompt(products, target_lang=lang)
        for lang in ("zh", "en")
    }
    shared_zh, user_zh, prefix_zh = prompts["zh"]
    shared_en, user_en, prefix_en = prompts["en"]

    assert shared_zh == shared_en
    for fragment in ("Analyze each input product text", "1. dress", "2. linen shirt"):
        assert fragment in shared_zh

    # The product list lives only in the shared context, never in the tail.
    for user_prompt in (user_zh, user_en):
        assert "Product list" not in user_prompt
    assert "specified language" in user_zh
    assert "Language: Chinese" in user_zh
    assert "Language: English" in user_en

    assert prefix_zh.startswith("| 序号 | 商品标题 | 品类路径 |")
    assert prefix_en.startswith("| No. | Product title | Category path |")


def test_call_llm_logs_shared_context_once_and_verbose_contains_full_requests():
    """Check call_llm's request payloads and its main/verbose logging.

    The same single-product prompt is sent once in Chinese and once in
    English. The main logger must log the shared context only once but log
    each localized request variant. The verbose logger must record every full
    request and response. HTTP is replaced by an in-memory fake session that
    returns canned chat-completion bodies in call order.
    """
    payloads = []
    response_bodies = [
        {
            "choices": [
                {
                    "message": {
                        "content": (
                            "| 1 | 连衣裙 | 女装>连衣裙 | 法式,收腰 | 年轻女性 | "
                            "通勤,约会 | 春季,夏季 | 中长款 | 聚酯纤维 | 透气 | "
                            "修身显瘦 | 法式收腰连衣裙 |\n"
                        )
                    }
                }
            ],
            "usage": {"prompt_tokens": 120, "completion_tokens": 45, "total_tokens": 165},
        },
        {
            "choices": [
                {
                    "message": {
                        "content": (
                            "| 1 | Dress | Women>Dress | French,Waisted | Young women | "
                            "Commute,Date | Spring,Summer | Midi | Polyester | Breathable | "
                            "Slim fit | French waisted dress |\n"
                        )
                    }
                }
            ],
            "usage": {"prompt_tokens": 118, "completion_tokens": 43, "total_tokens": 161},
        },
    ]

    class _FakeResponse:
        def __init__(self, body):
            self.body = body

        def raise_for_status(self):
            return None

        def json(self):
            return self.body

    class _FakeSession:
        trust_env = True

        def post(self, url, headers=None, json=None, timeout=None, proxies=None):
            # Parameter names mirror requests.Session.post; ``json`` is kept in
            # case the caller passes the request body by that keyword.
            del url, headers, timeout, proxies
            payloads.append(json)
            # The n-th request receives the n-th canned body.
            return _FakeResponse(response_bodies[len(payloads) - 1])

        def close(self):
            return None

    product_enrich.reset_logged_shared_context_keys()
    main_stream, main_handler = _attach_stream(product_enrich.logger)
    verbose_stream, verbose_handler = _attach_stream(product_enrich.verbose_logger)

    try:
        with mock.patch.object(product_enrich, "API_KEY", "fake-key"), mock.patch.object(
            product_enrich.requests,
            "Session",
            lambda: _FakeSession(),
        ):
            zh_shared, zh_user, zh_prefix = product_enrich.create_prompt(
                [{"id": "1", "title": "dress"}],
                target_lang="zh",
            )
            en_shared, en_user, en_prefix = product_enrich.create_prompt(
                [{"id": "1", "title": "dress"}],
                target_lang="en",
            )

            zh_markdown, zh_raw = product_enrich.call_llm(
                zh_shared,
                zh_user,
                zh_prefix,
                target_lang="zh",
            )
            en_markdown, en_raw = product_enrich.call_llm(
                en_shared,
                en_user,
                en_prefix,
                target_lang="en",
            )
    finally:
        product_enrich.logger.removeHandler(main_handler)
        product_enrich.verbose_logger.removeHandler(verbose_handler)
        # Closing a StreamHandler does not close its StringIO, so the captured
        # text below stays readable.
        main_handler.close()
        verbose_handler.close()
        # Clear the "already logged" state this test populated so later tests
        # start from a clean slate instead of inheriting suppressed logging.
        product_enrich.reset_logged_shared_context_keys()

    assert zh_shared == en_shared
    assert len(payloads) == 2
    assert len(payloads[0]["messages"]) == 3
    assert payloads[0]["messages"][1]["role"] == "user"
    assert "1. dress" in payloads[0]["messages"][1]["content"]
    assert "Language: Chinese" in payloads[0]["messages"][1]["content"]
    assert "Language: English" in payloads[1]["messages"][1]["content"]
    # The last message is the assistant prefix sent in partial mode.
    assert payloads[0]["messages"][-1]["partial"] is True
    assert payloads[1]["messages"][-1]["partial"] is True

    main_log = main_stream.getvalue()
    verbose_log = verbose_stream.getvalue()

    assert main_log.count("LLM Shared Context") == 1
    assert main_log.count("LLM Request Variant") == 2
    assert "Localized Requirement" in main_log
    assert "Shared Context" in main_log

    assert verbose_log.count("LLM Request [model=") == 2
    assert verbose_log.count("LLM Response [model=") == 2
    assert '"partial": true' in verbose_log
    assert "Combined User Prompt" in verbose_log
    assert "French waisted dress" in verbose_log
    assert "法式收腰连衣裙" in verbose_log

    assert zh_markdown.startswith(zh_prefix)
    assert en_markdown.startswith(en_prefix)
    assert json.loads(zh_raw)["usage"]["total_tokens"] == 165
    assert json.loads(en_raw)["usage"]["total_tokens"] == 161


def test_process_batch_reads_result_and_validates_expected_fields():
    """process_batch turns call_llm's markdown table into one row per product.

    call_llm is stubbed to return a single-row table. Each parsed column must
    land in its named field, next to the input id, the requested language and
    the original input title.
    """
    merged_markdown = """| 序号 | 商品标题 | 品类路径 | 细分标签 | 适用人群 | 使用场景 | 适用季节 | 关键属性 | 材质说明 | 功能特点 | 锚文本 |
|----|----|----|----|----|----|----|----|----|----|----|
| 1 | 法式连衣裙 | 女装>连衣裙 | 法式,收腰 | 年轻女性 | 通勤,约会 | 春季,夏季 | 中长款 | 聚酯纤维 | 透气 | 法式收腰连衣裙 |
"""
    stub_raw = json.dumps({"choices": [{"message": {"content": "stub"}}]})
    expected = {
        "id": "sku-1",
        "lang": "zh",
        "title_input": "dress",
        "title": "法式连衣裙",
        "category_path": "女装>连衣裙",
        "tags": "法式,收腰",
        "target_audience": "年轻女性",
        "usage_scene": "通勤,约会",
        "season": "春季,夏季",
        "key_attributes": "中长款",
        "material": "聚酯纤维",
        "features": "透气",
        "anchor_text": "法式收腰连衣裙",
    }

    with mock.patch.object(
        product_enrich, "call_llm", return_value=(merged_markdown, stub_raw)
    ):
        results = product_enrich.process_batch(
            [{"id": "sku-1", "title": "dress"}], batch_num=1, target_lang="zh"
        )

    assert len(results) == 1
    row = results[0]
    for field, value in expected.items():
        assert row[field] == value, field


def test_analyze_products_uses_product_level_cache_across_batch_requests():
    """analyze_products consults the per-product anchor cache before batching.

    The cache helpers and process_batch are replaced by in-memory fakes whose
    cache key ignores the tenant. Once a product has been enriched, later
    calls (even for another tenant) send only the still-uncached products to
    process_batch, and cached rows are returned unchanged.
    """
    store = {}
    batch_log = []
    key_fields = ("title", "brief", "description", "image_url")

    def _cache_key(product, target_lang):
        return (target_lang,) + tuple(product.get(name, "") for name in key_fields)

    def fake_get_cached_anchor_result(product, target_lang):
        return store.get(_cache_key(product, target_lang))

    def fake_set_cached_anchor_result(product, target_lang, result):
        store[_cache_key(product, target_lang)] = result

    def _enriched_row(item, target_lang):
        # Deterministic stand-in for one enriched product row.
        title = item["title"]
        return {
            "id": item["id"],
            "lang": target_lang,
            "title_input": title,
            "title": f"normalized:{title}",
            "category_path": "cat",
            "tags": "tags",
            "target_audience": "audience",
            "usage_scene": "scene",
            "season": "season",
            "key_attributes": "attrs",
            "material": "material",
            "features": "features",
            "anchor_text": f"anchor:{title}",
        }

    def fake_process_batch(batch_data, batch_num, target_lang="zh"):
        batch_log.append(
            {
                "batch_num": batch_num,
                "target_lang": target_lang,
                "titles": [item["title"] for item in batch_data],
            }
        )
        return [_enriched_row(item, target_lang) for item in batch_data]

    dress = {"id": "1", "title": "dress"}
    shirt = {"id": "2", "title": "shirt"}
    products = [dress, shirt]
    # (products to analyze, tenant id) for three consecutive calls.
    runs = (([dress], "170"), (products, "999"), (products, "170"))

    patch = mock.patch.object
    with patch(product_enrich, "API_KEY", "fake-key"), patch(
        product_enrich, "_get_cached_anchor_result", side_effect=fake_get_cached_anchor_result
    ), patch(
        product_enrich, "_set_cached_anchor_result", side_effect=fake_set_cached_anchor_result
    ), patch(product_enrich, "process_batch", side_effect=fake_process_batch):
        first, second, third = [
            product_enrich.analyze_products(batch, target_lang="zh", tenant_id=tenant)
            for batch, tenant in runs
        ]

    def _column(rows, field):
        return [row[field] for row in rows]

    assert _column(first, "title_input") == ["dress"]
    assert _column(second, "title_input") == ["dress", "shirt"]
    assert _column(third, "title_input") == ["dress", "shirt"]

    # Only the first request for each distinct product reaches process_batch.
    assert batch_log == [
        {"batch_num": 1, "target_lang": "zh", "titles": ["dress"]},
        {"batch_num": 1, "target_lang": "zh", "titles": ["shirt"]},
    ]
    assert _column(second, "anchor_text") == ["anchor:dress", "anchor:shirt"]
    assert _column(third, "anchor_text") == ["anchor:dress", "anchor:shirt"]


def test_anchor_cache_key_depends_on_product_input_not_identifiers():
    """_make_anchor_cache_key is derived from product content, not identity.

    Products that differ only in id/spu_id share a key. Changing a content
    field (here the brief) produces a different key.
    """
    base = {
        "id": "1",
        "spu_id": "1001",
        "title": "dress",
        "brief": "soft cotton",
        "description": "summer dress",
        "image_url": "https://img/a.jpg",
    }
    renamed = {**base, "id": "2", "spu_id": "9999"}
    edited = {**base, "brief": "soft cotton updated"}

    key_base, key_renamed, key_edited = (
        product_enrich._make_anchor_cache_key(item, "zh") for item in (base, renamed, edited)
    )

    assert key_base == key_renamed
    assert key_base != key_edited


def test_build_prompt_input_text_appends_brief_and_description_for_short_title():
    """A short CJK title is followed by text from the brief and description."""
    text = product_enrich._build_prompt_input_text(
        {
            "title": "T恤",
            "brief": "夏季透气纯棉短袖,舒适亲肤",
            "description": "100%棉,圆领版型,适合日常通勤与休闲穿搭。",
        }
    )

    assert text.startswith("T恤")
    for fragment in ("夏季透气纯棉短袖", "100%棉"):
        assert fragment in text


def test_build_prompt_input_text_truncates_non_cjk_by_words():
    """Over-long non-CJK input is cut to at most PROMPT_INPUT_MAX_WORDS words.

    The brief and description are each sized from the limit itself, so the
    combined input always exceeds it. A fixed-size fixture would pass without
    exercising truncation whenever the configured limit is at least as large
    as the fixture's total word count.
    """
    max_words = product_enrich.PROMPT_INPUT_MAX_WORDS
    # Each field alone reaches the limit (and never shrinks below the original
    # 50-word fixture), so title + brief + description is always too long.
    word_count = max(max_words, 50)
    product = {
        "title": "dress",
        "brief": " ".join(f"brief{i}" for i in range(word_count)),
        "description": " ".join(f"desc{i}" for i in range(word_count)),
    }

    text = product_enrich._build_prompt_input_text(product)

    # The title leads the prompt text and survives truncation.
    assert text.startswith("dress")
    assert len(text.split()) <= max_words