""" Style intent detection for query understanding. """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple from .tokenization import QueryTextAnalysisCache, TokenizedText, normalize_query_text, tokenize_text @dataclass(frozen=True) class StyleIntentTermDefinition: canonical_value: str en_terms: Tuple[str, ...] zh_terms: Tuple[str, ...] attribute_terms: Tuple[str, ...] @dataclass(frozen=True) class StyleIntentDefinition: intent_type: str terms: Tuple[StyleIntentTermDefinition, ...] dimension_aliases: Tuple[str, ...] en_synonym_to_term: Dict[str, StyleIntentTermDefinition] zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] max_term_ngram: int = 3 @classmethod def from_rows( cls, intent_type: str, rows: Sequence[Dict[str, List[str]]], dimension_aliases: Sequence[str], ) -> "StyleIntentDefinition": terms: List[StyleIntentTermDefinition] = [] en_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} max_ngram = 1 for row in rows: normalized_en = tuple( dict.fromkeys( term for term in (normalize_query_text(raw) for raw in row.get("en_terms", [])) if term ) ) normalized_zh = tuple( dict.fromkeys( term for term in (normalize_query_text(raw) for raw in row.get("zh_terms", [])) if term ) ) normalized_attribute = tuple( dict.fromkeys( term for term in (normalize_query_text(raw) for raw in row.get("attribute_terms", [])) if term ) ) if not normalized_en and not normalized_zh and not normalized_attribute: continue canonical = ( normalized_attribute[0] if normalized_attribute else normalized_en[0] if normalized_en else normalized_zh[0] ) term_definition = StyleIntentTermDefinition( canonical_value=canonical, en_terms=normalized_en, zh_terms=normalized_zh, attribute_terms=normalized_attribute, ) terms.append(term_definition) for term in normalized_en: en_synonym_to_term[term] = term_definition max_ngram = max(max_ngram, len(term.split())) for term in normalized_zh: zh_synonym_to_term[term] = term_definition max_ngram = max(max_ngram, len(term.split())) aliases = tuple( dict.fromkeys( term for term in ( normalize_query_text(alias) for alias in dimension_aliases ) if term ) ) return cls( intent_type=intent_type, terms=tuple(terms), dimension_aliases=aliases, en_synonym_to_term=en_synonym_to_term, zh_synonym_to_term=zh_synonym_to_term, max_term_ngram=max_ngram, ) def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]: mapping = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term matched: Set[StyleIntentTermDefinition] = set() for candidate in candidates: term_definition = mapping.get(normalize_query_text(candidate)) if term_definition: matched.add(term_definition) return matched def match_text( self, text: str, *, language: str, tokenizer: Optional[Callable[[str], Any]] = None, ) -> Set[StyleIntentTermDefinition]: bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram) return self.match_candidates(bundle.candidates, language=language) @dataclass(frozen=True) class DetectedStyleIntent: intent_type: str canonical_value: str matched_term: str matched_query_text: str attribute_terms: Tuple[str, ...] dimension_aliases: Tuple[str, ...] 

@dataclass(frozen=True)
class DetectedStyleIntent:
    intent_type: str
    canonical_value: str
    matched_term: str
    matched_query_text: str
    attribute_terms: Tuple[str, ...]
    dimension_aliases: Tuple[str, ...]

    def to_dict(self) -> Dict[str, Any]:
        return {
            "intent_type": self.intent_type,
            "canonical_value": self.canonical_value,
            "matched_term": self.matched_term,
            "matched_query_text": self.matched_query_text,
            "attribute_terms": list(self.attribute_terms),
            "dimension_aliases": list(self.dimension_aliases),
        }


@dataclass(frozen=True)
class StyleIntentProfile:
    query_variants: Tuple[TokenizedText, ...] = field(default_factory=tuple)
    intents: Tuple[DetectedStyleIntent, ...] = field(default_factory=tuple)

    @property
    def is_active(self) -> bool:
        return bool(self.intents)

    def get_intents(self, intent_type: Optional[str] = None) -> List[DetectedStyleIntent]:
        if intent_type is None:
            return list(self.intents)
        normalized = normalize_query_text(intent_type)
        return [intent for intent in self.intents if intent.intent_type == normalized]

    def get_canonical_values(self, intent_type: str) -> Set[str]:
        return {intent.canonical_value for intent in self.get_intents(intent_type)}

    def to_dict(self) -> Dict[str, Any]:
        return {
            "active": self.is_active,
            "intents": [intent.to_dict() for intent in self.intents],
            "query_variants": [
                {
                    "text": variant.text,
                    "normalized_text": variant.normalized_text,
                    "fine_tokens": list(variant.fine_tokens),
                    "coarse_tokens": list(variant.coarse_tokens),
                    "candidates": list(variant.candidates),
                }
                for variant in self.query_variants
            ],
        }


class StyleIntentRegistry:
    """Holds style intent vocabularies and matching helpers."""

    def __init__(
        self,
        definitions: Dict[str, StyleIntentDefinition],
        *,
        enabled: bool = True,
    ) -> None:
        self.definitions = definitions
        self.enabled = bool(enabled)

    @classmethod
    def from_query_config(cls, query_config: Any) -> "StyleIntentRegistry":
        style_terms = getattr(query_config, "style_intent_terms", {}) or {}
        dimension_aliases = getattr(query_config, "style_intent_dimension_aliases", {}) or {}
        definitions: Dict[str, StyleIntentDefinition] = {}
        for intent_type, rows in style_terms.items():
            definition = StyleIntentDefinition.from_rows(
                intent_type=normalize_query_text(intent_type),
                rows=rows or [],
                dimension_aliases=dimension_aliases.get(intent_type, []),
            )
            if definition.terms:
                definitions[definition.intent_type] = definition
        return cls(
            definitions,
            enabled=bool(getattr(query_config, "style_intent_enabled", True)),
        )

    def get_definition(self, intent_type: str) -> Optional[StyleIntentDefinition]:
        return self.definitions.get(normalize_query_text(intent_type))

    def get_dimension_aliases(self, intent_type: str) -> Tuple[str, ...]:
        definition = self.get_definition(intent_type)
        return definition.dimension_aliases if definition else tuple()


class StyleIntentDetector:
    """Detects style intents from parsed query variants."""

    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> None:
        self.registry = registry
        self.tokenizer = tokenizer

    def _max_term_ngram(self) -> int:
        return max(
            (definition.max_term_ngram for definition in self.registry.definitions.values()),
            default=3,
        )

    def _tokenize_text(
        self,
        text: str,
        *,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> TokenizedText:
        max_term_ngram = self._max_term_ngram()
        if analysis_cache is not None:
            return analysis_cache.get_tokenized_text(text, max_ngram=max_term_ngram)
        return tokenize_text(
            text,
            tokenizer=self.tokenizer,
            max_ngram=max_term_ngram,
        )
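
    # The helpers below build the per-language views of the query:
    # ``_get_language_query_text`` prefers ``translations`` and falls back to
    # ``original_query``, ``_build_language_variants`` tokenizes each non-empty
    # text once per language, and ``_build_query_variants`` de-duplicates those
    # variants by normalized text for the resulting profile.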
    def _build_language_variants(
        self,
        parsed_query: Any,
        *,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Dict[str, TokenizedText]:
        variants: Dict[str, TokenizedText] = {}
        for language in ("zh", "en"):
            text = self._get_language_query_text(parsed_query, language).strip()
            if not text:
                continue
            variants[language] = self._tokenize_text(
                text,
                analysis_cache=analysis_cache,
            )
        return variants

    def _build_query_variants(
        self,
        parsed_query: Any,
        *,
        language_variants: Optional[Dict[str, TokenizedText]] = None,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Tuple[TokenizedText, ...]:
        seen = set()
        variants: List[TokenizedText] = []
        for variant in (
            language_variants
            or self._build_language_variants(
                parsed_query,
                analysis_cache=analysis_cache,
            )
        ).values():
            normalized = variant.normalized_text
            if not normalized or normalized in seen:
                continue
            seen.add(normalized)
            variants.append(variant)
        return tuple(variants)

    @staticmethod
    def _get_language_query_text(parsed_query: Any, language: str) -> str:
        translations = getattr(parsed_query, "translations", {}) or {}
        if isinstance(translations, dict):
            translated = translations.get(language)
            if translated:
                return str(translated)
        return str(getattr(parsed_query, "original_query", "") or "")

    def _tokenize_language_query(
        self,
        parsed_query: Any,
        language: str,
        *,
        language_variants: Optional[Dict[str, TokenizedText]] = None,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> Optional[TokenizedText]:
        if language_variants is not None:
            return language_variants.get(language)
        text = self._get_language_query_text(parsed_query, language).strip()
        if not text:
            return None
        return self._tokenize_text(
            text,
            analysis_cache=analysis_cache,
        )

    def detect(self, parsed_query: Any) -> StyleIntentProfile:
        if not self.registry.enabled or not self.registry.definitions:
            return StyleIntentProfile()
        # Reuse the parsed query's analysis cache, when present, so the same
        # text is not re-tokenized for every definition.
        analysis_cache = getattr(parsed_query, "_text_analysis_cache", None)
        language_variants = self._build_language_variants(
            parsed_query,
            analysis_cache=analysis_cache,
        )
        query_variants = self._build_query_variants(
            parsed_query,
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        zh_variant = self._tokenize_language_query(
            parsed_query,
            "zh",
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        en_variant = self._tokenize_language_query(
            parsed_query,
            "en",
            language_variants=language_variants,
            analysis_cache=analysis_cache,
        )
        detected: List[DetectedStyleIntent] = []
        # (intent_type, canonical_value) pairs already emitted; prevents
        # duplicate detections across languages and candidate n-grams.
        seen_pairs = set()
        for intent_type, definition in self.registry.definitions.items():
            for language, variant, mapping in (
                ("zh", zh_variant, definition.zh_synonym_to_term),
                ("en", en_variant, definition.en_synonym_to_term),
            ):
                if variant is None or not mapping:
                    continue
                matched_terms = definition.match_candidates(variant.candidates, language=language)
                if not matched_terms:
                    continue
                for candidate in variant.candidates:
                    normalized_candidate = normalize_query_text(candidate)
                    term_definition = mapping.get(normalized_candidate)
                    if term_definition is None or term_definition not in matched_terms:
                        continue
                    pair = (intent_type, term_definition.canonical_value)
                    if pair in seen_pairs:
                        continue
                    seen_pairs.add(pair)
                    detected.append(
                        DetectedStyleIntent(
                            intent_type=intent_type,
                            canonical_value=term_definition.canonical_value,
                            matched_term=normalized_candidate,
                            matched_query_text=variant.text,
                            attribute_terms=term_definition.attribute_terms,
                            dimension_aliases=definition.dimension_aliases,
                        )
                    )
                    # At most one new canonical value is recorded per
                    # (intent, language) pass.
                    break
        return StyleIntentProfile(
            query_variants=query_variants,
            intents=tuple(detected),
        )
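

# A minimal end-to-end sketch (illustrative only; ``query_config`` and
# ``parsed_query`` stand in for objects supplied by the surrounding pipeline,
# and the "color" intent type is hypothetical). ``query_config`` is expected to
# expose ``style_intent_terms``, ``style_intent_dimension_aliases`` and
# ``style_intent_enabled``; ``parsed_query`` is expected to expose
# ``original_query``, ``translations`` and, optionally, ``_text_analysis_cache``:
#
#     registry = StyleIntentRegistry.from_query_config(query_config)
#     detector = StyleIntentDetector(registry)
#     profile = detector.detect(parsed_query)
#     if profile.is_active:
#         canonical_colors = profile.get_canonical_values("color")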