From cda1cd6231ec713689f779d3a0f464b582f47110 Mon Sep 17 00:00:00 2001 From: tangwang Date: Mon, 23 Mar 2026 22:35:20 +0800 Subject: [PATCH] 意图分析&应用 baseline --- config/config.yaml | 12 ++++++++++-- config/dictionaries/style_intent_color.csv | 15 +++++++++++++++ config/dictionaries/style_intent_size.csv | 8 ++++++++ config/loader.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ config/schema.py | 3 +++ docs/TODO-意图判断.md | 12 ++++++++++++ docs/数据统计/options名称和取值统计.md | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ query/query_parser.py | 65 +++++++++++++++++++++++++++++------------------------------------ query/style_intent.py | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ query/tokenization.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ search/rerank_client.py | 12 ++++++++++-- search/searcher.py | 332 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- search/sku_intent_selector.py | 405 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_search_rerank_window.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- tests/test_style_intent.py | 35 +++++++++++++++++++++++++++++++++++ 15 files changed, 1254 insertions(+), 292 deletions(-) create mode 100644 config/dictionaries/style_intent_color.csv create mode 100644 config/dictionaries/style_intent_size.csv create mode 100644 docs/数据统计/options名称和取值统计.md create mode 100644 query/style_intent.py create mode 100644 query/tokenization.py create mode 100644 search/sku_intent_selector.py create mode 100644 tests/test_style_intent.py diff --git a/config/config.yaml b/config/config.yaml index e9d9349..5335ebc 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,9 +17,9 @@ runtime: embedding_port: 6005 embedding_text_port: 6005 embedding_image_port: 6008 - translator_host: "127.0.0.1" + translator_host: "0.0.0.0" translator_port: 6006 - reranker_host: "127.0.0.1" + reranker_host: "0.0.0.0" reranker_port: 6007 # 基础设施连接(敏感项优先读环境变量:ES_*、REDIS_*、DB_*、DASHSCOPE_API_KEY、DEEPL_AUTH_KEY) @@ -116,6 +116,14 @@ query_config: translation_embedding_wait_budget_ms_source_in_index: 500 # 80 translation_embedding_wait_budget_ms_source_not_in_index: 500 #200 + style_intent: + enabled: true + color_dictionary_path: "config/dictionaries/style_intent_color.csv" + size_dictionary_path: "config/dictionaries/style_intent_size.csv" + dimension_aliases: + color: ["color", "colors", "colour", "colours", "颜色", "色", "色系"] + size: ["size", "sizes", "sizing", "尺码", "尺寸", "码数", "号码", "码"] + # 动态多语言检索字段配置 # multilingual_fields 会被拼成 title.{lang}/brief.{lang}/... 形式; # shared_fields 为无语言后缀字段。 diff --git a/config/dictionaries/style_intent_color.csv b/config/dictionaries/style_intent_color.csv new file mode 100644 index 0000000..4068f18 --- /dev/null +++ b/config/dictionaries/style_intent_color.csv @@ -0,0 +1,15 @@ +black,black,blk,黑,黑色 +white,white,wht,白,白色 +red,red,reddish,红,红色 +blue,blue,blu,蓝,蓝色 +green,green,grn,绿,绿色 +yellow,yellow,ylw,黄,黄色 +pink,pink,粉,粉色 +purple,purple,violet,紫,紫色 +gray,gray,grey,灰,灰色 +brown,brown,棕,棕色,咖啡色 +beige,beige,khaki,米色,卡其色 +navy,navy,navy blue,藏青,藏蓝,深蓝 +silver,silver,银,银色 +gold,gold,金,金色 +orange,orange,橙,橙色 diff --git a/config/dictionaries/style_intent_size.csv b/config/dictionaries/style_intent_size.csv new file mode 100644 index 0000000..011dc26 --- /dev/null +++ b/config/dictionaries/style_intent_size.csv @@ -0,0 +1,8 @@ +xs,xs,extra small,x-small,加小码 +s,s,small,小码,小号 +m,m,medium,中码,中号 +l,l,large,大码,大号 +xl,xl,x-large,extra large,加大码 +xxl,xxl,2xl,xx-large,双加大码 +xxxl,xxxl,3xl,xxx-large,三加大码 +one size,one size,onesize,free size,均码 diff --git a/config/loader.py b/config/loader.py index fa49031..584a37d 100644 --- a/config/loader.py +++ b/config/loader.py @@ -95,6 +95,29 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]: return rewrite_dict +def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]: + rows: List[List[str]] = [] + if not path.exists(): + return rows + + with open(path, "r", encoding="utf-8") as handle: + for raw_line in handle: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + parts = [segment.strip() for segment in line.split(",")] + normalized = [segment for segment in parts if segment] + if normalized: + rows.append(normalized) + return rows + + +_DEFAULT_STYLE_INTENT_DIMENSION_ALIASES: Dict[str, List[str]] = { + "color": ["color", "colors", "colour", "colours", "颜色", "色", "色系"], + "size": ["size", "sizes", "sizing", "尺码", "尺寸", "码数", "号码", "码"], +} + + class AppConfigLoader: """Load the unified application configuration.""" @@ -253,6 +276,45 @@ class AppConfigLoader: if isinstance(query_cfg.get("text_query_strategy"), dict) else {} ) + style_intent_cfg = ( + query_cfg.get("style_intent") + if isinstance(query_cfg.get("style_intent"), dict) + else {} + ) + + def _resolve_project_path(value: Any, default_path: Path) -> Path: + if value in (None, ""): + return default_path + candidate = Path(str(value)) + if candidate.is_absolute(): + return candidate + return self.project_root / candidate + + style_color_path = _resolve_project_path( + style_intent_cfg.get("color_dictionary_path"), + self.config_dir / "dictionaries" / "style_intent_color.csv", + ) + style_size_path = _resolve_project_path( + style_intent_cfg.get("size_dictionary_path"), + self.config_dir / "dictionaries" / "style_intent_size.csv", + ) + configured_dimension_aliases = ( + style_intent_cfg.get("dimension_aliases") + if isinstance(style_intent_cfg.get("dimension_aliases"), dict) + else {} + ) + style_dimension_aliases: Dict[str, List[str]] = {} + for intent_type, default_aliases in _DEFAULT_STYLE_INTENT_DIMENSION_ALIASES.items(): + aliases = configured_dimension_aliases.get(intent_type) + if isinstance(aliases, list) and aliases: + style_dimension_aliases[intent_type] = [str(alias) for alias in aliases if str(alias).strip()] + else: + style_dimension_aliases[intent_type] = list(default_aliases) + + style_intent_terms = { + "color": _read_synonym_csv_dictionary(style_color_path), + "size": _read_synonym_csv_dictionary(style_size_path), + } query_config = QueryConfig( supported_languages=list(query_cfg.get("supported_languages") or ["zh", "en"]), default_language=str(query_cfg.get("default_language") or "en"), @@ -324,6 +386,9 @@ class AppConfigLoader: translation_embedding_wait_budget_ms_source_not_in_index=int( query_cfg.get("translation_embedding_wait_budget_ms_source_not_in_index", 200) ), + style_intent_enabled=bool(style_intent_cfg.get("enabled", True)), + style_intent_terms=style_intent_terms, + style_intent_dimension_aliases=style_dimension_aliases, ) function_score_cfg = raw.get("function_score") if isinstance(raw.get("function_score"), dict) else {} diff --git a/config/schema.py b/config/schema.py index 60ac0f1..690c2b1 100644 --- a/config/schema.py +++ b/config/schema.py @@ -64,6 +64,9 @@ class QueryConfig: # 检测语言不在 index_languages 内:翻译对召回更关键,预算较长。 translation_embedding_wait_budget_ms_source_in_index: int = 80 translation_embedding_wait_budget_ms_source_not_in_index: int = 200 + style_intent_enabled: bool = True + style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict) + style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict) @dataclass(frozen=True) diff --git a/docs/TODO-意图判断.md b/docs/TODO-意图判断.md index 8b72fb7..4686d4d 100644 --- a/docs/TODO-意图判断.md +++ b/docs/TODO-意图判断.md @@ -39,3 +39,15 @@ intent 考虑由 QueryParser 编排、具体实现拆成独立模块,主义好 5. TODO: 搜索接口里,results[].skus 不是全量子 SKU:由 sku_filter_dimension 控制在应用层按维度分组折叠,每个「维度取值组合」只保留一条 SKU(组内第一条)。请求未传该字段时,Pydantic 默认是 ["option1"],等价于只按 option1_value 去重;服务端不会读取店铺主题的「主展示维」,需调用方与装修配置对齐并传入正确维度。因此当用户有款式等更细粒度意图、而款式落在 option2/option3(或对应 option*_name)时,若仍用默认只按 option1(常见为颜色)折叠,同一颜色下多种款式只会出现一条代表 SKU,无法从返回的 skus 里拿到该颜色下的全部款式行。(若业务需要全量子款,需传包含对应维度的 sku_filter_dimension,或传 null/[] 跳过折叠——以当前 ResultFormatter 实现为准。) + + + + +GPT5.4 主要思路: +我准备按这条线改,尽量把逻辑收束成两个新模块而不是继续堆在 Searcher 里: + +QueryParser 负责产出统一的“款式意图分析结果”,包含:命中的意图类型、标准化值、命中的 query 变体、分词结果、维度候选词。 +独立的 SKU 选择模块负责:维度解析、候选值抽取、三阶段匹配、选中 SKU 后的主图替换,以及生成 rerank 用的属性后缀。 +Searcher 只编排:解析 query -> 首次 ES 召回 -> 若有款式意图则对 rerank 窗口内全部 hits 做 SKU 预筛选并补充 rerank doc 文本 -> run_rerank -> 分页 -> page fill -> ResultFormatter。这样旧的页内 SKU 置顶逻辑可以下沉或删除,避免前后两个阶段重复做同一件事。我再补看一下你最近写的意图草稿文档,确认有没有现成词表示例或边界条件,避免我这边把词表格式和你预期做偏。 + + diff --git a/docs/数据统计/options名称和取值统计.md b/docs/数据统计/options名称和取值统计.md new file mode 100644 index 0000000..8a4e8c5 --- /dev/null +++ b/docs/数据统计/options名称和取值统计.md @@ -0,0 +1,118 @@ + +## prompt +参考 @docs/常用查询 - ES.md 帮我查询 tenant_Id=163的@search_products.json (1314-1332) 这些字段的统计分布,也就是这里每个字段的所有的值及其对应的文档数 + + +已按 `docs/常用查询 - ES.md` 的写法,在索引 **`search_products_tenant_163`** 上做了 `terms` 聚合(租户已按索引拆分,用 `match_all` 即可,等价于只查 tenant 163)。集群在本机 `localhost:9200` 上执行成功。 + +## 结果概要 + +- **总文档数**:8977 +- 六个字段的 **`sum_other_doc_count` 均为 0**,在 `size: 20000` 下没有截断,即:**当前数据里每个字段的取值都已全部列出**。 +- **`__missing__`**:来自聚合里的 `"missing": "__missing__"`,表示该字段缺失或为空的文档数(便于和「有值」一起看)。 +- **`option1_values` / `option2_values`** 多为多值数组:同一篇文档可落入多个桶,因此 **各桶 `doc_count` 之和会大于 8977**(这是 ES `terms` 的正常语义,表示「含该取值的文档数」,不是互斥分区)。 +- **`option1_name` / `option2_name` / `option3_name`** 及 **`option3_values`** 的桶计数之和与文档关系更接近单值/稀疏(见下方校验)。 + +### `option1_name`(13 个取值) + +| 值 | 文档数 | +|---:|---:| +| __missing__ | 4538 | +| Size | 2916 | +| Color | 1174 | +| 颜色 | 244 | +| COLOR | 56 | +| color | 16 | +| Colour | 15 | +| Variant | 9 | +| Style | 3 | +| Colors | 2 | +| Scent | 2 | +| Cup Size | 1 | +| Pattern Name | 1 | + +### `option2_name`(13 个取值) + +| 值 | 文档数 | +|---:|---:| +| __missing__ | 4666 | +| Color | 2879 | +| Size | 1134 | +| 尺码 | 244 | +| SIZE | 17 | +| size | 16 | +| Style | 12 | +| Cup Size | 4 | +| Item Package Quantity | 1 | +| Number of Items | 1 | +| Ring Size | 1 | +| Scent | 1 | +| Team Name | 1 | + +### `option3_name`(11 个取值) + +| 值 | 文档数 | +|---:|---:| +| __missing__ | 8831 | +| Fit Type | 52 | +| Special Size | 31 | +| Number of Items | 17 | +| Special Size Type | 16 | +| Team Name | 15 | +| Material Type | 8 | +| Style | 4 | +| Configuration | 1 | +| Hand Orientation | 1 | +| Pattern Name | 1 | + +### `option1_values` / `option2_values` + +- **桶数量**:`option1_values` **3463**,`option2_values` **3605**(完整列表过长,不宜贴在对话里)。 +- **文档中出现最多的取值(前 10)** + - **option1_values**:`__missing__` 4538,其后如 Medium 749、Black 653、Large 635、8/9/White/10/7/Small… + - **option2_values**:`__missing__` 4666,其后如 L 1161、M 1158、S 1064、XL 1036、Black 876… + +### `option3_values`(94 个桶,已全部枚举) + +除 `__missing__` 8831 外,另有 93 个具体取值;上文 Markdown 表里已列全(从 Standard 到 `standard` 等)。 + +--- + +## 可复用的查询(与文档一致) + +```bash +curl -sS -u 'saas:4hOaLaf41y2VuI8y' -X GET 'http://localhost:9200/search_products_tenant_163/_search?pretty' \ + -H 'Content-Type: application/json' -d '{ + "query": { "match_all": {} }, + "size": 0, + "track_total_hits": true, + "aggs": { + "option1_name_dist": { + "terms": { "field": "option1_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + }, + "option2_name_dist": { + "terms": { "field": "option2_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + }, + "option3_name_dist": { + "terms": { "field": "option3_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + }, + "option1_values_dist": { + "terms": { "field": "option1_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + }, + "option2_values_dist": { + "terms": { "field": "option2_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + }, + "option3_values_dist": { + "terms": { "field": "option3_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } + } + } +}' +``` + +若某租户取值超过 2 万,需把 `size` 再加大或用 **composite** 分页扫完全部桶。 + +## 完整 JSON(含 3463 + 3605 条 value 分布) + +已整理为:`/tmp/tenant163_option_fields_distribution.json`(约 620KB),结构为按字段分组的 `values: [{ "value", "doc_count" }, ...]`,便于你用脚本或表格工具打开。 + +如需去掉 `__missing__` 桶,从请求里删掉各聚合中的 `"missing": "__missing__"` 即可。 \ No newline at end of file diff --git a/query/query_parser.py b/query/query_parser.py index 655d827..0063308 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -12,7 +12,6 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import logging -import re from concurrent.futures import ThreadPoolExecutor, wait from embeddings.text_encoder import TextEmbeddingEncoder @@ -20,25 +19,14 @@ from config import SearchConfig from translation import create_translation_client from .language_detector import LanguageDetector from .query_rewriter import QueryRewriter, QueryNormalizer +from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry +from .tokenization import extract_token_strings, simple_tokenize_query logger = logging.getLogger(__name__) import hanlp # type: ignore -def simple_tokenize_query(text: str) -> List[str]: - """ - Lightweight tokenizer for suggestion-side heuristics only. - - - Consecutive CJK characters form one token - - Latin / digit runs (with internal hyphens) form tokens - """ - if not text: - return [] - pattern = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*") - return pattern.findall(text) - - @dataclass(slots=True) class ParsedQuery: """Container for query parser facts.""" @@ -50,6 +38,7 @@ class ParsedQuery: translations: Dict[str, str] = field(default_factory=dict) query_vector: Optional[np.ndarray] = None query_tokens: List[str] = field(default_factory=list) + style_intent_profile: Optional[StyleIntentProfile] = None def to_dict(self) -> Dict[str, Any]: """Convert to dictionary representation.""" @@ -60,6 +49,9 @@ class ParsedQuery: "detected_language": self.detected_language, "translations": self.translations, "query_tokens": self.query_tokens, + "style_intent_profile": ( + self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None + ), } @@ -97,6 +89,11 @@ class QueryParser: self.language_detector = LanguageDetector() self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) self._tokenizer = tokenizer or self._build_tokenizer() + self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) + self.style_intent_detector = StyleIntentDetector( + self.style_intent_registry, + tokenizer=self._tokenizer, + ) # Eager initialization (startup-time failure visibility, no lazy init in request path) if self.config.query_config.enable_text_embedding and self._text_encoder is None: @@ -172,28 +169,7 @@ class QueryParser: @staticmethod def _extract_tokens(tokenizer_result: Any) -> List[str]: """Normalize tokenizer output into a flat token string list.""" - if not tokenizer_result: - return [] - if isinstance(tokenizer_result, str): - token = tokenizer_result.strip() - return [token] if token else [] - - tokens: List[str] = [] - for item in tokenizer_result: - token: Optional[str] = None - if isinstance(item, str): - token = item - elif isinstance(item, (list, tuple)) and item: - token = str(item[0]) - elif item is not None: - token = str(item) - - if token is None: - continue - token = token.strip() - if token: - tokens.append(token) - return tokens + return extract_token_strings(tokenizer_result) def _get_query_tokens(self, query: str) -> List[str]: return self._extract_tokens(self._tokenizer(query)) @@ -425,6 +401,22 @@ class QueryParser: context.store_intermediate_result("translations", translations) # Build result + base_result = ParsedQuery( + original_query=query, + query_normalized=normalized, + rewritten_query=query_text, + detected_language=detected_lang, + translations=translations, + query_vector=query_vector, + query_tokens=query_tokens, + ) + style_intent_profile = self.style_intent_detector.detect(base_result) + if context: + context.store_intermediate_result( + "style_intent_profile", + style_intent_profile.to_dict(), + ) + result = ParsedQuery( original_query=query, query_normalized=normalized, @@ -433,6 +425,7 @@ class QueryParser: translations=translations, query_vector=query_vector, query_tokens=query_tokens, + style_intent_profile=style_intent_profile, ) if context and hasattr(context, 'logger'): diff --git a/query/style_intent.py b/query/style_intent.py new file mode 100644 index 0000000..13525fc --- /dev/null +++ b/query/style_intent.py @@ -0,0 +1,261 @@ +""" +Style intent detection for query understanding. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple + +from .tokenization import TokenizedText, normalize_query_text, tokenize_text + + +@dataclass(frozen=True) +class StyleIntentDefinition: + intent_type: str + term_groups: Tuple[Tuple[str, ...], ...] + dimension_aliases: Tuple[str, ...] + synonym_to_canonical: Dict[str, str] + max_term_ngram: int = 3 + + @classmethod + def from_rows( + cls, + intent_type: str, + rows: Sequence[Sequence[str]], + dimension_aliases: Sequence[str], + ) -> "StyleIntentDefinition": + term_groups: List[Tuple[str, ...]] = [] + synonym_to_canonical: Dict[str, str] = {} + max_ngram = 1 + + for row in rows: + normalized_terms: List[str] = [] + for raw_term in row: + term = normalize_query_text(raw_term) + if not term or term in normalized_terms: + continue + normalized_terms.append(term) + if not normalized_terms: + continue + + canonical = normalized_terms[0] + term_groups.append(tuple(normalized_terms)) + for term in normalized_terms: + synonym_to_canonical[term] = canonical + max_ngram = max(max_ngram, len(term.split())) + + aliases = tuple( + dict.fromkeys( + term + for term in ( + normalize_query_text(alias) + for alias in dimension_aliases + ) + if term + ) + ) + + return cls( + intent_type=intent_type, + term_groups=tuple(term_groups), + dimension_aliases=aliases, + synonym_to_canonical=synonym_to_canonical, + max_term_ngram=max_ngram, + ) + + def match_candidates(self, candidates: Iterable[str]) -> Set[str]: + matched: Set[str] = set() + for candidate in candidates: + canonical = self.synonym_to_canonical.get(normalize_query_text(candidate)) + if canonical: + matched.add(canonical) + return matched + + def match_text( + self, + text: str, + *, + tokenizer: Optional[Callable[[str], Any]] = None, + ) -> Set[str]: + bundle = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram) + return self.match_candidates(bundle.candidates) + + +@dataclass(frozen=True) +class DetectedStyleIntent: + intent_type: str + canonical_value: str + matched_term: str + matched_query_text: str + dimension_aliases: Tuple[str, ...] + + def to_dict(self) -> Dict[str, Any]: + return { + "intent_type": self.intent_type, + "canonical_value": self.canonical_value, + "matched_term": self.matched_term, + "matched_query_text": self.matched_query_text, + "dimension_aliases": list(self.dimension_aliases), + } + + +@dataclass(frozen=True) +class StyleIntentProfile: + query_variants: Tuple[TokenizedText, ...] = field(default_factory=tuple) + intents: Tuple[DetectedStyleIntent, ...] = field(default_factory=tuple) + + @property + def is_active(self) -> bool: + return bool(self.intents) + + def get_intents(self, intent_type: Optional[str] = None) -> List[DetectedStyleIntent]: + if intent_type is None: + return list(self.intents) + normalized = normalize_query_text(intent_type) + return [intent for intent in self.intents if intent.intent_type == normalized] + + def get_canonical_values(self, intent_type: str) -> Set[str]: + return {intent.canonical_value for intent in self.get_intents(intent_type)} + + def to_dict(self) -> Dict[str, Any]: + return { + "active": self.is_active, + "intents": [intent.to_dict() for intent in self.intents], + "query_variants": [ + { + "text": variant.text, + "normalized_text": variant.normalized_text, + "fine_tokens": list(variant.fine_tokens), + "coarse_tokens": list(variant.coarse_tokens), + "candidates": list(variant.candidates), + } + for variant in self.query_variants + ], + } + + +class StyleIntentRegistry: + """Holds style intent vocabularies and matching helpers.""" + + def __init__( + self, + definitions: Dict[str, StyleIntentDefinition], + *, + enabled: bool = True, + ) -> None: + self.definitions = definitions + self.enabled = bool(enabled) + + @classmethod + def from_query_config(cls, query_config: Any) -> "StyleIntentRegistry": + style_terms = getattr(query_config, "style_intent_terms", {}) or {} + dimension_aliases = getattr(query_config, "style_intent_dimension_aliases", {}) or {} + definitions: Dict[str, StyleIntentDefinition] = {} + + for intent_type, rows in style_terms.items(): + definition = StyleIntentDefinition.from_rows( + intent_type=normalize_query_text(intent_type), + rows=rows or [], + dimension_aliases=dimension_aliases.get(intent_type, []), + ) + if definition.synonym_to_canonical: + definitions[definition.intent_type] = definition + + return cls( + definitions, + enabled=bool(getattr(query_config, "style_intent_enabled", True)), + ) + + def get_definition(self, intent_type: str) -> Optional[StyleIntentDefinition]: + return self.definitions.get(normalize_query_text(intent_type)) + + def get_dimension_aliases(self, intent_type: str) -> Tuple[str, ...]: + definition = self.get_definition(intent_type) + return definition.dimension_aliases if definition else tuple() + + +class StyleIntentDetector: + """Detects style intents from parsed query variants.""" + + def __init__( + self, + registry: StyleIntentRegistry, + *, + tokenizer: Optional[Callable[[str], Any]] = None, + ) -> None: + self.registry = registry + self.tokenizer = tokenizer + + def _build_query_variants(self, parsed_query: Any) -> Tuple[TokenizedText, ...]: + seen = set() + variants: List[TokenizedText] = [] + texts = [ + getattr(parsed_query, "original_query", None), + getattr(parsed_query, "query_normalized", None), + getattr(parsed_query, "rewritten_query", None), + ] + + translations = getattr(parsed_query, "translations", {}) or {} + if isinstance(translations, dict): + texts.extend(translations.values()) + + for raw_text in texts: + text = str(raw_text or "").strip() + if not text: + continue + normalized = normalize_query_text(text) + if not normalized or normalized in seen: + continue + seen.add(normalized) + variants.append( + tokenize_text( + text, + tokenizer=self.tokenizer, + max_ngram=max( + (definition.max_term_ngram for definition in self.registry.definitions.values()), + default=3, + ), + ) + ) + + return tuple(variants) + + def detect(self, parsed_query: Any) -> StyleIntentProfile: + if not self.registry.enabled or not self.registry.definitions: + return StyleIntentProfile() + + query_variants = self._build_query_variants(parsed_query) + detected: List[DetectedStyleIntent] = [] + seen_pairs = set() + + for variant in query_variants: + for intent_type, definition in self.registry.definitions.items(): + matched_canonicals = definition.match_candidates(variant.candidates) + if not matched_canonicals: + continue + + for candidate in variant.candidates: + normalized_candidate = normalize_query_text(candidate) + canonical = definition.synonym_to_canonical.get(normalized_candidate) + if not canonical or canonical not in matched_canonicals: + continue + pair = (intent_type, canonical) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + detected.append( + DetectedStyleIntent( + intent_type=intent_type, + canonical_value=canonical, + matched_term=normalized_candidate, + matched_query_text=variant.text, + dimension_aliases=definition.dimension_aliases, + ) + ) + break + + return StyleIntentProfile( + query_variants=query_variants, + intents=tuple(detected), + ) diff --git a/query/tokenization.py b/query/tokenization.py new file mode 100644 index 0000000..61beaf2 --- /dev/null +++ b/query/tokenization.py @@ -0,0 +1,122 @@ +""" +Shared tokenization helpers for query understanding. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import re +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple + + +_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*") + + +def normalize_query_text(text: Optional[str]) -> str: + if text is None: + return "" + return " ".join(str(text).strip().casefold().split()) + + +def simple_tokenize_query(text: str) -> List[str]: + """ + Lightweight tokenizer for coarse query matching. + + - Consecutive CJK characters form one token + - Latin / digit runs (with internal hyphens) form tokens + """ + if not text: + return [] + return _TOKEN_PATTERN.findall(text) + + +def extract_token_strings(tokenizer_result: Any) -> List[str]: + """Normalize tokenizer output into a flat token string list.""" + if not tokenizer_result: + return [] + if isinstance(tokenizer_result, str): + token = tokenizer_result.strip() + return [token] if token else [] + + tokens: List[str] = [] + for item in tokenizer_result: + token: Optional[str] = None + if isinstance(item, str): + token = item + elif isinstance(item, (list, tuple)) and item: + token = str(item[0]) + elif item is not None: + token = str(item) + + if token is None: + continue + token = token.strip() + if token: + tokens.append(token) + return tokens + + +def _dedupe_preserve_order(values: Iterable[str]) -> List[str]: + result: List[str] = [] + seen = set() + for value in values: + normalized = normalize_query_text(value) + if not normalized or normalized in seen: + continue + seen.add(normalized) + result.append(normalized) + return result + + +def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]: + if not tokens: + return [] + + phrases: List[str] = [] + upper = max(1, int(max_ngram)) + for size in range(1, upper + 1): + if size > len(tokens): + break + for start in range(0, len(tokens) - size + 1): + phrase = " ".join(tokens[start:start + size]).strip() + if phrase: + phrases.append(phrase) + return phrases + + +@dataclass(frozen=True) +class TokenizedText: + text: str + normalized_text: str + fine_tokens: Tuple[str, ...] + coarse_tokens: Tuple[str, ...] + candidates: Tuple[str, ...] + + +def tokenize_text( + text: str, + *, + tokenizer: Optional[Callable[[str], Any]] = None, + max_ngram: int = 3, +) -> TokenizedText: + normalized_text = normalize_query_text(text) + coarse_tokens = _dedupe_preserve_order(simple_tokenize_query(text)) + + fine_raw = extract_token_strings(tokenizer(text)) if tokenizer is not None and text else [] + fine_tokens = _dedupe_preserve_order(fine_raw) + + candidates = _dedupe_preserve_order( + list(fine_tokens) + + list(coarse_tokens) + + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram) + + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram) + + ([normalized_text] if normalized_text else []) + ) + + return TokenizedText( + text=text, + normalized_text=normalized_text, + fine_tokens=tuple(fine_tokens), + coarse_tokens=tuple(coarse_tokens), + candidates=tuple(candidates), + ) diff --git a/search/rerank_client.py b/search/rerank_client.py index 7108b5f..7953dac 100644 --- a/search/rerank_client.py +++ b/search/rerank_client.py @@ -62,11 +62,19 @@ def build_docs_from_hits( need_category_path = "{category_path}" in doc_template for hit in es_hits: src = hit.get("_source") or {} + title_suffix = str(hit.get("_style_rerank_suffix") or "").strip() if only_title: - docs.append(pick_lang_text(src.get("title"))) + title = pick_lang_text(src.get("title")) + if title_suffix: + title = f"{title} {title_suffix}".strip() + docs.append(title) else: values = _SafeDict( - title=pick_lang_text(src.get("title")), + title=( + f"{pick_lang_text(src.get('title'))} {title_suffix}".strip() + if title_suffix + else pick_lang_text(src.get("title")) + ), brief=pick_lang_text(src.get("brief")) if need_brief else "", vendor=pick_lang_text(src.get("vendor")) if need_vendor else "", description=pick_lang_text(src.get("description")) if need_description else "", diff --git a/search/searcher.py b/search/searcher.py index 5285d2c..ba875f9 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -10,12 +10,13 @@ import time, json import logging import hashlib from string import Formatter -import numpy as np from utils.es_client import ESClient from query import QueryParser, ParsedQuery +from query.style_intent import StyleIntentRegistry from embeddings.image_encoder import CLIPImageEncoder from .es_query_builder import ESQueryBuilder +from .sku_intent_selector import SkuSelectionDecision, StyleSkuSelector from config import SearchConfig from config.tenant_config_loader import get_tenant_config_loader from context.request_context import RequestContext, RequestContextStage @@ -115,6 +116,12 @@ class Searcher: else: self.image_encoder = image_encoder self.source_fields = config.query_config.source_fields + self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) + self.style_sku_selector = StyleSkuSelector( + self.style_intent_registry, + text_encoder_getter=lambda: getattr(self.query_parser, "text_encoder", None), + tokenizer_getter=lambda: getattr(self.query_parser, "_tokenizer", None), + ) # Query builder - simplified single-layer architecture self.query_builder = ESQueryBuilder( @@ -155,7 +162,11 @@ class Searcher: return es_query["_source"] = {"includes": self.source_fields} - def _resolve_rerank_source_filter(self, doc_template: str) -> Dict[str, Any]: + def _resolve_rerank_source_filter( + self, + doc_template: str, + parsed_query: Optional[ParsedQuery] = None, + ) -> Dict[str, Any]: """ Build a lightweight _source filter for rerank prefetch. @@ -182,6 +193,16 @@ class Searcher: if not includes: includes.add("title") + if self._has_style_intent(parsed_query): + includes.update( + { + "skus", + "option1_name", + "option2_name", + "option3_name", + } + ) + return {"includes": sorted(includes)} def _fetch_hits_by_ids( @@ -225,256 +246,23 @@ class Searcher: return hits_by_id, int(resp.get("took", 0) or 0) @staticmethod - def _normalize_sku_match_text(value: Optional[str]) -> str: - """Normalize free text for lightweight SKU option matching.""" - if value is None: - return "" - return " ".join(str(value).strip().casefold().split()) - - @staticmethod - def _sku_option1_embedding_key( - sku: Dict[str, Any], - spu_option1_name: Optional[Any] = None, - ) -> Optional[str]: - """ - Text sent to the embedding service for option1 must be "name:value" - (option name from SKU row or SPU-level option1_name). - """ - value_raw = sku.get("option1_value") - if value_raw is None: - return None - value = str(value_raw).strip() - if not value: - return None - name = sku.get("option1_name") - if name is None or not str(name).strip(): - name = spu_option1_name - name_str = str(name).strip() if name is not None and str(name).strip() else "" - if name_str: - value = f"{name_str}:{value}" - return value.casefold() - - def _build_sku_query_texts(self, parsed_query: ParsedQuery) -> List[str]: - """Collect original and translated query texts for SKU option matching.""" - candidates: List[str] = [] - for text in ( - getattr(parsed_query, "original_query", None), - getattr(parsed_query, "query_normalized", None), - getattr(parsed_query, "rewritten_query", None), - ): - normalized = self._normalize_sku_match_text(text) - if normalized: - candidates.append(normalized) - - translations = getattr(parsed_query, "translations", {}) or {} - if isinstance(translations, dict): - for text in translations.values(): - normalized = self._normalize_sku_match_text(text) - if normalized: - candidates.append(normalized) - - deduped: List[str] = [] - seen = set() - for text in candidates: - if text in seen: - continue - seen.add(text) - deduped.append(text) - return deduped - - def _find_query_matching_sku_index( - self, - skus: List[Dict[str, Any]], - query_texts: List[str], - spu_option1_name: Optional[Any] = None, - ) -> Optional[int]: - """Return the first SKU whose option1_value (or name:value) appears in query texts.""" - if not skus or not query_texts: - return None - - for index, sku in enumerate(skus): - option1_value = self._normalize_sku_match_text(sku.get("option1_value")) - if not option1_value: - continue - if any(option1_value in query_text for query_text in query_texts): - return index - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) - if embed_key and embed_key != option1_value: - composite_norm = self._normalize_sku_match_text(embed_key.replace(":", " ")) - if any(composite_norm in query_text for query_text in query_texts): - return index - if any(embed_key.casefold() in query_text for query_text in query_texts): - return index - return None - - def _encode_query_vector_for_sku_matching( - self, - parsed_query: ParsedQuery, - context: Optional[RequestContext] = None, - ) -> Optional[np.ndarray]: - """Best-effort fallback query embedding for final-page SKU matching.""" - query_text = ( - getattr(parsed_query, "rewritten_query", None) - or getattr(parsed_query, "query_normalized", None) - or getattr(parsed_query, "original_query", None) - ) - if not query_text: - return None - - text_encoder = getattr(self.query_parser, "text_encoder", None) - if text_encoder is None: - return None - - try: - vectors = text_encoder.encode([query_text], priority=1) - except Exception as exc: - logger.warning("Failed to encode query vector for SKU matching: %s", exc, exc_info=True) - if context is not None: - context.add_warning(f"SKU query embedding failed: {exc}") - return None - - if vectors is None or len(vectors) == 0: - return None - - vector = vectors[0] - if vector is None: - return None - return np.asarray(vector, dtype=np.float32) - - def _select_sku_by_embedding( - self, - skus: List[Dict[str, Any]], - option1_vectors: Dict[str, np.ndarray], - query_vector: np.ndarray, - spu_option1_name: Optional[Any] = None, - ) -> Tuple[Optional[int], Optional[float]]: - """Select the SKU whose option1 embedding key (name:value) is most similar to the query.""" - best_index: Optional[int] = None - best_score: Optional[float] = None - - for index, sku in enumerate(skus): - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) - if not embed_key: - continue - option_vector = option1_vectors.get(embed_key) - if option_vector is None: - continue - score = float(np.inner(query_vector, option_vector)) - if best_score is None or score > best_score: - best_index = index - best_score = score - - return best_index, best_score - - @staticmethod - def _promote_matching_sku(source: Dict[str, Any], match_index: int) -> Optional[Dict[str, Any]]: - """Move the matched SKU to the front and swap the SPU image.""" - skus = source.get("skus") - if not isinstance(skus, list) or match_index < 0 or match_index >= len(skus): - return None - - matched_sku = skus.pop(match_index) - skus.insert(0, matched_sku) + def _has_style_intent(parsed_query: Optional[ParsedQuery]) -> bool: + profile = getattr(parsed_query, "style_intent_profile", None) + return bool(getattr(profile, "is_active", False)) - image_src = matched_sku.get("image_src") or matched_sku.get("imageSrc") - if image_src: - source["image_url"] = image_src - return matched_sku - - def _apply_sku_sorting_for_page_hits( + def _apply_style_intent_to_hits( self, es_hits: List[Dict[str, Any]], parsed_query: ParsedQuery, context: Optional[RequestContext] = None, - ) -> None: - """Sort each page hit's SKUs so the best-matching SKU is first.""" - if not es_hits: - return - - query_texts = self._build_sku_query_texts(parsed_query) - unmatched_hits: List[Dict[str, Any]] = [] - option1_values_to_encode: List[str] = [] - seen_option1_values = set() - text_matched = 0 - embedding_matched = 0 - - for hit in es_hits: - source = hit.get("_source") - if not isinstance(source, dict): - continue - skus = source.get("skus") - if not isinstance(skus, list) or not skus: - continue - - spu_option1_name = source.get("option1_name") - match_index = self._find_query_matching_sku_index( - skus, query_texts, spu_option1_name=spu_option1_name - ) - if match_index is not None: - self._promote_matching_sku(source, match_index) - text_matched += 1 - continue - - unmatched_hits.append(hit) - for sku in skus: - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) - if not embed_key or embed_key in seen_option1_values: - continue - seen_option1_values.add(embed_key) - option1_values_to_encode.append(embed_key) - - if not unmatched_hits or not option1_values_to_encode: - return - - query_vector = getattr(parsed_query, "query_vector", None) - if query_vector is None: - query_vector = self._encode_query_vector_for_sku_matching(parsed_query, context=context) - if query_vector is None: - return - - text_encoder = getattr(self.query_parser, "text_encoder", None) - if text_encoder is None: - return - - try: - encoded_option_vectors = text_encoder.encode(option1_values_to_encode, priority=1) - except Exception as exc: - logger.warning("Failed to encode SKU option1 values for final-page sorting: %s", exc, exc_info=True) - if context is not None: - context.add_warning(f"SKU option embedding failed: {exc}") - return - - option1_vectors: Dict[str, np.ndarray] = {} - for option1_value, vector in zip(option1_values_to_encode, encoded_option_vectors): - if vector is None: - continue - option1_vectors[option1_value] = np.asarray(vector, dtype=np.float32) - - query_vector_array = np.asarray(query_vector, dtype=np.float32) - for hit in unmatched_hits: - source = hit.get("_source") - if not isinstance(source, dict): - continue - skus = source.get("skus") - if not isinstance(skus, list) or not skus: - continue - match_index, _ = self._select_sku_by_embedding( - skus, - option1_vectors, - query_vector_array, - spu_option1_name=source.get("option1_name"), - ) - if match_index is None: - continue - self._promote_matching_sku(source, match_index) - embedding_matched += 1 - - if text_matched or embedding_matched: - logger.info( - "Final-page SKU sorting completed | text_matched=%s | embedding_matched=%s", - text_matched, - embedding_matched, + ) -> Dict[str, SkuSelectionDecision]: + decisions = self.style_sku_selector.prepare_hits(es_hits, parsed_query) + if decisions and context is not None: + context.store_intermediate_result( + "style_intent_sku_decisions", + {doc_id: decision.to_dict() for doc_id, decision in decisions.items()}, ) + return decisions def search( self, @@ -583,7 +371,8 @@ class Searcher: context.metadata['feature_flags'] = { 'translation_enabled': enable_translation, 'embedding_enabled': enable_embedding, - 'rerank_enabled': do_rerank + 'rerank_enabled': do_rerank, + 'style_intent_enabled': bool(self.style_intent_registry.enabled), } # Step 1: Parse query @@ -607,6 +396,7 @@ class Searcher: domain="default", is_simple_query=True ) + context.metadata["feature_flags"]["style_intent_active"] = self._has_style_intent(parsed_query) context.logger.info( f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " @@ -667,7 +457,10 @@ class Searcher: es_query_for_fetch = es_query rerank_prefetch_source = None if in_rerank_window: - rerank_prefetch_source = self._resolve_rerank_source_filter(effective_doc_template) + rerank_prefetch_source = self._resolve_rerank_source_filter( + effective_doc_template, + parsed_query=parsed_query, + ) es_query_for_fetch = dict(es_query) es_query_for_fetch["_source"] = rerank_prefetch_source @@ -751,6 +544,20 @@ class Searcher: finally: context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH_PRIMARY) + style_intent_decisions: Dict[str, SkuSelectionDecision] = {} + if self._has_style_intent(parsed_query) and in_rerank_window: + style_intent_decisions = self._apply_style_intent_to_hits( + es_response.get("hits", {}).get("hits") or [], + parsed_query, + context=context, + ) + if style_intent_decisions: + context.logger.info( + "款式意图 SKU 预筛选完成 | hits=%s", + len(style_intent_decisions), + extra={'reqid': context.reqid, 'uid': context.uid} + ) + # Optional Step 4.5: AI reranking(仅当请求范围在重排窗口内时执行) if do_rerank and in_rerank_window: context.start_stage(RequestContextStage.RERANKING) @@ -841,6 +648,11 @@ class Searcher: if "_source" in detail_hit: hit["_source"] = detail_hit.get("_source") or {} filled += 1 + if style_intent_decisions: + self.style_sku_selector.apply_precomputed_decisions( + sliced, + style_intent_decisions, + ) if fill_took: es_response["took"] = int((es_response.get("took", 0) or 0) + fill_took) context.logger.info( @@ -883,7 +695,18 @@ class Searcher: continue rerank_debug_by_doc[str(doc_id)] = item - self._apply_sku_sorting_for_page_hits(es_hits, parsed_query, context=context) + if self._has_style_intent(parsed_query): + if in_rerank_window and style_intent_decisions: + self.style_sku_selector.apply_precomputed_decisions( + es_hits, + style_intent_decisions, + ) + elif not in_rerank_window: + style_intent_decisions = self._apply_style_intent_to_hits( + es_hits, + parsed_query, + context=context, + ) # Format results using ResultFormatter formatted_results = ResultFormatter.format_search_results( @@ -902,6 +725,11 @@ class Searcher: rerank_debug = None if doc_id is not None: rerank_debug = rerank_debug_by_doc.get(str(doc_id)) + style_intent_debug = None + if doc_id is not None and style_intent_decisions: + decision = style_intent_decisions.get(str(doc_id)) + if decision is not None: + style_intent_debug = decision.to_dict() raw_score = hit.get("_score") try: @@ -940,6 +768,9 @@ class Searcher: debug_entry["fused_score"] = rerank_debug.get("fused_score") debug_entry["matched_queries"] = rerank_debug.get("matched_queries") + if style_intent_debug: + debug_entry["style_intent_sku"] = style_intent_debug + per_result_debug.append(debug_entry) # Format facets @@ -987,7 +818,8 @@ class Searcher: "translations": context.query_analysis.translations, "has_vector": context.query_analysis.query_vector is not None, "is_simple_query": context.query_analysis.is_simple_query, - "domain": context.query_analysis.domain + "domain": context.query_analysis.domain, + "style_intent_profile": context.get_intermediate_result("style_intent_profile"), }, "es_query": context.get_intermediate_result('es_query', {}), "es_response": { diff --git a/search/sku_intent_selector.py b/search/sku_intent_selector.py new file mode 100644 index 0000000..c832573 --- /dev/null +++ b/search/sku_intent_selector.py @@ -0,0 +1,405 @@ +""" +SKU selection for style-intent-aware search results. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple + +import numpy as np + +from query.style_intent import StyleIntentProfile, StyleIntentRegistry +from query.tokenization import normalize_query_text + + +@dataclass(frozen=True) +class SkuSelectionDecision: + selected_sku_id: Optional[str] + rerank_suffix: str + selected_text: str + matched_stage: str + similarity_score: Optional[float] = None + resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "selected_sku_id": self.selected_sku_id, + "rerank_suffix": self.rerank_suffix, + "selected_text": self.selected_text, + "matched_stage": self.matched_stage, + "similarity_score": self.similarity_score, + "resolved_dimensions": dict(self.resolved_dimensions), + } + + +@dataclass +class _SkuCandidate: + index: int + sku_id: str + sku: Dict[str, Any] + selection_text: str + intent_texts: Dict[str, str] + + +class StyleSkuSelector: + """Selects the best SKU for an SPU based on detected style intent.""" + + def __init__( + self, + registry: StyleIntentRegistry, + *, + text_encoder_getter: Optional[Callable[[], Any]] = None, + tokenizer_getter: Optional[Callable[[], Any]] = None, + ) -> None: + self.registry = registry + self._text_encoder_getter = text_encoder_getter + self._tokenizer_getter = tokenizer_getter + + def prepare_hits( + self, + es_hits: List[Dict[str, Any]], + parsed_query: Any, + ) -> Dict[str, SkuSelectionDecision]: + decisions: Dict[str, SkuSelectionDecision] = {} + style_profile = getattr(parsed_query, "style_intent_profile", None) + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: + return decisions + + query_texts = self._build_query_texts(parsed_query, style_profile) + query_vector = self._get_query_vector(parsed_query) + tokenizer = self._get_tokenizer() + + for hit in es_hits: + source = hit.get("_source") + if not isinstance(source, dict): + continue + + decision = self._select_for_source( + source, + style_profile=style_profile, + query_texts=query_texts, + query_vector=query_vector, + tokenizer=tokenizer, + ) + if decision is None: + continue + + self._apply_decision_to_source(source, decision) + if decision.rerank_suffix: + hit["_style_rerank_suffix"] = decision.rerank_suffix + + doc_id = hit.get("_id") + if doc_id is not None: + decisions[str(doc_id)] = decision + + return decisions + + def apply_precomputed_decisions( + self, + es_hits: List[Dict[str, Any]], + decisions: Dict[str, SkuSelectionDecision], + ) -> None: + if not es_hits or not decisions: + return + + for hit in es_hits: + doc_id = hit.get("_id") + if doc_id is None: + continue + decision = decisions.get(str(doc_id)) + if decision is None: + continue + source = hit.get("_source") + if not isinstance(source, dict): + continue + self._apply_decision_to_source(source, decision) + if decision.rerank_suffix: + hit["_style_rerank_suffix"] = decision.rerank_suffix + + def _build_query_texts( + self, + parsed_query: Any, + style_profile: StyleIntentProfile, + ) -> List[str]: + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] + if texts: + return list(dict.fromkeys(texts)) + + fallbacks: List[str] = [] + for value in ( + getattr(parsed_query, "original_query", None), + getattr(parsed_query, "query_normalized", None), + getattr(parsed_query, "rewritten_query", None), + ): + normalized = normalize_query_text(value) + if normalized: + fallbacks.append(normalized) + translations = getattr(parsed_query, "translations", {}) or {} + if isinstance(translations, dict): + for value in translations.values(): + normalized = normalize_query_text(value) + if normalized: + fallbacks.append(normalized) + return list(dict.fromkeys(fallbacks)) + + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: + query_vector = getattr(parsed_query, "query_vector", None) + if query_vector is not None: + return np.asarray(query_vector, dtype=np.float32) + + text_encoder = self._get_text_encoder() + if text_encoder is None: + return None + + query_text = ( + getattr(parsed_query, "rewritten_query", None) + or getattr(parsed_query, "query_normalized", None) + or getattr(parsed_query, "original_query", None) + ) + if not query_text: + return None + + vectors = text_encoder.encode([query_text], priority=1) + if vectors is None or len(vectors) == 0 or vectors[0] is None: + return None + return np.asarray(vectors[0], dtype=np.float32) + + def _get_text_encoder(self) -> Any: + if self._text_encoder_getter is None: + return None + return self._text_encoder_getter() + + def _get_tokenizer(self) -> Any: + if self._tokenizer_getter is None: + return None + return self._tokenizer_getter() + + @staticmethod + def _fallback_sku_text(sku: Dict[str, Any]) -> str: + parts = [] + for field_name in ("option1_value", "option2_value", "option3_value"): + value = str(sku.get(field_name) or "").strip() + if value: + parts.append(value) + return " ".join(parts) + + def _resolve_dimensions( + self, + source: Dict[str, Any], + style_profile: StyleIntentProfile, + ) -> Dict[str, Optional[str]]: + option_names = { + "option1_value": normalize_query_text(source.get("option1_name")), + "option2_value": normalize_query_text(source.get("option2_name")), + "option3_value": normalize_query_text(source.get("option3_name")), + } + resolved: Dict[str, Optional[str]] = {} + for intent in style_profile.intents: + if intent.intent_type in resolved: + continue + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) + matched_field = None + for field_name, option_name in option_names.items(): + if option_name and option_name in aliases: + matched_field = field_name + break + resolved[intent.intent_type] = matched_field + return resolved + + def _build_candidates( + self, + skus: List[Dict[str, Any]], + resolved_dimensions: Dict[str, Optional[str]], + ) -> List[_SkuCandidate]: + candidates: List[_SkuCandidate] = [] + for index, sku in enumerate(skus): + fallback_text = self._fallback_sku_text(sku) + intent_texts: Dict[str, str] = {} + for intent_type, field_name in resolved_dimensions.items(): + if field_name: + value = str(sku.get(field_name) or "").strip() + intent_texts[intent_type] = value or fallback_text + else: + intent_texts[intent_type] = fallback_text + + selection_parts: List[str] = [] + seen = set() + for value in intent_texts.values(): + normalized = normalize_query_text(value) + if not normalized or normalized in seen: + continue + seen.add(normalized) + selection_parts.append(str(value).strip()) + + selection_text = " ".join(selection_parts).strip() or fallback_text + candidates.append( + _SkuCandidate( + index=index, + sku_id=str(sku.get("sku_id") or ""), + sku=sku, + selection_text=selection_text, + intent_texts=intent_texts, + ) + ) + return candidates + + @staticmethod + def _is_direct_match( + candidate: _SkuCandidate, + query_texts: Sequence[str], + ) -> bool: + if not candidate.intent_texts or not query_texts: + return False + for value in candidate.intent_texts.values(): + normalized_value = normalize_query_text(value) + if not normalized_value: + return False + if not any(normalized_value in query_text for query_text in query_texts): + return False + return True + + def _is_generalized_match( + self, + candidate: _SkuCandidate, + style_profile: StyleIntentProfile, + tokenizer: Any, + ) -> bool: + if not candidate.intent_texts: + return False + + for intent_type, value in candidate.intent_texts.items(): + definition = self.registry.get_definition(intent_type) + if definition is None: + return False + matched_canonicals = definition.match_text(value, tokenizer=tokenizer) + if not matched_canonicals.intersection(style_profile.get_canonical_values(intent_type)): + return False + return True + + def _select_by_embedding( + self, + candidates: Sequence[_SkuCandidate], + query_vector: Optional[np.ndarray], + ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: + if not candidates: + return None, None + text_encoder = self._get_text_encoder() + if query_vector is None or text_encoder is None: + return candidates[0], None + + unique_texts = list( + dict.fromkeys( + normalize_query_text(candidate.selection_text) + for candidate in candidates + if normalize_query_text(candidate.selection_text) + ) + ) + if not unique_texts: + return candidates[0], None + + vectors = text_encoder.encode(unique_texts, priority=1) + vector_map: Dict[str, np.ndarray] = {} + for key, vector in zip(unique_texts, vectors): + if vector is None: + continue + vector_map[key] = np.asarray(vector, dtype=np.float32) + + best_candidate: Optional[_SkuCandidate] = None + best_score: Optional[float] = None + query_vector_array = np.asarray(query_vector, dtype=np.float32) + for candidate in candidates: + normalized_text = normalize_query_text(candidate.selection_text) + candidate_vector = vector_map.get(normalized_text) + if candidate_vector is None: + continue + score = float(np.inner(query_vector_array, candidate_vector)) + if best_score is None or score > best_score: + best_candidate = candidate + best_score = score + + return best_candidate or candidates[0], best_score + + def _select_for_source( + self, + source: Dict[str, Any], + *, + style_profile: StyleIntentProfile, + query_texts: Sequence[str], + query_vector: Optional[np.ndarray], + tokenizer: Any, + ) -> Optional[SkuSelectionDecision]: + skus = source.get("skus") + if not isinstance(skus, list) or not skus: + return None + + resolved_dimensions = self._resolve_dimensions(source, style_profile) + candidates = self._build_candidates(skus, resolved_dimensions) + if not candidates: + return None + + direct_matches = [candidate for candidate in candidates if self._is_direct_match(candidate, query_texts)] + if len(direct_matches) == 1: + chosen = direct_matches[0] + return self._build_decision(chosen, resolved_dimensions, matched_stage="direct") + + generalized_matches: List[_SkuCandidate] = [] + if not direct_matches: + generalized_matches = [ + candidate + for candidate in candidates + if self._is_generalized_match(candidate, style_profile, tokenizer) + ] + if len(generalized_matches) == 1: + chosen = generalized_matches[0] + return self._build_decision(chosen, resolved_dimensions, matched_stage="generalized") + + embedding_pool = direct_matches or generalized_matches or candidates + chosen, similarity_score = self._select_by_embedding(embedding_pool, query_vector) + if chosen is None: + return None + stage = "embedding_from_matches" if direct_matches or generalized_matches else "embedding_from_all" + return self._build_decision( + chosen, + resolved_dimensions, + matched_stage=stage, + similarity_score=similarity_score, + ) + + @staticmethod + def _build_decision( + candidate: _SkuCandidate, + resolved_dimensions: Dict[str, Optional[str]], + *, + matched_stage: str, + similarity_score: Optional[float] = None, + ) -> SkuSelectionDecision: + return SkuSelectionDecision( + selected_sku_id=candidate.sku_id or None, + rerank_suffix=str(candidate.selection_text or "").strip(), + selected_text=str(candidate.selection_text or "").strip(), + matched_stage=matched_stage, + similarity_score=similarity_score, + resolved_dimensions=dict(resolved_dimensions), + ) + + @staticmethod + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: + skus = source.get("skus") + if not isinstance(skus, list) or not skus or not decision.selected_sku_id: + return + + selected_index = None + for index, sku in enumerate(skus): + if str(sku.get("sku_id") or "") == decision.selected_sku_id: + selected_index = index + break + if selected_index is None: + return + + selected_sku = skus.pop(selected_index) + skus.insert(0, selected_sku) + + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") + if image_src: + source["image_url"] = image_src diff --git a/tests/test_search_rerank_window.py b/tests/test_search_rerank_window.py index d90c8f0..c03da39 100644 --- a/tests/test_search_rerank_window.py +++ b/tests/test_search_rerank_window.py @@ -18,6 +18,7 @@ from config import ( SearchConfig, ) from context import create_request_context +from query.style_intent import DetectedStyleIntent, StyleIntentProfile from search.searcher import Searcher @@ -30,6 +31,7 @@ class _FakeParsedQuery: translations: Dict[str, str] = None query_vector: Any = None domain: str = "default" + style_intent_profile: Any = None def to_dict(self) -> Dict[str, Any]: return { @@ -39,9 +41,27 @@ class _FakeParsedQuery: "detected_language": self.detected_language, "translations": self.translations or {}, "domain": self.domain, + "style_intent_profile": ( + self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None + ), } +def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimension_aliases: str) -> StyleIntentProfile: + aliases = dimension_aliases or (intent_type,) + return StyleIntentProfile( + intents=( + DetectedStyleIntent( + intent_type=intent_type, + canonical_value=canonical_value, + matched_term=canonical_value, + matched_query_text=canonical_value, + dimension_aliases=tuple(aliases), + ), + ) + ) + + class _FakeQueryParser: def parse( self, @@ -340,6 +360,57 @@ def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): assert es_client.calls[0]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} +def test_searcher_rerank_prefetch_source_includes_sku_fields_when_style_intent_active(monkeypatch): + es_client = _FakeESClient() + searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) + context = create_request_context(reqid="t1c", uid="u1c") + + monkeypatch.setattr( + "search.searcher.get_tenant_config_loader", + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), + ) + monkeypatch.setattr( + "search.rerank_client.run_rerank", + lambda **kwargs: (kwargs["es_response"], None, []), + ) + + class _IntentQueryParser: + text_encoder = None + + def parse( + self, + query: str, + tenant_id: str, + generate_vector: bool, + context: Any, + target_languages: Any = None, + ): + return _FakeParsedQuery( + original_query=query, + query_normalized=query, + rewritten_query=query, + translations={}, + style_intent_profile=_build_style_intent_profile( + "color", "black", "color", "colors", "颜色" + ), + ) + + searcher.query_parser = _IntentQueryParser() + + searcher.search( + query="black dress", + tenant_id="162", + from_=0, + size=5, + context=context, + enable_rerank=None, + ) + + assert es_client.calls[0]["body"]["_source"] == { + "includes": ["option1_name", "option2_name", "option3_name", "skus", "title"] + } + + def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): es_client = _FakeESClient() searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) @@ -434,6 +505,9 @@ def test_searcher_promotes_sku_when_option1_matches_translated_query(monkeypatch query_normalized=query, rewritten_query=query, translations={"en": "black dress"}, + style_intent_profile=_build_style_intent_profile( + "color", "black", "color", "colors", "颜色" + ), ) searcher.query_parser = _TranslatedQueryParser() @@ -481,8 +555,8 @@ def test_searcher_promotes_sku_by_embedding_when_query_has_no_direct_option_matc encoder = _FakeTextEncoder( { "linen summer dress": [0.8, 0.2], - "color:red": [1.0, 0.0], - "color:blue": [0.0, 1.0], + "red": [1.0, 0.0], + "blue": [0.0, 1.0], } ) @@ -503,6 +577,9 @@ def test_searcher_promotes_sku_by_embedding_when_query_has_no_direct_option_matc rewritten_query=query, translations={}, query_vector=np.array([0.0, 1.0], dtype=np.float32), + style_intent_profile=_build_style_intent_profile( + "color", "blue", "color", "colors", "颜色" + ), ) searcher.query_parser = _EmbeddingQueryParser() diff --git a/tests/test_style_intent.py b/tests/test_style_intent.py new file mode 100644 index 0000000..d46217a --- /dev/null +++ b/tests/test_style_intent.py @@ -0,0 +1,35 @@ +from types import SimpleNamespace + +from config import QueryConfig +from query.style_intent import StyleIntentDetector, StyleIntentRegistry + + +def test_style_intent_detector_matches_original_and_translated_queries(): + query_config = QueryConfig( + style_intent_terms={ + "color": [["black", "黑色", "black"]], + "size": [["xl", "x-large", "加大码"]], + }, + style_intent_dimension_aliases={ + "color": ["color", "颜色"], + "size": ["size", "尺码"], + }, + ) + detector = StyleIntentDetector( + StyleIntentRegistry.from_query_config(query_config), + tokenizer=lambda text: text.split(), + ) + + parsed_query = SimpleNamespace( + original_query="黑色 连衣裙", + query_normalized="黑色 连衣裙", + rewritten_query="黑色 连衣裙", + translations={"en": "black dress xl"}, + ) + + profile = detector.detect(parsed_query) + + assert profile.is_active is True + assert profile.get_canonical_values("color") == {"black"} + assert profile.get_canonical_values("size") == {"xl"} + assert len(profile.query_variants) == 2 -- libgit2 0.21.2