# product_title_exclusion.py
"""
Product title exclusion detection for query understanding.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple

from .tokenization import QueryTextAnalysisCache, TokenizedText, normalize_query_text, tokenize_text


def _dedupe_terms(terms: Iterable[str]) -> List[str]:
    result: List[str] = []
    seen: Set[str] = set()
    for raw_term in terms:
        term = normalize_query_text(raw_term)
        if not term or term in seen:
            continue
        seen.add(term)
        result.append(term)
    return result


@dataclass(frozen=True)
class ProductTitleExclusionRule:
    """A single rule pairing trigger terms with title exclusion terms.

    Instances built through ``from_config_row`` hold terms that have been
    normalized and de-duplicated by ``_dedupe_terms``.
    """

    # Trigger terms (zh / en) that are looked up among a query's candidates.
    zh_trigger_terms: Tuple[str, ...]
    en_trigger_terms: Tuple[str, ...]
    # Title exclusion terms (zh / en) carried by this rule.
    zh_title_exclusions: Tuple[str, ...]
    en_title_exclusions: Tuple[str, ...]
    # ``from_config_row`` sets this to the largest whitespace-separated word
    # count among the trigger terms (never below 1); direct construction
    # falls back to 3.
    max_term_ngram: int = 3

    @classmethod
    def from_config_row(cls, row: Dict[str, Sequence[str]]) -> Optional["ProductTitleExclusionRule"]:
        """Build a rule from one configuration row.

        The row is read through the keys ``zh_trigger_terms``,
        ``en_trigger_terms``, ``zh_title_exclusions`` and
        ``en_title_exclusions``. Missing or empty values count as "no terms".
        A value given as a bare string is treated as a single term instead of
        being iterated character by character.

        Args:
            row: Mapping from the keys above to sequences of terms.

        Returns:
            The rule, or ``None`` when the row has no title exclusions at all
            or no trigger terms at all.
        """

        def load_terms(key: str) -> Tuple[str, ...]:
            raw_terms = row.get(key) or []
            # A str is itself a sequence of one-character strings; wrap it so
            # the whole string is handled as one term.
            if isinstance(raw_terms, str):
                raw_terms = [raw_terms]
            return tuple(_dedupe_terms(raw_terms))

        zh_trigger_terms = load_terms("zh_trigger_terms")
        en_trigger_terms = load_terms("en_trigger_terms")
        zh_title_exclusions = load_terms("zh_title_exclusions")
        en_title_exclusions = load_terms("en_title_exclusions")

        # A rule needs at least one exclusion and at least one trigger.
        if not (zh_title_exclusions or en_title_exclusions):
            return None
        if not (zh_trigger_terms or en_trigger_terms):
            return None

        word_counts = [len(term.split()) for term in zh_trigger_terms + en_trigger_terms]
        return cls(
            zh_trigger_terms=zh_trigger_terms,
            en_trigger_terms=en_trigger_terms,
            zh_title_exclusions=zh_title_exclusions,
            en_title_exclusions=en_title_exclusions,
            max_term_ngram=max([1, *word_counts]),
        )

    def match_candidates(self, candidates: Iterable[str]) -> Optional[str]:
        """Return the first trigger term found among ``candidates``.

        Candidates are normalized with ``normalize_query_text`` before the
        lookup. zh trigger terms are checked before en trigger terms, each in
        their stored order.

        Args:
            candidates: Candidate strings to compare against the triggers.

        Returns:
            The matching trigger term, or ``None`` when nothing matches.
        """
        normalized_candidates = {normalize_query_text(candidate) for candidate in candidates}
        for trigger_terms in (self.zh_trigger_terms, self.en_trigger_terms):
            for term in trigger_terms:
                if term in normalized_candidates:
                    return term
        return None


@dataclass(frozen=True)
class DetectedProductTitleExclusion:
    """Record of one rule hit: what matched, where, and which exclusions apply."""

    # Trigger term that matched.
    matched_term: str
    # Text of the query variant in which the match was found.
    matched_query_text: str
    # Title exclusion terms (zh / en) copied from the matching rule.
    zh_title_exclusions: Tuple[str, ...]
    en_title_exclusions: Tuple[str, ...]

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dict, with the term tuples as lists."""
        zh_terms = list(self.zh_title_exclusions)
        en_terms = list(self.en_title_exclusions)
        serialized: Dict[str, Any] = {
            "matched_term": self.matched_term,
            "matched_query_text": self.matched_query_text,
            "zh_title_exclusions": zh_terms,
            "en_title_exclusions": en_terms,
        }
        return serialized


@dataclass(frozen=True)
class ProductTitleExclusionProfile:
    """Detection result: the tokenized query variants and the detected exclusions."""

    # Tokenized query variants the detection ran over.
    query_variants: Tuple[TokenizedText, ...] = field(default_factory=tuple)
    # Exclusions detected across those variants.
    exclusions: Tuple[DetectedProductTitleExclusion, ...] = field(default_factory=tuple)

    @property
    def is_active(self) -> bool:
        """True when at least one exclusion was detected."""
        return len(self.exclusions) > 0 if isinstance(self.exclusions, tuple) else bool(self.exclusions)

    def to_dict(self) -> Dict[str, Any]:
        """Return the profile as a plain dict.

        The result has the keys ``active``, ``exclusions`` (each item's
        ``to_dict()``) and ``query_variants`` (one dict per variant, with its
        token collections copied into lists).
        """
        active_flag = self.is_active

        exclusion_payloads = []
        for detected in self.exclusions:
            exclusion_payloads.append(detected.to_dict())

        variant_payloads = []
        for variant in self.query_variants:
            variant_payloads.append(
                {
                    "text": variant.text,
                    "normalized_text": variant.normalized_text,
                    "fine_tokens": list(variant.fine_tokens),
                    "coarse_tokens": list(variant.coarse_tokens),
                    "candidates": list(variant.candidates),
                }
            )

        return {
            "active": active_flag,
            "exclusions": exclusion_payloads,
            "query_variants": variant_payloads,
        }

    def all_zh_title_exclusions(self) -> List[str]:
        """Return the zh title exclusions of all detections, passed through ``_dedupe_terms``."""
        collected: List[str] = []
        for detected in self.exclusions:
            collected.extend(detected.zh_title_exclusions)
        return _dedupe_terms(collected)

    def all_en_title_exclusions(self) -> List[str]:
        """Return the en title exclusions of all detections, passed through ``_dedupe_terms``."""
        collected: List[str] = []
        for detected in self.exclusions:
            collected.extend(detected.en_title_exclusions)
        return _dedupe_terms(collected)


class ProductTitleExclusionRegistry:
    """Holds the configured exclusion rules plus the settings shared by them."""

    def __init__(
        self,
        rules: Sequence[ProductTitleExclusionRule],
        *,
        enabled: bool = True,
    ) -> None:
        """Store the rules and derive the registry-wide n-gram size.

        Args:
            rules: Rules to register; they are copied into a tuple.
            enabled: Whether exclusion detection is switched on.
        """
        self.rules = tuple(rules)
        self.enabled = bool(enabled)
        # Largest n-gram size any rule asks for; 3 when there are no rules.
        ngram_sizes = [rule.max_term_ngram for rule in self.rules]
        self.max_term_ngram = max(ngram_sizes) if ngram_sizes else 3

    @classmethod
    def from_query_config(cls, query_config: Any) -> "ProductTitleExclusionRegistry":
        """Build a registry from a query configuration object.

        Reads ``product_title_exclusion_rules`` (absent or falsy means no
        rules) and ``product_title_exclusion_enabled`` (absent means enabled)
        from ``query_config``. Rows that are not dicts, and rows for which
        ``ProductTitleExclusionRule.from_config_row`` returns ``None``, are
        skipped.
        """
        configured_rows = getattr(query_config, "product_title_exclusion_rules", []) or []
        built_rules = (
            ProductTitleExclusionRule.from_config_row(row)
            for row in configured_rows
            if isinstance(row, dict)
        )
        usable_rules: List[ProductTitleExclusionRule] = [
            rule for rule in built_rules if rule is not None
        ]
        enabled_setting = getattr(query_config, "product_title_exclusion_enabled", True)
        return cls(usable_rules, enabled=bool(enabled_setting))


class ProductTitleExclusionDetector:
    """Runs the registry's rules against the text variants of a parsed query."""

    def __init__(
        self,
        registry: ProductTitleExclusionRegistry,
        *,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> None:
        """Remember the rule registry and an optional tokenizer callable."""
        self.registry = registry
        self.tokenizer = tokenizer

    def _tokenize_text(
        self,
        text: str,
        *,
        analysis_cache: Optional[QueryTextAnalysisCache] = None,
    ) -> TokenizedText:
        """Tokenize ``text`` using the registry's ``max_term_ngram``.

        When ``analysis_cache`` is given, its ``get_tokenized_text`` is used
        and ``self.tokenizer`` is not passed along; otherwise
        ``tokenize_text`` is called with ``self.tokenizer``.
        """
        ngram_limit = self.registry.max_term_ngram
        if analysis_cache is None:
            return tokenize_text(
                text,
                tokenizer=self.tokenizer,
                max_ngram=ngram_limit,
            )
        return analysis_cache.get_tokenized_text(
            text,
            max_ngram=ngram_limit,
        )

    def _build_query_variants(self, parsed_query: Any) -> Tuple[TokenizedText, ...]:
        """Collect and tokenize the distinct text variants of ``parsed_query``.

        Sources, in order: ``original_query``, ``query_normalized``,
        ``rewritten_query`` and, when ``translations`` is a dict, its values.
        Blank texts are skipped, and texts sharing the same
        ``normalize_query_text`` form are tokenized only once (first wins).
        """
        cache = getattr(parsed_query, "_text_analysis_cache", None)
        source_texts = [
            getattr(parsed_query, attribute, None)
            for attribute in ("original_query", "query_normalized", "rewritten_query")
        ]

        translated = getattr(parsed_query, "translations", {}) or {}
        if isinstance(translated, dict):
            source_texts.extend(translated.values())

        known_forms = set()
        tokenized: List[TokenizedText] = []
        for source in source_texts:
            stripped = str(source or "").strip()
            if not stripped:
                continue
            normalized_form = normalize_query_text(stripped)
            if normalized_form and normalized_form not in known_forms:
                known_forms.add(normalized_form)
                tokenized.append(self._tokenize_text(stripped, analysis_cache=cache))

        return tuple(tokenized)

    def detect(self, parsed_query: Any) -> ProductTitleExclusionProfile:
        """Return the exclusion profile for ``parsed_query``.

        An empty profile is returned when the registry is disabled or has no
        rules. Otherwise every rule is tried against every query variant, in
        variant order then rule order. Rules whose (zh, en) exclusion tuples
        were already emitted are not reported again.
        """
        registry = self.registry
        if not (registry.enabled and registry.rules):
            return ProductTitleExclusionProfile()

        variants = self._build_query_variants(parsed_query)
        hits: List[DetectedProductTitleExclusion] = []
        emitted = set()

        for variant in variants:
            for rule in registry.rules:
                trigger = rule.match_candidates(variant.candidates)
                if trigger:
                    # Identical exclusion sets are reported once, even when
                    # several rules or variants produce them.
                    signature = (
                        tuple(rule.zh_title_exclusions),
                        tuple(rule.en_title_exclusions),
                    )
                    if signature not in emitted:
                        emitted.add(signature)
                        hits.append(
                            DetectedProductTitleExclusion(
                                matched_term=trigger,
                                matched_query_text=variant.text,
                                zh_title_exclusions=rule.zh_title_exclusions,
                                en_title_exclusions=rule.en_title_exclusions,
                            )
                        )

        return ProductTitleExclusionProfile(
            query_variants=variants,
            exclusions=tuple(hits),
        )