# style_intent.py
"""
Style intent detection for query understanding.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple

from .tokenization import QueryTextAnalysisCache, TokenizedText, normalize_query_text, tokenize_text


@dataclass(frozen=True)
class StyleIntentTermDefinition:
    """One vocabulary entry for a style intent: a canonical value plus its synonyms.

    All strings are pre-normalized via ``normalize_query_text`` by the builder
    (``StyleIntentDefinition.from_rows``). Instances are frozen/hashable so they
    can be collected into sets during matching.
    """

    # Value reported for a match; chosen as first attribute term, else first
    # English term, else first Chinese term (see StyleIntentDefinition.from_rows).
    canonical_value: str
    # Normalized English synonyms that map to this term.
    en_terms: Tuple[str, ...]
    # Normalized Chinese synonyms that map to this term.
    zh_terms: Tuple[str, ...]
    # Normalized attribute-facing terms (take priority when picking the canonical value).
    attribute_terms: Tuple[str, ...]


@dataclass(frozen=True)
class StyleIntentDefinition:
    """Vocabulary for one style intent type, with per-language synonym lookup tables."""

    # Normalized intent-type key (e.g. the dimension this vocabulary describes).
    intent_type: str
    terms: Tuple[StyleIntentTermDefinition, ...]
    dimension_aliases: Tuple[str, ...]
    # Normalized synonym -> owning term, per language.
    en_synonym_to_term: Dict[str, StyleIntentTermDefinition]
    zh_synonym_to_term: Dict[str, StyleIntentTermDefinition]
    # Longest synonym length in whitespace-separated words; drives n-gram generation.
    max_term_ngram: int = 3

    @classmethod
    def from_rows(
        cls,
        intent_type: str,
        rows: Sequence[Dict[str, List[str]]],
        dimension_aliases: Sequence[str],
    ) -> "StyleIntentDefinition":
        """Build a definition from raw vocabulary rows.

        Every raw term is normalized and deduplicated (first occurrence wins);
        rows that normalize down to nothing are dropped entirely.
        """

        def _normalized_unique(raw_terms: Iterable[str]) -> Tuple[str, ...]:
            # Normalize, drop empties, and dedupe while preserving order.
            ordered: Dict[str, None] = {}
            for raw in raw_terms:
                cleaned = normalize_query_text(raw)
                if cleaned:
                    ordered[cleaned] = None
            return tuple(ordered)

        collected: List[StyleIntentTermDefinition] = []
        en_lookup: Dict[str, StyleIntentTermDefinition] = {}
        zh_lookup: Dict[str, StyleIntentTermDefinition] = {}
        longest_ngram = 1

        for row in rows:
            en_terms = _normalized_unique(row.get("en_terms", []))
            zh_terms = _normalized_unique(row.get("zh_terms", []))
            attribute_terms = _normalized_unique(row.get("attribute_terms", []))
            if not (en_terms or zh_terms or attribute_terms):
                continue

            # Canonical value preference: attribute term, then English, then Chinese.
            canonical = ""
            for group in (attribute_terms, en_terms, zh_terms):
                if group:
                    canonical = group[0]
                    break

            entry = StyleIntentTermDefinition(
                canonical_value=canonical,
                en_terms=en_terms,
                zh_terms=zh_terms,
                attribute_terms=attribute_terms,
            )
            collected.append(entry)

            # Index every synonym and track the longest multi-word synonym seen.
            for synonym in en_terms:
                en_lookup[synonym] = entry
                longest_ngram = max(longest_ngram, len(synonym.split()))
            for synonym in zh_terms:
                zh_lookup[synonym] = entry
                longest_ngram = max(longest_ngram, len(synonym.split()))

        return cls(
            intent_type=intent_type,
            terms=tuple(collected),
            dimension_aliases=_normalized_unique(dimension_aliases),
            en_synonym_to_term=en_lookup,
            zh_synonym_to_term=zh_lookup,
            max_term_ngram=longest_ngram,
        )

    def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]:
        """Return the set of term definitions matched by any candidate string.

        Candidates are normalized before lookup; ``language`` selects the zh or
        en synonym table ("zh" means Chinese, anything else means English).
        """
        lookup = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term
        hits: Set[StyleIntentTermDefinition] = set()
        for raw in candidates:
            found = lookup.get(normalize_query_text(raw))
            if found is not None:
                hits.add(found)
        return hits

    def match_text(
        self,
        text: str,
        *,
        language: str,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> Set[StyleIntentTermDefinition]:
        """Tokenize ``text`` (up to this vocabulary's max n-gram) and match its candidates."""
        tokenized = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram)
        return self.match_candidates(tokenized.candidates, language=language)


@dataclass(frozen=True)
class DetectedStyleIntent:
    """A single style-intent match found in a query variant."""

    # Normalized intent type this match belongs to.
    intent_type: str
    # Canonical value of the matched vocabulary term.
    canonical_value: str
    # The normalized synonym that actually matched.
    matched_term: str
    # Original text of the query variant the match came from.
    matched_query_text: str
    attribute_terms: Tuple[str, ...]
    dimension_aliases: Tuple[str, ...]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain JSON-friendly dict (tuple fields become lists)."""
        payload: Dict[str, Any] = {
            "intent_type": self.intent_type,
            "canonical_value": self.canonical_value,
            "matched_term": self.matched_term,
            "matched_query_text": self.matched_query_text,
        }
        payload["attribute_terms"] = list(self.attribute_terms)
        payload["dimension_aliases"] = list(self.dimension_aliases)
        return payload


@dataclass(frozen=True)
class StyleIntentProfile:
    """Immutable result of style-intent detection for one parsed query."""

    # Deduplicated tokenized query variants the detector examined.
    query_variants: Tuple[TokenizedText, ...] = field(default_factory=tuple)
    # All detected intents, in detection order.
    intents: Tuple[DetectedStyleIntent, ...] = field(default_factory=tuple)

    @property
    def is_active(self) -> bool:
        """True when at least one intent was detected."""
        return len(self.intents) > 0

    def get_intents(self, intent_type: Optional[str] = None) -> List[DetectedStyleIntent]:
        """Return detected intents, optionally filtered by (normalized) intent type."""
        if intent_type is None:
            return list(self.intents)
        wanted = normalize_query_text(intent_type)
        selected: List[DetectedStyleIntent] = []
        for intent in self.intents:
            if intent.intent_type == wanted:
                selected.append(intent)
        return selected

    def get_canonical_values(self, intent_type: str) -> Set[str]:
        """Return the distinct canonical values detected for ``intent_type``."""
        return set(intent.canonical_value for intent in self.get_intents(intent_type))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the profile (intents and query variants) to plain dicts/lists."""
        serialized_variants: List[Dict[str, Any]] = []
        for variant in self.query_variants:
            serialized_variants.append(
                {
                    "text": variant.text,
                    "normalized_text": variant.normalized_text,
                    "fine_tokens": list(variant.fine_tokens),
                    "coarse_tokens": list(variant.coarse_tokens),
                    "candidates": list(variant.candidates),
                }
            )
        return {
            "active": self.is_active,
            "intents": [item.to_dict() for item in self.intents],
            "query_variants": serialized_variants,
        }


class StyleIntentRegistry:
    """Holds style intent vocabularies and matching helpers."""

    def __init__(
        self,
        definitions: Dict[str, StyleIntentDefinition],
        *,
        enabled: bool = True,
    ) -> None:
        """Keep the intent-type -> definition mapping and coerce the enabled flag to bool."""
        self.definitions = definitions
        self.enabled = bool(enabled)

    @classmethod
    def from_query_config(cls, query_config: Any) -> "StyleIntentRegistry":
        """Build a registry from a query-config object's style-intent settings.

        Intent-type keys are normalized, vocabularies that normalize to empty
        are skipped, and ``style_intent_enabled`` (default True) gates detection.
        """
        term_rows = getattr(query_config, "style_intent_terms", {}) or {}
        alias_map = getattr(query_config, "style_intent_dimension_aliases", {}) or {}
        built: Dict[str, StyleIntentDefinition] = {}

        for raw_intent_type, rows in term_rows.items():
            candidate = StyleIntentDefinition.from_rows(
                intent_type=normalize_query_text(raw_intent_type),
                rows=rows or [],
                dimension_aliases=alias_map.get(raw_intent_type, []),
            )
            # Keep only intent types that still have vocabulary after normalization.
            if candidate.terms:
                built[candidate.intent_type] = candidate

        enabled_flag = getattr(query_config, "style_intent_enabled", True)
        return cls(built, enabled=bool(enabled_flag))

    def get_definition(self, intent_type: str) -> Optional[StyleIntentDefinition]:
        """Look up a definition by (normalized) intent type; None when absent."""
        key = normalize_query_text(intent_type)
        return self.definitions.get(key)

    def get_dimension_aliases(self, intent_type: str) -> Tuple[str, ...]:
        """Return the dimension aliases for ``intent_type``, or an empty tuple."""
        definition = self.get_definition(intent_type)
        if definition is None:
            return tuple()
        return definition.dimension_aliases


class StyleIntentDetector:
    """Detects style intents from parsed query variants."""

    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> None:
        """Store the vocabulary registry and an optional tokenizer override.

        The tokenizer, when given, is passed through to ``tokenize_text``.
        """
        self.registry = registry
        self.tokenizer = tokenizer

    def _max_term_ngram(self) -> int:
        # Largest n-gram any vocabulary needs, so candidate generation covers
        # every multi-word synonym; falls back to 3 when no definitions exist.
        return max(
            (definition.max_term_ngram for definition in self.registry.definitions.values()),
            default=3,
        )

    def _tokenize_text(
        self,
        text: str,
        *,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> TokenizedText:
        """Tokenize ``text``, preferring the shared analysis cache when supplied."""
        max_term_ngram = self._max_term_ngram()
        if analysis_cache is not None:
            return analysis_cache.get_tokenized_text(text, max_ngram=max_term_ngram)
        return tokenize_text(
            text,
            tokenizer=self.tokenizer,
            max_ngram=max_term_ngram,
        )

    def _build_language_variants(
        self,
        parsed_query: Any,
        *,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Dict[str, TokenizedText]:
        """Tokenize the zh and en query texts; languages with no text are omitted."""
        variants: Dict[str, TokenizedText] = {}
        for language in ("zh", "en"):
            text = self._get_language_query_text(parsed_query, language).strip()
            if not text:
                continue
            variants[language] = self._tokenize_text(
                text,
                analysis_cache=analysis_cache,
            )
        return variants

    def _build_query_variants(
        self,
        parsed_query: Any,
        *,
        language_variants: Optional[Dict[str, TokenizedText]] = None,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Tuple[TokenizedText, ...]:
        """Return the language variants deduplicated by normalized text, order kept."""
        seen = set()
        variants: List[TokenizedText] = []

        # Reuse precomputed variants when the caller already built them;
        # otherwise tokenize from scratch.
        for variant in (language_variants or self._build_language_variants(
            parsed_query,
            analysis_cache=analysis_cache,
        )).values():
            normalized = variant.normalized_text
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            variants.append(variant)

        return tuple(variants)

    @staticmethod
    def _get_language_query_text(parsed_query: Any, language: str) -> str:
        """Return the query text for ``language``, falling back to the original query."""
        translations = getattr(parsed_query, "translations", {}) or {}
        if isinstance(translations, dict):
            translated = translations.get(language)
            if translated:
                return str(translated)
        return str(getattr(parsed_query, "original_query", "") or "")

    def _tokenize_language_query(
        self,
        parsed_query: Any,
        language: str,
        *,
        language_variants: Optional[Dict[str, TokenizedText]] = None,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Optional[TokenizedText]:
        """Return the tokenized query for ``language``, or None when there is no text."""
        # When a prebuilt variants dict is given, it is authoritative — a missing
        # language yields None rather than triggering fresh tokenization.
        if language_variants is not None:
            return language_variants.get(language)
        text = self._get_language_query_text(parsed_query, language).strip()
        if not text:
            return None
        return self._tokenize_text(
            text,
            analysis_cache=analysis_cache,
        )

    def detect(self, parsed_query: Any) -> StyleIntentProfile:
        """Match every registered intent vocabulary against the zh/en query variants.

        Returns an empty profile when detection is disabled or no vocabulary is
        loaded. Each (intent_type, canonical_value) pair is recorded at most once.
        """
        if not self.registry.enabled or not self.registry.definitions:
            return StyleIntentProfile()

        # Reuse a tokenization cache if the parsed query carries one (private
        # attribute — presumably attached by the query parser; not guaranteed).
        analysis_cache = getattr(parsed_query, "_text_analysis_cache", None)
        language_variants = self._build_language_variants(
            parsed_query,
            analysis_cache=analysis_cache,
        )
        query_variants = self._build_query_variants(
            parsed_query,
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        zh_variant = self._tokenize_language_query(
            parsed_query,
            "zh",
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        en_variant = self._tokenize_language_query(
            parsed_query,
            "en",
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        detected: List[DetectedStyleIntent] = []
        seen_pairs = set()

        for intent_type, definition in self.registry.definitions.items():
            for language, variant, mapping in (
                ("zh", zh_variant, definition.zh_synonym_to_term),
                ("en", en_variant, definition.en_synonym_to_term),
            ):
                if variant is None or not mapping:
                    continue

                matched_terms = definition.match_candidates(variant.candidates, language=language)
                if not matched_terms:
                    continue

                # Walk candidates in tokenization order so the first matching
                # synonym determines which term is reported.
                for candidate in variant.candidates:
                    normalized_candidate = normalize_query_text(candidate)
                    term_definition = mapping.get(normalized_candidate)
                    if term_definition is None or term_definition not in matched_terms:
                        continue
                    pair = (intent_type, term_definition.canonical_value)
                    if pair in seen_pairs:
                        continue
                    seen_pairs.add(pair)
                    detected.append(
                        DetectedStyleIntent(
                            intent_type=intent_type,
                            canonical_value=term_definition.canonical_value,
                            matched_term=normalized_candidate,
                            matched_query_text=variant.text,
                            attribute_terms=term_definition.attribute_terms,
                            dimension_aliases=definition.dimension_aliases,
                        )
                    )
                    # NOTE(review): this break records only ONE new pair per
                    # language variant per intent type, discarding any other
                    # terms in matched_terms — confirm single-match-per-language
                    # is intended rather than `continue`.
                    break

        return StyleIntentProfile(
            query_variants=query_variants,
            intents=tuple(detected),
        )