""" SKU selection for style-intent-aware search results. """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple import numpy as np from query.style_intent import StyleIntentProfile, StyleIntentRegistry from query.tokenization import normalize_query_text @dataclass(frozen=True) class SkuSelectionDecision: selected_sku_id: Optional[str] rerank_suffix: str selected_text: str matched_stage: str similarity_score: Optional[float] = None resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return { "selected_sku_id": self.selected_sku_id, "rerank_suffix": self.rerank_suffix, "selected_text": self.selected_text, "matched_stage": self.matched_stage, "similarity_score": self.similarity_score, "resolved_dimensions": dict(self.resolved_dimensions), } @dataclass class _SkuCandidate: index: int sku_id: str sku: Dict[str, Any] selection_text: str intent_texts: Dict[str, str] class StyleSkuSelector: """Selects the best SKU for an SPU based on detected style intent.""" def __init__( self, registry: StyleIntentRegistry, *, text_encoder_getter: Optional[Callable[[], Any]] = None, tokenizer_getter: Optional[Callable[[], Any]] = None, ) -> None: self.registry = registry self._text_encoder_getter = text_encoder_getter self._tokenizer_getter = tokenizer_getter def prepare_hits( self, es_hits: List[Dict[str, Any]], parsed_query: Any, ) -> Dict[str, SkuSelectionDecision]: decisions: Dict[str, SkuSelectionDecision] = {} style_profile = getattr(parsed_query, "style_intent_profile", None) if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: return decisions query_texts = self._build_query_texts(parsed_query, style_profile) query_vector = self._get_query_vector(parsed_query) tokenizer = self._get_tokenizer() for hit in es_hits: source = hit.get("_source") if not isinstance(source, dict): continue decision = self._select_for_source( source, style_profile=style_profile, query_texts=query_texts, query_vector=query_vector, tokenizer=tokenizer, ) if decision is None: continue self._apply_decision_to_source(source, decision) if decision.rerank_suffix: hit["_style_rerank_suffix"] = decision.rerank_suffix doc_id = hit.get("_id") if doc_id is not None: decisions[str(doc_id)] = decision return decisions def apply_precomputed_decisions( self, es_hits: List[Dict[str, Any]], decisions: Dict[str, SkuSelectionDecision], ) -> None: if not es_hits or not decisions: return for hit in es_hits: doc_id = hit.get("_id") if doc_id is None: continue decision = decisions.get(str(doc_id)) if decision is None: continue source = hit.get("_source") if not isinstance(source, dict): continue self._apply_decision_to_source(source, decision) if decision.rerank_suffix: hit["_style_rerank_suffix"] = decision.rerank_suffix def _build_query_texts( self, parsed_query: Any, style_profile: StyleIntentProfile, ) -> List[str]: texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] if texts: return list(dict.fromkeys(texts)) fallbacks: List[str] = [] for value in ( getattr(parsed_query, "original_query", None), getattr(parsed_query, "query_normalized", None), getattr(parsed_query, "rewritten_query", None), ): normalized = normalize_query_text(value) if normalized: fallbacks.append(normalized) translations = getattr(parsed_query, "translations", {}) or {} if isinstance(translations, dict): for value in translations.values(): normalized = normalize_query_text(value) if normalized: fallbacks.append(normalized) return list(dict.fromkeys(fallbacks)) def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: query_vector = getattr(parsed_query, "query_vector", None) if query_vector is not None: return np.asarray(query_vector, dtype=np.float32) text_encoder = self._get_text_encoder() if text_encoder is None: return None query_text = ( getattr(parsed_query, "rewritten_query", None) or getattr(parsed_query, "query_normalized", None) or getattr(parsed_query, "original_query", None) ) if not query_text: return None vectors = text_encoder.encode([query_text], priority=1) if vectors is None or len(vectors) == 0 or vectors[0] is None: return None return np.asarray(vectors[0], dtype=np.float32) def _get_text_encoder(self) -> Any: if self._text_encoder_getter is None: return None return self._text_encoder_getter() def _get_tokenizer(self) -> Any: if self._tokenizer_getter is None: return None return self._tokenizer_getter() @staticmethod def _fallback_sku_text(sku: Dict[str, Any]) -> str: parts = [] for field_name in ("option1_value", "option2_value", "option3_value"): value = str(sku.get(field_name) or "").strip() if value: parts.append(value) return " ".join(parts) def _resolve_dimensions( self, source: Dict[str, Any], style_profile: StyleIntentProfile, ) -> Dict[str, Optional[str]]: option_names = { "option1_value": normalize_query_text(source.get("option1_name")), "option2_value": normalize_query_text(source.get("option2_name")), "option3_value": normalize_query_text(source.get("option3_name")), } resolved: Dict[str, Optional[str]] = {} for intent in style_profile.intents: if intent.intent_type in resolved: continue aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) matched_field = None for field_name, option_name in option_names.items(): if option_name and option_name in aliases: matched_field = field_name break resolved[intent.intent_type] = matched_field return resolved def _build_candidates( self, skus: List[Dict[str, Any]], resolved_dimensions: Dict[str, Optional[str]], ) -> List[_SkuCandidate]: candidates: List[_SkuCandidate] = [] for index, sku in enumerate(skus): fallback_text = self._fallback_sku_text(sku) intent_texts: Dict[str, str] = {} for intent_type, field_name in resolved_dimensions.items(): if field_name: value = str(sku.get(field_name) or "").strip() intent_texts[intent_type] = value or fallback_text else: intent_texts[intent_type] = fallback_text selection_parts: List[str] = [] seen = set() for value in intent_texts.values(): normalized = normalize_query_text(value) if not normalized or normalized in seen: continue seen.add(normalized) selection_parts.append(str(value).strip()) selection_text = " ".join(selection_parts).strip() or fallback_text candidates.append( _SkuCandidate( index=index, sku_id=str(sku.get("sku_id") or ""), sku=sku, selection_text=selection_text, intent_texts=intent_texts, ) ) return candidates @staticmethod def _is_direct_match( candidate: _SkuCandidate, query_texts: Sequence[str], ) -> bool: if not candidate.intent_texts or not query_texts: return False for value in candidate.intent_texts.values(): normalized_value = normalize_query_text(value) if not normalized_value: return False if not any(normalized_value in query_text for query_text in query_texts): return False return True def _is_generalized_match( self, candidate: _SkuCandidate, style_profile: StyleIntentProfile, tokenizer: Any, ) -> bool: if not candidate.intent_texts: return False for intent_type, value in candidate.intent_texts.items(): definition = self.registry.get_definition(intent_type) if definition is None: return False matched_canonicals = definition.match_text(value, tokenizer=tokenizer) if not matched_canonicals.intersection(style_profile.get_canonical_values(intent_type)): return False return True def _select_by_embedding( self, candidates: Sequence[_SkuCandidate], query_vector: Optional[np.ndarray], ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: if not candidates: return None, None text_encoder = self._get_text_encoder() if query_vector is None or text_encoder is None: return candidates[0], None unique_texts = list( dict.fromkeys( normalize_query_text(candidate.selection_text) for candidate in candidates if normalize_query_text(candidate.selection_text) ) ) if not unique_texts: return candidates[0], None vectors = text_encoder.encode(unique_texts, priority=1) vector_map: Dict[str, np.ndarray] = {} for key, vector in zip(unique_texts, vectors): if vector is None: continue vector_map[key] = np.asarray(vector, dtype=np.float32) best_candidate: Optional[_SkuCandidate] = None best_score: Optional[float] = None query_vector_array = np.asarray(query_vector, dtype=np.float32) for candidate in candidates: normalized_text = normalize_query_text(candidate.selection_text) candidate_vector = vector_map.get(normalized_text) if candidate_vector is None: continue score = float(np.inner(query_vector_array, candidate_vector)) if best_score is None or score > best_score: best_candidate = candidate best_score = score return best_candidate or candidates[0], best_score def _select_for_source( self, source: Dict[str, Any], *, style_profile: StyleIntentProfile, query_texts: Sequence[str], query_vector: Optional[np.ndarray], tokenizer: Any, ) -> Optional[SkuSelectionDecision]: skus = source.get("skus") if not isinstance(skus, list) or not skus: return None resolved_dimensions = self._resolve_dimensions(source, style_profile) candidates = self._build_candidates(skus, resolved_dimensions) if not candidates: return None direct_matches = [candidate for candidate in candidates if self._is_direct_match(candidate, query_texts)] if len(direct_matches) == 1: chosen = direct_matches[0] return self._build_decision(chosen, resolved_dimensions, matched_stage="direct") generalized_matches: List[_SkuCandidate] = [] if not direct_matches: generalized_matches = [ candidate for candidate in candidates if self._is_generalized_match(candidate, style_profile, tokenizer) ] if len(generalized_matches) == 1: chosen = generalized_matches[0] return self._build_decision(chosen, resolved_dimensions, matched_stage="generalized") embedding_pool = direct_matches or generalized_matches or candidates chosen, similarity_score = self._select_by_embedding(embedding_pool, query_vector) if chosen is None: return None stage = "embedding_from_matches" if direct_matches or generalized_matches else "embedding_from_all" return self._build_decision( chosen, resolved_dimensions, matched_stage=stage, similarity_score=similarity_score, ) @staticmethod def _build_decision( candidate: _SkuCandidate, resolved_dimensions: Dict[str, Optional[str]], *, matched_stage: str, similarity_score: Optional[float] = None, ) -> SkuSelectionDecision: return SkuSelectionDecision( selected_sku_id=candidate.sku_id or None, rerank_suffix=str(candidate.selection_text or "").strip(), selected_text=str(candidate.selection_text or "").strip(), matched_stage=matched_stage, similarity_score=similarity_score, resolved_dimensions=dict(resolved_dimensions), ) @staticmethod def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: skus = source.get("skus") if not isinstance(skus, list) or not skus or not decision.selected_sku_id: return selected_index = None for index, sku in enumerate(skus): if str(sku.get("sku_id") or "") == decision.selected_sku_id: selected_index = index break if selected_index is None: return selected_sku = skus.pop(selected_index) skus.insert(0, selected_sku) image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") if image_src: source["image_url"] = image_src