# test_suggestions.py (6.26 KB)
import json
from typing import Any, Dict, List

import pytest

from suggestion.builder import SuggestionIndexBuilder
from suggestion.service import SuggestionService


class FakeESClient:
    """In-memory stand-in for the ES client used by the suggestion tests.

    ``search`` answers with canned payloads chosen by substring match on the
    index name. ``search``, ``bulk_index`` and ``create_index`` each append a
    record of the call to ``self.calls``; the remaining methods return fixed
    values and record nothing.
    """

    def __init__(self) -> None:
        # Chronological log of the recorded calls, one dict per call.
        self.calls: List[Dict[str, Any]] = []

    @staticmethod
    def _suggestion_hit() -> Dict[str, Any]:
        """Build a fresh copy of the single canned suggestion document."""
        source = {
            "text": "iphone 15",
            "lang": "en",
            "rank_score": 5.0,
            "sources": ["query_log", "qanchor"],
            "lang_source": "log_field",
            "lang_confidence": 1.0,
            "lang_conflict": False,
            "top_spu_ids": ["12345"],
        }
        return {"_id": "1", "_score": 3.2, "_source": source}

    @staticmethod
    def _product_hit() -> Dict[str, Any]:
        """Build a fresh copy of the single canned product document."""
        source = {
            "spu_id": "12345",
            "title": {"en": "iPhone 15 Pro Max"},
            "min_price": 999.0,
            "image_url": "https://example.com/image.jpg",
            "sales": 100,
            "total_inventory": 50,
        }
        return {"_id": "12345", "_score": 2.5, "_source": source}

    @staticmethod
    def _envelope(hits: List[Dict[str, Any]], max_score: float) -> Dict[str, Any]:
        """Wrap ``hits`` in the response structure returned by ``search``."""
        return {
            "hits": {
                "total": {"value": len(hits)},
                "max_score": max_score,
                "hits": hits,
            }
        }

    def search(self, index_name: str, body: Dict[str, Any], size: int = 10, from_: int = 0) -> Dict[str, Any]:
        """Record the call and return the canned response for ``index_name``.

        The suggestion marker is checked before the product marker; an index
        name matching neither yields an empty result with ``max_score`` 0.0.
        """
        self.calls.append({"index": index_name, "body": body, "size": size, "from": from_})
        canned = (
            ("search_suggestions_tenant_", self._suggestion_hit),  # suggestion index
            ("search_products_tenant_", self._product_hit),  # product index
        )
        for marker, make_hit in canned:
            if marker in index_name:
                hit = make_hit()
                return self._envelope([hit], hit["_score"])
        return self._envelope([], 0.0)

    # Kept for compatibility with the builder's bulk-index path in full runs;
    # the unit tests in this module do not exercise it.
    def bulk_index(self, index_name: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Record the bulk request and report every document as indexed."""
        record = {"index": index_name, "bulk": True, "docs": docs}
        self.calls.append(record)
        return {"success": len(docs), "failed": 0, "errors": []}

    def index_exists(self, index_name: str) -> bool:
        """Always report that the index does not exist."""
        return False

    def delete_index(self, index_name: str) -> bool:
        """Always report a successful deletion."""
        return True

    def create_index(self, index_name: str, body: Dict[str, Any]) -> bool:
        """Record the create request and report success."""
        record = {"index": index_name, "create": True, "body": body}
        self.calls.append(record)
        return True

    def refresh(self, index_name: str) -> bool:
        """Always report a successful refresh."""
        return True


@pytest.mark.unit
def test_resolve_query_language_prefers_log_field(monkeypatch):
    """The builder's query-language resolution should prefer the log's language field."""
    index_builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)
    # index_languages contains en/zh, while the primary language is set to zh.
    call_kwargs = dict(
        query="iphone 15",
        log_language="en",
        request_params=None,
        index_languages=["zh", "en"],
        primary_language="zh",
    )
    language, confidence, origin, has_conflict = index_builder._resolve_query_language(**call_kwargs)

    assert language == "en"
    assert confidence == 1.0
    assert origin == "log_field"
    assert has_conflict is False


@pytest.mark.unit
def test_resolve_query_language_uses_request_params_when_log_missing():
    """When the log language is empty, the language should come from request_params.language."""
    index_builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)
    call_kwargs = dict(
        query="芭比娃娃",
        log_language=None,
        # request_params is passed as a JSON-encoded string.
        request_params=json.dumps({"language": "zh"}),
        index_languages=["zh", "en"],
        primary_language="en",
    )
    language, confidence, origin, has_conflict = index_builder._resolve_query_language(**call_kwargs)

    assert language == "zh"
    assert confidence == 1.0
    assert origin == "request_params"
    assert has_conflict is False


@pytest.mark.unit
def test_resolve_query_language_fallback_to_primary():
    """With no language hints at all, resolution should fall back to the tenant's primary_language."""
    index_builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)
    call_kwargs = dict(
        query="some text",
        log_language=None,
        request_params=None,
        index_languages=["zh", "en"],
        primary_language="zh",
    )
    # The confidence value is unpacked but not asserted on in this test.
    language, _confidence, origin, has_conflict = index_builder._resolve_query_language(**call_kwargs)

    assert language == "zh"
    assert origin == "default"
    assert has_conflict is False


@pytest.mark.unit
def test_suggestion_service_basic_flow(monkeypatch):
    """
    SuggestionService.search should query ES and return suggestions with products.

    FakeESClient is used so the test does not depend on a real ES instance.
    """
    # Override the tenant config so the test does not depend on changes to
    # the external config.yaml.
    from config import tenant_config_loader as tcl

    loader = tcl.get_tenant_config_loader()
    # Replace the loader's cached config through monkeypatch so the previous
    # value is restored at teardown; a plain assignment would leak this fake
    # config into every later test that uses the same loader object.
    # raising=False keeps the override working even if ``_config`` has not
    # been set on the loader yet.
    monkeypatch.setattr(
        loader,
        "_config",
        {
            "default": {"primary_language": "en", "index_languages": ["en", "zh"]},
            "tenants": {
                "1": {"primary_language": "en", "index_languages": ["en", "zh"]},
            },
        },
        raising=False,
    )

    fake_es = FakeESClient()
    service = SuggestionService(es_client=fake_es)
    result = service.search(
        tenant_id="1",
        query="iph",
        language="en",
        size=5,
        with_results=True,
        result_size=2,
    )

    # Top-level response metadata.
    assert result["resolved_language"] == "en"
    assert result["query"] == "iph"
    assert result["took_ms"] >= 0

    # Exactly the one canned suggestion from FakeESClient is returned.
    suggestions = result["suggestions"]
    assert len(suggestions) == 1
    s0 = suggestions[0]
    assert s0["text"] == "iphone 15"
    assert s0["lang"] == "en"

    # with_results=True should attach at least one product to the suggestion.
    assert isinstance(s0.get("products"), list)
    assert len(s0["products"]) >= 1
    p0 = s0["products"][0]
    assert p0["spu_id"] == "12345"
    assert "title" in p0
    assert "price" in p0