service.py 10.1 KB
"""
Online suggestion query service.
"""

import logging
import time
from typing import Any, Dict, List, Optional

from config.tenant_config_loader import get_tenant_config_loader
from query.tokenization import simple_tokenize_query
from suggestion.builder import get_suggestion_alias_name
from utils.es_client import ESClient

logger = logging.getLogger(__name__)


def _suggestion_length_factor(text: str) -> float:
    """Return the query-time length penalty ``1 / sqrt(token_count)`` for *text*.

    Falsy input is tokenized as the empty string, and the token count is
    floored at one so the result is always defined and never exceeds 1.0.
    """
    tokens = simple_tokenize_query(str(text or ""))
    token_count = len(tokens)
    if token_count < 1:
        # Floor at one token to avoid dividing by zero for empty text.
        token_count = 1
    return 1.0 / (token_count ** 0.5)


def _score_with_token_length_penalty(item: Dict[str, Any]) -> float:
    """Return the item's ``"score"`` scaled by the length factor of its ``"text"``.

    A missing or falsy score counts as 0.0; a missing or falsy text is
    treated as the empty string.
    """
    raw_score = item.get("score")
    base_score = float(raw_score or 0.0)
    suggestion_text = str(item.get("text") or "")
    return base_score * _suggestion_length_factor(suggestion_text)


class SuggestionService:
    """Answer autocomplete requests from a tenant's suggestion alias.

    Recall runs in up to two stages: the completion suggester on
    ``completion.<lang>`` first, then -- only for prefixes longer than two
    characters that completion did not already fill -- a ``bool_prefix`` query
    on ``sat.<lang>``.  Candidates are de-duplicated by normalized text and
    ranked by the token-length-penalized score, then by ``rank_score``.
    """

    # ``_source`` fields requested by both recall paths; kept in one place so
    # the two queries cannot drift apart.
    _SOURCE_FIELDS = (
        "text",
        "lang",
        "rank_score",
        "sources",
        "lang_source",
        "lang_confidence",
        "lang_conflict",
    )

    # Region-qualified codes that are kept whole instead of being reduced to
    # their base language.
    _REGIONAL_LANGUAGES = frozenset({"zh_tw", "pt_br"})

    def __init__(self, es_client: ESClient):
        """Store the ES client wrapper used for every lookup."""
        self.es_client = es_client

    def _resolve_language(self, tenant_id: str, language: str) -> str:
        """Map a requested language code onto one the tenant actually indexes.

        The code is lower-cased and ``-`` becomes ``_``; ``zh_tw`` and
        ``pt_br`` are kept as-is, any other code is cut down to the part
        before the first underscore.  When the result is not one of the
        tenant's index languages, the tenant's primary language is used if it
        is indexed, otherwise the first index language.

        Args:
            tenant_id: Tenant whose configuration is consulted.
            language: Requested language code; may be empty or ``None``.

        Returns:
            A language code taken from the tenant's index languages.
        """
        cfg = get_tenant_config_loader().get_tenant_config(tenant_id)
        # Defaults apply when the tenant config omits (or empties) the keys.
        index_languages = cfg.get("index_languages") or ["en", "zh"]
        primary = cfg.get("primary_language") or "en"

        lang = (language or "").strip().lower().replace("-", "_")
        if lang in self._REGIONAL_LANGUAGES:
            normalized = lang
        else:
            normalized = lang.split("_")[0] if lang else ""

        if normalized in index_languages:
            return normalized
        if primary in index_languages:
            return primary
        return index_languages[0]

    def _resolve_search_target(self, tenant_id: str) -> Optional[str]:
        """Return the tenant's suggestion alias name, or ``None`` if it does not exist."""
        alias_name = get_suggestion_alias_name(tenant_id)
        if self.es_client.alias_exists(alias_name):
            return alias_name
        return None

    @staticmethod
    def _elapsed_ms(started: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - started) * 1000)

    @staticmethod
    def _item_from_source(
        src: Dict[str, Any],
        score: Any,
        default_text: Any = None,
        default_lang: Any = None,
    ) -> Dict[str, Any]:
        """Build one suggestion dict from a document ``_source`` and its score.

        ``default_text`` / ``default_lang`` are used when the source has no
        (or a falsy) ``text`` / ``lang`` value.
        """
        return {
            "text": src.get("text") or default_text,
            "lang": src.get("lang") or default_lang,
            "score": score,
            "rank_score": src.get("rank_score"),
            "sources": src.get("sources", []),
            "lang_source": src.get("lang_source"),
            "lang_confidence": src.get("lang_confidence"),
            "lang_conflict": src.get("lang_conflict", False),
        }

    @staticmethod
    def _merge_unique(
        merged: List[Dict[str, Any]],
        seen: set,
        items: List[Dict[str, Any]],
    ) -> None:
        """Append copies of ``items`` to ``merged``, skipping blank or already-seen texts.

        Texts are compared after stripping and lower-casing; ``seen`` carries
        the normalized texts across calls so earlier sources win.
        """
        for item in items:
            norm = str(item.get("text") or "").strip().lower()
            if not norm or norm in seen:
                continue
            seen.add(norm)
            merged.append(dict(item))

    @staticmethod
    def _rank(items: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Return the top ``limit`` items, best first, without mutating ``items``.

        Sort key: length-penalized score, then ``rank_score`` (missing -> 0.0).
        """
        ranked = sorted(
            items,
            key=lambda x: (
                _score_with_token_length_penalty(x),
                float(x.get("rank_score") or 0.0),
            ),
            reverse=True,
        )
        return ranked[:limit]

    @staticmethod
    def _response(
        query: Any,
        language: Any,
        resolved_lang: str,
        suggestions: List[Dict[str, Any]],
        took_ms: int,
    ) -> Dict[str, Any]:
        """Assemble the endpoint payload; ``query``/``language`` are echoed as received."""
        return {
            "query": query,
            "language": language,
            "resolved_language": resolved_lang,
            "suggestions": suggestions,
            "took_ms": took_ms,
        }

    def _completion_suggest(
        self,
        index_name: str,
        query: str,
        lang: str,
        size: int,
        tenant_id: str,
    ) -> List[Dict[str, Any]]:
        """
        Query ES completion suggester from `completion.<lang>`.

        Args:
            index_name: Index or alias to query.
            query: Prefix typed so far.
            lang: Resolved language; selects the ``completion.<lang>`` field
                and is the fallback ``lang`` of returned items.
            size: Maximum number of options requested.
            tenant_id: Used (stringified) as the routing value.

        Returns:
            Items in the same shape as search hits: dicts with
            "text"/"lang"/"score"/"rank_score"/"sources" plus the language
            detection fields.  Any ES failure is logged and yields ``[]``.
        """
        field_name = f"completion.{lang}"
        body = {
            "suggest": {
                "s": {
                    "prefix": query,
                    "completion": {
                        "field": field_name,
                        "size": size,
                        "skip_duplicates": True,
                    },
                }
            },
            "_source": list(self._SOURCE_FIELDS),
        }
        try:
            resp = self.es_client.client.search(index=index_name, body=body, routing=str(tenant_id))
        except Exception as e:
            # completion is an optimization path; never hard-fail the whole endpoint
            logger.warning("Completion suggest failed for index=%s field=%s: %s", index_name, field_name, e)
            return []

        entries = (resp.get("suggest", {}) or {}).get("s", []) or []
        if not entries:
            return []
        options = entries[0].get("options", []) or []
        return [
            self._item_from_source(
                opt.get("_source", {}) or {},
                opt.get("_score", 0.0),
                default_text=opt.get("text"),
                default_lang=lang,
            )
            for opt in options
        ]

    def search(
        self,
        tenant_id: str,
        query: str,
        language: str,
        size: int = 10,
    ) -> Dict[str, Any]:
        """Return ranked suggestions for ``query`` in the tenant's suggestion index.

        Args:
            tenant_id: Tenant identifier; also used as the ES routing value.
            query: Raw user input; surrounding whitespace is ignored.
            language: Requested language code, resolved against tenant config.
            size: Maximum number of suggestions to return.

        Returns:
            A dict with ``query`` and ``language`` echoed as received,
            ``resolved_language``, ``suggestions`` (at most ``size`` items)
            and ``took_ms``.  An empty/blank query, or a tenant whose
            suggestion alias does not exist, yields an empty suggestion list.

        Raises:
            Whatever ``self.es_client.search`` raises for the ``bool_prefix``
            query; only the completion path is best-effort.
        """
        # perf_counter is monotonic, so durations cannot go negative when the
        # wall clock is adjusted.
        started = time.perf_counter()
        query_text = str(query or "").strip()
        resolved_lang = self._resolve_language(tenant_id, language)

        if not query_text:
            # Nothing to complete: skip every ES round-trip, alias lookup included.
            return self._response(query, language, resolved_lang, [], self._elapsed_ms(started))

        index_name = self._resolve_search_target(tenant_id)
        if not index_name:
            # On a fresh ES cluster the suggestion index might not be built yet.
            # Keep endpoint stable for frontend autocomplete: return empty list instead of 500.
            return self._response(query, language, resolved_lang, [], self._elapsed_ms(started))

        # Recall path A: completion suggester (fast path, usually enough for short prefix typing)
        completion_started = time.perf_counter()
        completion_items = self._completion_suggest(
            index_name=index_name,
            query=query_text,
            lang=resolved_lang,
            size=size,
            tenant_id=tenant_id,
        )
        completion_ms = self._elapsed_ms(completion_started)

        suggestions: List[Dict[str, Any]] = []
        seen_text_norm: set = set()
        self._merge_unique(suggestions, seen_text_norm, completion_items)

        # Fast path: avoid a second ES query for short prefixes or when completion already full.
        if len(query_text) <= 2 or len(suggestions) >= size:
            took_ms = self._elapsed_ms(started)
            logger.info(
                "suggest completion-fast-return | tenant=%s lang=%s q=%s completion=%d took_ms=%d completion_ms=%d",
                tenant_id,
                resolved_lang,
                query_text,
                len(suggestions),
                took_ms,
                completion_ms,
            )
            return self._response(
                query, language, resolved_lang, self._rank(suggestions, size), took_ms
            )

        # Recall path B: bool_prefix on search_as_you_type (fallback to widen recall)
        sat_field = f"sat.{resolved_lang}"
        dsl = {
            "track_total_hits": False,
            "query": {
                "function_score": {
                    "query": {
                        "bool": {
                            "filter": [
                                {"term": {"lang": resolved_lang}},
                                {"term": {"status": 1}},
                            ],
                            "should": [
                                {
                                    "multi_match": {
                                        "query": query_text,
                                        "type": "bool_prefix",
                                        "fields": [sat_field, f"{sat_field}._2gram", f"{sat_field}._3gram"],
                                    }
                                }
                            ],
                            "minimum_should_match": 1,
                        }
                    },
                    # Adds log1p(rank_score) to the text relevance score.
                    "field_value_factor": {
                        "field": "rank_score",
                        "factor": 1.0,
                        "modifier": "log1p",
                        "missing": 0.0,
                    },
                    "boost_mode": "sum",
                    "score_mode": "sum",
                }
            },
            "_source": list(self._SOURCE_FIELDS),
        }
        sat_started = time.perf_counter()
        es_resp = self.es_client.search(
            index_name=index_name,
            body=dsl,
            size=size,
            from_=0,
            routing=str(tenant_id),
        )
        sat_ms = self._elapsed_ms(sat_started)
        # Tolerate a null response or null "hits" object, like the completion path does.
        hits = ((es_resp or {}).get("hits") or {}).get("hits") or []

        sat_items = [
            self._item_from_source(
                hit.get("_source", {}) or {},
                hit.get("_score", 0.0),
                default_lang=resolved_lang,
            )
            for hit in hits
        ]
        # Completion items were merged first, so they win on duplicate text.
        self._merge_unique(suggestions, seen_text_norm, sat_items)

        took_ms = self._elapsed_ms(started)
        logger.info(
            "suggest completion+sat-return | tenant=%s lang=%s q=%s completion=%d sat_hits=%d took_ms=%d completion_ms=%d sat_ms=%d",
            tenant_id,
            resolved_lang,
            query_text,
            len(completion_items),
            len(hits),
            took_ms,
            completion_ms,
            sat_ms,
        )
        return self._response(
            query, language, resolved_lang, self._rank(suggestions, size), took_ms
        )