# sku_intent_selector.py (12.2 KB)
"""
SKU selection for style-intent-aware search results.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

from query.style_intent import StyleIntentProfile, StyleIntentRegistry
from query.tokenization import normalize_query_text, simple_tokenize_query


@dataclass(frozen=True)
class SkuSelectionDecision:
    """Immutable outcome of choosing (or failing to choose) a SKU for one hit."""

    # Id of the chosen SKU; None when no SKU was selected.
    selected_sku_id: Optional[str]
    # Text to append for reranking; empty when nothing was selected.
    rerank_suffix: str
    # Text describing the selected SKU's matched option values.
    selected_text: str
    # Label of the stage that produced the decision (e.g. "text", "no_match", "unresolved").
    matched_stage: str
    # Optional similarity score; defaults to None.
    similarity_score: Optional[float] = None
    # Intent type -> SKU field name resolved for that dimension (None when unresolved).
    resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this decision.

        ``resolved_dimensions`` is shallow-copied so the returned mapping can be
        mutated without affecting this frozen instance.
        """
        scalar_fields = (
            "selected_sku_id",
            "rerank_suffix",
            "selected_text",
            "matched_stage",
            "similarity_score",
        )
        payload: Dict[str, Any] = {name: getattr(self, name) for name in scalar_fields}
        payload["resolved_dimensions"] = dict(self.resolved_dimensions)
        return payload


@dataclass
class _SelectionContext:
    attribute_terms_by_intent: Dict[str, Tuple[str, ...]]
    normalized_text_cache: Dict[str, str] = field(default_factory=dict)
    tokenized_text_cache: Dict[str, Tuple[str, ...]] = field(default_factory=dict)
    text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict)


class StyleSkuSelector:
    """Selects the best SKU for an SPU based on detected style intent.

    Selection is split into two phases:

    * :meth:`prepare_hits` computes a :class:`SkuSelectionDecision` for each
      hit and mirrors the decision's ``rerank_suffix`` onto the hit as
      ``_style_rerank_suffix``. It does not reorder SKUs.
    * :meth:`apply_precomputed_decisions` takes those decisions (keyed by
      ``str(hit["_id"])``), moves each selected SKU to the front of
      ``_source["skus"]`` and copies its image into ``_source["image_url"]``.
    """

    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        text_encoder_getter: Optional[Callable[[], Any]] = None,
    ) -> None:
        """Create a selector.

        Args:
            registry: Supplies fallback dimension aliases per intent type.
            text_encoder_getter: Optional zero-argument factory for a text
                encoder. It is stored on the instance; the text-matching path
                in this class does not call it.
        """
        self.registry = registry
        self._text_encoder_getter = text_encoder_getter

    def prepare_hits(
        self,
        es_hits: List[Dict[str, Any]],
        parsed_query: Any,
    ) -> Dict[str, SkuSelectionDecision]:
        """Compute a SKU selection decision for every hit with a usable ``_source``.

        Args:
            es_hits: Search hits. Each processed hit's ``_style_rerank_suffix``
                key is set or removed in place to mirror its decision.
            parsed_query: Object with an optional ``style_intent_profile``
                attribute. Nothing is computed unless it is an active
                :class:`StyleIntentProfile`.

        Returns:
            Mapping of ``str(hit["_id"])`` to its decision. Hits without an
            ``_id``, without a dict ``_source``, or without a non-empty
            ``skus`` list are absent from the mapping.
        """
        decisions: Dict[str, SkuSelectionDecision] = {}
        style_profile = getattr(parsed_query, "style_intent_profile", None)
        if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active:
            return decisions

        # Normalized attribute terms and memo caches are shared by all hits of this query.
        selection_context = self._build_selection_context(style_profile)

        for hit in es_hits:
            source = hit.get("_source")
            if not isinstance(source, dict):
                continue

            decision = self._select_for_source(
                source,
                style_profile=style_profile,
                selection_context=selection_context,
            )
            if decision is None:
                continue

            # The suffix is synced even for hits lacking an _id (matches prior behavior).
            self._sync_rerank_suffix(hit, decision)

            doc_id = hit.get("_id")
            if doc_id is not None:
                decisions[str(doc_id)] = decision

        return decisions

    def apply_precomputed_decisions(
        self,
        es_hits: List[Dict[str, Any]],
        decisions: Dict[str, SkuSelectionDecision],
    ) -> None:
        """Apply decisions produced by :meth:`prepare_hits` to hits, matched by ``_id``.

        For each hit with a matching decision and a dict ``_source``, the
        selected SKU is moved to the front of ``_source["skus"]`` and
        ``_style_rerank_suffix`` is synced with the decision. Hits are mutated
        in place.
        """
        if not es_hits or not decisions:
            return

        for hit in es_hits:
            doc_id = hit.get("_id")
            if doc_id is None:
                continue
            decision = decisions.get(str(doc_id))
            if decision is None:
                continue
            source = hit.get("_source")
            if not isinstance(source, dict):
                continue
            self._apply_decision_to_source(source, decision)
            self._sync_rerank_suffix(hit, decision)

    @staticmethod
    def _sync_rerank_suffix(hit: Dict[str, Any], decision: SkuSelectionDecision) -> None:
        """Set ``hit["_style_rerank_suffix"]`` from the decision, or remove it when empty."""
        if decision.rerank_suffix:
            hit["_style_rerank_suffix"] = decision.rerank_suffix
        else:
            hit.pop("_style_rerank_suffix", None)

    def _build_selection_context(
        self,
        style_profile: StyleIntentProfile,
    ) -> _SelectionContext:
        """Collect normalized, de-duplicated attribute terms per intent type.

        Terms keep first-seen order. Empty terms after normalization are dropped.
        """
        attribute_terms_by_intent: Dict[str, List[str]] = {}
        for intent in style_profile.intents:
            terms = attribute_terms_by_intent.setdefault(intent.intent_type, [])
            for raw_term in intent.attribute_terms:
                normalized_term = normalize_query_text(raw_term)
                if not normalized_term or normalized_term in terms:
                    continue
                terms.append(normalized_term)

        return _SelectionContext(
            attribute_terms_by_intent={
                intent_type: tuple(terms)
                for intent_type, terms in attribute_terms_by_intent.items()
            },
        )

    @staticmethod
    def _normalize_cached(selection_context: _SelectionContext, value: Any) -> str:
        """Return ``normalize_query_text`` of ``value``, memoized on the stripped raw text.

        Falsy values and whitespace-only strings normalize to ``""``.
        """
        raw = str(value or "").strip()
        if not raw:
            return ""
        cached = selection_context.normalized_text_cache.get(raw)
        if cached is not None:
            return cached
        normalized = normalize_query_text(raw)
        selection_context.normalized_text_cache[raw] = normalized
        return normalized

    def _dimension_aliases(self, intent: Any) -> set[str]:
        """Return the option-name aliases identifying ``intent``'s dimension.

        The intent's own ``dimension_aliases`` take precedence; when empty, the
        registry's aliases for ``intent.intent_type`` are used. The raw aliases
        are kept and their normalized forms are added, because they are compared
        against normalized option names. The result is therefore a superset of
        the raw alias set.
        """
        raw_aliases = intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)
        raw = set(raw_aliases or ())
        normalized = {normalize_query_text(str(alias or "")) for alias in raw}
        normalized.discard("")
        return raw | normalized

    def _resolve_dimensions(
        self,
        source: Dict[str, Any],
        style_profile: StyleIntentProfile,
    ) -> Dict[str, Optional[str]]:
        """Map each intent type to the SKU value field that holds its dimension.

        ``optionN_name`` on the source is matched against the intent's aliases.
        The first matching slot's ``optionN_value`` field name is returned, or
        None when no slot matches. The first intent of each type wins.
        """
        option_names: Dict[str, str] = {}
        for value_field, name_field in (
            ("option1_value", "option1_name"),
            ("option2_value", "option2_name"),
            ("option3_value", "option3_name"),
        ):
            # Coerce to str so a missing/None option name never reaches the normalizer.
            option_names[value_field] = normalize_query_text(str(source.get(name_field) or ""))

        resolved: Dict[str, Optional[str]] = {}
        for intent in style_profile.intents:
            if intent.intent_type in resolved:
                continue
            aliases = self._dimension_aliases(intent)
            resolved[intent.intent_type] = next(
                (
                    field_name
                    for field_name, option_name in option_names.items()
                    if option_name and option_name in aliases
                ),
                None,
            )
        return resolved

    @staticmethod
    def _empty_decision(
        resolved_dimensions: Dict[str, Optional[str]],
        matched_stage: str,
    ) -> SkuSelectionDecision:
        """Build a decision that selects no SKU and carries no rerank suffix."""
        return SkuSelectionDecision(
            selected_sku_id=None,
            rerank_suffix="",
            selected_text="",
            matched_stage=matched_stage,
            resolved_dimensions=dict(resolved_dimensions),
        )

    def _is_text_match(
        self,
        intent_type: str,
        selection_context: _SelectionContext,
        *,
        normalized_value: str,
    ) -> bool:
        """Return whether a normalized SKU value matches any attribute term of ``intent_type``.

        Results are memoized per ``(intent_type, normalized_value)``.
        """
        if not normalized_value:
            return False

        cache_key = (intent_type, normalized_value)
        cached = selection_context.text_match_cache.get(cache_key)
        if cached is not None:
            return cached

        attribute_terms = selection_context.attribute_terms_by_intent.get(intent_type, ())
        value_tokens = self._tokenize_cached(selection_context, normalized_value)
        matched = any(
            self._matches_term_tokens(
                term=term,
                value_tokens=value_tokens,
                selection_context=selection_context,
                normalized_value=normalized_value,
            )
            for term in attribute_terms
            if term
        )
        selection_context.text_match_cache[cache_key] = matched
        return matched

    @staticmethod
    def _tokenize_cached(selection_context: _SelectionContext, value: str) -> Tuple[str, ...]:
        """Tokenize ``value`` after normalizing it, memoized on the normalized text."""
        normalized_value = normalize_query_text(value)
        if not normalized_value:
            return ()
        cached = selection_context.tokenized_text_cache.get(normalized_value)
        if cached is not None:
            return cached
        tokens = tuple(normalize_query_text(token) for token in simple_tokenize_query(normalized_value) if token)
        selection_context.tokenized_text_cache[normalized_value] = tokens
        return tokens

    def _matches_term_tokens(
        self,
        *,
        term: str,
        value_tokens: Tuple[str, ...],
        selection_context: _SelectionContext,
        normalized_value: str,
    ) -> bool:
        """Return whether ``term`` matches the value.

        A match is either an exact match, or the term's tokens appearing as a
        contiguous run inside ``value_tokens``. When either side tokenizes to
        nothing, this falls back to a substring test on the normalized strings.
        """
        normalized_term = normalize_query_text(term)
        if not normalized_term:
            return False
        if normalized_term == normalized_value:
            return True

        term_tokens = self._tokenize_cached(selection_context, normalized_term)
        if not term_tokens or not value_tokens:
            return normalized_term in normalized_value

        width = len(term_tokens)
        # An empty range (term longer than value) yields False.
        return any(
            value_tokens[start:start + width] == term_tokens
            for start in range(len(value_tokens) - width + 1)
        )

    def _find_first_text_match(
        self,
        skus: List[Dict[str, Any]],
        resolved_dimensions: Dict[str, Optional[str]],
        selection_context: _SelectionContext,
    ) -> Optional[Tuple[str, str]]:
        """Return ``(sku_id, selected_text)`` for the first SKU matching every dimension.

        A SKU matches when, for every resolved dimension, its value field
        matches that intent's attribute terms. ``selected_text`` joins the
        matched raw values with spaces, de-duplicated by normalized form.
        Non-dict entries in ``skus`` are skipped. Returns None if no SKU matches.
        """
        for sku in skus:
            if not isinstance(sku, dict):
                # Malformed entries cannot be selected; skip them instead of raising.
                continue

            selection_parts: List[str] = []
            seen_parts: set[str] = set()
            matched = True

            for intent_type, field_name in resolved_dimensions.items():
                if not field_name:
                    matched = False
                    break

                raw_value = str(sku.get(field_name) or "").strip()
                normalized_value = self._normalize_cached(selection_context, raw_value)
                if not self._is_text_match(
                    intent_type,
                    selection_context,
                    normalized_value=normalized_value,
                ):
                    matched = False
                    break

                if raw_value and normalized_value not in seen_parts:
                    seen_parts.add(normalized_value)
                    selection_parts.append(raw_value)

            if matched:
                return str(sku.get("sku_id") or ""), " ".join(selection_parts).strip()
        return None

    def _select_for_source(
        self,
        source: Dict[str, Any],
        *,
        style_profile: StyleIntentProfile,
        selection_context: _SelectionContext,
    ) -> Optional[SkuSelectionDecision]:
        """Choose a SKU for one document ``_source``.

        Returns:
            None when ``source`` has no non-empty ``skus`` list. Otherwise it
            returns a decision with stage ``"unresolved"`` (some intent has no
            matching option slot), ``"no_match"`` (no SKU matched), or
            ``"text"`` (a SKU was selected by text matching).
        """
        skus = source.get("skus")
        if not isinstance(skus, list) or not skus:
            return None

        resolved_dimensions = self._resolve_dimensions(source, style_profile)
        if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
            return self._empty_decision(resolved_dimensions, matched_stage="unresolved")

        text_match = self._find_first_text_match(skus, resolved_dimensions, selection_context)
        if text_match is None:
            return self._empty_decision(resolved_dimensions, matched_stage="no_match")
        return self._build_decision(
            selected_sku_id=text_match[0],
            selected_text=text_match[1],
            resolved_dimensions=resolved_dimensions,
            matched_stage="text",
        )

    @staticmethod
    def _build_decision(
        selected_sku_id: str,
        selected_text: str,
        resolved_dimensions: Dict[str, Optional[str]],
        *,
        matched_stage: str,
        similarity_score: Optional[float] = None,
    ) -> SkuSelectionDecision:
        """Build a decision for a selected SKU.

        The stripped selected text doubles as the rerank suffix. An empty
        ``selected_sku_id`` is stored as None.
        """
        return SkuSelectionDecision(
            selected_sku_id=selected_sku_id or None,
            rerank_suffix=str(selected_text or "").strip(),
            selected_text=str(selected_text or "").strip(),
            matched_stage=matched_stage,
            similarity_score=similarity_score,
            resolved_dimensions=dict(resolved_dimensions),
        )

    @staticmethod
    def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None:
        """Move the decided SKU to the front of ``source["skus"]`` and copy its image.

        This is a no-op when there are no SKUs, no selected id, or no SKU with
        that id. Non-dict entries are ignored when searching. If the SKU has
        ``image_src`` (or ``imageSrc``), it is written to ``source["image_url"]``.
        """
        skus = source.get("skus")
        if not isinstance(skus, list) or not skus or not decision.selected_sku_id:
            return

        selected_index = next(
            (
                index
                for index, sku in enumerate(skus)
                if isinstance(sku, dict) and str(sku.get("sku_id") or "") == decision.selected_sku_id
            ),
            None,
        )
        if selected_index is None:
            return

        selected_sku = skus.pop(selected_index)
        skus.insert(0, selected_sku)

        image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc")
        if image_src:
            source["image_url"] = image_src