Commit ceaf6d03e28f33c267072d3138ed2b962b521cf2
1 parent
ed13851c
召回限定:must条件补充主干词命中逻辑。baseline的主干词提取方法。
TODO-keywords限定-done.txt
Showing 8 changed files with 394 additions and 24 deletions. Show diff stats.
config/config.yaml
| @@ -100,16 +100,19 @@ query_config: | @@ -100,16 +100,19 @@ query_config: | ||
| 100 | 100 | ||
| 101 | # 查询翻译模型(须与 services.translation.capabilities 中某项一致) | 101 | # 查询翻译模型(须与 services.translation.capabilities 中某项一致) |
| 102 | # 源语种在租户 index_languages 内:主召回可打在源语种字段,用下面三项。 | 102 | # 源语种在租户 index_languages 内:主召回可打在源语种字段,用下面三项。 |
| 103 | - # zh_to_en_model: "opus-mt-zh-en" | ||
| 104 | - # en_to_zh_model: "opus-mt-en-zh" | ||
| 105 | - # default_translation_model: "nllb-200-distilled-600m" | ||
| 106 | - zh_to_en_model: "deepl" | ||
| 107 | - en_to_zh_model: "deepl" | ||
| 108 | - default_translation_model: "deepl" | 103 | + zh_to_en_model: "nllb-200-distilled-600m" # "opus-mt-zh-en" |
| 104 | + en_to_zh_model: "nllb-200-distilled-600m" # "opus-mt-en-zh" | ||
| 105 | + default_translation_model: "nllb-200-distilled-600m" | ||
| 106 | + # zh_to_en_model: "deepl" | ||
| 107 | + # en_to_zh_model: "deepl" | ||
| 108 | + # default_translation_model: "deepl" | ||
| 109 | # 源语种不在 index_languages:翻译对可检索文本更关键,可单独指定(缺省则与上一组相同) | 109 | # 源语种不在 index_languages:翻译对可检索文本更关键,可单独指定(缺省则与上一组相同) |
| 110 | - zh_to_en_model__source_not_in_index: "deepl" | ||
| 111 | - en_to_zh_model__source_not_in_index: "deepl" | ||
| 112 | - default_translation_model__source_not_in_index: "deepl" | 110 | + zh_to_en_model__source_not_in_index: "nllb-200-distilled-600m" |
| 111 | + en_to_zh_model__source_not_in_index: "nllb-200-distilled-600m" | ||
| 112 | + default_translation_model__source_not_in_index: "nllb-200-distilled-600m" | ||
| 113 | + # zh_to_en_model__source_not_in_index: "deepl" | ||
| 114 | + # en_to_zh_model__source_not_in_index: "deepl" | ||
| 115 | + # default_translation_model__source_not_in_index: "deepl" | ||
| 113 | 116 | ||
| 114 | # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 | 117 | # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 |
| 115 | # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 | 118 | # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 |
| @@ -153,8 +156,8 @@ query_config: | @@ -153,8 +156,8 @@ query_config: | ||
| 153 | 156 | ||
| 154 | # 统一文本召回策略(主查询 + 翻译查询) | 157 | # 统一文本召回策略(主查询 + 翻译查询) |
| 155 | text_query_strategy: | 158 | text_query_strategy: |
| 156 | - base_minimum_should_match: "75%" | ||
| 157 | - translation_minimum_should_match: "75%" | 159 | + base_minimum_should_match: "60%" |
| 160 | + translation_minimum_should_match: "60%" | ||
| 158 | translation_boost: 0.75 | 161 | translation_boost: 0.75 |
| 159 | tie_breaker_base_query: 0.5 | 162 | tie_breaker_base_query: 0.5 |
| 160 | best_fields_boost: 2.0 | 163 | best_fields_boost: 2.0 |
| @@ -207,8 +210,8 @@ query_config: | @@ -207,8 +210,8 @@ query_config: | ||
| 207 | - skus | 210 | - skus |
| 208 | 211 | ||
| 209 | # KNN:文本向量与多模态(图片)向量各自 boost 与召回(k / num_candidates) | 212 | # KNN:文本向量与多模态(图片)向量各自 boost 与召回(k / num_candidates) |
| 210 | - knn_text_boost: 20 | ||
| 211 | - knn_image_boost: 20 | 213 | + knn_text_boost: 4 |
| 214 | + knn_image_boost: 4 | ||
| 212 | 215 | ||
| 213 | knn_text_k: 150 | 216 | knn_text_k: 150 |
| 214 | knn_text_num_candidates: 400 | 217 | knn_text_num_candidates: 400 |
| @@ -247,7 +250,7 @@ rerank: | @@ -247,7 +250,7 @@ rerank: | ||
| 247 | knn_image_weight: 1.0 | 250 | knn_image_weight: 1.0 |
| 248 | knn_tie_breaker: 0.1 | 251 | knn_tie_breaker: 0.1 |
| 249 | knn_bias: 0.6 | 252 | knn_bias: 0.6 |
| 250 | - knn_exponent: 0.2 | 253 | + knn_exponent: 0.0 |
| 251 | 254 | ||
| 252 | # 可扩展服务/provider 注册表(单一配置源) | 255 | # 可扩展服务/provider 注册表(单一配置源) |
| 253 | services: | 256 | services: |
| @@ -0,0 +1,93 @@ | @@ -0,0 +1,93 @@ | ||
| 1 | +@query/query_parser.py @scripts/es_debug_search.py | ||
| 2 | +原始query、以及每一个翻译,都要有一个对应的keywords_query(token分词后,得到名词) | ||
| 3 | +参考这段代码,获取每一个长度大于 1 的名词,然后用空格拼接起来,作为keywords_query | ||
| 4 | +import hanlp | ||
| 5 | +from typing import List, Tuple, Dict, Any | ||
| 6 | + | ||
| 7 | +class KeywordExtractor: | ||
| 8 | + """ | ||
| 9 | + 基于 HanLP 的名词关键词提取器 | ||
| 10 | + """ | ||
| 11 | + def __init__(self): | ||
| 12 | + # 加载带位置信息的分词模型(细粒度) | ||
| 13 | + self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) | ||
| 14 | + self.tok.config.output_spans = True # 启用位置输出 | ||
| 15 | + | ||
| 16 | + # 加载词性标注模型 | ||
| 17 | + self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) | ||
| 18 | + | ||
| 19 | + def extract_keywords(self, query: str) -> str: | ||
| 20 | + """ | ||
| 21 | + 从查询中提取关键词(名词,长度 ≥ 2) | ||
| 22 | + | ||
| 23 | + Args: | ||
| 24 | + query: 输入文本 | ||
| 25 | + | ||
| 26 | + Returns: | ||
| 27 | + 拼接后的关键词字符串,非连续词之间自动插入空格 | ||
| 28 | + """ | ||
| 29 | + query = query.strip() | ||
| 30 | + # 分词结果带位置:[[word, start, end], ...] | ||
| 31 | + tok_result_with_position = self.tok(query) | ||
| 32 | + tok_result = [x[0] for x in tok_result_with_position] | ||
| 33 | + | ||
| 34 | + # 词性标注 | ||
| 35 | + pos_tag_result = list(zip(tok_result, self.pos_tag(tok_result))) | ||
| 36 | + | ||
| 37 | + # 需要忽略的词 | ||
| 38 | + ignore_keywords = ['玩具'] | ||
| 39 | + | ||
| 40 | + keywords = [] | ||
| 41 | + last_end_pos = 0 | ||
| 42 | + | ||
| 43 | + for (word, postag), (_, start_pos, end_pos) in zip(pos_tag_result, tok_result_with_position): | ||
| 44 | + if len(word) >= 2 and postag.startswith('N'): | ||
| 45 | + if word in ignore_keywords: | ||
| 46 | + continue | ||
| 47 | + # 如果当前词与上一个词在原文中不连续,插入空格 | ||
| 48 | + if start_pos != last_end_pos and keywords: | ||
| 49 | + keywords.append(" ") | ||
| 50 | + keywords.append(word) | ||
| 51 | + last_end_pos = end_pos | ||
| 52 | + # 可选:打印调试信息 | ||
| 53 | + # print(f'分词: {word} | 词性: {postag} | 起始: {start_pos} | 结束: {end_pos}') | ||
| 54 | + | ||
| 55 | + return "".join(keywords).strip() | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +最后,在组织检索表达式时,目前是每一个 query (base_query base_query_trans_en base_query_trans_zh 三种情况)。 会组成一个bool查询,以base_query为例: | ||
| 59 | + "bool": { | ||
| 60 | + "should": [ | ||
| 61 | + { | ||
| 62 | + "bool": { | ||
| 63 | + "_name": "base_query", | ||
| 64 | + "must": [ | ||
| 65 | + { | ||
| 66 | + "combined_fields": { | ||
| 67 | +... | ||
| 68 | + } | ||
| 69 | + } | ||
| 70 | + ], | ||
| 71 | + "should": [ | ||
| 72 | + { | ||
| 73 | + "multi_match": { | ||
| 74 | +... "type": "best_fields", | ||
| 75 | +... | ||
| 76 | + }, | ||
| 77 | + { | ||
| 78 | + "multi_match": { | ||
| 79 | +... | ||
| 80 | + "type": "phrase", | ||
| 81 | +... | ||
| 82 | + } | ||
| 83 | + } | ||
| 84 | + ] | ||
| 85 | + } | ||
| 86 | + }, | ||
| 87 | + | ||
| 88 | +base_query_trans_en base_query_trans_zh 也是同样 | ||
| 89 | + | ||
| 90 | +在这个布尔查询的must里面加一项:keywords,搜索的字段和combined_fields一样,命中比例要求50% | ||
| 91 | + | ||
| 92 | + | ||
| 93 | +结合现有代码做出合理的设计,呈现简单清晰的数据接口,而不是打补丁 | ||
| 0 | \ No newline at end of file | 94 | \ No newline at end of file |
query/__init__.py
| @@ -2,6 +2,7 @@ | @@ -2,6 +2,7 @@ | ||
| 2 | 2 | ||
| 3 | from .language_detector import LanguageDetector | 3 | from .language_detector import LanguageDetector |
| 4 | from .query_rewriter import QueryRewriter, QueryNormalizer | 4 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 5 | +from .keyword_extractor import KEYWORDS_QUERY_BASE_KEY | ||
| 5 | from .query_parser import QueryParser, ParsedQuery | 6 | from .query_parser import QueryParser, ParsedQuery |
| 6 | 7 | ||
| 7 | __all__ = [ | 8 | __all__ = [ |
| @@ -10,4 +11,5 @@ __all__ = [ | @@ -10,4 +11,5 @@ __all__ = [ | ||
| 10 | 'QueryNormalizer', | 11 | 'QueryNormalizer', |
| 11 | 'QueryParser', | 12 | 'QueryParser', |
| 12 | 'ParsedQuery', | 13 | 'ParsedQuery', |
| 14 | + 'KEYWORDS_QUERY_BASE_KEY', | ||
| 13 | ] | 15 | ] |
| @@ -0,0 +1,86 @@ | @@ -0,0 +1,86 @@ | ||
| 1 | +""" | ||
| 2 | +HanLP-based noun keyword string for lexical constraints (token POS starts with N, length >= 2). | ||
| 3 | + | ||
| 4 | +``ParsedQuery.keywords_queries`` uses the same key layout as text variants: | ||
| 5 | +``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like language | ||
| 6 | +codes for each ``ParsedQuery.translations`` entry (non-empty extractions only). | ||
| 7 | +""" | ||
| 8 | + | ||
| 9 | +from __future__ import annotations | ||
| 10 | + | ||
| 11 | +import logging | ||
| 12 | +from typing import Any, Dict, List, Optional | ||
| 13 | + | ||
| 14 | +logger = logging.getLogger(__name__) | ||
| 15 | + | ||
| 16 | +import hanlp # type: ignore | ||
| 17 | + | ||
| 18 | +# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code). | ||
| 19 | +KEYWORDS_QUERY_BASE_KEY = "base" | ||
| 20 | + | ||
| 21 | + | ||
class KeywordExtractor:
    """HanLP-based noun keyword extractor.

    Keeps nouns (POS tag starting with ``N``) of length >= 2, aligned to the
    tokenizer's character spans; a single space is inserted only between kept
    nouns that are not adjacent in the original text.
    """

    def __init__(
        self,
        tokenizer: Optional[Any] = None,
        *,
        ignore_keywords: Optional[List[str]] = None,
    ):
        if tokenizer is not None:
            self.tok = tokenizer
        else:
            self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF)
        # BUGFIX: span output must be enabled on an *injected* tokenizer too,
        # otherwise extract_keywords receives plain strings instead of
        # (word, start, end) triples. Not every injected tokenizer exposes a
        # ``config`` attribute (QueryParser may pass a plain callable), so
        # guard the assignment; extract_keywords also tolerates span-less
        # output as a second line of defense.
        try:
            self.tok.config.output_spans = True
        except AttributeError:
            pass
        self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)
        self.ignore_keywords = frozenset(ignore_keywords or ["玩具"])

    def extract_keywords(self, query: str) -> str:
        """Extract noun keywords (length >= 2) from *query*.

        Returns the kept nouns concatenated in reading order; one space
        separates nouns that are not contiguous in the (stripped) original
        text. Empty string when nothing qualifies.
        """
        query = (query or "").strip()
        if not query:
            return ""
        raw = self.tok(query)
        if not raw:
            return ""
        if isinstance(raw[0], str):
            # Tokenizer without span output: rebuild (word, start, end) by
            # scanning the query left-to-right so the adjacency logic below
            # still works. ``find`` failures fall back to the running cursor.
            rebuilt: List[Any] = []
            cursor = 0
            for token in raw:
                start = query.find(token, cursor)
                if start < 0:
                    start = cursor
                end = start + len(token)
                rebuilt.append((token, start, end))
                cursor = end
            tok_result_with_position = rebuilt
        else:
            tok_result_with_position = raw
        tok_result = [item[0] for item in tok_result_with_position]
        pos_tags = self.pos_tag(tok_result)
        keywords: List[str] = []
        last_end_pos = 0
        for (word, start_pos, end_pos), postag in zip(tok_result_with_position, pos_tags):
            if len(word) < 2 or not str(postag).startswith("N"):
                continue
            if word in self.ignore_keywords:
                continue
            # Separate nouns that are not adjacent to the previously kept one.
            if start_pos != last_end_pos and keywords:
                keywords.append(" ")
            keywords.append(word)
            last_end_pos = end_pos
        return "".join(keywords).strip()
| 63 | + | ||
| 64 | + | ||
def collect_keywords_queries(
    extractor: KeywordExtractor,
    rewritten_query: str,
    translations: Dict[str, str],
) -> Dict[str, str]:
    """
    Build the keyword map for all lexical variants (base + translations).

    Keys mirror the text variants: ``KEYWORDS_QUERY_BASE_KEY`` for the
    rewritten source query and the lower-cased language code for each
    translation. Variants whose extraction comes back empty are omitted.
    """
    result: Dict[str, str] = {}

    def _add_variant(key: str, text: str) -> None:
        # Blank keys/texts are skipped; empty extractions are dropped.
        if not key or not (text or "").strip():
            return
        extracted = extractor.extract_keywords(text)
        if extracted:
            result[key] = extracted

    _add_variant(KEYWORDS_QUERY_BASE_KEY, rewritten_query)
    for language, translated in translations.items():
        _add_variant(str(language or "").strip().lower(), translated)
    return result
query/query_parser.py
| @@ -27,6 +27,7 @@ from .product_title_exclusion import ( | @@ -27,6 +27,7 @@ from .product_title_exclusion import ( | ||
| 27 | from .query_rewriter import QueryRewriter, QueryNormalizer | 27 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 28 | from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry | 28 | from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry |
| 29 | from .tokenization import extract_token_strings, simple_tokenize_query | 29 | from .tokenization import extract_token_strings, simple_tokenize_query |
| 30 | +from .keyword_extractor import KeywordExtractor, collect_keywords_queries | ||
| 30 | 31 | ||
| 31 | logger = logging.getLogger(__name__) | 32 | logger = logging.getLogger(__name__) |
| 32 | 33 | ||
| @@ -59,7 +60,14 @@ def rerank_query_text( | @@ -59,7 +60,14 @@ def rerank_query_text( | ||
| 59 | 60 | ||
| 60 | @dataclass(slots=True) | 61 | @dataclass(slots=True) |
| 61 | class ParsedQuery: | 62 | class ParsedQuery: |
| 62 | - """Container for query parser facts.""" | 63 | + """ |
| 64 | + Container for query parser facts. | ||
| 65 | + | ||
| 66 | + ``keywords_queries`` parallels text variants: key ``base`` (see | ||
| 67 | + ``keyword_extractor.KEYWORDS_QUERY_BASE_KEY``) for ``rewritten_query``, | ||
| 68 | + and the same language codes as ``translations`` for each translated string. | ||
| 69 | + Entries with no extracted nouns are omitted. | ||
| 70 | + """ | ||
| 63 | 71 | ||
| 64 | original_query: str | 72 | original_query: str |
| 65 | query_normalized: str | 73 | query_normalized: str |
| @@ -69,6 +77,7 @@ class ParsedQuery: | @@ -69,6 +77,7 @@ class ParsedQuery: | ||
| 69 | query_vector: Optional[np.ndarray] = None | 77 | query_vector: Optional[np.ndarray] = None |
| 70 | image_query_vector: Optional[np.ndarray] = None | 78 | image_query_vector: Optional[np.ndarray] = None |
| 71 | query_tokens: List[str] = field(default_factory=list) | 79 | query_tokens: List[str] = field(default_factory=list) |
| 80 | + keywords_queries: Dict[str, str] = field(default_factory=dict) | ||
| 72 | style_intent_profile: Optional[StyleIntentProfile] = None | 81 | style_intent_profile: Optional[StyleIntentProfile] = None |
| 73 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None | 82 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None |
| 74 | 83 | ||
| @@ -91,6 +100,7 @@ class ParsedQuery: | @@ -91,6 +100,7 @@ class ParsedQuery: | ||
| 91 | "has_query_vector": self.query_vector is not None, | 100 | "has_query_vector": self.query_vector is not None, |
| 92 | "has_image_query_vector": self.image_query_vector is not None, | 101 | "has_image_query_vector": self.image_query_vector is not None, |
| 93 | "query_tokens": self.query_tokens, | 102 | "query_tokens": self.query_tokens, |
| 103 | + "keywords_queries": dict(self.keywords_queries), | ||
| 94 | "style_intent_profile": ( | 104 | "style_intent_profile": ( |
| 95 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None | 105 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None |
| 96 | ), | 106 | ), |
| @@ -138,6 +148,7 @@ class QueryParser: | @@ -138,6 +148,7 @@ class QueryParser: | ||
| 138 | self.language_detector = LanguageDetector() | 148 | self.language_detector = LanguageDetector() |
| 139 | self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) | 149 | self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) |
| 140 | self._tokenizer = tokenizer or self._build_tokenizer() | 150 | self._tokenizer = tokenizer or self._build_tokenizer() |
| 151 | + self._keyword_extractor = KeywordExtractor(tokenizer=self._tokenizer) | ||
| 141 | self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) | 152 | self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) |
| 142 | self.style_intent_detector = StyleIntentDetector( | 153 | self.style_intent_detector = StyleIntentDetector( |
| 143 | self.style_intent_registry, | 154 | self.style_intent_registry, |
| @@ -523,6 +534,16 @@ class QueryParser: | @@ -523,6 +534,16 @@ class QueryParser: | ||
| 523 | if translations and context: | 534 | if translations and context: |
| 524 | context.store_intermediate_result("translations", translations) | 535 | context.store_intermediate_result("translations", translations) |
| 525 | 536 | ||
| 537 | + keywords_queries: Dict[str, str] = {} | ||
| 538 | + try: | ||
| 539 | + keywords_queries = collect_keywords_queries( | ||
| 540 | + self._keyword_extractor, | ||
| 541 | + query_text, | ||
| 542 | + translations, | ||
| 543 | + ) | ||
| 544 | + except Exception as e: | ||
| 545 | + log_info(f"Keyword extraction failed | Error: {e}") | ||
| 546 | + | ||
| 526 | # Build result | 547 | # Build result |
| 527 | base_result = ParsedQuery( | 548 | base_result = ParsedQuery( |
| 528 | original_query=query, | 549 | original_query=query, |
| @@ -533,6 +554,7 @@ class QueryParser: | @@ -533,6 +554,7 @@ class QueryParser: | ||
| 533 | query_vector=query_vector, | 554 | query_vector=query_vector, |
| 534 | image_query_vector=image_query_vector, | 555 | image_query_vector=image_query_vector, |
| 535 | query_tokens=query_tokens, | 556 | query_tokens=query_tokens, |
| 557 | + keywords_queries=keywords_queries, | ||
| 536 | ) | 558 | ) |
| 537 | style_intent_profile = self.style_intent_detector.detect(base_result) | 559 | style_intent_profile = self.style_intent_detector.detect(base_result) |
| 538 | product_title_exclusion_profile = self.product_title_exclusion_detector.detect(base_result) | 560 | product_title_exclusion_profile = self.product_title_exclusion_detector.detect(base_result) |
| @@ -555,6 +577,7 @@ class QueryParser: | @@ -555,6 +577,7 @@ class QueryParser: | ||
| 555 | query_vector=query_vector, | 577 | query_vector=query_vector, |
| 556 | image_query_vector=image_query_vector, | 578 | image_query_vector=image_query_vector, |
| 557 | query_tokens=query_tokens, | 579 | query_tokens=query_tokens, |
| 580 | + keywords_queries=keywords_queries, | ||
| 558 | style_intent_profile=style_intent_profile, | 581 | style_intent_profile=style_intent_profile, |
| 559 | product_title_exclusion_profile=product_title_exclusion_profile, | 582 | product_title_exclusion_profile=product_title_exclusion_profile, |
| 560 | ) | 583 | ) |
search/es_query_builder.py
| @@ -12,6 +12,7 @@ from typing import Dict, Any, List, Optional, Tuple | @@ -12,6 +12,7 @@ from typing import Dict, Any, List, Optional, Tuple | ||
| 12 | 12 | ||
| 13 | import numpy as np | 13 | import numpy as np |
| 14 | from config import FunctionScoreConfig | 14 | from config import FunctionScoreConfig |
| 15 | +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY | ||
| 15 | 16 | ||
| 16 | 17 | ||
| 17 | class ESQueryBuilder: | 18 | class ESQueryBuilder: |
| @@ -39,6 +40,7 @@ class ESQueryBuilder: | @@ -39,6 +40,7 @@ class ESQueryBuilder: | ||
| 39 | knn_image_num_candidates: int = 400, | 40 | knn_image_num_candidates: int = 400, |
| 40 | base_minimum_should_match: str = "70%", | 41 | base_minimum_should_match: str = "70%", |
| 41 | translation_minimum_should_match: str = "70%", | 42 | translation_minimum_should_match: str = "70%", |
| 43 | + keywords_minimum_should_match: str = "50%", | ||
| 42 | translation_boost: float = 0.4, | 44 | translation_boost: float = 0.4, |
| 43 | tie_breaker_base_query: float = 0.9, | 45 | tie_breaker_base_query: float = 0.9, |
| 44 | best_fields_boosts: Optional[Dict[str, float]] = None, | 46 | best_fields_boosts: Optional[Dict[str, float]] = None, |
| @@ -85,6 +87,7 @@ class ESQueryBuilder: | @@ -85,6 +87,7 @@ class ESQueryBuilder: | ||
| 85 | self.knn_image_num_candidates = int(knn_image_num_candidates) | 87 | self.knn_image_num_candidates = int(knn_image_num_candidates) |
| 86 | self.base_minimum_should_match = base_minimum_should_match | 88 | self.base_minimum_should_match = base_minimum_should_match |
| 87 | self.translation_minimum_should_match = translation_minimum_should_match | 89 | self.translation_minimum_should_match = translation_minimum_should_match |
| 90 | + self.keywords_minimum_should_match = str(keywords_minimum_should_match) | ||
| 88 | self.translation_boost = float(translation_boost) | 91 | self.translation_boost = float(translation_boost) |
| 89 | self.tie_breaker_base_query = float(tie_breaker_base_query) | 92 | self.tie_breaker_base_query = float(tie_breaker_base_query) |
| 90 | default_best_fields = { | 93 | default_best_fields = { |
| @@ -505,6 +508,7 @@ class ESQueryBuilder: | @@ -505,6 +508,7 @@ class ESQueryBuilder: | ||
| 505 | clause_name: str, | 508 | clause_name: str, |
| 506 | *, | 509 | *, |
| 507 | is_source: bool, | 510 | is_source: bool, |
| 511 | + keywords_query: Optional[str] = None, | ||
| 508 | ) -> Optional[Dict[str, Any]]: | 512 | ) -> Optional[Dict[str, Any]]: |
| 509 | combined_fields = self._match_field_strings(lang) | 513 | combined_fields = self._match_field_strings(lang) |
| 510 | if not combined_fields: | 514 | if not combined_fields: |
| @@ -512,6 +516,26 @@ class ESQueryBuilder: | @@ -512,6 +516,26 @@ class ESQueryBuilder: | ||
| 512 | minimum_should_match = ( | 516 | minimum_should_match = ( |
| 513 | self.base_minimum_should_match if is_source else self.translation_minimum_should_match | 517 | self.base_minimum_should_match if is_source else self.translation_minimum_should_match |
| 514 | ) | 518 | ) |
| 519 | + must_clauses: List[Dict[str, Any]] = [ | ||
| 520 | + { | ||
| 521 | + "combined_fields": { | ||
| 522 | + "query": lang_query, | ||
| 523 | + "fields": combined_fields, | ||
| 524 | + "minimum_should_match": minimum_should_match, | ||
| 525 | + } | ||
| 526 | + } | ||
| 527 | + ] | ||
| 528 | + kw = (keywords_query or "").strip() | ||
| 529 | + if kw: | ||
| 530 | + must_clauses.append( | ||
| 531 | + { | ||
| 532 | + "combined_fields": { | ||
| 533 | + "query": kw, | ||
| 534 | + "fields": combined_fields, | ||
| 535 | + "minimum_should_match": self.keywords_minimum_should_match, | ||
| 536 | + } | ||
| 537 | + } | ||
| 538 | + ) | ||
| 515 | should_clauses = [ | 539 | should_clauses = [ |
| 516 | clause | 540 | clause |
| 517 | for clause in ( | 541 | for clause in ( |
| @@ -523,15 +547,7 @@ class ESQueryBuilder: | @@ -523,15 +547,7 @@ class ESQueryBuilder: | ||
| 523 | clause: Dict[str, Any] = { | 547 | clause: Dict[str, Any] = { |
| 524 | "bool": { | 548 | "bool": { |
| 525 | "_name": clause_name, | 549 | "_name": clause_name, |
| 526 | - "must": [ | ||
| 527 | - { | ||
| 528 | - "combined_fields": { | ||
| 529 | - "query": lang_query, | ||
| 530 | - "fields": combined_fields, | ||
| 531 | - "minimum_should_match": minimum_should_match, | ||
| 532 | - } | ||
| 533 | - } | ||
| 534 | - ], | 550 | + "must": must_clauses, |
| 535 | } | 551 | } |
| 536 | } | 552 | } |
| 537 | if should_clauses: | 553 | if should_clauses: |
| @@ -572,6 +588,11 @@ class ESQueryBuilder: | @@ -572,6 +588,11 @@ class ESQueryBuilder: | ||
| 572 | base_query_text = ( | 588 | base_query_text = ( |
| 573 | getattr(parsed_query, "rewritten_query", None) if parsed_query else None | 589 | getattr(parsed_query, "rewritten_query", None) if parsed_query else None |
| 574 | ) or query_text | 590 | ) or query_text |
| 591 | + kw_by_variant: Dict[str, str] = ( | ||
| 592 | + getattr(parsed_query, "keywords_queries", None) or {} | ||
| 593 | + if parsed_query | ||
| 594 | + else {} | ||
| 595 | + ) | ||
| 575 | 596 | ||
| 576 | if base_query_text: | 597 | if base_query_text: |
| 577 | base_clause = self._build_lexical_language_clause( | 598 | base_clause = self._build_lexical_language_clause( |
| @@ -579,6 +600,7 @@ class ESQueryBuilder: | @@ -579,6 +600,7 @@ class ESQueryBuilder: | ||
| 579 | base_query_text, | 600 | base_query_text, |
| 580 | "base_query", | 601 | "base_query", |
| 581 | is_source=True, | 602 | is_source=True, |
| 603 | + keywords_query=(kw_by_variant.get(KEYWORDS_QUERY_BASE_KEY) or "").strip(), | ||
| 582 | ) | 604 | ) |
| 583 | if base_clause: | 605 | if base_clause: |
| 584 | should_clauses.append(base_clause) | 606 | should_clauses.append(base_clause) |
| @@ -590,11 +612,13 @@ class ESQueryBuilder: | @@ -590,11 +612,13 @@ class ESQueryBuilder: | ||
| 590 | continue | 612 | continue |
| 591 | if normalized_lang == source_lang and normalized_text == base_query_text: | 613 | if normalized_lang == source_lang and normalized_text == base_query_text: |
| 592 | continue | 614 | continue |
| 615 | + trans_kw = (kw_by_variant.get(normalized_lang) or "").strip() | ||
| 593 | trans_clause = self._build_lexical_language_clause( | 616 | trans_clause = self._build_lexical_language_clause( |
| 594 | normalized_lang, | 617 | normalized_lang, |
| 595 | normalized_text, | 618 | normalized_text, |
| 596 | f"base_query_trans_{normalized_lang}", | 619 | f"base_query_trans_{normalized_lang}", |
| 597 | is_source=False, | 620 | is_source=False, |
| 621 | + keywords_query=trans_kw, | ||
| 598 | ) | 622 | ) |
| 599 | if trans_clause: | 623 | if trans_clause: |
| 600 | should_clauses.append(trans_clause) | 624 | should_clauses.append(trans_clause) |
tests/test_es_query_builder_text_recall_languages.py
| @@ -11,6 +11,7 @@ from typing import Any, Dict, List | @@ -11,6 +11,7 @@ from typing import Any, Dict, List | ||
| 11 | 11 | ||
| 12 | import numpy as np | 12 | import numpy as np |
| 13 | 13 | ||
| 14 | +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY | ||
| 14 | from search.es_query_builder import ESQueryBuilder | 15 | from search.es_query_builder import ESQueryBuilder |
| 15 | 16 | ||
| 16 | 17 | ||
| @@ -129,6 +130,29 @@ def test_zh_query_index_zh_en_includes_base_zh_and_trans_en(): | @@ -129,6 +130,29 @@ def test_zh_query_index_zh_en_includes_base_zh_and_trans_en(): | ||
| 129 | assert "title.en" in _title_fields(idx["base_query_trans_en"]) | 130 | assert "title.en" in _title_fields(idx["base_query_trans_en"]) |
| 130 | 131 | ||
| 131 | 132 | ||
def test_keywords_combined_fields_second_must_same_fields_and_50pct():
    """When ParsedQuery.keywords_queries is set, must includes a second combined_fields."""
    builder = _builder_multilingual_title_only(default_language="en")
    parsed = SimpleNamespace(
        rewritten_query="连衣裙",
        detected_language="zh",
        translations={"en": "red dress"},
        keywords_queries={KEYWORDS_QUERY_BASE_KEY: "连衣 裙", "en": "dress"},
    )
    built = builder.build_query(query_text="连衣裙", parsed_query=parsed, enable_knn=False)
    clauses = _clauses_index(built)

    base_must = clauses["base_query"]["must"]
    assert len(base_must) == 2
    primary_cf = base_must[0]["combined_fields"]
    keywords_cf = base_must[1]["combined_fields"]
    assert primary_cf["query"] == "连衣裙"
    assert keywords_cf["query"] == "连衣 裙"
    assert keywords_cf["minimum_should_match"] == "50%"
    assert keywords_cf["fields"] == primary_cf["fields"]

    trans_must = clauses["base_query_trans_en"]["must"]
    assert len(trans_must) == 2
    assert trans_must[1]["combined_fields"]["query"] == "dress"
    assert trans_must[1]["combined_fields"]["minimum_should_match"] == "50%"
| 154 | + | ||
| 155 | + | ||
| 132 | def test_en_query_index_zh_en_includes_base_en_and_trans_zh(): | 156 | def test_en_query_index_zh_en_includes_base_en_and_trans_zh(): |
| 133 | qb = _builder_multilingual_title_only(default_language="en") | 157 | qb = _builder_multilingual_title_only(default_language="en") |
| 134 | q = _build( | 158 | q = _build( |
| @@ -0,0 +1,115 @@ | @@ -0,0 +1,115 @@ | ||
| 1 | +import hanlp | ||
| 2 | +from typing import List, Tuple, Dict, Any | ||
| 3 | + | ||
class KeywordExtractor:
    """
    HanLP-based noun keyword extractor.

    Keeps nouns (POS tag starting with ``N``) of length >= 2 and joins them
    in reading order, inserting a space only between nouns that are not
    adjacent in the original text.
    """

    def __init__(self):
        # Fine-grained tokenizer with character-span output enabled.
        self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF)
        self.tok.config.output_spans = True
        # Part-of-speech tagger applied to the token sequence.
        self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)

    def extract_keywords(self, query: str) -> str:
        """
        Return the noun keywords (length >= 2) of *query* as one string.

        Non-contiguous nouns are separated by a single space; the result is
        empty when no token qualifies.
        """
        text = query.strip()
        spans = self.tok(text)  # [[word, start, end], ...]
        words = [span[0] for span in spans]
        tags = self.pos_tag(words)

        skipped = {'玩具'}  # nouns dropped entirely
        pieces = []
        previous_end = 0
        for tag, (word, begin, end) in zip(tags, spans):
            keep = len(word) >= 2 and tag.startswith('N') and word not in skipped
            if not keep:
                continue
            # A gap between this noun and the previously kept one means the
            # two were not adjacent in the source text: emit a separator.
            if pieces and begin != previous_end:
                pieces.append(" ")
            pieces.append(word)
            previous_end = end
        return "".join(pieces).strip()
| 53 | + | ||
| 54 | + | ||
def _run_demo() -> None:
    """Manual smoke test: print extracted keywords for sample queries."""
    extractor = KeywordExtractor()

    sample_queries = [
        # Chinese (9 representative queries)
        "2.4G遥控大蛇",
        "充气的篮球",
        "遥控 塑料 飞船 汽车 ",
        "亚克力相框",
        "8寸 搪胶蘑菇钉",
        "7寸娃娃",
        "太空沙套装",
        "脚蹬工程车",
        "捏捏乐钥匙扣",
        # English (new)
        "plastic toy car",
        "remote control helicopter",
        "inflatable beach ball",
        "music keychain",
        "sand play set",
        # Common product searches
        "plastic dinosaur toy",
        "wireless bluetooth speaker",
        "4K action camera",
        "stainless steel water bottle",
        "baby stroller with cup holder",
        # Question-style / natural language
        "what is the best smartphone under 500 dollars",
        "how to clean a laptop screen",
        "where can I buy organic coffee beans",
        # Digits and special characters
        "USB-C to HDMI adapter 4K",
        "LED strip lights 16.4ft",
        "Nintendo Switch OLED model",
        "iPhone 15 Pro Max case",
        # Short phrases
        "gaming mouse",
        "mechanical keyboard",
        "wireless earbuds",
        # Long-tail phrases
        "rechargeable AA batteries with charger",
        "foldable picnic blanket waterproof",
        # Attribute combinations
        "women's running shoes size 8",
        "men's cotton t-shirt crew neck",
        # Other languages (kept as-is for multilingual testing)
        "свет USB с пультом дистанционного управления красочные",  # Russian
    ]

    for query in sample_queries:
        keywords = extractor.extract_keywords(query)
        print(f"{query:30} => {keywords}")


if __name__ == "__main__":
    _run_demo()