Commit b712a831f1a6b2735b07e3d342ea680b5a5933a0
1 parent
74fdf9bd
意图识别策略和性能优化
@config/dictionaries/style_intent_color.csv @config/dictionaries/style_intent_size.csv @query/style_intent.py @search/sku_intent_selector.py 1. 两个csv词典,分为三列, - 英文关键词 - 中文关键词 - 标准属性名称词 三列都可以允许逗号分割。补充的第三列使用在商品属性中,使用的是标准的英文名称 2. 判断意图的时候,中文词用中文翻译名去匹配,如果不存在中文翻译名,则用原始 query,英文词同理 3. SKU 选择的时候,用每一个 SKU 的属性名去匹配。 匹配规则要大幅度简化,并做性能优化: 1)文本匹配规则只需要看规范化后的属性值是否包含了词典配置的第三列"标准属性名称词",如果包含了,则认为匹配成功。 找到第一个匹配成功的即可。如果都没有成功,后面也不再需要用向量匹配。 暂时废弃向量匹配、双向匹配等复杂逻辑。
Showing
11 changed files
with
812 additions
and
297 deletions
Show diff stats
config/dictionaries/product_title_exclusion.tsv
config/dictionaries/style_intent_color.csv
| 1 | -black,black,blk,黑,黑色 | |
| 2 | -white,white,wht,白,白色 | |
| 3 | -red,red,reddish,红,红色 | |
| 4 | -blue,blue,blu,蓝,蓝色 | |
| 5 | -green,green,grn,绿,绿色 | |
| 6 | -yellow,yellow,ylw,黄,黄色 | |
| 7 | -pink,pink,粉,粉色 | |
| 8 | -purple,purple,violet,紫,紫色 | |
| 9 | -gray,gray,grey,灰,灰色 | |
| 10 | -brown,brown,棕,棕色,咖啡色 | |
| 11 | -beige,beige,khaki,米色,卡其色 | |
| 12 | -navy,navy,navy blue,藏青,藏蓝,深蓝 | |
| 13 | -silver,silver,银,银色 | |
| 14 | -gold,gold,金,金色 | |
| 15 | -orange,orange,橙,橙色 | |
| 1 | +"black,blk","黑,黑色","black" | |
| 2 | +"white,wht","白,白色","white" | |
| 3 | +"red,reddish","红,红色","red" | |
| 4 | +"blue,blu","蓝,蓝色","blue" | |
| 5 | +"green,grn","绿,绿色","green" | |
| 6 | +"yellow,ylw","黄,黄色","yellow" | |
| 7 | +"pink","粉,粉色","pink" | |
| 8 | +"purple,violet","紫,紫色","purple" | |
| 9 | +"gray,grey","灰,灰色","gray,grey" | |
| 10 | +"brown","棕,棕色,咖啡色","brown" | |
| 11 | +"beige,khaki","米色,卡其色","beige,khaki" | |
| 12 | +"navy,navy blue","藏青,藏蓝,深蓝","navy" | |
| 13 | +"silver","银,银色","silver" | |
| 14 | +"gold","金,金色","gold" | |
| 15 | +"orange","橙,橙色","orange" | ... | ... |
config/dictionaries/style_intent_size.csv
| 1 | -xs,xs,extra small,x-small,加小码 | |
| 2 | -s,s,small,小码,小号 | |
| 3 | -m,m,medium,中码,中号 | |
| 4 | -l,l,large,大码,大号 | |
| 5 | -xl,xl,x-large,extra large,加大码 | |
| 6 | -xxl,xxl,2xl,xx-large,双加大码 | |
| 7 | -xxxl,xxxl,3xl,xxx-large,三加大码 | |
| 8 | -one size,one size,onesize,free size,均码 | |
| 1 | +"xs,extra small,x-small","加小码","xs,extra small,x-small" | |
| 2 | +"s,small","小码,小号","s,small" | |
| 3 | +"m,medium","中码,中号","m,medium" | |
| 4 | +"l,large","大码,大号","l,large" | |
| 5 | +"xl,x-large,extra large","加大码","xl,x-large,extra large" | |
| 6 | +"xxl,2xl,xx-large","双加大码","xxl,2xl,xx-large" | |
| 7 | +"xxxl,3xl,xxx-large","三加大码","xxxl,3xl,xxx-large" | ... | ... |
config/loader.py
| ... | ... | @@ -10,6 +10,7 @@ from __future__ import annotations |
| 10 | 10 | import hashlib |
| 11 | 11 | import json |
| 12 | 12 | import os |
| 13 | +import csv | |
| 13 | 14 | from copy import deepcopy |
| 14 | 15 | from dataclasses import asdict |
| 15 | 16 | from functools import lru_cache |
| ... | ... | @@ -96,20 +97,33 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]: |
| 96 | 97 | return rewrite_dict |
| 97 | 98 | |
| 98 | 99 | |
| 99 | -def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]: | |
| 100 | - rows: List[List[str]] = [] | |
| 100 | +def _read_synonym_csv_dictionary(path: Path) -> List[Dict[str, List[str]]]: | |
| 101 | + rows: List[Dict[str, List[str]]] = [] | |
| 101 | 102 | if not path.exists(): |
| 102 | 103 | return rows |
| 103 | 104 | |
| 105 | + def _split_terms(cell: str) -> List[str]: | |
| 106 | + return [item.strip() for item in str(cell or "").split(",") if item.strip()] | |
| 107 | + | |
| 104 | 108 | with open(path, "r", encoding="utf-8") as handle: |
| 105 | - for raw_line in handle: | |
| 106 | - line = raw_line.strip() | |
| 107 | - if not line or line.startswith("#"): | |
| 109 | + reader = csv.reader(handle) | |
| 110 | + for parts in reader: | |
| 111 | + if not parts: | |
| 112 | + continue | |
| 113 | + if parts[0].strip().startswith("#"): | |
| 108 | 114 | continue |
| 109 | - parts = [segment.strip() for segment in line.split(",")] | |
| 110 | - normalized = [segment for segment in parts if segment] | |
| 111 | - if normalized: | |
| 112 | - rows.append(normalized) | |
| 115 | + | |
| 116 | + normalized = [segment.strip() for segment in parts] | |
| 117 | + if len(normalized) < 3: | |
| 118 | + continue | |
| 119 | + | |
| 120 | + row = { | |
| 121 | + "en_terms": _split_terms(normalized[0]), | |
| 122 | + "zh_terms": _split_terms(normalized[1]), | |
| 123 | + "attribute_terms": _split_terms(normalized[2]), | |
| 124 | + } | |
| 125 | + if any(row.values()): | |
| 126 | + rows.append(row) | |
| 113 | 127 | return rows |
| 114 | 128 | |
| 115 | 129 | ... | ... |
config/schema.py
| ... | ... | @@ -65,7 +65,7 @@ class QueryConfig: |
| 65 | 65 | translation_embedding_wait_budget_ms_source_in_index: int = 80 |
| 66 | 66 | translation_embedding_wait_budget_ms_source_not_in_index: int = 200 |
| 67 | 67 | style_intent_enabled: bool = True |
| 68 | - style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict) | |
| 68 | + style_intent_terms: Dict[str, List[Dict[str, List[str]]]] = field(default_factory=dict) | |
| 69 | 69 | style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict) |
| 70 | 70 | product_title_exclusion_enabled: bool = True |
| 71 | 71 | product_title_exclusion_rules: List[Dict[str, List[str]]] = field(default_factory=list) | ... | ... |
query/style_intent.py
| ... | ... | @@ -11,38 +11,79 @@ from .tokenization import TokenizedText, normalize_query_text, tokenize_text |
| 11 | 11 | |
| 12 | 12 | |
| 13 | 13 | @dataclass(frozen=True) |
| 14 | +class StyleIntentTermDefinition: | |
| 15 | + canonical_value: str | |
| 16 | + en_terms: Tuple[str, ...] | |
| 17 | + zh_terms: Tuple[str, ...] | |
| 18 | + attribute_terms: Tuple[str, ...] | |
| 19 | + | |
| 20 | + | |
| 21 | +@dataclass(frozen=True) | |
| 14 | 22 | class StyleIntentDefinition: |
| 15 | 23 | intent_type: str |
| 16 | - term_groups: Tuple[Tuple[str, ...], ...] | |
| 24 | + terms: Tuple[StyleIntentTermDefinition, ...] | |
| 17 | 25 | dimension_aliases: Tuple[str, ...] |
| 18 | - synonym_to_canonical: Dict[str, str] | |
| 26 | + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] | |
| 27 | + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] | |
| 19 | 28 | max_term_ngram: int = 3 |
| 20 | 29 | |
| 21 | 30 | @classmethod |
| 22 | 31 | def from_rows( |
| 23 | 32 | cls, |
| 24 | 33 | intent_type: str, |
| 25 | - rows: Sequence[Sequence[str]], | |
| 34 | + rows: Sequence[Dict[str, List[str]]], | |
| 26 | 35 | dimension_aliases: Sequence[str], |
| 27 | 36 | ) -> "StyleIntentDefinition": |
| 28 | - term_groups: List[Tuple[str, ...]] = [] | |
| 29 | - synonym_to_canonical: Dict[str, str] = {} | |
| 37 | + terms: List[StyleIntentTermDefinition] = [] | |
| 38 | + en_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} | |
| 39 | + zh_synonym_to_term: Dict[str, StyleIntentTermDefinition] = {} | |
| 30 | 40 | max_ngram = 1 |
| 31 | 41 | |
| 32 | 42 | for row in rows: |
| 33 | - normalized_terms: List[str] = [] | |
| 34 | - for raw_term in row: | |
| 35 | - term = normalize_query_text(raw_term) | |
| 36 | - if not term or term in normalized_terms: | |
| 37 | - continue | |
| 38 | - normalized_terms.append(term) | |
| 39 | - if not normalized_terms: | |
| 43 | + normalized_en = tuple( | |
| 44 | + dict.fromkeys( | |
| 45 | + term | |
| 46 | + for term in (normalize_query_text(raw) for raw in row.get("en_terms", [])) | |
| 47 | + if term | |
| 48 | + ) | |
| 49 | + ) | |
| 50 | + normalized_zh = tuple( | |
| 51 | + dict.fromkeys( | |
| 52 | + term | |
| 53 | + for term in (normalize_query_text(raw) for raw in row.get("zh_terms", [])) | |
| 54 | + if term | |
| 55 | + ) | |
| 56 | + ) | |
| 57 | + normalized_attribute = tuple( | |
| 58 | + dict.fromkeys( | |
| 59 | + term | |
| 60 | + for term in (normalize_query_text(raw) for raw in row.get("attribute_terms", [])) | |
| 61 | + if term | |
| 62 | + ) | |
| 63 | + ) | |
| 64 | + if not normalized_en and not normalized_zh and not normalized_attribute: | |
| 40 | 65 | continue |
| 41 | 66 | |
| 42 | - canonical = normalized_terms[0] | |
| 43 | - term_groups.append(tuple(normalized_terms)) | |
| 44 | - for term in normalized_terms: | |
| 45 | - synonym_to_canonical[term] = canonical | |
| 67 | + canonical = ( | |
| 68 | + normalized_attribute[0] | |
| 69 | + if normalized_attribute | |
| 70 | + else normalized_en[0] | |
| 71 | + if normalized_en | |
| 72 | + else normalized_zh[0] | |
| 73 | + ) | |
| 74 | + term_definition = StyleIntentTermDefinition( | |
| 75 | + canonical_value=canonical, | |
| 76 | + en_terms=normalized_en, | |
| 77 | + zh_terms=normalized_zh, | |
| 78 | + attribute_terms=normalized_attribute, | |
| 79 | + ) | |
| 80 | + terms.append(term_definition) | |
| 81 | + | |
| 82 | + for term in normalized_en: | |
| 83 | + en_synonym_to_term[term] = term_definition | |
| 84 | + max_ngram = max(max_ngram, len(term.split())) | |
| 85 | + for term in normalized_zh: | |
| 86 | + zh_synonym_to_term[term] = term_definition | |
| 46 | 87 | max_ngram = max(max_ngram, len(term.split())) |
| 47 | 88 | |
| 48 | 89 | aliases = tuple( |
| ... | ... | @@ -58,28 +99,31 @@ class StyleIntentDefinition: |
| 58 | 99 | |
| 59 | 100 | return cls( |
| 60 | 101 | intent_type=intent_type, |
| 61 | - term_groups=tuple(term_groups), | |
| 102 | + terms=tuple(terms), | |
| 62 | 103 | dimension_aliases=aliases, |
| 63 | - synonym_to_canonical=synonym_to_canonical, | |
| 104 | + en_synonym_to_term=en_synonym_to_term, | |
| 105 | + zh_synonym_to_term=zh_synonym_to_term, | |
| 64 | 106 | max_term_ngram=max_ngram, |
| 65 | 107 | ) |
| 66 | 108 | |
| 67 | - def match_candidates(self, candidates: Iterable[str]) -> Set[str]: | |
| 68 | - matched: Set[str] = set() | |
| 109 | + def match_candidates(self, candidates: Iterable[str], *, language: str) -> Set[StyleIntentTermDefinition]: | |
| 110 | + mapping = self.zh_synonym_to_term if language == "zh" else self.en_synonym_to_term | |
| 111 | + matched: Set[StyleIntentTermDefinition] = set() | |
| 69 | 112 | for candidate in candidates: |
| 70 | - canonical = self.synonym_to_canonical.get(normalize_query_text(candidate)) | |
| 71 | - if canonical: | |
| 72 | - matched.add(canonical) | |
| 113 | + term_definition = mapping.get(normalize_query_text(candidate)) | |
| 114 | + if term_definition: | |
| 115 | + matched.add(term_definition) | |
| 73 | 116 | return matched |
| 74 | 117 | |
| 75 | 118 | def match_text( |
| 76 | 119 | self, |
| 77 | 120 | text: str, |
| 78 | 121 | *, |
| 122 | + language: str, | |
| 79 | 123 | tokenizer: Optional[Callable[[str], Any]] = None, |
| 80 | - ) -> Set[str]: | |
| 124 | + ) -> Set[StyleIntentTermDefinition]: | |
| 81 | 125 | bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram) |
| 82 | - return self.match_candidates(bundle.candidates) | |
| 126 | + return self.match_candidates(bundle.candidates, language=language) | |
| 83 | 127 | |
| 84 | 128 | |
| 85 | 129 | @dataclass(frozen=True) |
| ... | ... | @@ -88,6 +132,7 @@ class DetectedStyleIntent: |
| 88 | 132 | canonical_value: str |
| 89 | 133 | matched_term: str |
| 90 | 134 | matched_query_text: str |
| 135 | + attribute_terms: Tuple[str, ...] | |
| 91 | 136 | dimension_aliases: Tuple[str, ...] |
| 92 | 137 | |
| 93 | 138 | def to_dict(self) -> Dict[str, Any]: |
| ... | ... | @@ -96,6 +141,7 @@ class DetectedStyleIntent: |
| 96 | 141 | "canonical_value": self.canonical_value, |
| 97 | 142 | "matched_term": self.matched_term, |
| 98 | 143 | "matched_query_text": self.matched_query_text, |
| 144 | + "attribute_terms": list(self.attribute_terms), | |
| 99 | 145 | "dimension_aliases": list(self.dimension_aliases), |
| 100 | 146 | } |
| 101 | 147 | |
| ... | ... | @@ -159,7 +205,7 @@ class StyleIntentRegistry: |
| 159 | 205 | rows=rows or [], |
| 160 | 206 | dimension_aliases=dimension_aliases.get(intent_type, []), |
| 161 | 207 | ) |
| 162 | - if definition.synonym_to_canonical: | |
| 208 | + if definition.terms: | |
| 163 | 209 | definitions[definition.intent_type] = definition |
| 164 | 210 | |
| 165 | 211 | return cls( |
| ... | ... | @@ -191,15 +237,10 @@ class StyleIntentDetector: |
| 191 | 237 | seen = set() |
| 192 | 238 | variants: List[TokenizedText] = [] |
| 193 | 239 | texts = [ |
| 194 | - getattr(parsed_query, "original_query", None), | |
| 195 | - getattr(parsed_query, "query_normalized", None), | |
| 196 | - getattr(parsed_query, "rewritten_query", None), | |
| 240 | + self._get_language_query_text(parsed_query, "zh"), | |
| 241 | + self._get_language_query_text(parsed_query, "en"), | |
| 197 | 242 | ] |
| 198 | 243 | |
| 199 | - translations = getattr(parsed_query, "translations", {}) or {} | |
| 200 | - if isinstance(translations, dict): | |
| 201 | - texts.extend(translations.values()) | |
| 202 | - | |
| 203 | 244 | for raw_text in texts: |
| 204 | 245 | text = str(raw_text or "").strip() |
| 205 | 246 | if not text: |
| ... | ... | @@ -221,35 +262,66 @@ class StyleIntentDetector: |
| 221 | 262 | |
| 222 | 263 | return tuple(variants) |
| 223 | 264 | |
| 265 | + @staticmethod | |
| 266 | + def _get_language_query_text(parsed_query: Any, language: str) -> str: | |
| 267 | + translations = getattr(parsed_query, "translations", {}) or {} | |
| 268 | + if isinstance(translations, dict): | |
| 269 | + translated = translations.get(language) | |
| 270 | + if translated: | |
| 271 | + return str(translated) | |
| 272 | + return str(getattr(parsed_query, "original_query", "") or "") | |
| 273 | + | |
| 274 | + def _tokenize_language_query(self, parsed_query: Any, language: str) -> Optional[TokenizedText]: | |
| 275 | + text = self._get_language_query_text(parsed_query, language).strip() | |
| 276 | + if not text: | |
| 277 | + return None | |
| 278 | + return tokenize_text( | |
| 279 | + text, | |
| 280 | + tokenizer=self.tokenizer, | |
| 281 | + max_ngram=max( | |
| 282 | + (definition.max_term_ngram for definition in self.registry.definitions.values()), | |
| 283 | + default=3, | |
| 284 | + ), | |
| 285 | + ) | |
| 286 | + | |
| 224 | 287 | def detect(self, parsed_query: Any) -> StyleIntentProfile: |
| 225 | 288 | if not self.registry.enabled or not self.registry.definitions: |
| 226 | 289 | return StyleIntentProfile() |
| 227 | 290 | |
| 228 | 291 | query_variants = self._build_query_variants(parsed_query) |
| 292 | + zh_variant = self._tokenize_language_query(parsed_query, "zh") | |
| 293 | + en_variant = self._tokenize_language_query(parsed_query, "en") | |
| 229 | 294 | detected: List[DetectedStyleIntent] = [] |
| 230 | 295 | seen_pairs = set() |
| 231 | 296 | |
| 232 | - for variant in query_variants: | |
| 233 | - for intent_type, definition in self.registry.definitions.items(): | |
| 234 | - matched_canonicals = definition.match_candidates(variant.candidates) | |
| 235 | - if not matched_canonicals: | |
| 297 | + for intent_type, definition in self.registry.definitions.items(): | |
| 298 | + for language, variant, mapping in ( | |
| 299 | + ("zh", zh_variant, definition.zh_synonym_to_term), | |
| 300 | + ("en", en_variant, definition.en_synonym_to_term), | |
| 301 | + ): | |
| 302 | + if variant is None or not mapping: | |
| 303 | + continue | |
| 304 | + | |
| 305 | + matched_terms = definition.match_candidates(variant.candidates, language=language) | |
| 306 | + if not matched_terms: | |
| 236 | 307 | continue |
| 237 | 308 | |
| 238 | 309 | for candidate in variant.candidates: |
| 239 | 310 | normalized_candidate = normalize_query_text(candidate) |
| 240 | - canonical = definition.synonym_to_canonical.get(normalized_candidate) | |
| 241 | - if not canonical or canonical not in matched_canonicals: | |
| 311 | + term_definition = mapping.get(normalized_candidate) | |
| 312 | + if term_definition is None or term_definition not in matched_terms: | |
| 242 | 313 | continue |
| 243 | - pair = (intent_type, canonical) | |
| 314 | + pair = (intent_type, term_definition.canonical_value) | |
| 244 | 315 | if pair in seen_pairs: |
| 245 | 316 | continue |
| 246 | 317 | seen_pairs.add(pair) |
| 247 | 318 | detected.append( |
| 248 | 319 | DetectedStyleIntent( |
| 249 | 320 | intent_type=intent_type, |
| 250 | - canonical_value=canonical, | |
| 321 | + canonical_value=term_definition.canonical_value, | |
| 251 | 322 | matched_term=normalized_candidate, |
| 252 | 323 | matched_query_text=variant.text, |
| 324 | + attribute_terms=term_definition.attribute_terms, | |
| 253 | 325 | dimension_aliases=definition.dimension_aliases, |
| 254 | 326 | ) |
| 255 | 327 | ) | ... | ... |
search/sku_intent_selector.py
| ... | ... | @@ -5,9 +5,7 @@ SKU selection for style-intent-aware search results. |
| 5 | 5 | from __future__ import annotations |
| 6 | 6 | |
| 7 | 7 | from dataclasses import dataclass, field |
| 8 | -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | |
| 9 | - | |
| 10 | -import numpy as np | |
| 8 | +from typing import Any, Callable, Dict, List, Optional, Tuple | |
| 11 | 9 | |
| 12 | 10 | from query.style_intent import StyleIntentProfile, StyleIntentRegistry |
| 13 | 11 | from query.tokenization import normalize_query_text |
| ... | ... | @@ -34,24 +32,10 @@ class SkuSelectionDecision: |
| 34 | 32 | |
| 35 | 33 | |
| 36 | 34 | @dataclass |
| 37 | -class _SkuCandidate: | |
| 38 | - index: int | |
| 39 | - sku_id: str | |
| 40 | - sku: Dict[str, Any] | |
| 41 | - selection_text: str | |
| 42 | - normalized_selection_text: str | |
| 43 | - intent_values: Dict[str, str] | |
| 44 | - normalized_intent_values: Dict[str, str] | |
| 45 | - | |
| 46 | - | |
| 47 | -@dataclass | |
| 48 | 35 | class _SelectionContext: |
| 49 | - query_texts: Tuple[str, ...] | |
| 50 | - matched_terms_by_intent: Dict[str, Tuple[str, ...]] | |
| 51 | - query_vector: Optional[np.ndarray] | |
| 36 | + attribute_terms_by_intent: Dict[str, Tuple[str, ...]] | |
| 37 | + normalized_text_cache: Dict[str, str] = field(default_factory=dict) | |
| 52 | 38 | text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict) |
| 53 | - selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict) | |
| 54 | - similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict) | |
| 55 | 39 | |
| 56 | 40 | |
| 57 | 41 | class StyleSkuSelector: |
| ... | ... | @@ -76,7 +60,7 @@ class StyleSkuSelector: |
| 76 | 60 | if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: |
| 77 | 61 | return decisions |
| 78 | 62 | |
| 79 | - selection_context = self._build_selection_context(parsed_query, style_profile) | |
| 63 | + selection_context = self._build_selection_context(style_profile) | |
| 80 | 64 | |
| 81 | 65 | for hit in es_hits: |
| 82 | 66 | source = hit.get("_source") |
| ... | ... | @@ -126,81 +110,37 @@ class StyleSkuSelector: |
| 126 | 110 | else: |
| 127 | 111 | hit.pop("_style_rerank_suffix", None) |
| 128 | 112 | |
| 129 | - def _build_query_texts( | |
| 130 | - self, | |
| 131 | - parsed_query: Any, | |
| 132 | - style_profile: StyleIntentProfile, | |
| 133 | - ) -> List[str]: | |
| 134 | - texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] | |
| 135 | - if texts: | |
| 136 | - return list(dict.fromkeys(texts)) | |
| 137 | - | |
| 138 | - fallbacks: List[str] = [] | |
| 139 | - for value in ( | |
| 140 | - getattr(parsed_query, "original_query", None), | |
| 141 | - getattr(parsed_query, "query_normalized", None), | |
| 142 | - getattr(parsed_query, "rewritten_query", None), | |
| 143 | - ): | |
| 144 | - normalized = normalize_query_text(value) | |
| 145 | - if normalized: | |
| 146 | - fallbacks.append(normalized) | |
| 147 | - translations = getattr(parsed_query, "translations", {}) or {} | |
| 148 | - if isinstance(translations, dict): | |
| 149 | - for value in translations.values(): | |
| 150 | - normalized = normalize_query_text(value) | |
| 151 | - if normalized: | |
| 152 | - fallbacks.append(normalized) | |
| 153 | - return list(dict.fromkeys(fallbacks)) | |
| 154 | - | |
| 155 | - def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: | |
| 156 | - query_vector = getattr(parsed_query, "query_vector", None) | |
| 157 | - if query_vector is not None: | |
| 158 | - return np.asarray(query_vector, dtype=np.float32) | |
| 159 | - | |
| 160 | - text_encoder = self._get_text_encoder() | |
| 161 | - if text_encoder is None: | |
| 162 | - return None | |
| 163 | - | |
| 164 | - query_text = ( | |
| 165 | - getattr(parsed_query, "rewritten_query", None) | |
| 166 | - or getattr(parsed_query, "query_normalized", None) | |
| 167 | - or getattr(parsed_query, "original_query", None) | |
| 168 | - ) | |
| 169 | - if not query_text: | |
| 170 | - return None | |
| 171 | - | |
| 172 | - vectors = text_encoder.encode([query_text], priority=1) | |
| 173 | - if vectors is None or len(vectors) == 0 or vectors[0] is None: | |
| 174 | - return None | |
| 175 | - return np.asarray(vectors[0], dtype=np.float32) | |
| 176 | - | |
| 177 | 113 | def _build_selection_context( |
| 178 | 114 | self, |
| 179 | - parsed_query: Any, | |
| 180 | 115 | style_profile: StyleIntentProfile, |
| 181 | 116 | ) -> _SelectionContext: |
| 182 | - matched_terms_by_intent: Dict[str, List[str]] = {} | |
| 117 | + attribute_terms_by_intent: Dict[str, List[str]] = {} | |
| 183 | 118 | for intent in style_profile.intents: |
| 184 | - normalized_term = normalize_query_text(intent.matched_term) | |
| 185 | - if not normalized_term: | |
| 186 | - continue | |
| 187 | - matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) | |
| 188 | - if normalized_term not in matched_terms: | |
| 189 | - matched_terms.append(normalized_term) | |
| 119 | + terms = attribute_terms_by_intent.setdefault(intent.intent_type, []) | |
| 120 | + for raw_term in intent.attribute_terms: | |
| 121 | + normalized_term = normalize_query_text(raw_term) | |
| 122 | + if not normalized_term or normalized_term in terms: | |
| 123 | + continue | |
| 124 | + terms.append(normalized_term) | |
| 190 | 125 | |
| 191 | 126 | return _SelectionContext( |
| 192 | - query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), | |
| 193 | - matched_terms_by_intent={ | |
| 127 | + attribute_terms_by_intent={ | |
| 194 | 128 | intent_type: tuple(terms) |
| 195 | - for intent_type, terms in matched_terms_by_intent.items() | |
| 129 | + for intent_type, terms in attribute_terms_by_intent.items() | |
| 196 | 130 | }, |
| 197 | - query_vector=self._get_query_vector(parsed_query), | |
| 198 | 131 | ) |
| 199 | 132 | |
| 200 | - def _get_text_encoder(self) -> Any: | |
| 201 | - if self._text_encoder_getter is None: | |
| 202 | - return None | |
| 203 | - return self._text_encoder_getter() | |
| 133 | + @staticmethod | |
| 134 | + def _normalize_cached(selection_context: _SelectionContext, value: Any) -> str: | |
| 135 | + raw = str(value or "").strip() | |
| 136 | + if not raw: | |
| 137 | + return "" | |
| 138 | + cached = selection_context.normalized_text_cache.get(raw) | |
| 139 | + if cached is not None: | |
| 140 | + return cached | |
| 141 | + normalized = normalize_query_text(raw) | |
| 142 | + selection_context.normalized_text_cache[raw] = normalized | |
| 143 | + return normalized | |
| 204 | 144 | |
| 205 | 145 | def _resolve_dimensions( |
| 206 | 146 | self, |
| ... | ... | @@ -225,51 +165,6 @@ class StyleSkuSelector: |
| 225 | 165 | resolved[intent.intent_type] = matched_field |
| 226 | 166 | return resolved |
| 227 | 167 | |
| 228 | - def _build_candidates( | |
| 229 | - self, | |
| 230 | - skus: List[Dict[str, Any]], | |
| 231 | - resolved_dimensions: Dict[str, Optional[str]], | |
| 232 | - ) -> List[_SkuCandidate]: | |
| 233 | - if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): | |
| 234 | - return [] | |
| 235 | - | |
| 236 | - candidates: List[_SkuCandidate] = [] | |
| 237 | - for index, sku in enumerate(skus): | |
| 238 | - intent_values: Dict[str, str] = {} | |
| 239 | - normalized_intent_values: Dict[str, str] = {} | |
| 240 | - for intent_type, field_name in resolved_dimensions.items(): | |
| 241 | - if not field_name: | |
| 242 | - continue | |
| 243 | - raw = str(sku.get(field_name) or "").strip() | |
| 244 | - intent_values[intent_type] = raw | |
| 245 | - normalized_intent_values[intent_type] = normalize_query_text(raw) | |
| 246 | - | |
| 247 | - selection_parts: List[str] = [] | |
| 248 | - norm_parts: List[str] = [] | |
| 249 | - seen: set[str] = set() | |
| 250 | - for intent_type, raw in intent_values.items(): | |
| 251 | - nv = normalized_intent_values[intent_type] | |
| 252 | - if not nv or nv in seen: | |
| 253 | - continue | |
| 254 | - seen.add(nv) | |
| 255 | - selection_parts.append(raw) | |
| 256 | - norm_parts.append(nv) | |
| 257 | - | |
| 258 | - selection_text = " ".join(selection_parts).strip() | |
| 259 | - normalized_selection_text = " ".join(norm_parts).strip() | |
| 260 | - candidates.append( | |
| 261 | - _SkuCandidate( | |
| 262 | - index=index, | |
| 263 | - sku_id=str(sku.get("sku_id") or ""), | |
| 264 | - sku=sku, | |
| 265 | - selection_text=selection_text, | |
| 266 | - normalized_selection_text=normalized_selection_text, | |
| 267 | - intent_values=intent_values, | |
| 268 | - normalized_intent_values=normalized_intent_values, | |
| 269 | - ) | |
| 270 | - ) | |
| 271 | - return candidates | |
| 272 | - | |
| 273 | 168 | @staticmethod |
| 274 | 169 | def _empty_decision( |
| 275 | 170 | resolved_dimensions: Dict[str, Optional[str]], |
| ... | ... | @@ -286,13 +181,10 @@ class StyleSkuSelector: |
| 286 | 181 | def _is_text_match( |
| 287 | 182 | self, |
| 288 | 183 | intent_type: str, |
| 289 | - value: str, | |
| 290 | 184 | selection_context: _SelectionContext, |
| 291 | 185 | *, |
| 292 | - normalized_value: Optional[str] = None, | |
| 186 | + normalized_value: str, | |
| 293 | 187 | ) -> bool: |
| 294 | - if normalized_value is None: | |
| 295 | - normalized_value = normalize_query_text(value) | |
| 296 | 188 | if not normalized_value: |
| 297 | 189 | return False |
| 298 | 190 | |
| ... | ... | @@ -301,84 +193,44 @@ class StyleSkuSelector: |
| 301 | 193 | if cached is not None: |
| 302 | 194 | return cached |
| 303 | 195 | |
| 304 | - matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ()) | |
| 305 | - has_term_match = any(term in normalized_value for term in matched_terms if term) | |
| 306 | - query_contains_value = any( | |
| 307 | - normalized_value in query_text | |
| 308 | - for query_text in selection_context.query_texts | |
| 309 | - ) | |
| 310 | - matched = bool(has_term_match or query_contains_value) | |
| 196 | + attribute_terms = selection_context.attribute_terms_by_intent.get(intent_type, ()) | |
| 197 | + matched = any(term in normalized_value for term in attribute_terms if term) | |
| 311 | 198 | selection_context.text_match_cache[cache_key] = matched |
| 312 | 199 | return matched |
| 313 | 200 | |
| 314 | 201 | def _find_first_text_match( |
| 315 | 202 | self, |
| 316 | - candidates: Sequence[_SkuCandidate], | |
| 203 | + skus: List[Dict[str, Any]], | |
| 204 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 317 | 205 | selection_context: _SelectionContext, |
| 318 | - ) -> Optional[_SkuCandidate]: | |
| 319 | - for candidate in candidates: | |
| 320 | - if candidate.intent_values and all( | |
| 321 | - self._is_text_match( | |
| 322 | - intent_type, | |
| 323 | - value, | |
| 324 | - selection_context, | |
| 325 | - normalized_value=candidate.normalized_intent_values[intent_type], | |
| 326 | - ) | |
| 327 | - for intent_type, value in candidate.intent_values.items() | |
| 328 | - ): | |
| 329 | - return candidate | |
| 330 | - return None | |
| 206 | + ) -> Optional[Tuple[str, str]]: | |
| 207 | + for sku in skus: | |
| 208 | + selection_parts: List[str] = [] | |
| 209 | + seen_parts: set[str] = set() | |
| 210 | + matched = True | |
| 331 | 211 | |
| 332 | - def _select_by_embedding( | |
| 333 | - self, | |
| 334 | - candidates: Sequence[_SkuCandidate], | |
| 335 | - selection_context: _SelectionContext, | |
| 336 | - ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: | |
| 337 | - if not candidates: | |
| 338 | - return None, None | |
| 339 | - text_encoder = self._get_text_encoder() | |
| 340 | - if selection_context.query_vector is None or text_encoder is None: | |
| 341 | - return None, None | |
| 342 | - | |
| 343 | - unique_texts = list( | |
| 344 | - dict.fromkeys( | |
| 345 | - candidate.normalized_selection_text | |
| 346 | - for candidate in candidates | |
| 347 | - if candidate.normalized_selection_text | |
| 348 | - and candidate.normalized_selection_text not in selection_context.selection_vector_cache | |
| 349 | - ) | |
| 350 | - ) | |
| 351 | - if unique_texts: | |
| 352 | - vectors = text_encoder.encode(unique_texts, priority=1) | |
| 353 | - for key, vector in zip(unique_texts, vectors): | |
| 354 | - selection_context.selection_vector_cache[key] = ( | |
| 355 | - np.asarray(vector, dtype=np.float32) if vector is not None else None | |
| 356 | - ) | |
| 357 | - | |
| 358 | - best_candidate: Optional[_SkuCandidate] = None | |
| 359 | - best_score: Optional[float] = None | |
| 360 | - query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32) | |
| 361 | - for candidate in candidates: | |
| 362 | - normalized_text = candidate.normalized_selection_text | |
| 363 | - if not normalized_text: | |
| 364 | - continue | |
| 212 | + for intent_type, field_name in resolved_dimensions.items(): | |
| 213 | + if not field_name: | |
| 214 | + matched = False | |
| 215 | + break | |
| 365 | 216 | |
| 366 | - score = selection_context.similarity_cache.get(normalized_text) | |
| 367 | - if score is None: | |
| 368 | - candidate_vector = selection_context.selection_vector_cache.get(normalized_text) | |
| 369 | - if candidate_vector is None: | |
| 370 | - selection_context.similarity_cache[normalized_text] = None | |
| 371 | - continue | |
| 372 | - score = float(np.inner(query_vector_array, candidate_vector)) | |
| 373 | - selection_context.similarity_cache[normalized_text] = score | |
| 217 | + raw_value = str(sku.get(field_name) or "").strip() | |
| 218 | + normalized_value = self._normalize_cached(selection_context, raw_value) | |
| 219 | + if not self._is_text_match( | |
| 220 | + intent_type, | |
| 221 | + selection_context, | |
| 222 | + normalized_value=normalized_value, | |
| 223 | + ): | |
| 224 | + matched = False | |
| 225 | + break | |
| 374 | 226 | |
| 375 | - if score is None: | |
| 376 | - continue | |
| 377 | - if best_score is None or score > best_score: | |
| 378 | - best_candidate = candidate | |
| 379 | - best_score = score | |
| 227 | + if raw_value and normalized_value not in seen_parts: | |
| 228 | + seen_parts.add(normalized_value) | |
| 229 | + selection_parts.append(raw_value) | |
| 380 | 230 | |
| 381 | - return best_candidate, best_score | |
| 231 | + if matched: | |
| 232 | + return str(sku.get("sku_id") or ""), " ".join(selection_parts).strip() | |
| 233 | + return None | |
| 382 | 234 | |
| 383 | 235 | def _select_for_source( |
| 384 | 236 | self, |
| ... | ... | @@ -395,36 +247,29 @@ class StyleSkuSelector: |
| 395 | 247 | if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): |
| 396 | 248 | return self._empty_decision(resolved_dimensions, matched_stage="unresolved") |
| 397 | 249 | |
| 398 | - candidates = self._build_candidates(skus, resolved_dimensions) | |
| 399 | - if not candidates: | |
| 400 | - return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") | |
| 401 | - | |
| 402 | - text_match = self._find_first_text_match(candidates, selection_context) | |
| 403 | - if text_match is not None: | |
| 404 | - return self._build_decision(text_match, resolved_dimensions, matched_stage="text") | |
| 405 | - | |
| 406 | - chosen, similarity_score = self._select_by_embedding(candidates, selection_context) | |
| 407 | - if chosen is None: | |
| 250 | + text_match = self._find_first_text_match(skus, resolved_dimensions, selection_context) | |
| 251 | + if text_match is None: | |
| 408 | 252 | return self._empty_decision(resolved_dimensions, matched_stage="no_match") |
| 409 | 253 | return self._build_decision( |
| 410 | - chosen, | |
| 411 | - resolved_dimensions, | |
| 412 | - matched_stage="embedding", | |
| 413 | - similarity_score=similarity_score, | |
| 254 | + selected_sku_id=text_match[0], | |
| 255 | + selected_text=text_match[1], | |
| 256 | + resolved_dimensions=resolved_dimensions, | |
| 257 | + matched_stage="text", | |
| 414 | 258 | ) |
| 415 | 259 | |
| 416 | 260 | @staticmethod |
| 417 | 261 | def _build_decision( |
| 418 | - candidate: _SkuCandidate, | |
| 262 | + selected_sku_id: str, | |
| 263 | + selected_text: str, | |
| 419 | 264 | resolved_dimensions: Dict[str, Optional[str]], |
| 420 | 265 | *, |
| 421 | 266 | matched_stage: str, |
| 422 | 267 | similarity_score: Optional[float] = None, |
| 423 | 268 | ) -> SkuSelectionDecision: |
| 424 | 269 | return SkuSelectionDecision( |
| 425 | - selected_sku_id=candidate.sku_id or None, | |
| 426 | - rerank_suffix=str(candidate.selection_text or "").strip(), | |
| 427 | - selected_text=str(candidate.selection_text or "").strip(), | |
| 270 | + selected_sku_id=selected_sku_id or None, | |
| 271 | + rerank_suffix=str(selected_text or "").strip(), | |
| 272 | + selected_text=str(selected_text or "").strip(), | |
| 428 | 273 | matched_stage=matched_stage, |
| 429 | 274 | similarity_score=similarity_score, |
| 430 | 275 | resolved_dimensions=dict(resolved_dimensions), | ... | ... |
| ... | ... | @@ -0,0 +1,452 @@ |
| 1 | +""" | |
| 2 | +SKU selection for style-intent-aware search results. | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from __future__ import annotations | |
| 6 | + | |
| 7 | +from dataclasses import dataclass, field | |
| 8 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple | |
| 9 | + | |
| 10 | +import numpy as np | |
| 11 | + | |
| 12 | +from query.style_intent import StyleIntentProfile, StyleIntentRegistry | |
| 13 | +from query.tokenization import normalize_query_text | |
| 14 | + | |
| 15 | + | |
| 16 | +@dataclass(frozen=True) | |
| 17 | +class SkuSelectionDecision: | |
| 18 | + selected_sku_id: Optional[str] | |
| 19 | + rerank_suffix: str | |
| 20 | + selected_text: str | |
| 21 | + matched_stage: str | |
| 22 | + similarity_score: Optional[float] = None | |
| 23 | + resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict) | |
| 24 | + | |
| 25 | + def to_dict(self) -> Dict[str, Any]: | |
| 26 | + return { | |
| 27 | + "selected_sku_id": self.selected_sku_id, | |
| 28 | + "rerank_suffix": self.rerank_suffix, | |
| 29 | + "selected_text": self.selected_text, | |
| 30 | + "matched_stage": self.matched_stage, | |
| 31 | + "similarity_score": self.similarity_score, | |
| 32 | + "resolved_dimensions": dict(self.resolved_dimensions), | |
| 33 | + } | |
| 34 | + | |
| 35 | + | |
| 36 | +@dataclass | |
| 37 | +class _SkuCandidate: | |
| 38 | + index: int | |
| 39 | + sku_id: str | |
| 40 | + sku: Dict[str, Any] | |
| 41 | + selection_text: str | |
| 42 | + normalized_selection_text: str | |
| 43 | + intent_values: Dict[str, str] | |
| 44 | + normalized_intent_values: Dict[str, str] | |
| 45 | + | |
| 46 | + | |
| 47 | +@dataclass | |
| 48 | +class _SelectionContext: | |
| 49 | + query_texts: Tuple[str, ...] | |
| 50 | + matched_terms_by_intent: Dict[str, Tuple[str, ...]] | |
| 51 | + query_vector: Optional[np.ndarray] | |
| 52 | + text_match_cache: Dict[Tuple[str, str], bool] = field(default_factory=dict) | |
| 53 | + selection_vector_cache: Dict[str, Optional[np.ndarray]] = field(default_factory=dict) | |
| 54 | + similarity_cache: Dict[str, Optional[float]] = field(default_factory=dict) | |
| 55 | + | |
| 56 | + | |
| 57 | +class StyleSkuSelector: | |
| 58 | + """Selects the best SKU for an SPU based on detected style intent.""" | |
| 59 | + | |
| 60 | + def __init__( | |
| 61 | + self, | |
| 62 | + registry: StyleIntentRegistry, | |
| 63 | + *, | |
| 64 | + text_encoder_getter: Optional[Callable[[], Any]] = None, | |
| 65 | + ) -> None: | |
| 66 | + self.registry = registry | |
| 67 | + self._text_encoder_getter = text_encoder_getter | |
| 68 | + | |
| 69 | + def prepare_hits( | |
| 70 | + self, | |
| 71 | + es_hits: List[Dict[str, Any]], | |
| 72 | + parsed_query: Any, | |
| 73 | + ) -> Dict[str, SkuSelectionDecision]: | |
| 74 | + decisions: Dict[str, SkuSelectionDecision] = {} | |
| 75 | + style_profile = getattr(parsed_query, "style_intent_profile", None) | |
| 76 | + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: | |
| 77 | + return decisions | |
| 78 | + | |
| 79 | + selection_context = self._build_selection_context(parsed_query, style_profile) | |
| 80 | + | |
| 81 | + for hit in es_hits: | |
| 82 | + source = hit.get("_source") | |
| 83 | + if not isinstance(source, dict): | |
| 84 | + continue | |
| 85 | + | |
| 86 | + decision = self._select_for_source( | |
| 87 | + source, | |
| 88 | + style_profile=style_profile, | |
| 89 | + selection_context=selection_context, | |
| 90 | + ) | |
| 91 | + if decision is None: | |
| 92 | + continue | |
| 93 | + | |
| 94 | + if decision.rerank_suffix: | |
| 95 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 96 | + else: | |
| 97 | + hit.pop("_style_rerank_suffix", None) | |
| 98 | + | |
| 99 | + doc_id = hit.get("_id") | |
| 100 | + if doc_id is not None: | |
| 101 | + decisions[str(doc_id)] = decision | |
| 102 | + | |
| 103 | + return decisions | |
| 104 | + | |
| 105 | + def apply_precomputed_decisions( | |
| 106 | + self, | |
| 107 | + es_hits: List[Dict[str, Any]], | |
| 108 | + decisions: Dict[str, SkuSelectionDecision], | |
| 109 | + ) -> None: | |
| 110 | + if not es_hits or not decisions: | |
| 111 | + return | |
| 112 | + | |
| 113 | + for hit in es_hits: | |
| 114 | + doc_id = hit.get("_id") | |
| 115 | + if doc_id is None: | |
| 116 | + continue | |
| 117 | + decision = decisions.get(str(doc_id)) | |
| 118 | + if decision is None: | |
| 119 | + continue | |
| 120 | + source = hit.get("_source") | |
| 121 | + if not isinstance(source, dict): | |
| 122 | + continue | |
| 123 | + self._apply_decision_to_source(source, decision) | |
| 124 | + if decision.rerank_suffix: | |
| 125 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 126 | + else: | |
| 127 | + hit.pop("_style_rerank_suffix", None) | |
| 128 | + | |
| 129 | + def _build_query_texts( | |
| 130 | + self, | |
| 131 | + parsed_query: Any, | |
| 132 | + style_profile: StyleIntentProfile, | |
| 133 | + ) -> List[str]: | |
| 134 | + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] | |
| 135 | + if texts: | |
| 136 | + return list(dict.fromkeys(texts)) | |
| 137 | + | |
| 138 | + fallbacks: List[str] = [] | |
| 139 | + for value in ( | |
| 140 | + getattr(parsed_query, "original_query", None), | |
| 141 | + getattr(parsed_query, "query_normalized", None), | |
| 142 | + getattr(parsed_query, "rewritten_query", None), | |
| 143 | + ): | |
| 144 | + normalized = normalize_query_text(value) | |
| 145 | + if normalized: | |
| 146 | + fallbacks.append(normalized) | |
| 147 | + translations = getattr(parsed_query, "translations", {}) or {} | |
| 148 | + if isinstance(translations, dict): | |
| 149 | + for value in translations.values(): | |
| 150 | + normalized = normalize_query_text(value) | |
| 151 | + if normalized: | |
| 152 | + fallbacks.append(normalized) | |
| 153 | + return list(dict.fromkeys(fallbacks)) | |
| 154 | + | |
| 155 | + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: | |
| 156 | + query_vector = getattr(parsed_query, "query_vector", None) | |
| 157 | + if query_vector is not None: | |
| 158 | + return np.asarray(query_vector, dtype=np.float32) | |
| 159 | + | |
| 160 | + text_encoder = self._get_text_encoder() | |
| 161 | + if text_encoder is None: | |
| 162 | + return None | |
| 163 | + | |
| 164 | + query_text = ( | |
| 165 | + getattr(parsed_query, "rewritten_query", None) | |
| 166 | + or getattr(parsed_query, "query_normalized", None) | |
| 167 | + or getattr(parsed_query, "original_query", None) | |
| 168 | + ) | |
| 169 | + if not query_text: | |
| 170 | + return None | |
| 171 | + | |
| 172 | + vectors = text_encoder.encode([query_text], priority=1) | |
| 173 | + if vectors is None or len(vectors) == 0 or vectors[0] is None: | |
| 174 | + return None | |
| 175 | + return np.asarray(vectors[0], dtype=np.float32) | |
| 176 | + | |
| 177 | + def _build_selection_context( | |
| 178 | + self, | |
| 179 | + parsed_query: Any, | |
| 180 | + style_profile: StyleIntentProfile, | |
| 181 | + ) -> _SelectionContext: | |
| 182 | + matched_terms_by_intent: Dict[str, List[str]] = {} | |
| 183 | + for intent in style_profile.intents: | |
| 184 | + normalized_term = normalize_query_text(intent.matched_term) | |
| 185 | + if not normalized_term: | |
| 186 | + continue | |
| 187 | + matched_terms = matched_terms_by_intent.setdefault(intent.intent_type, []) | |
| 188 | + if normalized_term not in matched_terms: | |
| 189 | + matched_terms.append(normalized_term) | |
| 190 | + | |
| 191 | + return _SelectionContext( | |
| 192 | + query_texts=tuple(self._build_query_texts(parsed_query, style_profile)), | |
| 193 | + matched_terms_by_intent={ | |
| 194 | + intent_type: tuple(terms) | |
| 195 | + for intent_type, terms in matched_terms_by_intent.items() | |
| 196 | + }, | |
| 197 | + query_vector=self._get_query_vector(parsed_query), | |
| 198 | + ) | |
| 199 | + | |
| 200 | + def _get_text_encoder(self) -> Any: | |
| 201 | + if self._text_encoder_getter is None: | |
| 202 | + return None | |
| 203 | + return self._text_encoder_getter() | |
| 204 | + | |
| 205 | + def _resolve_dimensions( | |
| 206 | + self, | |
| 207 | + source: Dict[str, Any], | |
| 208 | + style_profile: StyleIntentProfile, | |
| 209 | + ) -> Dict[str, Optional[str]]: | |
| 210 | + option_names = { | |
| 211 | + "option1_value": normalize_query_text(source.get("option1_name")), | |
| 212 | + "option2_value": normalize_query_text(source.get("option2_name")), | |
| 213 | + "option3_value": normalize_query_text(source.get("option3_name")), | |
| 214 | + } | |
| 215 | + resolved: Dict[str, Optional[str]] = {} | |
| 216 | + for intent in style_profile.intents: | |
| 217 | + if intent.intent_type in resolved: | |
| 218 | + continue | |
| 219 | + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) | |
| 220 | + matched_field = None | |
| 221 | + for field_name, option_name in option_names.items(): | |
| 222 | + if option_name and option_name in aliases: | |
| 223 | + matched_field = field_name | |
| 224 | + break | |
| 225 | + resolved[intent.intent_type] = matched_field | |
| 226 | + return resolved | |
| 227 | + | |
| 228 | + def _build_candidates( | |
| 229 | + self, | |
| 230 | + skus: List[Dict[str, Any]], | |
| 231 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 232 | + ) -> List[_SkuCandidate]: | |
| 233 | + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): | |
| 234 | + return [] | |
| 235 | + | |
| 236 | + candidates: List[_SkuCandidate] = [] | |
| 237 | + for index, sku in enumerate(skus): | |
| 238 | + intent_values: Dict[str, str] = {} | |
| 239 | + normalized_intent_values: Dict[str, str] = {} | |
| 240 | + for intent_type, field_name in resolved_dimensions.items(): | |
| 241 | + if not field_name: | |
| 242 | + continue | |
| 243 | + raw = str(sku.get(field_name) or "").strip() | |
| 244 | + intent_values[intent_type] = raw | |
| 245 | + normalized_intent_values[intent_type] = normalize_query_text(raw) | |
| 246 | + | |
| 247 | + selection_parts: List[str] = [] | |
| 248 | + norm_parts: List[str] = [] | |
| 249 | + seen: set[str] = set() | |
| 250 | + for intent_type, raw in intent_values.items(): | |
| 251 | + nv = normalized_intent_values[intent_type] | |
| 252 | + if not nv or nv in seen: | |
| 253 | + continue | |
| 254 | + seen.add(nv) | |
| 255 | + selection_parts.append(raw) | |
| 256 | + norm_parts.append(nv) | |
| 257 | + | |
| 258 | + selection_text = " ".join(selection_parts).strip() | |
| 259 | + normalized_selection_text = " ".join(norm_parts).strip() | |
| 260 | + candidates.append( | |
| 261 | + _SkuCandidate( | |
| 262 | + index=index, | |
| 263 | + sku_id=str(sku.get("sku_id") or ""), | |
| 264 | + sku=sku, | |
| 265 | + selection_text=selection_text, | |
| 266 | + normalized_selection_text=normalized_selection_text, | |
| 267 | + intent_values=intent_values, | |
| 268 | + normalized_intent_values=normalized_intent_values, | |
| 269 | + ) | |
| 270 | + ) | |
| 271 | + return candidates | |
| 272 | + | |
| 273 | + @staticmethod | |
| 274 | + def _empty_decision( | |
| 275 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 276 | + matched_stage: str, | |
| 277 | + ) -> SkuSelectionDecision: | |
| 278 | + return SkuSelectionDecision( | |
| 279 | + selected_sku_id=None, | |
| 280 | + rerank_suffix="", | |
| 281 | + selected_text="", | |
| 282 | + matched_stage=matched_stage, | |
| 283 | + resolved_dimensions=dict(resolved_dimensions), | |
| 284 | + ) | |
| 285 | + | |
| 286 | + def _is_text_match( | |
| 287 | + self, | |
| 288 | + intent_type: str, | |
| 289 | + value: str, | |
| 290 | + selection_context: _SelectionContext, | |
| 291 | + *, | |
| 292 | + normalized_value: Optional[str] = None, | |
| 293 | + ) -> bool: | |
| 294 | + if normalized_value is None: | |
| 295 | + normalized_value = normalize_query_text(value) | |
| 296 | + if not normalized_value: | |
| 297 | + return False | |
| 298 | + | |
| 299 | + cache_key = (intent_type, normalized_value) | |
| 300 | + cached = selection_context.text_match_cache.get(cache_key) | |
| 301 | + if cached is not None: | |
| 302 | + return cached | |
| 303 | + | |
| 304 | + matched_terms = selection_context.matched_terms_by_intent.get(intent_type, ()) | |
| 305 | + has_term_match = any(term in normalized_value for term in matched_terms if term) | |
| 306 | + query_contains_value = any( | |
| 307 | + normalized_value in query_text | |
| 308 | + for query_text in selection_context.query_texts | |
| 309 | + ) | |
| 310 | + matched = bool(has_term_match or query_contains_value) | |
| 311 | + selection_context.text_match_cache[cache_key] = matched | |
| 312 | + return matched | |
| 313 | + | |
| 314 | + def _find_first_text_match( | |
| 315 | + self, | |
| 316 | + candidates: Sequence[_SkuCandidate], | |
| 317 | + selection_context: _SelectionContext, | |
| 318 | + ) -> Optional[_SkuCandidate]: | |
| 319 | + for candidate in candidates: | |
| 320 | + if candidate.intent_values and all( | |
| 321 | + self._is_text_match( | |
| 322 | + intent_type, | |
| 323 | + value, | |
| 324 | + selection_context, | |
| 325 | + normalized_value=candidate.normalized_intent_values[intent_type], | |
| 326 | + ) | |
| 327 | + for intent_type, value in candidate.intent_values.items() | |
| 328 | + ): | |
| 329 | + return candidate | |
| 330 | + return None | |
| 331 | + | |
| 332 | + def _select_by_embedding( | |
| 333 | + self, | |
| 334 | + candidates: Sequence[_SkuCandidate], | |
| 335 | + selection_context: _SelectionContext, | |
| 336 | + ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: | |
| 337 | + if not candidates: | |
| 338 | + return None, None | |
| 339 | + text_encoder = self._get_text_encoder() | |
| 340 | + if selection_context.query_vector is None or text_encoder is None: | |
| 341 | + return None, None | |
| 342 | + | |
| 343 | + unique_texts = list( | |
| 344 | + dict.fromkeys( | |
| 345 | + candidate.normalized_selection_text | |
| 346 | + for candidate in candidates | |
| 347 | + if candidate.normalized_selection_text | |
| 348 | + and candidate.normalized_selection_text not in selection_context.selection_vector_cache | |
| 349 | + ) | |
| 350 | + ) | |
| 351 | + if unique_texts: | |
| 352 | + vectors = text_encoder.encode(unique_texts, priority=1) | |
| 353 | + for key, vector in zip(unique_texts, vectors): | |
| 354 | + selection_context.selection_vector_cache[key] = ( | |
| 355 | + np.asarray(vector, dtype=np.float32) if vector is not None else None | |
| 356 | + ) | |
| 357 | + | |
| 358 | + best_candidate: Optional[_SkuCandidate] = None | |
| 359 | + best_score: Optional[float] = None | |
| 360 | + query_vector_array = np.asarray(selection_context.query_vector, dtype=np.float32) | |
| 361 | + for candidate in candidates: | |
| 362 | + normalized_text = candidate.normalized_selection_text | |
| 363 | + if not normalized_text: | |
| 364 | + continue | |
| 365 | + | |
| 366 | + score = selection_context.similarity_cache.get(normalized_text) | |
| 367 | + if score is None: | |
| 368 | + candidate_vector = selection_context.selection_vector_cache.get(normalized_text) | |
| 369 | + if candidate_vector is None: | |
| 370 | + selection_context.similarity_cache[normalized_text] = None | |
| 371 | + continue | |
| 372 | + score = float(np.inner(query_vector_array, candidate_vector)) | |
| 373 | + selection_context.similarity_cache[normalized_text] = score | |
| 374 | + | |
| 375 | + if score is None: | |
| 376 | + continue | |
| 377 | + if best_score is None or score > best_score: | |
| 378 | + best_candidate = candidate | |
| 379 | + best_score = score | |
| 380 | + | |
| 381 | + return best_candidate, best_score | |
| 382 | + | |
| 383 | + def _select_for_source( | |
| 384 | + self, | |
| 385 | + source: Dict[str, Any], | |
| 386 | + *, | |
| 387 | + style_profile: StyleIntentProfile, | |
| 388 | + selection_context: _SelectionContext, | |
| 389 | + ) -> Optional[SkuSelectionDecision]: | |
| 390 | + skus = source.get("skus") | |
| 391 | + if not isinstance(skus, list) or not skus: | |
| 392 | + return None | |
| 393 | + | |
| 394 | + resolved_dimensions = self._resolve_dimensions(source, style_profile) | |
| 395 | + if not resolved_dimensions or any(not field_name for field_name in resolved_dimensions.values()): | |
| 396 | + return self._empty_decision(resolved_dimensions, matched_stage="unresolved") | |
| 397 | + | |
| 398 | + candidates = self._build_candidates(skus, resolved_dimensions) | |
| 399 | + if not candidates: | |
| 400 | + return self._empty_decision(resolved_dimensions, matched_stage="no_candidates") | |
| 401 | + | |
| 402 | + text_match = self._find_first_text_match(candidates, selection_context) | |
| 403 | + if text_match is not None: | |
| 404 | + return self._build_decision(text_match, resolved_dimensions, matched_stage="text") | |
| 405 | + | |
| 406 | + chosen, similarity_score = self._select_by_embedding(candidates, selection_context) | |
| 407 | + if chosen is None: | |
| 408 | + return self._empty_decision(resolved_dimensions, matched_stage="no_match") | |
| 409 | + return self._build_decision( | |
| 410 | + chosen, | |
| 411 | + resolved_dimensions, | |
| 412 | + matched_stage="embedding", | |
| 413 | + similarity_score=similarity_score, | |
| 414 | + ) | |
| 415 | + | |
| 416 | + @staticmethod | |
| 417 | + def _build_decision( | |
| 418 | + candidate: _SkuCandidate, | |
| 419 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 420 | + *, | |
| 421 | + matched_stage: str, | |
| 422 | + similarity_score: Optional[float] = None, | |
| 423 | + ) -> SkuSelectionDecision: | |
| 424 | + return SkuSelectionDecision( | |
| 425 | + selected_sku_id=candidate.sku_id or None, | |
| 426 | + rerank_suffix=str(candidate.selection_text or "").strip(), | |
| 427 | + selected_text=str(candidate.selection_text or "").strip(), | |
| 428 | + matched_stage=matched_stage, | |
| 429 | + similarity_score=similarity_score, | |
| 430 | + resolved_dimensions=dict(resolved_dimensions), | |
| 431 | + ) | |
| 432 | + | |
| 433 | + @staticmethod | |
| 434 | + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: | |
| 435 | + skus = source.get("skus") | |
| 436 | + if not isinstance(skus, list) or not skus or not decision.selected_sku_id: | |
| 437 | + return | |
| 438 | + | |
| 439 | + selected_index = None | |
| 440 | + for index, sku in enumerate(skus): | |
| 441 | + if str(sku.get("sku_id") or "") == decision.selected_sku_id: | |
| 442 | + selected_index = index | |
| 443 | + break | |
| 444 | + if selected_index is None: | |
| 445 | + return | |
| 446 | + | |
| 447 | + selected_sku = skus.pop(selected_index) | |
| 448 | + skus.insert(0, selected_sku) | |
| 449 | + | |
| 450 | + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") | |
| 451 | + if image_src: | |
| 452 | + source["image_url"] = image_src | ... | ... |
tests/test_search_rerank_window.py
| ... | ... | @@ -63,6 +63,7 @@ def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimensi |
| 63 | 63 | canonical_value=canonical_value, |
| 64 | 64 | matched_term=canonical_value, |
| 65 | 65 | matched_query_text=canonical_value, |
| 66 | + attribute_terms=(canonical_value,), | |
| 66 | 67 | dimension_aliases=tuple(aliases), |
| 67 | 68 | ), |
| 68 | 69 | ) | ... | ... |
| ... | ... | @@ -0,0 +1,106 @@ |
| 1 | +from types import SimpleNamespace | |
| 2 | + | |
| 3 | +from config import QueryConfig | |
| 4 | +from query.style_intent import DetectedStyleIntent, StyleIntentProfile, StyleIntentRegistry | |
| 5 | +from search.sku_intent_selector import StyleSkuSelector | |
| 6 | + | |
| 7 | + | |
| 8 | +def test_style_sku_selector_matches_first_sku_by_attribute_terms(): | |
| 9 | + registry = StyleIntentRegistry.from_query_config( | |
| 10 | + QueryConfig( | |
| 11 | + style_intent_terms={ | |
| 12 | + "color": [{"en_terms": ["navy"], "zh_terms": ["藏青"], "attribute_terms": ["navy"]}], | |
| 13 | + "size": [{"en_terms": ["xl"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}], | |
| 14 | + }, | |
| 15 | + style_intent_dimension_aliases={ | |
| 16 | + "color": ["color", "颜色"], | |
| 17 | + "size": ["size", "尺码"], | |
| 18 | + }, | |
| 19 | + ) | |
| 20 | + ) | |
| 21 | + selector = StyleSkuSelector(registry) | |
| 22 | + parsed_query = SimpleNamespace( | |
| 23 | + style_intent_profile=StyleIntentProfile( | |
| 24 | + intents=( | |
| 25 | + DetectedStyleIntent( | |
| 26 | + intent_type="color", | |
| 27 | + canonical_value="navy", | |
| 28 | + matched_term="藏青", | |
| 29 | + matched_query_text="藏青", | |
| 30 | + attribute_terms=("navy",), | |
| 31 | + dimension_aliases=("color", "颜色"), | |
| 32 | + ), | |
| 33 | + DetectedStyleIntent( | |
| 34 | + intent_type="size", | |
| 35 | + canonical_value="x-large", | |
| 36 | + matched_term="xl", | |
| 37 | + matched_query_text="xl", | |
| 38 | + attribute_terms=("x-large",), | |
| 39 | + dimension_aliases=("size", "尺码"), | |
| 40 | + ), | |
| 41 | + ), | |
| 42 | + ) | |
| 43 | + ) | |
| 44 | + source = { | |
| 45 | + "option1_name": "Color", | |
| 46 | + "option2_name": "Size", | |
| 47 | + "skus": [ | |
| 48 | + {"sku_id": "1", "option1_value": "Black", "option2_value": "M"}, | |
| 49 | + {"sku_id": "2", "option1_value": "Navy Blue", "option2_value": "X-Large", "image_src": "matched.jpg"}, | |
| 50 | + {"sku_id": "3", "option1_value": "Navy", "option2_value": "XL"}, | |
| 51 | + ], | |
| 52 | + } | |
| 53 | + hits = [{"_id": "spu-1", "_source": source}] | |
| 54 | + | |
| 55 | + decisions = selector.prepare_hits(hits, parsed_query) | |
| 56 | + decision = decisions["spu-1"] | |
| 57 | + | |
| 58 | + assert decision.selected_sku_id == "2" | |
| 59 | + assert decision.selected_text == "Navy Blue X-Large" | |
| 60 | + assert decision.matched_stage == "text" | |
| 61 | + | |
| 62 | + selector.apply_precomputed_decisions(hits, decisions) | |
| 63 | + | |
| 64 | + assert source["skus"][0]["sku_id"] == "2" | |
| 65 | + assert source["image_url"] == "matched.jpg" | |
| 66 | + | |
| 67 | + | |
| 68 | +def test_style_sku_selector_returns_no_match_without_attribute_contains(): | |
| 69 | + registry = StyleIntentRegistry.from_query_config( | |
| 70 | + QueryConfig( | |
| 71 | + style_intent_terms={ | |
| 72 | + "color": [{"en_terms": ["beige"], "zh_terms": ["米色"], "attribute_terms": ["beige"]}], | |
| 73 | + }, | |
| 74 | + style_intent_dimension_aliases={"color": ["color", "颜色"]}, | |
| 75 | + ) | |
| 76 | + ) | |
| 77 | + selector = StyleSkuSelector(registry) | |
| 78 | + parsed_query = SimpleNamespace( | |
| 79 | + style_intent_profile=StyleIntentProfile( | |
| 80 | + intents=( | |
| 81 | + DetectedStyleIntent( | |
| 82 | + intent_type="color", | |
| 83 | + canonical_value="beige", | |
| 84 | + matched_term="米色", | |
| 85 | + matched_query_text="米色", | |
| 86 | + attribute_terms=("beige",), | |
| 87 | + dimension_aliases=("color", "颜色"), | |
| 88 | + ), | |
| 89 | + ), | |
| 90 | + ) | |
| 91 | + ) | |
| 92 | + hits = [{ | |
| 93 | + "_id": "spu-1", | |
| 94 | + "_source": { | |
| 95 | + "option1_name": "Color", | |
| 96 | + "skus": [ | |
| 97 | + {"sku_id": "1", "option1_value": "Khaki"}, | |
| 98 | + {"sku_id": "2", "option1_value": "Light Brown"}, | |
| 99 | + ], | |
| 100 | + }, | |
| 101 | + }] | |
| 102 | + | |
| 103 | + decisions = selector.prepare_hits(hits, parsed_query) | |
| 104 | + | |
| 105 | + assert decisions["spu-1"].selected_sku_id is None | |
| 106 | + assert decisions["spu-1"].matched_stage == "no_match" | ... | ... |
tests/test_style_intent.py
| ... | ... | @@ -7,8 +7,8 @@ from query.style_intent import StyleIntentDetector, StyleIntentRegistry |
| 7 | 7 | def test_style_intent_detector_matches_original_and_translated_queries(): |
| 8 | 8 | query_config = QueryConfig( |
| 9 | 9 | style_intent_terms={ |
| 10 | - "color": [["black", "黑色", "black"]], | |
| 11 | - "size": [["xl", "x-large", "加大码"]], | |
| 10 | + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], | |
| 11 | + "size": [{"en_terms": ["xl", "x-large"], "zh_terms": ["加大码"], "attribute_terms": ["x-large"]}], | |
| 12 | 12 | }, |
| 13 | 13 | style_intent_dimension_aliases={ |
| 14 | 14 | "color": ["color", "颜色"], |
| ... | ... | @@ -31,5 +31,30 @@ def test_style_intent_detector_matches_original_and_translated_queries(): |
| 31 | 31 | |
| 32 | 32 | assert profile.is_active is True |
| 33 | 33 | assert profile.get_canonical_values("color") == {"black"} |
| 34 | - assert profile.get_canonical_values("size") == {"xl"} | |
| 34 | + assert profile.get_canonical_values("size") == {"x-large"} | |
| 35 | 35 | assert len(profile.query_variants) == 2 |
| 36 | + | |
| 37 | + | |
| 38 | +def test_style_intent_detector_uses_original_query_when_language_translation_missing(): | |
| 39 | + query_config = QueryConfig( | |
| 40 | + style_intent_terms={ | |
| 41 | + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], | |
| 42 | + }, | |
| 43 | + style_intent_dimension_aliases={"color": ["color", "颜色"]}, | |
| 44 | + ) | |
| 45 | + detector = StyleIntentDetector( | |
| 46 | + StyleIntentRegistry.from_query_config(query_config), | |
| 47 | + tokenizer=lambda text: text.split(), | |
| 48 | + ) | |
| 49 | + | |
| 50 | + parsed_query = SimpleNamespace( | |
| 51 | + original_query="black dress", | |
| 52 | + query_normalized="black dress", | |
| 53 | + rewritten_query="black dress", | |
| 54 | + translations={"zh": "连衣裙"}, | |
| 55 | + ) | |
| 56 | + | |
| 57 | + profile = detector.detect(parsed_query) | |
| 58 | + | |
| 59 | + assert profile.get_canonical_values("color") == {"black"} | |
| 60 | + assert profile.intents[0].attribute_terms == ("black",) | ... | ... |