diff --git a/config/config.yaml b/config/config.yaml index ca42b90..2d5ce00 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -100,16 +100,19 @@ query_config: # 查询翻译模型(须与 services.translation.capabilities 中某项一致) # 源语种在租户 index_languages 内:主召回可打在源语种字段,用下面三项。 - # zh_to_en_model: "opus-mt-zh-en" - # en_to_zh_model: "opus-mt-en-zh" - # default_translation_model: "nllb-200-distilled-600m" - zh_to_en_model: "deepl" - en_to_zh_model: "deepl" - default_translation_model: "deepl" + zh_to_en_model: "nllb-200-distilled-600m" # "opus-mt-zh-en" + en_to_zh_model: "nllb-200-distilled-600m" # "opus-mt-en-zh" + default_translation_model: "nllb-200-distilled-600m" + # zh_to_en_model: "deepl" + # en_to_zh_model: "deepl" + # default_translation_model: "deepl" # 源语种不在 index_languages:翻译对可检索文本更关键,可单独指定(缺省则与上一组相同) - zh_to_en_model__source_not_in_index: "deepl" - en_to_zh_model__source_not_in_index: "deepl" - default_translation_model__source_not_in_index: "deepl" + zh_to_en_model__source_not_in_index: "nllb-200-distilled-600m" + en_to_zh_model__source_not_in_index: "nllb-200-distilled-600m" + default_translation_model__source_not_in_index: "nllb-200-distilled-600m" + # zh_to_en_model__source_not_in_index: "deepl" + # en_to_zh_model__source_not_in_index: "deepl" + # default_translation_model__source_not_in_index: "deepl" # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 @@ -153,8 +156,8 @@ query_config: # 统一文本召回策略(主查询 + 翻译查询) text_query_strategy: - base_minimum_should_match: "75%" - translation_minimum_should_match: "75%" + base_minimum_should_match: "60%" + translation_minimum_should_match: "60%" translation_boost: 0.75 tie_breaker_base_query: 0.5 best_fields_boost: 2.0 @@ -207,8 +210,8 @@ query_config: - skus # KNN:文本向量与多模态(图片)向量各自 boost 与召回(k / num_candidates) - knn_text_boost: 20 - knn_image_boost: 20 + knn_text_boost: 4 + knn_image_boost: 4 knn_text_k: 150 knn_text_num_candidates: 400 @@ -247,7 +250,7 @@ rerank: knn_image_weight: 
1.0 knn_tie_breaker: 0.1 knn_bias: 0.6 - knn_exponent: 0.2 + knn_exponent: 0.0 # 可扩展服务/provider 注册表(单一配置源) services: diff --git a/docs/TODO-keywords限定-done.txt b/docs/TODO-keywords限定-done.txt new file mode 100644 index 0000000..a68186e --- /dev/null +++ b/docs/TODO-keywords限定-done.txt @@ -0,0 +1,93 @@ +@query/query_parser.py @scripts/es_debug_search.py +原始query、以及每一个翻译,都要有一个对应的keywords_query(token分词后,得到名词) +参考这段代码,获取每一个长度大于 1 的名词,然后用空格拼接起来,作为keywords_query +import hanlp +from typing import List, Tuple, Dict, Any + +class KeywordExtractor: + """ + 基于 HanLP 的名词关键词提取器 + """ + def __init__(self): + # 加载带位置信息的分词模型(细粒度) + self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) + self.tok.config.output_spans = True # 启用位置输出 + + # 加载词性标注模型 + self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) + + def extract_keywords(self, query: str) -> str: + """ + 从查询中提取关键词(名词,长度 ≥ 2) + + Args: + query: 输入文本 + + Returns: + 拼接后的关键词字符串,非连续词之间自动插入空格 + """ + query = query.strip() + # 分词结果带位置:[[word, start, end], ...] + tok_result_with_position = self.tok(query) + tok_result = [x[0] for x in tok_result_with_position] + + # 词性标注 + pos_tag_result = list(zip(tok_result, self.pos_tag(tok_result))) + + # 需要忽略的词 + ignore_keywords = ['玩具'] + + keywords = [] + last_end_pos = 0 + + for (word, postag), (_, start_pos, end_pos) in zip(pos_tag_result, tok_result_with_position): + if len(word) >= 2 and postag.startswith('N'): + if word in ignore_keywords: + continue + # 如果当前词与上一个词在原文中不连续,插入空格 + if start_pos != last_end_pos and keywords: + keywords.append(" ") + keywords.append(word) + last_end_pos = end_pos + # 可选:打印调试信息 + # print(f'分词: {word} | 词性: {postag} | 起始: {start_pos} | 结束: {end_pos}') + + return "".join(keywords).strip() + + +最后,在组织检索表达式时,目前是每一个 query (base_query base_query_trans_en base_query_trans_zh 三种情况)。 会组成一个bool查询,以base_query为例: + "bool": { + "should": [ + { + "bool": { + "_name": "base_query", + "must": [ + { + "combined_fields": { +... 
+ } + } + ], + "should": [ + { + "multi_match": { +... "type": "best_fields", +... + }, + { + "multi_match": { +... + "type": "phrase", +... + } + } + ] + } + }, + +base_query_trans_en base_query_trans_zh 也是同样 + +在这个布尔查询的must里面加一项:keywords,搜索的字段和combined_fields一样,命中比例要求50% + + +结合现有代码做出合理的设计,呈现简单清晰的数据接口,而不是打补丁 \ No newline at end of file diff --git a/query/__init__.py b/query/__init__.py index 4a3bea2..181cb02 100644 --- a/query/__init__.py +++ b/query/__init__.py @@ -2,6 +2,7 @@ from .language_detector import LanguageDetector from .query_rewriter import QueryRewriter, QueryNormalizer +from .keyword_extractor import KEYWORDS_QUERY_BASE_KEY from .query_parser import QueryParser, ParsedQuery __all__ = [ @@ -10,4 +11,5 @@ __all__ = [ 'QueryNormalizer', 'QueryParser', 'ParsedQuery', + 'KEYWORDS_QUERY_BASE_KEY', ] diff --git a/query/keyword_extractor.py b/query/keyword_extractor.py new file mode 100644 index 0000000..082c8a6 --- /dev/null +++ b/query/keyword_extractor.py @@ -0,0 +1,86 @@ +""" +HanLP-based noun keyword string for lexical constraints (token POS starts with N, length >= 2). + +``ParsedQuery.keywords_queries`` uses the same key layout as text variants: +``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like language +codes for each ``ParsedQuery.translations`` entry (non-empty extractions only). +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +import hanlp # type: ignore + +# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code). 
KEYWORDS_QUERY_BASE_KEY = "base"


class KeywordExtractor:
    """Extract a noun-keyword string from a query with a tokenizer and a POS tagger.

    A token is kept when it is at least two characters long, its POS tag
    starts with ``"N"`` and it is not listed in ``ignore_keywords``. Kept
    tokens that are adjacent in the source text are concatenated as-is; a
    single space is inserted between kept tokens that are not adjacent.
    """

    # Nouns dropped when the caller does not supply ``ignore_keywords``.
    DEFAULT_IGNORE_KEYWORDS = ("玩具",)

    def __init__(
        self,
        tokenizer: Optional[Any] = None,
        *,
        ignore_keywords: Optional[List[str]] = None,
        pos_tagger: Optional[Any] = None,
    ):
        """Build an extractor, loading HanLP models only for the parts not injected.

        Args:
            tokenizer: Callable mapping a string to a sequence of tokens. Tokens
                may be ``(word, start, end)`` triples or plain strings. When
                omitted, the HanLP CTB9 tokenizer is loaded with span output on.
            ignore_keywords: Nouns to drop. ``None`` selects
                ``DEFAULT_IGNORE_KEYWORDS``; an empty list disables filtering.
            pos_tagger: Callable mapping a list of words to a list of POS tags.
                When omitted, the HanLP CTB9 POS model is loaded.
        """
        if tokenizer is None or pos_tagger is None:
            # Imported here so an extractor built from injected models does not
            # need HanLP at construction time.
            import hanlp  # type: ignore

            if tokenizer is None:
                tokenizer = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF)
                tokenizer.config.output_spans = True
            if pos_tagger is None:
                pos_tagger = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)
        self.tok = tokenizer
        self.pos_tag = pos_tagger
        # ``is None`` (not truthiness) so that ``[]`` really means "ignore nothing".
        self.ignore_keywords = frozenset(
            self.DEFAULT_IGNORE_KEYWORDS if ignore_keywords is None else ignore_keywords
        )

    @staticmethod
    def _token_spans(query: str, raw_tokens: Any) -> List[tuple]:
        """Normalise tokenizer output to ``(word, start, end)`` tuples.

        An injected tokenizer is not reconfigured here, so it may return plain
        strings or short sequences without offsets. For those tokens the offsets
        are recovered by searching ``query`` forward from the previous token.
        """
        spans: List[tuple] = []
        cursor = 0
        if raw_tokens is None:
            return spans
        for token in raw_tokens:
            start = end = None
            if isinstance(token, str):
                word = token
            else:
                parts = list(token)
                if not parts:
                    continue
                word = str(parts[0])
                if len(parts) >= 3:
                    start, end = int(parts[1]), int(parts[2])
            if not word:
                continue
            if start is None:
                found = query.find(word, cursor)
                # Fall back to the cursor when the token text was altered by
                # the tokenizer and cannot be located in the query.
                start = found if found >= 0 else cursor
                end = start + len(word)
            cursor = max(cursor, end)
            spans.append((word, start, end))
        return spans

    def extract_keywords(self, query: str) -> str:
        """Return the noun keywords of ``query`` as one string.

        Returns an empty string for ``None``, blank input, or when no token
        qualifies.
        """
        text = (query or "").strip()
        if not text:
            return ""
        spans = self._token_spans(text, self.tok(text))
        if not spans:
            return ""
        tags = self.pos_tag([word for word, _, _ in spans])
        pieces: List[str] = []
        last_end = 0
        for (word, start, end), tag in zip(spans, tags):
            if len(word) < 2 or not str(tag).startswith("N"):
                continue
            if word in self.ignore_keywords:
                continue
            # Separate this noun from the previously kept one unless the two
            # are directly adjacent in the source text.
            if pieces and start != last_end:
                pieces.append(" ")
            pieces.append(word)
            last_end = end
        return "".join(pieces).strip()


def collect_keywords_queries(
    extractor: KeywordExtractor,
    rewritten_query: str,
    translations: Dict[str, str],
) -> Dict[str, str]:
    """Build the keyword map for all lexical variants (base + translations).

    Args:
        extractor: Object exposing ``extract_keywords(str) -> str``.
        rewritten_query: Source-language query; stored under
            ``KEYWORDS_QUERY_BASE_KEY``.
        translations: Mapping of language code to translated text; ``None`` is
            treated as empty.

    Returns:
        Mapping of variant key to keyword string. Language keys are stripped
        and lower-cased. Variants with a blank key, blank text or an empty
        extraction are omitted.
    """
    out: Dict[str, str] = {}
    base_kw = extractor.extract_keywords(rewritten_query)
    if base_kw:
        out[KEYWORDS_QUERY_BASE_KEY] = base_kw
    for lang, text in (translations or {}).items():
        lang_key = str(lang or "").strip().lower()
        # The base key is reserved for the source query; a translation must
        # never overwrite it.
        if not lang_key or lang_key == KEYWORDS_QUERY_BASE_KEY:
            continue
        if not (text or "").strip():
            continue
        kw = extractor.extract_keywords(text)
        if kw:
            out[lang_key] = kw
    return out
+ """ original_query: str query_normalized: str @@ -69,6 +77,7 @@ class ParsedQuery: query_vector: Optional[np.ndarray] = None image_query_vector: Optional[np.ndarray] = None query_tokens: List[str] = field(default_factory=list) + keywords_queries: Dict[str, str] = field(default_factory=dict) style_intent_profile: Optional[StyleIntentProfile] = None product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None @@ -91,6 +100,7 @@ class ParsedQuery: "has_query_vector": self.query_vector is not None, "has_image_query_vector": self.image_query_vector is not None, "query_tokens": self.query_tokens, + "keywords_queries": dict(self.keywords_queries), "style_intent_profile": ( self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None ), @@ -138,6 +148,7 @@ class QueryParser: self.language_detector = LanguageDetector() self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) self._tokenizer = tokenizer or self._build_tokenizer() + self._keyword_extractor = KeywordExtractor(tokenizer=self._tokenizer) self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) self.style_intent_detector = StyleIntentDetector( self.style_intent_registry, @@ -523,6 +534,16 @@ class QueryParser: if translations and context: context.store_intermediate_result("translations", translations) + keywords_queries: Dict[str, str] = {} + try: + keywords_queries = collect_keywords_queries( + self._keyword_extractor, + query_text, + translations, + ) + except Exception as e: + log_info(f"Keyword extraction failed | Error: {e}") + # Build result base_result = ParsedQuery( original_query=query, @@ -533,6 +554,7 @@ class QueryParser: query_vector=query_vector, image_query_vector=image_query_vector, query_tokens=query_tokens, + keywords_queries=keywords_queries, ) style_intent_profile = self.style_intent_detector.detect(base_result) product_title_exclusion_profile = self.product_title_exclusion_detector.detect(base_result) @@ 
-555,6 +577,7 @@ class QueryParser: query_vector=query_vector, image_query_vector=image_query_vector, query_tokens=query_tokens, + keywords_queries=keywords_queries, style_intent_profile=style_intent_profile, product_title_exclusion_profile=product_title_exclusion_profile, ) diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 09f6bfe..59d8543 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -12,6 +12,7 @@ from typing import Dict, Any, List, Optional, Tuple import numpy as np from config import FunctionScoreConfig +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY class ESQueryBuilder: @@ -39,6 +40,7 @@ class ESQueryBuilder: knn_image_num_candidates: int = 400, base_minimum_should_match: str = "70%", translation_minimum_should_match: str = "70%", + keywords_minimum_should_match: str = "50%", translation_boost: float = 0.4, tie_breaker_base_query: float = 0.9, best_fields_boosts: Optional[Dict[str, float]] = None, @@ -85,6 +87,7 @@ class ESQueryBuilder: self.knn_image_num_candidates = int(knn_image_num_candidates) self.base_minimum_should_match = base_minimum_should_match self.translation_minimum_should_match = translation_minimum_should_match + self.keywords_minimum_should_match = str(keywords_minimum_should_match) self.translation_boost = float(translation_boost) self.tie_breaker_base_query = float(tie_breaker_base_query) default_best_fields = { @@ -505,6 +508,7 @@ class ESQueryBuilder: clause_name: str, *, is_source: bool, + keywords_query: Optional[str] = None, ) -> Optional[Dict[str, Any]]: combined_fields = self._match_field_strings(lang) if not combined_fields: @@ -512,6 +516,26 @@ class ESQueryBuilder: minimum_should_match = ( self.base_minimum_should_match if is_source else self.translation_minimum_should_match ) + must_clauses: List[Dict[str, Any]] = [ + { + "combined_fields": { + "query": lang_query, + "fields": combined_fields, + "minimum_should_match": minimum_should_match, + } + } + ] + kw = 
(keywords_query or "").strip() + if kw: + must_clauses.append( + { + "combined_fields": { + "query": kw, + "fields": combined_fields, + "minimum_should_match": self.keywords_minimum_should_match, + } + } + ) should_clauses = [ clause for clause in ( @@ -523,15 +547,7 @@ class ESQueryBuilder: clause: Dict[str, Any] = { "bool": { "_name": clause_name, - "must": [ - { - "combined_fields": { - "query": lang_query, - "fields": combined_fields, - "minimum_should_match": minimum_should_match, - } - } - ], + "must": must_clauses, } } if should_clauses: @@ -572,6 +588,11 @@ class ESQueryBuilder: base_query_text = ( getattr(parsed_query, "rewritten_query", None) if parsed_query else None ) or query_text + kw_by_variant: Dict[str, str] = ( + getattr(parsed_query, "keywords_queries", None) or {} + if parsed_query + else {} + ) if base_query_text: base_clause = self._build_lexical_language_clause( @@ -579,6 +600,7 @@ class ESQueryBuilder: base_query_text, "base_query", is_source=True, + keywords_query=(kw_by_variant.get(KEYWORDS_QUERY_BASE_KEY) or "").strip(), ) if base_clause: should_clauses.append(base_clause) @@ -590,11 +612,13 @@ class ESQueryBuilder: continue if normalized_lang == source_lang and normalized_text == base_query_text: continue + trans_kw = (kw_by_variant.get(normalized_lang) or "").strip() trans_clause = self._build_lexical_language_clause( normalized_lang, normalized_text, f"base_query_trans_{normalized_lang}", is_source=False, + keywords_query=trans_kw, ) if trans_clause: should_clauses.append(trans_clause) diff --git a/tests/test_es_query_builder_text_recall_languages.py b/tests/test_es_query_builder_text_recall_languages.py index ff98e64..8cfca7c 100644 --- a/tests/test_es_query_builder_text_recall_languages.py +++ b/tests/test_es_query_builder_text_recall_languages.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List import numpy as np +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY from search.es_query_builder import ESQueryBuilder @@ 
"""Manual smoke script for the noun-keyword extractor.

Run this file directly to print the keywords extracted from a fixed set of
sample queries. It defines no test functions, so a pytest run collects nothing
from it. The extractor is imported from ``query.keyword_extractor`` instead of
being re-declared here, so this script always exercises the production code.
"""

from query.keyword_extractor import KeywordExtractor


def main() -> None:
    """Print ``query => keywords`` for every sample query."""
    extractor = KeywordExtractor()

    test_queries = [
        # Chinese (9 representative queries)
        "2.4G遥控大蛇",
        "充气的篮球",
        "遥控 塑料 飞船 汽车 ",
        "亚克力相框",
        "8寸 搪胶蘑菇钉",
        "7寸娃娃",
        "太空沙套装",
        "脚蹬工程车",
        "捏捏乐钥匙扣",

        # English
        "plastic toy car",
        "remote control helicopter",
        "inflatable beach ball",
        "music keychain",
        "sand play set",
        # Common product searches
        "plastic dinosaur toy",
        "wireless bluetooth speaker",
        "4K action camera",
        "stainless steel water bottle",
        "baby stroller with cup holder",

        # Questions / natural language
        "what is the best smartphone under 500 dollars",
        "how to clean a laptop screen",
        "where can I buy organic coffee beans",

        # Digits and special characters
        "USB-C to HDMI adapter 4K",
        "LED strip lights 16.4ft",
        "Nintendo Switch OLED model",
        "iPhone 15 Pro Max case",

        # Short phrases
        "gaming mouse",
        "mechanical keyboard",
        "wireless earbuds",

        # Long-tail queries
        "rechargeable AA batteries with charger",
        "foldable picnic blanket waterproof",

        # Product attribute combinations
        "women's running shoes size 8",
        "men's cotton t-shirt crew neck",

        # Other languages (kept verbatim for multilingual checks)
        "свет USB с пультом дистанционного управления красочные",  # Russian
    ]

    for q in test_queries:
        keywords = extractor.extract_keywords(q)
        print(f"{q:30} => {keywords}")


if __name__ == "__main__":
    main()