# test_suggestions.py (11.8 KB)
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import pytest

from suggestion.builder import (
    QueryDelta,
    SuggestionIndexBuilder,
    get_suggestion_alias_name,
)
from suggestion.service import SuggestionService


class FakeESClient:
    """Lightweight in-memory fake ES client for suggestion unit tests.

    Every search, bulk, create, refresh, alias-update and delete call appends a
    record to ``self.calls`` (a dict whose ``"op"`` key names the operation), so
    tests can assert on what was sent. Index and alias state live in
    ``self.indices`` and ``self.aliases``.

    Search results are canned. A body containing ``"suggest"`` gets one
    completion option. An index name containing ``"search_suggestions_tenant_"``
    gets one hit. Anything else gets an empty hit list.
    """

    def __init__(self) -> None:
        # Recorded calls, in invocation order.
        self.calls: List[Dict[str, Any]] = []
        # Names of the indices that currently exist.
        self.indices: set[str] = set()
        # Alias name -> names of the indices the alias points to.
        self.aliases: Dict[str, List[str]] = {}
        self.client = self  # support service._completion_suggest -> self.es_client.client.search

    @staticmethod
    def _suggestion_source() -> Dict[str, Any]:
        """Return a fresh copy of the single suggestion document used by canned responses."""
        return {
            "text": "iphone 15",
            "lang": "en",
            "rank_score": 5.0,
            "sources": ["query_log", "qanchor"],
            "lang_source": "log_field",
            "lang_confidence": 1.0,
            "lang_conflict": False,
        }

    def search(
        self,
        index_name: Optional[str] = None,
        body: Optional[Dict[str, Any]] = None,
        size: int = 10,
        from_: int = 0,
        routing: Optional[str] = None,
        index: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Record the search and return a canned response.

        Both the wrapper-style ``index_name`` keyword and the raw-client ``index``
        keyword are accepted. Presumably this is because callers may go through
        either ``es_client.search`` or ``es_client.client.search``. Extra keyword
        arguments are accepted and ignored.
        """
        idx = index_name or index
        body = body or {}
        self.calls.append(
            {
                "op": "search",
                "index": idx,
                # Shallow snapshot so later mutation by the caller does not rewrite history.
                "body": dict(body),
                "size": size,
                "from": from_,
                "routing": routing,
            }
        )

        # Completion suggest path
        if "suggest" in body:
            return {
                "suggest": {
                    "s": [
                        {
                            "text": "iph",
                            "offset": 0,
                            "length": 3,
                            "options": [
                                {
                                    "text": "iphone 15",
                                    "_score": 6.3,
                                    "_source": self._suggestion_source(),
                                }
                            ],
                        }
                    ]
                }
            }

        # bool_prefix path: any index name containing the per-tenant prefix gets one hit.
        if idx and "search_suggestions_tenant_" in idx:
            return {
                "hits": {
                    "total": {"value": 1},
                    "max_score": 3.2,
                    "hits": [
                        {
                            "_id": "1",
                            "_score": 3.2,
                            "_source": self._suggestion_source(),
                        }
                    ],
                }
            }

        return {"hits": {"total": {"value": 0}, "max_score": 0.0, "hits": []}}

    def bulk_index(self, index_name: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Record a bulk index request; every document is reported as successful."""
        # Copy the list so a caller that reuses/clears its batch does not alter the record.
        self.calls.append({"op": "bulk_index", "index": index_name, "docs": list(docs)})
        return {"success": len(docs), "failed": 0, "errors": []}

    def bulk_actions(self, actions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Record a raw bulk-actions request; every action is reported as successful."""
        self.calls.append({"op": "bulk_actions", "actions": list(actions)})
        return {"success": len(actions), "failed": 0, "errors": []}

    def index_exists(self, index_name: str) -> bool:
        """Return True if ``index_name`` has been created and not deleted."""
        return index_name in self.indices

    def delete_index(self, index_name: str) -> bool:
        """Delete an index and drop it from any alias that pointed at it.

        This mirrors Elasticsearch, where deleting an index also removes it from
        its aliases. The call is recorded even when the index does not exist.

        Returns:
            True if the index existed and was removed, False otherwise.
        """
        self.calls.append({"op": "delete_index", "index": index_name})
        if index_name not in self.indices:
            return False
        self.indices.remove(index_name)
        # Reassigning values of existing keys is safe while iterating over items().
        for alias, targets in self.aliases.items():
            if index_name in targets:
                self.aliases[alias] = [x for x in targets if x != index_name]
        return True

    def create_index(self, index_name: str, body: Dict[str, Any]) -> bool:
        """Record the create request and mark ``index_name`` as existing."""
        self.calls.append({"op": "create_index", "index": index_name, "body": body})
        self.indices.add(index_name)
        return True

    def refresh(self, index_name: str) -> bool:
        """Record a refresh request; there is no visibility delay to simulate."""
        self.calls.append({"op": "refresh", "index": index_name})
        return True

    def alias_exists(self, alias_name: str) -> bool:
        """Return True if ``alias_name`` points at one or more indices."""
        return bool(self.aliases.get(alias_name))

    def get_alias_indices(self, alias_name: str) -> List[str]:
        """Return a copy of the index list behind ``alias_name`` (empty if unknown)."""
        return list(self.aliases.get(alias_name, []))

    def update_aliases(self, actions: List[Dict[str, Any]]) -> bool:
        """Apply ``remove`` / ``add`` alias actions in order.

        Deliberate simplification: ``add`` makes the alias point at exactly that
        one index, replacing whatever it pointed at before. In Elasticsearch,
        ``add`` would add the index alongside the existing ones.
        """
        self.calls.append({"op": "update_aliases", "actions": list(actions)})
        for action in actions:
            if "remove" in action:
                alias = action["remove"]["alias"]
                index = action["remove"]["index"]
                self.aliases[alias] = [x for x in self.aliases.get(alias, []) if x != index]
            if "add" in action:
                alias = action["add"]["alias"]
                index = action["add"]["index"]
                self.aliases[alias] = [index]
        return True

    def list_indices(self, pattern: str) -> List[str]:
        """Return existing indices, sorted, whose names start with ``pattern``.

        Trailing ``*`` characters are stripped from ``pattern`` before the prefix
        match. No other wildcard syntax is supported.
        """
        prefix = pattern.rstrip("*")
        return sorted([x for x in self.indices if x.startswith(prefix)])


@pytest.mark.unit
def test_resolve_query_language_prefers_log_field():
    """A language recorded on the query-log row wins over the tenant primary language."""
    builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)

    resolved = builder._resolve_query_language(
        query="iphone 15",
        log_language="en",
        request_params=None,
        index_languages=["zh", "en"],
        primary_language="zh",
    )
    lang, confidence, lang_source, has_conflict = resolved

    assert (lang, confidence, lang_source) == ("en", 1.0, "log_field")
    assert has_conflict is False


@pytest.mark.unit
def test_resolve_query_language_uses_request_params_when_log_missing():
    """Without a log-row language, the ``language`` in the JSON request params decides."""
    builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)
    params_json = json.dumps({"language": "zh"})

    lang, confidence, lang_source, has_conflict = builder._resolve_query_language(
        query="芭比娃娃",
        log_language=None,
        request_params=params_json,
        index_languages=["zh", "en"],
        primary_language="en",
    )

    assert (lang, confidence, lang_source) == ("zh", 1.0, "request_params")
    assert has_conflict is False


@pytest.mark.unit
def test_resolve_query_language_fallback_to_primary():
    """With no log language and no request params, resolution uses the primary language."""
    builder = SuggestionIndexBuilder(es_client=FakeESClient(), db_engine=None)

    outcome = builder._resolve_query_language(
        query="123",
        log_language=None,
        request_params=None,
        index_languages=["zh", "en"],
        primary_language="zh",
    )
    # Confidence is not asserted for the default branch.
    lang, _confidence, lang_source, has_conflict = outcome

    assert (lang, lang_source) == ("zh", "default")
    assert has_conflict is False


@pytest.mark.unit
def test_suggestion_service_basic_flow_uses_alias_and_routing(monkeypatch):
    """SuggestionService.search queries the tenant alias with tenant routing.

    The tenant config is installed through ``monkeypatch`` so it is restored on
    teardown. ``get_tenant_config_loader()`` presumably returns a shared loader.
    Assigning ``loader._config`` directly would leak this fake config into every
    test that runs afterwards.
    """
    from config import tenant_config_loader as tcl

    loader = tcl.get_tenant_config_loader()
    # raising=False tolerates a loader that has not populated _config yet.
    # monkeypatch then deletes the attribute again on teardown.
    monkeypatch.setattr(
        loader,
        "_config",
        {
            "default": {"primary_language": "en", "index_languages": ["en", "zh"]},
            "tenants": {
                "1": {"primary_language": "en", "index_languages": ["en", "zh"]},
            },
        },
        raising=False,
    )

    fake_es = FakeESClient()
    alias_name = get_suggestion_alias_name("1")
    fake_es.aliases[alias_name] = ["search_suggestions_tenant_1_v20260310190000"]

    service = SuggestionService(es_client=fake_es)
    result = service.search(
        tenant_id="1",
        query="iph",
        language="en",
        size=5,
    )

    assert result["resolved_language"] == "en"
    assert result["query"] == "iph"
    assert result["took_ms"] >= 0
    suggestions = result["suggestions"]
    assert len(suggestions) == 1
    assert suggestions[0]["text"] == "iphone 15"

    # Expect at least two searches; presumably one completion-suggest and one bool_prefix query.
    search_calls = [x for x in fake_es.calls if x.get("op") == "search"]
    assert len(search_calls) >= 2
    assert any(x.get("routing") == "1" for x in search_calls)
    assert any(x.get("index") == alias_name for x in search_calls)


@pytest.mark.unit
def test_publish_alias_and_cleanup_old_versions(monkeypatch):
    """Publishing a new version repoints the alias and prunes versions beyond keep_versions.

    Three versions exist. With keep_versions=2, the newly published index and the
    previous alias target must survive, and the oldest version must be deleted.
    """
    fake_es = FakeESClient()
    builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None)

    tenant_id = "162"
    alias_name = get_suggestion_alias_name(tenant_id)
    oldest = "search_suggestions_tenant_162_v20260310170000"
    previous = "search_suggestions_tenant_162_v20260310180000"
    newest = "search_suggestions_tenant_162_v20260310190000"
    fake_es.indices.update({oldest, previous, newest})
    fake_es.aliases[alias_name] = [previous]

    # Stub meta persistence (presumably DB-backed; db_engine is None here).
    # *args/**kwargs keeps the stub independent of how _publish_alias passes arguments.
    monkeypatch.setattr(builder, "_upsert_meta", lambda *args, **kwargs: None)

    result = builder._publish_alias(
        tenant_id=tenant_id,
        index_name=newest,
        keep_versions=2,
    )

    assert result["current_index"] == newest
    assert fake_es.aliases[alias_name] == [newest]
    assert oldest not in fake_es.indices
    assert newest in fake_es.indices
    assert previous in fake_es.indices


@pytest.mark.unit
def test_incremental_bootstrap_when_no_active_index(monkeypatch):
    """With no alias for the tenant and bootstrap_if_missing=True, a full rebuild runs.

    The rebuild itself is stubbed. The test checks that the incremental result
    reports the bootstrap and embeds the rebuild's result for the right tenant.
    """
    fake_es = FakeESClient()
    builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None)

    from config import tenant_config_loader as tcl

    loader = tcl.get_tenant_config_loader()
    # Patch through monkeypatch so the (presumably shared) loader's config is restored
    # after the test. raising=False tolerates a loader without _config yet.
    monkeypatch.setattr(
        loader,
        "_config",
        {
            "default": {"primary_language": "en", "index_languages": ["en", "zh"]},
            "tenants": {"162": {"primary_language": "en", "index_languages": ["en", "zh"]}},
        },
        raising=False,
    )

    monkeypatch.setattr(
        builder,
        "rebuild_tenant_index",
        lambda **kwargs: {"mode": "full", "tenant_id": kwargs["tenant_id"], "index_name": "v_idx"},
    )

    result = builder.incremental_update_tenant_index(tenant_id="162", bootstrap_if_missing=True)
    assert result["mode"] == "incremental"
    assert result["bootstrapped"] is True
    assert result["bootstrap_result"]["mode"] == "full"
    assert result["bootstrap_result"]["tenant_id"] == "162"


@pytest.mark.unit
def test_incremental_updates_existing_index(monkeypatch):
    """An incremental run writes the computed deltas into the alias's current index."""
    fake_es = FakeESClient()
    builder = SuggestionIndexBuilder(es_client=fake_es, db_engine=None)

    tenant_id = "162"

    from config import tenant_config_loader as tcl

    loader = tcl.get_tenant_config_loader()
    # Patch through monkeypatch so the (presumably shared) loader's config is restored
    # after the test. raising=False tolerates a loader without _config yet.
    monkeypatch.setattr(
        loader,
        "_config",
        {
            "default": {"primary_language": "en", "index_languages": ["en", "zh"]},
            "tenants": {tenant_id: {"primary_language": "en", "index_languages": ["en", "zh"]}},
        },
        raising=False,
    )

    alias_name = get_suggestion_alias_name(tenant_id)
    active_index = "search_suggestions_tenant_162_v20260310190000"
    fake_es.aliases[alias_name] = [active_index]

    # Stored meta carries a watermark one hour in the past (timezone-aware ISO string).
    # The stubs take *args/**kwargs so they do not depend on the private methods'
    # parameter names.
    watermark = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()
    monkeypatch.setattr(
        builder,
        "_get_meta",
        lambda *args, **kwargs: {"last_incremental_watermark": watermark},
    )
    monkeypatch.setattr(builder, "_upsert_meta", lambda *args, **kwargs: None)

    # Bypass the DB-driven delta computation: one delta keyed by (lang, text).
    monkeypatch.setattr(
        builder,
        "_build_incremental_deltas",
        lambda **kwargs: {
            ("en", "iphone 15"): QueryDelta(
                tenant_id=tenant_id,
                lang="en",
                text="iphone 15",
                text_norm="iphone 15",
                delta_7d=2,
                delta_30d=3,
                lang_confidence=1.0,
                lang_source="log_field",
                lang_conflict=False,
            )
        },
    )

    result = builder.incremental_update_tenant_index(
        tenant_id=tenant_id,
        bootstrap_if_missing=False,
        overlap_minutes=10,
    )

    assert result["mode"] == "incremental"
    assert result["target_index"] == active_index
    assert result["updated_terms"] == 1
    assert result["bulk_result"]["failed"] == 0

    # The single delta should turn into exactly one bulk action in one request.
    bulk_calls = [x for x in fake_es.calls if x.get("op") == "bulk_actions"]
    assert len(bulk_calls) == 1
    assert len(bulk_calls[0]["actions"]) == 1