"""
SKU selection for style-intent-aware search results.

For each SPU hit returned by search, pick the SKU whose option values best
match the detected style intent: first by text matching against the query,
then (as a fallback) by embedding similarity.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np

from query.style_intent import StyleIntentProfile, StyleIntentRegistry
from query.tokenization import normalize_query_text


@dataclass(frozen=True)
class SkuSelectionDecision:
    """Outcome of choosing a SKU for one SPU hit.

    ``matched_stage`` records how the decision was reached: ``"text"`` or
    ``"embedding"`` when a SKU was chosen, or ``"unresolved"``,
    ``"no_candidates"``, ``"no_match"`` when none was.
    ``resolved_dimensions`` maps each intent type to the SKU option field
    (e.g. ``"option1_value"``) it was resolved to, or ``None``.
    """

    selected_sku_id: Optional[str]
    rerank_suffix: str
    selected_text: str
    matched_stage: str
    similarity_score: Optional[float] = None
    resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation (``resolved_dimensions`` is copied)."""
        return {
            "selected_sku_id": self.selected_sku_id,
            "rerank_suffix": self.rerank_suffix,
            "selected_text": self.selected_text,
            "matched_stage": self.matched_stage,
            "similarity_score": self.similarity_score,
            "resolved_dimensions": dict(self.resolved_dimensions),
        }


@dataclass
class _SkuCandidate:
    """One SKU of an SPU, with the option values relevant to the style intent."""

    index: int  # position of the SKU in the source ``skus`` list
    sku_id: str  # "" when the SKU has no ``sku_id``
    sku: Dict[str, Any]
    selection_text: str  # raw intent values joined by spaces (deduplicated)
    normalized_selection_text: str  # normalized counterpart of selection_text
    intent_values: Dict[str, str]  # intent type -> raw option value
    normalized_intent_values: Dict[str, str]  # intent type -> normalized value


@dataclass
class _SelectionContext:
    """Per-request state shared across all hits of one ``prepare_hits`` call."""

    query_texts: Tuple[str, ...]
    matched_terms_by_intent: Dict[str, Tuple[str, ...]]
    query_vector: Optional[np.ndarray]
    # Memoization so repeated option values across hits are evaluated once.
    text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict)
    # normalized selection text -> embedding (None when the encoder gave none)
    selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict)
    # normalized selection text -> similarity to the query (None if unscorable)
    similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict)


class StyleSkuSelector:
    """Selects the best SKU for an SPU based on detected style intent."""

    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        text_encoder_getter: Optional[Callable[[], Any]] = None,
    ) -> None:
        """
        Args:
            registry: source of dimension aliases per intent type.
            text_encoder_getter: zero-arg callable returning an encoder with an
                ``encode(texts, priority=...)`` method, or ``None``; when absent
                (or it returns ``None``) embedding-based selection is disabled.
        """
        self.registry = registry
        self._text_encoder_getter = text_encoder_getter

    def prepare_hits(
        self,
        es_hits: List[Dict[str, Any]],
        parsed_query: Any,
    ) -> Dict[str, SkuSelectionDecision]:
        """Compute a SKU selection decision for every hit that has SKUs.

        Sets ``hit["_style_rerank_suffix"]`` to the decision's rerank suffix,
        or removes it when the suffix is empty. The ``_source`` itself is not
        modified here; ``apply_precomputed_decisions`` does that.

        Returns:
            Mapping from ``str(hit["_id"])`` to its decision. Empty when there
            are no hits or the query has no active style intent.
        """
        decisions: Dict[str, SkuSelectionDecision] = {}
        if not es_hits:
            # Nothing to select for; skip building the context, which may call
            # the text encoder to embed the query.
            return decisions
        style_profile = getattr(parsed_query, "style_intent_profile", None)
        if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active:
            return decisions

        selection_context = self._build_selection_context(parsed_query, style_profile)
        for hit in es_hits:
            source = hit.get("_source")
            if not isinstance(source, dict):
                continue
            decision = self._select_for_source(
                source,
                style_profile=style_profile,
                selection_context=selection_context,
            )
            if decision is None:
                continue
            if decision.rerank_suffix:
                hit["_style_rerank_suffix"] = decision.rerank_suffix
            else:
                hit.pop("_style_rerank_suffix", None)
            doc_id = hit.get("_id")
            if doc_id is not None:
                decisions[str(doc_id)] = decision
        return decisions

    def apply_precomputed_decisions(
        self,
        es_hits: List[Dict[str, Any]],
        decisions: Dict[str, SkuSelectionDecision],
    ) -> None:
        """Apply decisions (keyed by ``str(hit["_id"])``) to hits in place.

        For each hit with a decision and a dict ``_source``: moves the selected
        SKU to the front of ``_source["skus"]``, updates ``_source["image_url"]``
        from that SKU's image when present, and syncs
        ``hit["_style_rerank_suffix"]``.
        """
        if not es_hits or not decisions:
            return
        for hit in es_hits:
            doc_id = hit.get("_id")
            if doc_id is None:
                continue
            decision = decisions.get(str(doc_id))
            if decision is None:
                continue
            source = hit.get("_source")
            if not isinstance(source, dict):
                continue
            self._apply_decision_to_source(source, decision)
            if decision.rerank_suffix:
                hit["_style_rerank_suffix"] = decision.rerank_suffix
            else:
                hit.pop("_style_rerank_suffix", None)

    def _build_query_texts(
        self,
        parsed_query: Any,
        style_profile: StyleIntentProfile,
    ) -> List[str]:
        """Return deduplicated, normalized query strings for text matching.

        Prefers the style profile's query variants. When none are present,
        falls back to the original, normalized, and rewritten query plus any
        translation values.
        """
        texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text]
        if texts:
            return list(dict.fromkeys(texts))

        fallbacks: List[str] = []
        for value in (
            getattr(parsed_query, "original_query", None),
            getattr(parsed_query, "query_normalized", None),
            getattr(parsed_query, "rewritten_query", None),
        ):
            normalized = normalize_query_text(value)
            if normalized:
                fallbacks.append(normalized)
        translations = getattr(parsed_query, "translations", {}) or {}
        if isinstance(translations, dict):
            for value in translations.values():
                normalized = normalize_query_text(value)
                if normalized:
                    fallbacks.append(normalized)
        return list(dict.fromkeys(fallbacks))

    def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]:
        """Return the query embedding as float32, or ``None`` if unavailable.

        Uses ``parsed_query.query_vector`` when set. Otherwise, encodes the
        first non-empty of rewritten/normalized/original query text with the
        text encoder.
        """
        query_vector = getattr(parsed_query, "query_vector", None)
        if query_vector is not None:
            return np.asarray(query_vector, dtype=np.float32)

        text_encoder = self._get_text_encoder()
        if text_encoder is None:
            return None
        query_text = (
            getattr(parsed_query, "rewritten_query", None)
            or getattr(parsed_query, "query_normalized", None)
            or getattr(parsed_query, "original_query", None)
        )
        if not query_text:
            return None
        vectors = text_encoder.encode([query_text], priority=1)
        if vectors is None or len(vectors) == 0 or vectors[0] is None:
            return None
        return np.asarray(vectors[0], dtype=np.float32)

    def _build_selection_context(
        self,
        parsed_query: Any,
        style_profile: StyleIntentProfile,
    ) -> _SelectionContext:
        """Precompute query texts, matched terms, and query vector for a request."""
        matched_terms_by_intent: Dict[str, List[str]] = {}
        for intent in style_profile.intents:
            normalized_term = normalize_query_text(intent.matched_term)
            if not normalized_term:
                continue
            matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, [])
            if normalized_term not in matched_terms:
                matched_terms.append(normalized_term)

        return _SelectionContext(
            query_texts=tuple(self._build_query_texts(parsed_query, style_profile)),
            matched_terms_by_intent={
                intent_type: tuple(terms)
                for intent_type, terms in matched_terms_by_intent.items()
            },
            query_vector=self._get_query_vector(parsed_query),
        )

    def _get_text_encoder(self) -> Any:
        """Return the text encoder from the getter, or ``None`` if not configured."""
        if self._text_encoder_getter is None:
            return None
        return self._text_encoder_getter()

    def _resolve_dimensions(
        self,
        source: Dict[str, Any],
        style_profile: StyleIntentProfile,
    ) -> Dict[str, Optional[str]]:
        """Map each intent type to the SKU option field that carries it.

        An intent type resolves to ``"optionN_value"`` when the source's
        normalized ``"optionN_name"`` is among the intent's dimension aliases
        (N checked in order 1, 2, 3). Otherwise it resolves to ``None``. The
        first intent of each type decides the aliases used.
        """
        option_names = {
            "option1_value": normalize_query_text(source.get("option1_name")),
            "option2_value": normalize_query_text(source.get("option2_name")),
            "option3_value": normalize_query_text(source.get("option3_name")),
        }
        resolved: Dict[str, Optional[str]] = {}
        for intent in style_profile.intents:
            if intent.intent_type in resolved:
                continue
            # NOTE(review): option names are normalized, but aliases are compared
            # as provided; presumably the registry supplies normalized aliases —
            # confirm.
            aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type))
            matched_field = None
            for field_name, option_name in option_names.items():
                if option_name and option_name in aliases:
                    matched_field = field_name
                    break
            resolved[intent.intent_type] = matched_field
        return resolved

    def _build_candidates(
        self,
        skus: List[Dict[str, Any]],
        resolved_dimensions: Dict[str, Optional[str]],
    ) -> List[_SkuCandidate]:
        """Build one candidate per dict SKU; empty unless every dimension resolved.

        Non-dict entries in ``skus`` are skipped; ``index`` still reflects the
        SKU's position in the original list.
        """
        if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
            return []

        candidates: List[_SkuCandidate] = []
        for index, sku in enumerate(skus):
            if not isinstance(sku, dict):
                # Malformed entry; it cannot be matched, so it is not a candidate.
                continue
            intent_values: Dict[str, str] = {}
            normalized_intent_values: Dict[str, str] = {}
            for intent_type, field_name in resolved_dimensions.items():
                if not field_name:
                    continue
                raw = str(sku.get(field_name) or "").strip()
                intent_values[intent_type] = raw
                normalized_intent_values[intent_type] = normalize_query_text(raw)

            # Several intent types may resolve to the same option field; join
            # each distinct normalized value only once.
            selection_parts: List[str] = []
            norm_parts: List[str] = []
            seen: set[str] = set()
            for intent_type, raw in intent_values.items():
                nv = normalized_intent_values[intent_type]
                if not nv or nv in seen:
                    continue
                seen.add(nv)
                selection_parts.append(raw)
                norm_parts.append(nv)

            selection_text = " ".join(selection_parts).strip()
            normalized_selection_text = " ".join(norm_parts).strip()
            candidates.append(
                _SkuCandidate(
                    index=index,
                    sku_id=str(sku.get("sku_id") or ""),
                    sku=sku,
                    selection_text=selection_text,
                    normalized_selection_text=normalized_selection_text,
                    intent_values=intent_values,
                    normalized_intent_values=normalized_intent_values,
                )
            )
        return candidates

    @staticmethod
    def _empty_decision(
        resolved_dimensions: Dict[str, Optional[str]],
        matched_stage: str,
    ) -> SkuSelectionDecision:
        """Return a decision with no selected SKU, tagged with ``matched_stage``."""
        return SkuSelectionDecision(
            selected_sku_id=None,
            rerank_suffix="",
            selected_text="",
            matched_stage=matched_stage,
            resolved_dimensions=dict(resolved_dimensions),
        )

    def _is_text_match(
        self,
        intent_type: str,
        value: str,
        selection_context: _SelectionContext,
        *,
        normalized_value: Optional[str] = None,
    ) -> bool:
        """Return whether an option value matches the query textually.

        A value matches when any matched term for ``intent_type`` is a
        substring of the normalized value, or the normalized value is a
        substring of any query text. Empty values never match. Results are
        memoized per ``(intent_type, normalized_value)``.
        """
        if normalized_value is None:
            normalized_value = normalize_query_text(value)
        if not normalized_value:
            return False

        cache_key = (intent_type, normalized_value)
        cached = selection_context.text_match_cache.get(cache_key)
        if cached is not None:
            return cached

        matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ())
        has_term_match = any(term in normalized_value for term in matched_terms if term)
        query_contains_value = any(
            normalized_value in query_text
            for query_text in selection_context.query_texts
        )
        matched = bool(has_term_match or query_contains_value)
        selection_context.text_match_cache[cache_key] = matched
        return matched

    def _find_first_text_match(
        self,
        candidates: Sequence[_SkuCandidate],
        selection_context: _SelectionContext,
    ) -> Optional[_SkuCandidate]:
        """Return the first candidate (in SKU order) whose every intent value text-matches."""
        for candidate in candidates:
            if candidate.intent_values and all(
                self._is_text_match(
                    intent_type,
                    value,
                    selection_context,
                    normalized_value=candidate.normalized_intent_values[intent_type],
                )
                for intent_type, value in candidate.intent_values.items()
            ):
                return candidate
        return None

    def _select_by_embedding(
        self,
        candidates: Sequence[_SkuCandidate],
        selection_context: _SelectionContext,
    ) -> Tuple[Optional[_SkuCandidate], Optional[float]]:
        """Pick the candidate whose selection text has the highest inner product
        with the query vector.

        Returns ``(None, None)`` when there are no candidates, no query vector,
        no text encoder, or no candidate could be scored. Ties keep the earliest
        candidate.
        """
        if not candidates:
            return None, None

        text_encoder = self._get_text_encoder()
        if selection_context.query_vector is None or text_encoder is None:
            return None, None

        # Encode only selection texts not already seen earlier in this request.
        unique_texts = list(
            dict.fromkeys(
                candidate.normalized_selection_text
                for candidate in candidates
                if candidate.normalized_selection_text
                and candidate.normalized_selection_text not in selection_context.selection_vector_cache
            )
        )
        if unique_texts:
            vectors = text_encoder.encode(unique_texts, priority=1)
            # The encoder may return None, or fewer vectors than requested.
            # Record None for every text without a vector so it is treated as
            # unscorable and not re-encoded for later hits.
            encoded = list(vectors) if vectors is not None else []
            for position, key in enumerate(unique_texts):
                vector = encoded[position] if position < len(encoded) else None
                selection_context.selection_vector_cache[key] = (
                    np.asarray(vector, dtype=np.float32) if vector is not None else None
                )

        best_candidate: Optional[_SkuCandidate] = None
        best_score: Optional[float] = None
        query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32)
        for candidate in candidates:
            normalized_text = candidate.normalized_selection_text
            if not normalized_text:
                continue
            score = selection_context.similarity_cache.get(normalized_text)
            if score is None:
                candidate_vector = selection_context.selection_vector_cache.get(normalized_text)
                if candidate_vector is None:
                    selection_context.similarity_cache[normalized_text] = None
                    continue
                score = float(np.inner(query_vector_array, candidate_vector))
                selection_context.similarity_cache[normalized_text] = score
            if score is None:
                continue
            if best_score is None or score > best_score:
                best_candidate = candidate
                best_score = score
        return best_candidate, best_score

    def _select_for_source(
        self,
        source: Dict[str, Any],
        *,
        style_profile: StyleIntentProfile,
        selection_context: _SelectionContext,
    ) -> Optional[SkuSelectionDecision]:
        """Run the selection pipeline for one hit's ``_source``.

        The pipeline is: resolve option dimensions, build candidates, try text
        matching, then fall back to embedding similarity.

        Returns ``None`` when the source has no non-empty ``skus`` list. Returns
        an empty decision tagged ``"unresolved"``, ``"no_candidates"``, or
        ``"no_match"`` when no SKU could be chosen.
        """
        skus = source.get("skus")
        if not isinstance(skus, list) or not skus:
            return None

        resolved_dimensions = self._resolve_dimensions(source, style_profile)
        if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()):
            return self._empty_decision(resolved_dimensions, matched_stage="unresolved")

        candidates = self._build_candidates(skus, resolved_dimensions)
        if not candidates:
            return self._empty_decision(resolved_dimensions, matched_stage="no_candidates")

        text_match = self._find_first_text_match(candidates, selection_context)
        if text_match is not None:
            return self._build_decision(text_match, resolved_dimensions, matched_stage="text")

        chosen, similarity_score = self._select_by_embedding(candidates, selection_context)
        if chosen is None:
            return self._empty_decision(resolved_dimensions, matched_stage="no_match")

        return self._build_decision(
            chosen,
            resolved_dimensions,
            matched_stage="embedding",
            similarity_score=similarity_score,
        )

    @staticmethod
    def _build_decision(
        candidate: _SkuCandidate,
        resolved_dimensions: Dict[str, Optional[str]],
        *,
        matched_stage: str,
        similarity_score: Optional[float] = None,
    ) -> SkuSelectionDecision:
        """Build a decision for ``candidate``.

        The candidate's selection text serves as both the rerank suffix and the
        selected text. An empty ``sku_id`` becomes ``None``.
        """
        return SkuSelectionDecision(
            selected_sku_id=candidate.sku_id or None,
            rerank_suffix=str(candidate.selection_text or "").strip(),
            selected_text=str(candidate.selection_text or "").strip(),
            matched_stage=matched_stage,
            similarity_score=similarity_score,
            resolved_dimensions=dict(resolved_dimensions),
        )

    @staticmethod
    def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None:
        """Move the selected SKU to the front of ``source["skus"]`` in place.

        Also sets ``source["image_url"]`` from the SKU's ``image_src`` (or
        ``imageSrc``) when present. This is a no-op when there is no selected
        SKU id or no dict SKU with that id. The first SKU with a matching id
        wins.
        """
        skus = source.get("skus")
        if not isinstance(skus, list) or not skus or not decision.selected_sku_id:
            return

        selected_index = None
        for index, sku in enumerate(skus):
            # Skip malformed (non-dict) entries instead of raising on .get.
            if isinstance(sku, dict) and str(sku.get("sku_id") or "") == decision.selected_sku_id:
                selected_index = index
                break
        if selected_index is None:
            return

        selected_sku = skus.pop(selected_index)
        skus.insert(0, selected_sku)
        image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc")
        if image_src:
            source["image_url"] = image_src