From b712a831f1a6b2735b07e3d342ea680b5a5933a0 Mon Sep 17 00:00:00 2001 From: tangwang Date: Wed, 25 Mar 2026 09:33:16 +0800 Subject: [PATCH] 意图识别策略和性能优化 --- config/dictionaries/product_title_exclusion.tsv | 3 ++- config/dictionaries/style_intent_color.csv | 30 +++++++++++++++--------------- config/dictionaries/style_intent_size.csv | 15 +++++++-------- config/loader.py | 32 +++++++++++++++++++++++--------- config/schema.py | 2 +- query/style_intent.py | 154 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------- search/sku_intent_selector.py | 283 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- search/sku_intent_selector___deprecated.py | 452 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_search_rerank_window.py | 1 + tests/test_sku_intent_selector.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_style_intent.py | 31 ++++++++++++++++++++++++++++--- 11 files changed, 812 insertions(+), 297 deletions(-) create mode 100644 search/sku_intent_selector___deprecated.py create mode 100644 tests/test_sku_intent_selector.py diff --git a/config/dictionaries/product_title_exclusion.tsv b/config/dictionaries/product_title_exclusion.tsv index 7c10912..cf2ccff 100644 --- a/config/dictionaries/product_title_exclusion.tsv +++ b/config/dictionaries/product_title_exclusion.tsv @@ -1,2 +1,3 @@ # zh triggers en triggers zh title exclusions en title exclusions -修身 fitted 宽松 loose,relaxed,oversized,baggy,slouchy +修身,紧身 fitted,tight 宽松 loose,relaxed,oversized,baggy,slouchy +宽松 loose,relaxed,oversized,baggy,slouchy 修身,紧身 fitted,tight diff --git a/config/dictionaries/style_intent_color.csv b/config/dictionaries/style_intent_color.csv index 4068f18..2385373 100644 --- a/config/dictionaries/style_intent_color.csv +++ b/config/dictionaries/style_intent_color.csv @@ -1,15 +1,15 @@ -black,black,blk,黑,黑色 -white,white,wht,白,白色 -red,red,reddish,红,红色 -blue,blue,blu,蓝,蓝色 -green,green,grn,绿,绿色 -yellow,yellow,ylw,黄,黄色 -pink,pink,粉,粉色 -purple,purple,violet,紫,紫色 -gray,gray,grey,灰,灰色 -brown,brown,棕,棕色,咖啡色 -beige,beige,khaki,米色,卡其色 -navy,navy,navy blue,藏青,藏蓝,深蓝 -silver,silver,银,银色 -gold,gold,金,金色 -orange,orange,橙,橙色 +"black,blk","黑,黑色","black" +"white,wht","白,白色","white" +"red,reddish","红,红色","red" +"blue,blu","蓝,蓝色","blue" +"green,grn","绿,绿色","green" +"yellow,ylw","黄,黄色","yellow" +"pink","粉,粉色","pink" +"purple,violet","紫,紫色","purple" +"gray,grey","灰,灰色","gray,grey" +"brown","棕,棕色,咖啡色","brown" +"beige,khaki","米色,卡其色","beige,khaki" +"navy,navy blue","藏青,藏蓝,深蓝","navy" +"silver","银,银色","silver" +"gold","金,金色","gold" +"orange","橙,橙色","orange" diff --git a/config/dictionaries/style_intent_size.csv b/config/dictionaries/style_intent_size.csv index 011dc26..d865c6f 100644 --- a/config/dictionaries/style_intent_size.csv +++ b/config/dictionaries/style_intent_size.csv @@ -1,8 +1,7 @@ -xs,xs,extra small,x-small,加小码 -s,s,small,小码,小号 -m,m,medium,中码,中号 -l,l,large,大码,大号 -xl,xl,x-large,extra large,加大码 -xxl,xxl,2xl,xx-large,双加大码 -xxxl,xxxl,3xl,xxx-large,三加大码 -one size,one size,onesize,free size,均码 +"xs,extra small,x-small","加小码","xs,extra small,x-small" +"s,small","小码,小号","s,small" +"m,medium","中码,中号","m,medium" +"l,large","大码,大号","l,large" +"xl,x-large,extra large","加大码","xl,x-large,extra large" +"xxl,2xl,xx-large","双加大码","xxl,2xl,xx-large" +"xxxl,3xl,xxx-large","三加大码","xxxl,3xl,xxx-large" diff --git a/config/loader.py b/config/loader.py index e8498fd..96ba410 100644 --- a/config/loader.py +++ b/config/loader.py @@ -10,6 +10,7 @@ from __future__ import annotations import hashlib import json import os +import csv from copy import deepcopy from dataclasses import asdict from functools import lru_cache @@ -96,20 +97,33 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]: return rewrite_dict -def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]: - rows: List[List[str]] = [] +def _read_synonym_csv_dictionary(path: Path) -> List[Dict[str, List[str]]]: + rows: List[Dict[str, List[str]]] = [] if not path.exists(): return rows + def _split_terms(cell: str) -> List[str]: + return [item.strip() for item in str(cell or "").split(",") if item.strip()] + with open(path, "r", encoding="utf-8") as handle: - for raw_line in handle: - line = raw_line.strip() - if not line or line.startswith("#"): + reader = csv.reader(handle) + for parts in reader: + if not parts: + continue + if parts[0].strip().startswith("#"): continue - parts = [segment.strip() for segment in line.split(",")] - normalized = [segment for segment in parts if segment] - if normalized: - rows.append(normalized) + + normalized = [segment.strip() for segment in parts] + if len(normalized) < 3: + continue + + row = { + "en_terms": _split_terms(normalized[0]), + "zh_terms": _split_terms(normalized[1]), + "attribute_terms": _split_terms(normalized[2]), + } + if any(row.values()): + rows.append(row) return rows diff --git a/config/schema.py b/config/schema.py index 3f5300b..e83d79a 100644 --- a/config/schema.py +++ b/config/schema.py @@ -65,7 +65,7 @@ class QueryConfig: translation_embedding_wait_budget_ms_source_in_index: int = 80 translation_embedding_wait_budget_ms_source_not_in_index: int = 200 style_intent_enabled: bool = True - style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict) + style_intent_terms: Dict[str, List[Dict[str, List[str]]]] = field(default_factory=dict) style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict) product_title_exclusion_enabled: bool = True product_title_exclusion_rules: List[Dict[str, List[str]]] = field(default_factory=list) diff --git a/query/style_intent.py b/query/style_intent.py index 13525fc..96e4234 100644 --- a/query/style_intent.py +++ b/query/style_intent.py @@ -11,38 +11,79 @@ from .tokenization import TokenizedText, normalize_query_text, tokenize_text @dataclass(frozen=True) +class StyleIntentTermDefinition: + canonical_value: str + en_terms: Tuple[str, ...] + zh_terms: Tuple[str, ...] + attribute_terms: Tuple[str, ...] + + +@dataclass(frozen=True) class StyleIntentDefinition: intent_type: str - term_groups: Tuple[Tuple[str, ...], ...] + terms: Tuple[StyleIntentTermDefinition, ...] dimension_aliases: Tuple[str, ...] - synonym_to_canonical: Dict[str, str] + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] max_term_ngram: int = 3 @classmethod def from_rows( cls, intent_type: str, - rows: Sequence[Sequence[str]], + rows: Sequence[Dict[str, List[str]]], dimension_aliases: Sequence[str], ) -> "StyleIntentDefinition": - term_groups: List[Tuple[str, ...]] = [] - synonym_to_canonical: Dict[str, str] = {} + terms: List[StyleIntentTermDefinition] = [] + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} max_ngram = 1 for row in rows: - normalized_terms: List[str] = [] - for raw_term in row: - term = normalize_query_text(raw_term) - if not term or term in normalized_terms: - continue - normalized_terms.append(term) - if not normalized_terms: + normalized_en = tuple( + dict.fromkeys( + term + for term in (normalize_query_text(raw) for raw in row.get("en_terms", [])) + if term + ) + ) + normalized_zh = tuple( + dict.fromkeys( + term + for term in (normalize_query_text(raw) for raw in row.get("zh_terms", [])) + if term + ) + ) + normalized_attribute = tuple( + dict.fromkeys( + term + for term in (normalize_query_text(raw) for raw in row.get("attribute_terms", [])) + if term + ) + ) + if not normalized_en and not normalized_zh and not normalized_attribute: continue - canonical = normalized_terms[0] - term_groups.append(tuple(normalized_terms)) - for term in normalized_terms: - synonym_to_canonical[term] = canonical + canonical = ( + normalized_attribute[0] + if normalized_attribute + else normalized_en[0] + if normalized_en + else normalized_zh[0] + ) + term_definition = StyleIntentTermDefinition( + canonical_value=canonical, + en_terms=normalized_en, + zh_terms=normalized_zh, + attribute_terms=normalized_attribute, + ) + terms.append(term_definition) + + for term in normalized_en: + en_synonym_to_term[term] = term_definition + max_ngram = max(max_ngram, len(term.split())) + for term in normalized_zh: + zh_synonym_to_term[term] = term_definition max_ngram = max(max_ngram, len(term.split())) aliases = tuple( @@ -58,28 +99,31 @@ class StyleIntentDefinition: return cls( intent_type=intent_type, - term_groups=tuple(term_groups), + terms=tuple(terms), dimension_aliases=aliases, - synonym_to_canonical=synonym_to_canonical, + en_synonym_to_term=en_synonym_to_term, + zh_synonym_to_term=zh_synonym_to_term, max_term_ngram=max_ngram, ) - def match_candidates(self, candidates: Iterable[str]) -> Set[str]: - matched: Set[str] = set() + def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]: + mapping = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term + matched: Set[StyleIntentTermDefinition] = set() for candidate in candidates: - canonical = self.synonym_to_canonical.get(normalize_query_text(candidate)) - if canonical: - matched.add(canonical) + term_definition = mapping.get(normalize_query_text(candidate)) + if term_definition: + matched.add(term_definition) return matched def match_text( self, text: str, *, + language: str, tokenizer: Optional[Callable[[str], Any]] = None, - ) -> Set[str]: + ) -> Set[StyleIntentTermDefinition]: bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram) - return self.match_candidates(bundle.candidates) + return self.match_candidates(bundle.candidates, language=language) @dataclass(frozen=True) @@ -88,6 +132,7 @@ class DetectedStyleIntent: canonical_value: str matched_term: str matched_query_text: str + attribute_terms: Tuple[str, ...] dimension_aliases: Tuple[str, ...] def to_dict(self) -> Dict[str, Any]: @@ -96,6 +141,7 @@ class DetectedStyleIntent: "canonical_value": self.canonical_value, "matched_term": self.matched_term, "matched_query_text": self.matched_query_text, + "attribute_terms": list(self.attribute_terms), "dimension_aliases": list(self.dimension_aliases), } @@ -159,7 +205,7 @@ class StyleIntentRegistry: rows=rows or [], dimension_aliases=dimension_aliases.get(intent_type, []), ) - if definition.synonym_to_canonical: + if definition.terms: definitions[definition.intent_type] = definition return cls( @@ -191,15 +237,10 @@ class StyleIntentDetector: seen = set() variants: List[TokenizedText] = [] texts = [ - getattr(parsed_query, "original_query", None), - getattr(parsed_query, "query_normalized", None), - getattr(parsed_query, "rewritten_query", None), + self._get_language_query_text(parsed_query, "zh"), + self._get_language_query_text(parsed_query, "en"), ] - translations = getattr(parsed_query, "translations", {}) or {} - if isinstance(translations, dict): - texts.extend(translations.values()) - for raw_text in texts: text = str(raw_text or "").strip() if not text: @@ -221,35 +262,66 @@ class StyleIntentDetector: return tuple(variants) + @staticmethod + def _get_language_query_text(parsed_query: Any, language: str) -> str: + translations = getattr(parsed_query, "translations", {}) or {} + if isinstance(translations, dict): + translated = translations.get(language) + if translated: + return str(translated) + return str(getattr(parsed_query, "original_query", "") or "") + + def _tokenize_language_query(self, parsed_query: Any, language: str) -> Optional[TokenizedText]: + text = self._get_language_query_text(parsed_query, language).strip() + if not text: + return None + return tokenize_text( + text, + tokenizer=self.tokenizer, + max_ngram=max( + (definition.max_term_ngram for definition in self.registry.definitions.values()), + default=3, + ), + ) + def detect(self, parsed_query: Any) -> StyleIntentProfile: if not self.registry.enabled or not self.registry.definitions: return StyleIntentProfile() query_variants = self._build_query_variants(parsed_query) + zh_variant = self._tokenize_language_query(parsed_query, "zh") + en_variant = self._tokenize_language_query(parsed_query, "en") detected: List[DetectedStyleIntent] = [] seen_pairs = set() - for variant in query_variants: - for intent_type, definition in self.registry.definitions.items(): - matched_canonicals = definition.match_candidates(variant.candidates) - if not matched_canonicals: + for intent_type, definition in self.registry.definitions.items(): + for language, variant, mapping in ( + ("zh", zh_variant, definition.zh_synonym_to_term), + ("en", en_variant, definition.en_synonym_to_term), + ): + if variant is None or not mapping: + continue + + matched_terms = definition.match_candidates(variant.candidates, language=language) + if not matched_terms: continue for candidate in variant.candidates: normalized_candidate = normalize_query_text(candidate) - canonical = definition.synonym_to_canonical.get(normalized_candidate) - if not canonical or canonical not in matched_canonicals: + term_definition = mapping.get(normalized_candidate) + if term_definition is None or term_definition not in matched_terms: continue - pair = (intent_type, canonical) + pair = (intent_type, term_definition.canonical_value) if pair in seen_pairs: continue seen_pairs.add(pair) detected.append( DetectedStyleIntent( intent_type=intent_type, - canonical_value=canonical, + canonical_value=term_definition.canonical_value, matched_term=normalized_candidate, matched_query_text=variant.text, + attribute_terms=term_definition.attribute_terms, dimension_aliases=definition.dimension_aliases, ) ) diff --git a/search/sku_intent_selector.py b/search/sku_intent_selector.py index 9b49e06..de32799 100644 --- a/search/sku_intent_selector.py +++ b/search/sku_intent_selector.py @@ -5,9 +5,7 @@ SKU selection for style-intent-aware search results. from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple - -import numpy as np +from typing import Any, Callable, Dict, List, Optional, Tuple from query.style_intent import StyleIntentProfile, StyleIntentRegistry from query.tokenization import normalize_query_text @@ -34,24 +32,10 @@ class SkuSelectionDecision: @dataclass -class _SkuCandidate: - index: int - sku_id: str - sku: Dict[str, Any] - selection_text: str - normalized_selection_text: str - intent_values: Dict[str, str] - normalized_intent_values: Dict[str, str] - - -@dataclass class _SelectionContext: - query_texts: Tuple[str, ...] - matched_terms_by_intent: Dict[str, Tuple[str, ...]] - query_vector: Optional[np.ndarray] + attribute_terms_by_intent: Dict[str, Tuple[str, ...]] + normalized_text_cache: Dict[str, str] = field(default_factory=dict) text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict) - selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict) - similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict) class StyleSkuSelector: @@ -76,7 +60,7 @@ class StyleSkuSelector: if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: return decisions - selection_context = self._build_selection_context(parsed_query, style_profile) + selection_context = self._build_selection_context(style_profile) for hit in es_hits: source = hit.get("_source") @@ -126,81 +110,37 @@ class StyleSkuSelector: else: hit.pop("_style_rerank_suffix", None) - def _build_query_texts( - self, - parsed_query: Any, - style_profile: StyleIntentProfile, - ) -> List[str]: - texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] - if texts: - return list(dict.fromkeys(texts)) - - fallbacks: List[str] = [] - for value in ( - getattr(parsed_query, "original_query", None), - getattr(parsed_query, "query_normalized", None), - getattr(parsed_query, "rewritten_query", None), - ): - normalized = normalize_query_text(value) - if normalized: - fallbacks.append(normalized) - translations = getattr(parsed_query, "translations", {}) or {} - if isinstance(translations, dict): - for value in translations.values(): - normalized = normalize_query_text(value) - if normalized: - fallbacks.append(normalized) - return list(dict.fromkeys(fallbacks)) - - def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: - query_vector = getattr(parsed_query, "query_vector", None) - if query_vector is not None: - return np.asarray(query_vector, dtype=np.float32) - - text_encoder = self._get_text_encoder() - if text_encoder is None: - return None - - query_text = ( - getattr(parsed_query, "rewritten_query", None) - or getattr(parsed_query, "query_normalized", None) - or getattr(parsed_query, "original_query", None) - ) - if not query_text: - return None - - vectors = text_encoder.encode([query_text], priority=1) - if vectors is None or len(vectors) == 0 or vectors[0] is None: - return None - return np.asarray(vectors[0], dtype=np.float32) - def _build_selection_context( self, - parsed_query: Any, style_profile: StyleIntentProfile, ) -> _SelectionContext: - matched_terms_by_intent: Dict[str, List[str]] = {} + attribute_terms_by_intent: Dict[str, List[str]] = {} for intent in style_profile.intents: - normalized_term = normalize_query_text(intent.matched_term) - if not normalized_term: - continue - matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) - if normalized_term not in matched_terms: - matched_terms.append(normalized_term) + terms = attribute_terms_by_intent.setdefault(intent.intent_type, []) + for raw_term in intent.attribute_terms: + normalized_term = normalize_query_text(raw_term) + if not normalized_term or normalized_term in terms: + continue + terms.append(normalized_term) return _SelectionContext( - query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), - matched_terms_by_intent={ + attribute_terms_by_intent={ intent_type: tuple(terms) - for intent_type, terms in matched_terms_by_intent.items() + for intent_type, terms in attribute_terms_by_intent.items() }, - query_vector=self._get_query_vector(parsed_query), ) - def _get_text_encoder(self) -> Any: - if self._text_encoder_getter is None: - return None - return self._text_encoder_getter() + @staticmethod + def _normalize_cached(selection_context: _SelectionContext, value: Any) -> str: + raw = str(value or "").strip() + if not raw: + return "" + cached = selection_context.normalized_text_cache.get(raw) + if cached is not None: + return cached + normalized = normalize_query_text(raw) + selection_context.normalized_text_cache[raw] = normalized + return normalized def _resolve_dimensions( self, @@ -225,51 +165,6 @@ class StyleSkuSelector: resolved[intent.intent_type] = matched_field return resolved - def _build_candidates( - self, - skus: List[Dict[str, Any]], - resolved_dimensions: Dict[str, Optional[str]], - ) -> List[_SkuCandidate]: - if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): - return [] - - candidates: List[_SkuCandidate] = [] - for index, sku in enumerate(skus): - intent_values: Dict[str, str] = {} - normalized_intent_values: Dict[str, str] = {} - for intent_type, field_name in resolved_dimensions.items(): - if not field_name: - continue - raw = str(sku.get(field_name) or "").strip() - intent_values[intent_type] = raw - normalized_intent_values[intent_type] = normalize_query_text(raw) - - selection_parts: List[str] = [] - norm_parts: List[str] = [] - seen: set[str] = set() - for intent_type, raw in intent_values.items(): - nv = normalized_intent_values[intent_type] - if not nv or nv in seen: - continue - seen.add(nv) - selection_parts.append(raw) - norm_parts.append(nv) - - selection_text = " ".join(selection_parts).strip() - normalized_selection_text = " ".join(norm_parts).strip() - candidates.append( - _SkuCandidate( - index=index, - sku_id=str(sku.get("sku_id") or ""), - sku=sku, - selection_text=selection_text, - normalized_selection_text=normalized_selection_text, - intent_values=intent_values, - normalized_intent_values=normalized_intent_values, - ) - ) - return candidates - @staticmethod def _empty_decision( resolved_dimensions: Dict[str, Optional[str]], @@ -286,13 +181,10 @@ class StyleSkuSelector: def _is_text_match( self, intent_type: str, - value: str, selection_context: _SelectionContext, *, - normalized_value: Optional[str] = None, + normalized_value: str, ) -> bool: - if normalized_value is None: - normalized_value = normalize_query_text(value) if not normalized_value: return False @@ -301,84 +193,44 @@ class StyleSkuSelector: if cached is not None: return cached - matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ()) - has_term_match = any(term in normalized_value for term in matched_terms if term) - query_contains_value = any( - normalized_value in query_text - for query_text in selection_context.query_texts - ) - matched = bool(has_term_match or query_contains_value) + attribute_terms = selection_context.attribute_terms_by_intent.get(intent_type, ()) + matched = any(term in normalized_value for term in attribute_terms if term) selection_context.text_match_cache[cache_key] = matched return matched def _find_first_text_match( self, - candidates: Sequence[_SkuCandidate], + skus: List[Dict[str, Any]], + resolved_dimensions: Dict[str, Optional[str]], selection_context: _SelectionContext, - ) -> Optional[_SkuCandidate]: - for candidate in candidates: - if candidate.intent_values and all( - self._is_text_match( - intent_type, - value, - selection_context, - normalized_value=candidate.normalized_intent_values[intent_type], - ) - for intent_type, value in candidate.intent_values.items() - ): - return candidate - return None + ) -> Optional[Tuple[str, str]]: + for sku in skus: + selection_parts: List[str] = [] + seen_parts: set[str] = set() + matched = True - def _select_by_embedding( - self, - candidates: Sequence[_SkuCandidate], - selection_context: _SelectionContext, - ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: - if not candidates: - return None, None - text_encoder = self._get_text_encoder() - if selection_context.query_vector is None or text_encoder is None: - return None, None - - unique_texts = list( - dict.fromkeys( - candidate.normalized_selection_text - for candidate in candidates - if candidate.normalized_selection_text - and candidate.normalized_selection_text not in selection_context.selection_vector_cache - ) - ) - if unique_texts: - vectors = text_encoder.encode(unique_texts, priority=1) - for key, vector in zip(unique_texts, vectors): - selection_context.selection_vector_cache[key] = ( - np.asarray(vector, dtype=np.float32) if vector is not None else None - ) - - best_candidate: Optional[_SkuCandidate] = None - best_score: Optional[float] = None - query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32) - for candidate in candidates: - normalized_text = candidate.normalized_selection_text - if not normalized_text: - continue + for intent_type, field_name in resolved_dimensions.items(): + if not field_name: + matched = False + break - score = selection_context.similarity_cache.get(normalized_text) - if score is None: - candidate_vector = selection_context.selection_vector_cache.get(normalized_text) - if candidate_vector is None: - selection_context.similarity_cache[normalized_text] = None - continue - score = float(np.inner(query_vector_array, candidate_vector)) - selection_context.similarity_cache[normalized_text] = score + raw_value = str(sku.get(field_name) or "").strip() + normalized_value = self._normalize_cached(selection_context, raw_value) + if not self._is_text_match( + intent_type, + selection_context, + normalized_value=normalized_value, + ): + matched = False + break - if score is None: - continue - if best_score is None or score > best_score: - best_candidate = candidate - best_score = score + if raw_value and normalized_value not in seen_parts: + seen_parts.add(normalized_value) + selection_parts.append(raw_value) - return best_candidate, best_score + if matched: + return str(sku.get("sku_id") or ""), " ".join(selection_parts).strip() + return None def _select_for_source( self, @@ -395,36 +247,29 @@ class StyleSkuSelector: if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): return self._empty_decision(resolved_dimensions, matched_stage="unresolved") - candidates = self._build_candidates(skus, resolved_dimensions) - if not candidates: - return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") - - text_match = self._find_first_text_match(candidates, selection_context) - if text_match is not None: - return self._build_decision(text_match, resolved_dimensions, matched_stage="text") - - chosen, similarity_score = self._select_by_embedding(candidates, selection_context) - if chosen is None: + text_match = self._find_first_text_match(skus, resolved_dimensions, selection_context) + if text_match is None: return self._empty_decision(resolved_dimensions, matched_stage="no_match") return self._build_decision( - chosen, - resolved_dimensions, - matched_stage="embedding", - similarity_score=similarity_score, + selected_sku_id=text_match[0], + selected_text=text_match[1], + resolved_dimensions=resolved_dimensions, + matched_stage="text", ) @staticmethod def _build_decision( - candidate: _SkuCandidate, + selected_sku_id: str, + selected_text: str, resolved_dimensions: Dict[str, Optional[str]], *, matched_stage: str, similarity_score: Optional[float] = None, ) -> SkuSelectionDecision: return SkuSelectionDecision( - selected_sku_id=candidate.sku_id or None, - rerank_suffix=str(candidate.selection_text or "").strip(), - selected_text=str(candidate.selection_text or "").strip(), + selected_sku_id=selected_sku_id or None, + rerank_suffix=str(selected_text or "").strip(), + selected_text=str(selected_text or "").strip(), matched_stage=matched_stage, similarity_score=similarity_score, resolved_dimensions=dict(resolved_dimensions), diff --git a/search/sku_intent_selector___deprecated.py b/search/sku_intent_selector___deprecated.py new file mode 100644 index 0000000..9b49e06 --- /dev/null +++ b/search/sku_intent_selector___deprecated.py @@ -0,0 +1,452 @@ +""" +SKU selection for style-intent-aware search results. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import numpy as np + +from query.style_intent import StyleIntentProfile, StyleIntentRegistry +from query.tokenization import normalize_query_text + + +@dataclass(frozen=True) +class SkuSelectionDecision: + selected_sku_id: Optional[str] + rerank_suffix: str + selected_text: str + matched_stage: str + similarity_score: Optional[float] = None + resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "selected_sku_id": self.selected_sku_id, + "rerank_suffix": self.rerank_suffix, + "selected_text": self.selected_text, + "matched_stage": self.matched_stage, + "similarity_score": self.similarity_score, + "resolved_dimensions": dict(self.resolved_dimensions), + } + + +@dataclass +class _SkuCandidate: + index: int + sku_id: str + sku: Dict[str, Any] + selection_text: str + normalized_selection_text: str + intent_values: Dict[str, str] + normalized_intent_values: Dict[str, str] + + +@dataclass +class _SelectionContext: + query_texts: Tuple[str, ...] + matched_terms_by_intent: Dict[str, Tuple[str, ...]] + query_vector: Optional[np.ndarray] + text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict) + selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict) + similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict) + + +class StyleSkuSelector: + """Selects the best SKU for an SPU based on detected style intent.""" + + def __init__( + self, + registry: StyleIntentRegistry, + *, + text_encoder_getter: Optional[Callable[[], Any]] = None, + ) -> None: + self.registry = registry + self._text_encoder_getter = text_encoder_getter + + def prepare_hits( + self, + es_hits: List[Dict[str, Any]], + parsed_query: Any, + ) -> Dict[str, SkuSelectionDecision]: + decisions: Dict[str, SkuSelectionDecision] = {} + style_profile = getattr(parsed_query, "style_intent_profile", None) + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: + return decisions + + selection_context = self._build_selection_context(parsed_query, style_profile) + + for hit in es_hits: + source = hit.get("_source") + if not isinstance(source, dict): + continue + + decision = self._select_for_source( + source, + style_profile=style_profile, + selection_context=selection_context, + ) + if decision is None: + continue + + if decision.rerank_suffix: + hit["_style_rerank_suffix"] = decision.rerank_suffix + else: + hit.pop("_style_rerank_suffix", None) + + doc_id = hit.get("_id") + if doc_id is not None: + decisions[str(doc_id)] = decision + + return decisions + + def apply_precomputed_decisions( + self, + es_hits: List[Dict[str, Any]], + decisions: Dict[str, SkuSelectionDecision], + ) -> None: + if not es_hits or not decisions: + return + + for hit in es_hits: + doc_id = hit.get("_id") + if doc_id is None: + continue + decision = decisions.get(str(doc_id)) + if decision is None: + continue + source = hit.get("_source") + if not isinstance(source, dict): + continue + self._apply_decision_to_source(source, decision) + if decision.rerank_suffix: + hit["_style_rerank_suffix"] = decision.rerank_suffix + else: + hit.pop("_style_rerank_suffix", None) + + def _build_query_texts( + self, + parsed_query: Any, + style_profile: StyleIntentProfile, + ) -> List[str]: + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] + if texts: + return list(dict.fromkeys(texts)) + + fallbacks: List[str] = [] + for value in ( + getattr(parsed_query, "original_query", None), + getattr(parsed_query, "query_normalized", None), + getattr(parsed_query, "rewritten_query", None), + ): + normalized = normalize_query_text(value) + if normalized: + fallbacks.append(normalized) + translations = getattr(parsed_query, "translations", {}) or {} + if isinstance(translations, dict): + for value in translations.values(): + normalized = normalize_query_text(value) + if normalized: + fallbacks.append(normalized) + return list(dict.fromkeys(fallbacks)) + + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: + query_vector = getattr(parsed_query, "query_vector", None) + if query_vector is not None: + return np.asarray(query_vector, dtype=np.float32) + + text_encoder = self._get_text_encoder() + if text_encoder is None: + return None + + query_text = ( + getattr(parsed_query, "rewritten_query", None) + or getattr(parsed_query, "query_normalized", None) + or getattr(parsed_query, "original_query", None) + ) + if not query_text: + return None + + vectors = text_encoder.encode([query_text], priority=1) + if vectors is None or len(vectors) == 0 or vectors[0] is None: + return None + return np.asarray(vectors[0], dtype=np.float32) + + def _build_selection_context( + self, + parsed_query: Any, + style_profile: StyleIntentProfile, + ) -> _SelectionContext: + matched_terms_by_intent: Dict[str, List[str]] = {} + for intent in style_profile.intents: + normalized_term = normalize_query_text(intent.matched_term) + if not normalized_term: + continue + matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) + if normalized_term not in matched_terms: + matched_terms.append(normalized_term) + + return _SelectionContext( + query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), + matched_terms_by_intent={ + intent_type: tuple(terms) + for intent_type, terms in matched_terms_by_intent.items() + }, + query_vector=self._get_query_vector(parsed_query), + ) + + def _get_text_encoder(self) -> Any: + if self._text_encoder_getter is None: + return None + return self._text_encoder_getter() + + def _resolve_dimensions( + self, + source: Dict[str, Any], + style_profile: StyleIntentProfile, + ) -> Dict[str, Optional[str]]: + option_names = { + "option1_value": normalize_query_text(source.get("option1_name")), + "option2_value": normalize_query_text(source.get("option2_name")), + "option3_value": normalize_query_text(source.get("option3_name")), + } + resolved: Dict[str, Optional[str]] = {} + for intent in style_profile.intents: + if intent.intent_type in resolved: + continue + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) + matched_field = None + for field_name, option_name in option_names.items(): + if option_name and option_name in aliases: + matched_field = field_name + break + resolved[intent.intent_type] = matched_field + return resolved + + def _build_candidates( + self, + skus: List[Dict[str, Any]], + resolved_dimensions: Dict[str, Optional[str]], + ) -> List[_SkuCandidate]: + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): + return [] + + candidates: List[_SkuCandidate] = [] + for index, sku in enumerate(skus): + intent_values: Dict[str, str] = {} + normalized_intent_values: Dict[str, str] = {} + for intent_type, field_name in resolved_dimensions.items(): + if not field_name: + continue + raw = str(sku.get(field_name) or "").strip() + intent_values[intent_type] = raw + normalized_intent_values[intent_type] = normalize_query_text(raw) + + selection_parts: List[str] = [] + norm_parts: List[str] = [] + seen: set[str] = set() + for intent_type, raw in intent_values.items(): + nv = normalized_intent_values[intent_type] + if not nv or nv in seen: + continue + seen.add(nv) + selection_parts.append(raw) + norm_parts.append(nv) + + selection_text = " ".join(selection_parts).strip() + normalized_selection_text = " ".join(norm_parts).strip() + candidates.append( + _SkuCandidate( + index=index, + sku_id=str(sku.get("sku_id") or ""), + sku=sku, + selection_text=selection_text, + normalized_selection_text=normalized_selection_text, + intent_values=intent_values, + normalized_intent_values=normalized_intent_values, + ) + ) + return candidates + + @staticmethod + def _empty_decision( + resolved_dimensions: Dict[str, Optional[str]], + matched_stage: str, + ) -> SkuSelectionDecision: + return SkuSelectionDecision( + selected_sku_id=None, + rerank_suffix="", + selected_text="", + matched_stage=matched_stage, + resolved_dimensions=dict(resolved_dimensions), + ) + + def _is_text_match( + self, + intent_type: str, + value: str, + selection_context: _SelectionContext, + *, + normalized_value: Optional[str] = None, + ) -> bool: + if normalized_value is None: + normalized_value = normalize_query_text(value) + if not normalized_value: + return False + + cache_key = (intent_type, normalized_value) + cached = selection_context.text_match_cache.get(cache_key) + if cached is not None: + return cached + + matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ()) + has_term_match = any(term in normalized_value for term in matched_terms if term) + query_contains_value = any( + normalized_value in query_text + for query_text in selection_context.query_texts + ) + matched = bool(has_term_match or query_contains_value) + selection_context.text_match_cache[cache_key] = matched + return matched + + def _find_first_text_match( + self, + candidates: Sequence[_SkuCandidate], + selection_context: _SelectionContext, + ) -> Optional[_SkuCandidate]: + for candidate in candidates: + if candidate.intent_values and all( + self._is_text_match( + intent_type, + value, + selection_context, + normalized_value=candidate.normalized_intent_values[intent_type], + ) + for intent_type, value in candidate.intent_values.items() + ): + return candidate + return None + + def _select_by_embedding( + self, + candidates: Sequence[_SkuCandidate], + selection_context: _SelectionContext, + ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: + if not candidates: + return None, None + text_encoder = self._get_text_encoder() + if selection_context.query_vector is None or text_encoder is None: + return None, None + + unique_texts = list( + dict.fromkeys( + candidate.normalized_selection_text + for candidate in candidates + if candidate.normalized_selection_text + and candidate.normalized_selection_text not in selection_context.selection_vector_cache + ) + ) + if unique_texts: + vectors = text_encoder.encode(unique_texts, priority=1) + for key, vector in zip(unique_texts, vectors): + selection_context.selection_vector_cache[key] = ( + np.asarray(vector, dtype=np.float32) if vector is not None else None + ) + + best_candidate: Optional[_SkuCandidate] = None + best_score: Optional[float] = None + query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32) + for candidate in candidates: + normalized_text = candidate.normalized_selection_text + if not normalized_text: + continue + + score = selection_context.similarity_cache.get(normalized_text) + if score is None: + candidate_vector = selection_context.selection_vector_cache.get(normalized_text) + if candidate_vector is None: + selection_context.similarity_cache[normalized_text] = None + continue + score = float(np.inner(query_vector_array, candidate_vector)) + selection_context.similarity_cache[normalized_text] = score + + if score is None: + continue + if best_score is None or score > best_score: + best_candidate = candidate + best_score = score + + return best_candidate, best_score + + def _select_for_source( + self, + source: Dict[str, Any], + *, + style_profile: StyleIntentProfile, + selection_context: _SelectionContext, + ) -> Optional[SkuSelectionDecision]: + skus = source.get("skus") + if not isinstance(skus, list) or not skus: + return None + + resolved_dimensions = self._resolve_dimensions(source, style_profile) + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): + return self._empty_decision(resolved_dimensions, matched_stage="unresolved") + + candidates = self._build_candidates(skus, resolved_dimensions) + if not candidates: + return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") + + text_match = self._find_first_text_match(candidates, selection_context) + if text_match is not None: + return self._build_decision(text_match, resolved_dimensions, matched_stage="text") + + chosen, similarity_score = self._select_by_embedding(candidates, selection_context) + if chosen is None: + return self._empty_decision(resolved_dimensions, matched_stage="no_match") + return self._build_decision( + chosen, + resolved_dimensions, + matched_stage="embedding", + similarity_score=similarity_score, + ) + + @staticmethod + def _build_decision( + candidate: _SkuCandidate, + resolved_dimensions: Dict[str, Optional[str]], + *, + matched_stage: str, + similarity_score: Optional[float] = None, + ) -> SkuSelectionDecision: + return SkuSelectionDecision( + selected_sku_id=candidate.sku_id or None, + rerank_suffix=str(candidate.selection_text or "").strip(), + selected_text=str(candidate.selection_text or "").strip(), + matched_stage=matched_stage, + similarity_score=similarity_score, + resolved_dimensions=dict(resolved_dimensions), + ) + + @staticmethod + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: + skus = source.get("skus") + if not isinstance(skus, list) or not skus or not decision.selected_sku_id: + return + + selected_index = None + for index, sku in enumerate(skus): + if str(sku.get("sku_id") or "") == decision.selected_sku_id: + selected_index = index + break + if selected_index is None: + return + + selected_sku = skus.pop(selected_index) + skus.insert(0, selected_sku) + + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") + if image_src: + source["image_url"] = image_src diff --git a/tests/test_search_rerank_window.py b/tests/test_search_rerank_window.py index 1cff026..f5ec64d 100644 --- a/tests/test_search_rerank_window.py +++ b/tests/test_search_rerank_window.py @@ -63,6 +63,7 @@ def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimensi canonical_value=canonical_value, matched_term=canonical_value, matched_query_text=canonical_value, + attribute_terms=(canonical_value,), dimension_aliases=tuple(aliases), ), ) diff --git a/tests/test_sku_intent_selector.py b/tests/test_sku_intent_selector.py new file mode 100644 index 0000000..20174e8 --- /dev/null +++ b/tests/test_sku_intent_selector.py @@ -0,0 +1,106 @@ +from types import SimpleNamespace + +from config import QueryConfig +from query.style_intent import DetectedStyleIntent, StyleIntentProfile, StyleIntentRegistry +from search.sku_intent_selector import StyleSkuSelector + + +def test_style_sku_selector_matches_first_sku_by_attribute_terms(): + registry = StyleIntentRegistry.from_query_config( + QueryConfig( + style_intent_terms={ + "color": [{"en_terms": ["navy"], "zh_terms": ["藏青"], "attribute_terms": ["navy"]}], + "size": [{"en_terms": ["xl"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}], + }, + style_intent_dimension_aliases={ + "color": ["color", "颜色"], + "size": ["size", "尺码"], + }, + ) + ) + selector = StyleSkuSelector(registry) + parsed_query = SimpleNamespace( + style_intent_profile=StyleIntentProfile( + intents=( + DetectedStyleIntent( + intent_type="color", + canonical_value="navy", + matched_term="藏青", + matched_query_text="藏青", + attribute_terms=("navy",), + dimension_aliases=("color", "颜色"), + ), + DetectedStyleIntent( + intent_type="size", + canonical_value="x-large", + matched_term="xl", + matched_query_text="xl", + attribute_terms=("x-large",), + dimension_aliases=("size", "尺码"), + ), + ), + ) + ) + source = { + "option1_name": "Color", + "option2_name": "Size", + "skus": [ + {"sku_id": "1", "option1_value": "Black", "option2_value": "M"}, + {"sku_id": "2", "option1_value": "Navy Blue", "option2_value": "X-Large", "image_src": "matched.jpg"}, + {"sku_id": "3", "option1_value": "Navy", "option2_value": "XL"}, + ], + } + hits = [{"_id": "spu-1", "_source": source}] + + decisions = selector.prepare_hits(hits, parsed_query) + decision = decisions["spu-1"] + + assert decision.selected_sku_id == "2" + assert decision.selected_text == "Navy Blue X-Large" + assert decision.matched_stage == "text" + + selector.apply_precomputed_decisions(hits, decisions) + + assert source["skus"][0]["sku_id"] == "2" + assert source["image_url"] == "matched.jpg" + + +def test_style_sku_selector_returns_no_match_without_attribute_contains(): + registry = StyleIntentRegistry.from_query_config( + QueryConfig( + style_intent_terms={ + "color": [{"en_terms": ["beige"], "zh_terms": ["米色"], "attribute_terms": ["beige"]}], + }, + style_intent_dimension_aliases={"color": ["color", "颜色"]}, + ) + ) + selector = StyleSkuSelector(registry) + parsed_query = SimpleNamespace( + style_intent_profile=StyleIntentProfile( + intents=( + DetectedStyleIntent( + intent_type="color", + canonical_value="beige", + matched_term="米色", + matched_query_text="米色", + attribute_terms=("beige",), + dimension_aliases=("color", "颜色"), + ), + ), + ) + ) + hits = [{ + "_id": "spu-1", + "_source": { + "option1_name": "Color", + "skus": [ + {"sku_id": "1", "option1_value": "Khaki"}, + {"sku_id": "2", "option1_value": "Light Brown"}, + ], + }, + }] + + decisions = selector.prepare_hits(hits, parsed_query) + + assert decisions["spu-1"].selected_sku_id is None + assert decisions["spu-1"].matched_stage == "no_match" diff --git a/tests/test_style_intent.py b/tests/test_style_intent.py index d46217a..6fe19db 100644 --- a/tests/test_style_intent.py +++ b/tests/test_style_intent.py @@ -7,8 +7,8 @@ from query.style_intent import StyleIntentDetector, StyleIntentRegistry def test_style_intent_detector_matches_original_and_translated_queries(): query_config = QueryConfig( style_intent_terms={ - "color": [["black", "黑色", "black"]], - "size": [["xl", "x-large", "加大码"]], + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], + "size": [{"en_terms": ["xl", "x-large"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}], }, style_intent_dimension_aliases={ "color": ["color", "颜色"], @@ -31,5 +31,30 @@ def test_style_intent_detector_matches_original_and_translated_queries(): assert profile.is_active is True assert profile.get_canonical_values("color") == {"black"} - assert profile.get_canonical_values("size") == {"xl"} + assert profile.get_canonical_values("size") == {"x-large"} assert len(profile.query_variants) == 2 + + +def test_style_intent_detector_uses_original_query_when_language_translation_missing(): + query_config = QueryConfig( + style_intent_terms={ + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], + }, + style_intent_dimension_aliases={"color": ["color", "颜色"]}, + ) + detector = StyleIntentDetector( + StyleIntentRegistry.from_query_config(query_config), + tokenizer=lambda text: text.split(), + ) + + parsed_query = SimpleNamespace( + original_query="black dress", + query_normalized="black dress", + rewritten_query="black dress", + translations={"zh": "连衣裙"}, + ) + + profile = detector.detect(parsed_query) + + assert profile.get_canonical_values("color") == {"black"} + assert profile.intents[0].attribute_terms == ("black",) -- libgit2 0.21.2