Commit ceaf6d03e28f33c267072d3138ed2b962b521cf2
1 parent
ed13851c
召回限定:must条件补充主干词命中逻辑。baseline的主干词提取方法。
TODO-keywords限定-done.txt
Showing 8 changed files with 394 additions and 24 deletions (diff stats below)
config/config.yaml
| ... | ... | @@ -100,16 +100,19 @@ query_config: |
| 100 | 100 | |
| 101 | 101 | # 查询翻译模型(须与 services.translation.capabilities 中某项一致) |
| 102 | 102 | # 源语种在租户 index_languages 内:主召回可打在源语种字段,用下面三项。 |
| 103 | - # zh_to_en_model: "opus-mt-zh-en" | |
| 104 | - # en_to_zh_model: "opus-mt-en-zh" | |
| 105 | - # default_translation_model: "nllb-200-distilled-600m" | |
| 106 | - zh_to_en_model: "deepl" | |
| 107 | - en_to_zh_model: "deepl" | |
| 108 | - default_translation_model: "deepl" | |
| 103 | + zh_to_en_model: "nllb-200-distilled-600m" # "opus-mt-zh-en" | |
| 104 | + en_to_zh_model: "nllb-200-distilled-600m" # "opus-mt-en-zh" | |
| 105 | + default_translation_model: "nllb-200-distilled-600m" | |
| 106 | + # zh_to_en_model: "deepl" | |
| 107 | + # en_to_zh_model: "deepl" | |
| 108 | + # default_translation_model: "deepl" | |
| 109 | 109 | # 源语种不在 index_languages:翻译对可检索文本更关键,可单独指定(缺省则与上一组相同) |
| 110 | - zh_to_en_model__source_not_in_index: "deepl" | |
| 111 | - en_to_zh_model__source_not_in_index: "deepl" | |
| 112 | - default_translation_model__source_not_in_index: "deepl" | |
| 110 | + zh_to_en_model__source_not_in_index: "nllb-200-distilled-600m" | |
| 111 | + en_to_zh_model__source_not_in_index: "nllb-200-distilled-600m" | |
| 112 | + default_translation_model__source_not_in_index: "nllb-200-distilled-600m" | |
| 113 | + # zh_to_en_model__source_not_in_index: "deepl" | |
| 114 | + # en_to_zh_model__source_not_in_index: "deepl" | |
| 115 | + # default_translation_model__source_not_in_index: "deepl" | |
| 113 | 116 | |
| 114 | 117 | # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 |
| 115 | 118 | # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 |
| ... | ... | @@ -153,8 +156,8 @@ query_config: |
| 153 | 156 | |
| 154 | 157 | # 统一文本召回策略(主查询 + 翻译查询) |
| 155 | 158 | text_query_strategy: |
| 156 | - base_minimum_should_match: "75%" | |
| 157 | - translation_minimum_should_match: "75%" | |
| 159 | + base_minimum_should_match: "60%" | |
| 160 | + translation_minimum_should_match: "60%" | |
| 158 | 161 | translation_boost: 0.75 |
| 159 | 162 | tie_breaker_base_query: 0.5 |
| 160 | 163 | best_fields_boost: 2.0 |
| ... | ... | @@ -207,8 +210,8 @@ query_config: |
| 207 | 210 | - skus |
| 208 | 211 | |
| 209 | 212 | # KNN:文本向量与多模态(图片)向量各自 boost 与召回(k / num_candidates) |
| 210 | - knn_text_boost: 20 | |
| 211 | - knn_image_boost: 20 | |
| 213 | + knn_text_boost: 4 | |
| 214 | + knn_image_boost: 4 | |
| 212 | 215 | |
| 213 | 216 | knn_text_k: 150 |
| 214 | 217 | knn_text_num_candidates: 400 |
| ... | ... | @@ -247,7 +250,7 @@ rerank: |
| 247 | 250 | knn_image_weight: 1.0 |
| 248 | 251 | knn_tie_breaker: 0.1 |
| 249 | 252 | knn_bias: 0.6 |
| 250 | - knn_exponent: 0.2 | |
| 253 | + knn_exponent: 0.0 | |
| 251 | 254 | |
| 252 | 255 | # 可扩展服务/provider 注册表(单一配置源) |
| 253 | 256 | services: | ... | ... |
| ... | ... | @@ -0,0 +1,93 @@ |
| 1 | +@query/query_parser.py @scripts/es_debug_search.py | |
| 2 | +原始query、以及每一个翻译,都要有一个对应的keywords_query(token分词后,得到名词) | |
| 3 | +参考这段代码,获取每一个长度大于 1 的名词,然后用空格拼接起来,作为keywords_query | |
| 4 | +import hanlp | |
| 5 | +from typing import List, Tuple, Dict, Any | |
| 6 | + | |
| 7 | +class KeywordExtractor: | |
| 8 | + """ | |
| 9 | + 基于 HanLP 的名词关键词提取器 | |
| 10 | + """ | |
| 11 | + def __init__(self): | |
| 12 | + # 加载带位置信息的分词模型(细粒度) | |
| 13 | + self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) | |
| 14 | + self.tok.config.output_spans = True # 启用位置输出 | |
| 15 | + | |
| 16 | + # 加载词性标注模型 | |
| 17 | + self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) | |
| 18 | + | |
| 19 | + def extract_keywords(self, query: str) -> str: | |
| 20 | + """ | |
| 21 | + 从查询中提取关键词(名词,长度 ≥ 2) | |
| 22 | + | |
| 23 | + Args: | |
| 24 | + query: 输入文本 | |
| 25 | + | |
| 26 | + Returns: | |
| 27 | + 拼接后的关键词字符串,非连续词之间自动插入空格 | |
| 28 | + """ | |
| 29 | + query = query.strip() | |
| 30 | + # 分词结果带位置:[[word, start, end], ...] | |
| 31 | + tok_result_with_position = self.tok(query) | |
| 32 | + tok_result = [x[0] for x in tok_result_with_position] | |
| 33 | + | |
| 34 | + # 词性标注 | |
| 35 | + pos_tag_result = list(zip(tok_result, self.pos_tag(tok_result))) | |
| 36 | + | |
| 37 | + # 需要忽略的词 | |
| 38 | + ignore_keywords = ['玩具'] | |
| 39 | + | |
| 40 | + keywords = [] | |
| 41 | + last_end_pos = 0 | |
| 42 | + | |
| 43 | + for (word, postag), (_, start_pos, end_pos) in zip(pos_tag_result, tok_result_with_position): | |
| 44 | + if len(word) >= 2 and postag.startswith('N'): | |
| 45 | + if word in ignore_keywords: | |
| 46 | + continue | |
| 47 | + # 如果当前词与上一个词在原文中不连续,插入空格 | |
| 48 | + if start_pos != last_end_pos and keywords: | |
| 49 | + keywords.append(" ") | |
| 50 | + keywords.append(word) | |
| 51 | + last_end_pos = end_pos | |
| 52 | + # 可选:打印调试信息 | |
| 53 | + # print(f'分词: {word} | 词性: {postag} | 起始: {start_pos} | 结束: {end_pos}') | |
| 54 | + | |
| 55 | + return "".join(keywords).strip() | |
| 56 | + | |
| 57 | + | |
| 58 | +最后,在组织检索表达式时,目前是每一个 query (base_query base_query_trans_en base_query_trans_zh 三种情况)。 会组成一个bool查询,以base_query为例: | |
| 59 | + "bool": { | |
| 60 | + "should": [ | |
| 61 | + { | |
| 62 | + "bool": { | |
| 63 | + "_name": "base_query", | |
| 64 | + "must": [ | |
| 65 | + { | |
| 66 | + "combined_fields": { | |
| 67 | +... | |
| 68 | + } | |
| 69 | + } | |
| 70 | + ], | |
| 71 | + "should": [ | |
| 72 | + { | |
| 73 | + "multi_match": { | |
| 74 | +... "type": "best_fields", | |
| 75 | +... | |
| 76 | + }, | |
| 77 | + { | |
| 78 | + "multi_match": { | |
| 79 | +... | |
| 80 | + "type": "phrase", | |
| 81 | +... | |
| 82 | + } | |
| 83 | + } | |
| 84 | + ] | |
| 85 | + } | |
| 86 | + }, | |
| 87 | + | |
| 88 | +base_query_trans_en base_query_trans_zh 也是同样 | |
| 89 | + | |
| 90 | +在这个布尔查询的must里面加一项:keywords,搜索的字段和combined_fields一样,命中比例要求50% | |
| 91 | + | |
| 92 | + | |
| 93 | +结合现有代码做出合理的设计,呈现简单清晰的数据接口,而不是打补丁 | |
| 0 | 94 | \ No newline at end of file | ... | ... |
query/__init__.py
| ... | ... | @@ -2,6 +2,7 @@ |
| 2 | 2 | |
| 3 | 3 | from .language_detector import LanguageDetector |
| 4 | 4 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 5 | +from .keyword_extractor import KEYWORDS_QUERY_BASE_KEY | |
| 5 | 6 | from .query_parser import QueryParser, ParsedQuery |
| 6 | 7 | |
| 7 | 8 | __all__ = [ |
| ... | ... | @@ -10,4 +11,5 @@ __all__ = [ |
| 10 | 11 | 'QueryNormalizer', |
| 11 | 12 | 'QueryParser', |
| 12 | 13 | 'ParsedQuery', |
| 14 | + 'KEYWORDS_QUERY_BASE_KEY', | |
| 13 | 15 | ] | ... | ... |
| ... | ... | @@ -0,0 +1,86 @@ |
| 1 | +""" | |
| 2 | +HanLP-based noun keyword extraction for lexical constraints (token POS starts with N, length >= 2). | |
| 3 | + | |
| 4 | +``ParsedQuery.keywords_queries`` uses the same key layout as text variants: | |
| 5 | +``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like language | |
| 6 | +codes for each ``ParsedQuery.translations`` entry (non-empty extractions only). | |
| 7 | +""" | |
| 8 | + | |
| 9 | +from __future__ import annotations | |
| 10 | + | |
| 11 | +import logging | |
| 12 | +from typing import Any, Dict, List, Optional | |
| 13 | + | |
| 14 | +logger = logging.getLogger(__name__) | |
| 15 | + | |
| 16 | +import hanlp # type: ignore | |
| 17 | + | |
| 18 | +# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code). | |
| 19 | +KEYWORDS_QUERY_BASE_KEY = "base" | |
| 20 | + | |
| 21 | + | |
| 22 | +class KeywordExtractor: | |
| 23 | + """基于 HanLP 的名词关键词提取器(与分词位置对齐,非连续名词间插入空格)。""" | |
| 24 | + | |
| 25 | + def __init__( | |
| 26 | + self, | |
| 27 | + tokenizer: Optional[Any] = None, | |
| 28 | + *, | |
| 29 | + ignore_keywords: Optional[List[str]] = None, | |
| 30 | + ): | |
| 31 | + if tokenizer is not None: | |
| 32 | + self.tok = tokenizer | |
| 33 | + else: | |
| 34 | + self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) | |
| 35 | + self.tok.config.output_spans = True | |
| 36 | + self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) | |
| 37 | + self.ignore_keywords = frozenset(ignore_keywords or ["玩具"]) | |
| 38 | + | |
| 39 | + def extract_keywords(self, query: str) -> str: | |
| 40 | + """ | |
| 41 | + 从查询中提取关键词(名词,长度 ≥ 2),以空格分隔非连续片段。 | |
| 42 | + """ | |
| 43 | + query = (query or "").strip() | |
| 44 | + if not query: | |
| 45 | + return "" | |
| 46 | + tok_result_with_position = self.tok(query) | |
| 47 | + tok_result = [x[0] for x in tok_result_with_position] | |
| 48 | + if not tok_result: | |
| 49 | + return "" | |
| 50 | + pos_tags = self.pos_tag(tok_result) | |
| 51 | + pos_tag_result = list(zip(tok_result, pos_tags)) | |
| 52 | + keywords: List[str] = [] | |
| 53 | + last_end_pos = 0 | |
| 54 | + for (word, postag), (_, start_pos, end_pos) in zip(pos_tag_result, tok_result_with_position): | |
| 55 | + if len(word) >= 2 and str(postag).startswith("N"): | |
| 56 | + if word in self.ignore_keywords: | |
| 57 | + continue | |
| 58 | + if start_pos != last_end_pos and keywords: | |
| 59 | + keywords.append(" ") | |
| 60 | + keywords.append(word) | |
| 61 | + last_end_pos = end_pos | |
| 62 | + return "".join(keywords).strip() | |
| 63 | + | |
| 64 | + | |
| 65 | +def collect_keywords_queries( | |
| 66 | + extractor: KeywordExtractor, | |
| 67 | + rewritten_query: str, | |
| 68 | + translations: Dict[str, str], | |
| 69 | +) -> Dict[str, str]: | |
| 70 | + """ | |
| 71 | + Build the keyword map for all lexical variants (base + translations). | |
| 72 | + | |
| 73 | + Omits entries when extraction yields an empty string. | |
| 74 | + """ | |
| 75 | + out: Dict[str, str] = {} | |
| 76 | + base_kw = extractor.extract_keywords(rewritten_query) | |
| 77 | + if base_kw: | |
| 78 | + out[KEYWORDS_QUERY_BASE_KEY] = base_kw | |
| 79 | + for lang, text in translations.items(): | |
| 80 | + lang_key = str(lang or "").strip().lower() | |
| 81 | + if not lang_key or not (text or "").strip(): | |
| 82 | + continue | |
| 83 | + kw = extractor.extract_keywords(text) | |
| 84 | + if kw: | |
| 85 | + out[lang_key] = kw | |
| 86 | + return out | ... | ... |
query/query_parser.py
| ... | ... | @@ -27,6 +27,7 @@ from .product_title_exclusion import ( |
| 27 | 27 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 28 | 28 | from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry |
| 29 | 29 | from .tokenization import extract_token_strings, simple_tokenize_query |
| 30 | +from .keyword_extractor import KeywordExtractor, collect_keywords_queries | |
| 30 | 31 | |
| 31 | 32 | logger = logging.getLogger(__name__) |
| 32 | 33 | |
| ... | ... | @@ -59,7 +60,14 @@ def rerank_query_text( |
| 59 | 60 | |
| 60 | 61 | @dataclass(slots=True) |
| 61 | 62 | class ParsedQuery: |
| 62 | - """Container for query parser facts.""" | |
| 63 | + """ | |
| 64 | + Container for query parser facts. | |
| 65 | + | |
| 66 | + ``keywords_queries`` parallels text variants: key ``base`` (see | |
| 67 | + ``keyword_extractor.KEYWORDS_QUERY_BASE_KEY``) for ``rewritten_query``, | |
| 68 | + and the same language codes as ``translations`` for each translated string. | |
| 69 | + Entries with no extracted nouns are omitted. | |
| 70 | + """ | |
| 63 | 71 | |
| 64 | 72 | original_query: str |
| 65 | 73 | query_normalized: str |
| ... | ... | @@ -69,6 +77,7 @@ class ParsedQuery: |
| 69 | 77 | query_vector: Optional[np.ndarray] = None |
| 70 | 78 | image_query_vector: Optional[np.ndarray] = None |
| 71 | 79 | query_tokens: List[str] = field(default_factory=list) |
| 80 | + keywords_queries: Dict[str, str] = field(default_factory=dict) | |
| 72 | 81 | style_intent_profile: Optional[StyleIntentProfile] = None |
| 73 | 82 | product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None |
| 74 | 83 | |
| ... | ... | @@ -91,6 +100,7 @@ class ParsedQuery: |
| 91 | 100 | "has_query_vector": self.query_vector is not None, |
| 92 | 101 | "has_image_query_vector": self.image_query_vector is not None, |
| 93 | 102 | "query_tokens": self.query_tokens, |
| 103 | + "keywords_queries": dict(self.keywords_queries), | |
| 94 | 104 | "style_intent_profile": ( |
| 95 | 105 | self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None |
| 96 | 106 | ), |
| ... | ... | @@ -138,6 +148,7 @@ class QueryParser: |
| 138 | 148 | self.language_detector = LanguageDetector() |
| 139 | 149 | self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) |
| 140 | 150 | self._tokenizer = tokenizer or self._build_tokenizer() |
| 151 | + self._keyword_extractor = KeywordExtractor(tokenizer=self._tokenizer) | |
| 141 | 152 | self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) |
| 142 | 153 | self.style_intent_detector = StyleIntentDetector( |
| 143 | 154 | self.style_intent_registry, |
| ... | ... | @@ -523,6 +534,16 @@ class QueryParser: |
| 523 | 534 | if translations and context: |
| 524 | 535 | context.store_intermediate_result("translations", translations) |
| 525 | 536 | |
| 537 | + keywords_queries: Dict[str, str] = {} | |
| 538 | + try: | |
| 539 | + keywords_queries = collect_keywords_queries( | |
| 540 | + self._keyword_extractor, | |
| 541 | + query_text, | |
| 542 | + translations, | |
| 543 | + ) | |
| 544 | + except Exception as e: | |
| 545 | + log_info(f"Keyword extraction failed | Error: {e}") | |
| 546 | + | |
| 526 | 547 | # Build result |
| 527 | 548 | base_result = ParsedQuery( |
| 528 | 549 | original_query=query, |
| ... | ... | @@ -533,6 +554,7 @@ class QueryParser: |
| 533 | 554 | query_vector=query_vector, |
| 534 | 555 | image_query_vector=image_query_vector, |
| 535 | 556 | query_tokens=query_tokens, |
| 557 | + keywords_queries=keywords_queries, | |
| 536 | 558 | ) |
| 537 | 559 | style_intent_profile = self.style_intent_detector.detect(base_result) |
| 538 | 560 | product_title_exclusion_profile = self.product_title_exclusion_detector.detect(base_result) |
| ... | ... | @@ -555,6 +577,7 @@ class QueryParser: |
| 555 | 577 | query_vector=query_vector, |
| 556 | 578 | image_query_vector=image_query_vector, |
| 557 | 579 | query_tokens=query_tokens, |
| 580 | + keywords_queries=keywords_queries, | |
| 558 | 581 | style_intent_profile=style_intent_profile, |
| 559 | 582 | product_title_exclusion_profile=product_title_exclusion_profile, |
| 560 | 583 | ) | ... | ... |
search/es_query_builder.py
| ... | ... | @@ -12,6 +12,7 @@ from typing import Dict, Any, List, Optional, Tuple |
| 12 | 12 | |
| 13 | 13 | import numpy as np |
| 14 | 14 | from config import FunctionScoreConfig |
| 15 | +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY | |
| 15 | 16 | |
| 16 | 17 | |
| 17 | 18 | class ESQueryBuilder: |
| ... | ... | @@ -39,6 +40,7 @@ class ESQueryBuilder: |
| 39 | 40 | knn_image_num_candidates: int = 400, |
| 40 | 41 | base_minimum_should_match: str = "70%", |
| 41 | 42 | translation_minimum_should_match: str = "70%", |
| 43 | + keywords_minimum_should_match: str = "50%", | |
| 42 | 44 | translation_boost: float = 0.4, |
| 43 | 45 | tie_breaker_base_query: float = 0.9, |
| 44 | 46 | best_fields_boosts: Optional[Dict[str, float]] = None, |
| ... | ... | @@ -85,6 +87,7 @@ class ESQueryBuilder: |
| 85 | 87 | self.knn_image_num_candidates = int(knn_image_num_candidates) |
| 86 | 88 | self.base_minimum_should_match = base_minimum_should_match |
| 87 | 89 | self.translation_minimum_should_match = translation_minimum_should_match |
| 90 | + self.keywords_minimum_should_match = str(keywords_minimum_should_match) | |
| 88 | 91 | self.translation_boost = float(translation_boost) |
| 89 | 92 | self.tie_breaker_base_query = float(tie_breaker_base_query) |
| 90 | 93 | default_best_fields = { |
| ... | ... | @@ -505,6 +508,7 @@ class ESQueryBuilder: |
| 505 | 508 | clause_name: str, |
| 506 | 509 | *, |
| 507 | 510 | is_source: bool, |
| 511 | + keywords_query: Optional[str] = None, | |
| 508 | 512 | ) -> Optional[Dict[str, Any]]: |
| 509 | 513 | combined_fields = self._match_field_strings(lang) |
| 510 | 514 | if not combined_fields: |
| ... | ... | @@ -512,6 +516,26 @@ class ESQueryBuilder: |
| 512 | 516 | minimum_should_match = ( |
| 513 | 517 | self.base_minimum_should_match if is_source else self.translation_minimum_should_match |
| 514 | 518 | ) |
| 519 | + must_clauses: List[Dict[str, Any]] = [ | |
| 520 | + { | |
| 521 | + "combined_fields": { | |
| 522 | + "query": lang_query, | |
| 523 | + "fields": combined_fields, | |
| 524 | + "minimum_should_match": minimum_should_match, | |
| 525 | + } | |
| 526 | + } | |
| 527 | + ] | |
| 528 | + kw = (keywords_query or "").strip() | |
| 529 | + if kw: | |
| 530 | + must_clauses.append( | |
| 531 | + { | |
| 532 | + "combined_fields": { | |
| 533 | + "query": kw, | |
| 534 | + "fields": combined_fields, | |
| 535 | + "minimum_should_match": self.keywords_minimum_should_match, | |
| 536 | + } | |
| 537 | + } | |
| 538 | + ) | |
| 515 | 539 | should_clauses = [ |
| 516 | 540 | clause |
| 517 | 541 | for clause in ( |
| ... | ... | @@ -523,15 +547,7 @@ class ESQueryBuilder: |
| 523 | 547 | clause: Dict[str, Any] = { |
| 524 | 548 | "bool": { |
| 525 | 549 | "_name": clause_name, |
| 526 | - "must": [ | |
| 527 | - { | |
| 528 | - "combined_fields": { | |
| 529 | - "query": lang_query, | |
| 530 | - "fields": combined_fields, | |
| 531 | - "minimum_should_match": minimum_should_match, | |
| 532 | - } | |
| 533 | - } | |
| 534 | - ], | |
| 550 | + "must": must_clauses, | |
| 535 | 551 | } |
| 536 | 552 | } |
| 537 | 553 | if should_clauses: |
| ... | ... | @@ -572,6 +588,11 @@ class ESQueryBuilder: |
| 572 | 588 | base_query_text = ( |
| 573 | 589 | getattr(parsed_query, "rewritten_query", None) if parsed_query else None |
| 574 | 590 | ) or query_text |
| 591 | + kw_by_variant: Dict[str, str] = ( | |
| 592 | + getattr(parsed_query, "keywords_queries", None) or {} | |
| 593 | + if parsed_query | |
| 594 | + else {} | |
| 595 | + ) | |
| 575 | 596 | |
| 576 | 597 | if base_query_text: |
| 577 | 598 | base_clause = self._build_lexical_language_clause( |
| ... | ... | @@ -579,6 +600,7 @@ class ESQueryBuilder: |
| 579 | 600 | base_query_text, |
| 580 | 601 | "base_query", |
| 581 | 602 | is_source=True, |
| 603 | + keywords_query=(kw_by_variant.get(KEYWORDS_QUERY_BASE_KEY) or "").strip(), | |
| 582 | 604 | ) |
| 583 | 605 | if base_clause: |
| 584 | 606 | should_clauses.append(base_clause) |
| ... | ... | @@ -590,11 +612,13 @@ class ESQueryBuilder: |
| 590 | 612 | continue |
| 591 | 613 | if normalized_lang == source_lang and normalized_text == base_query_text: |
| 592 | 614 | continue |
| 615 | + trans_kw = (kw_by_variant.get(normalized_lang) or "").strip() | |
| 593 | 616 | trans_clause = self._build_lexical_language_clause( |
| 594 | 617 | normalized_lang, |
| 595 | 618 | normalized_text, |
| 596 | 619 | f"base_query_trans_{normalized_lang}", |
| 597 | 620 | is_source=False, |
| 621 | + keywords_query=trans_kw, | |
| 598 | 622 | ) |
| 599 | 623 | if trans_clause: |
| 600 | 624 | should_clauses.append(trans_clause) | ... | ... |
tests/test_es_query_builder_text_recall_languages.py
| ... | ... | @@ -11,6 +11,7 @@ from typing import Any, Dict, List |
| 11 | 11 | |
| 12 | 12 | import numpy as np |
| 13 | 13 | |
| 14 | +from query.keyword_extractor import KEYWORDS_QUERY_BASE_KEY | |
| 14 | 15 | from search.es_query_builder import ESQueryBuilder |
| 15 | 16 | |
| 16 | 17 | |
| ... | ... | @@ -129,6 +130,29 @@ def test_zh_query_index_zh_en_includes_base_zh_and_trans_en(): |
| 129 | 130 | assert "title.en" in _title_fields(idx["base_query_trans_en"]) |
| 130 | 131 | |
| 131 | 132 | |
| 133 | +def test_keywords_combined_fields_second_must_same_fields_and_50pct(): | |
| 134 | + """When ParsedQuery.keywords_queries is set, must includes a second combined_fields.""" | |
| 135 | + qb = _builder_multilingual_title_only(default_language="en") | |
| 136 | + parsed = SimpleNamespace( | |
| 137 | + rewritten_query="连衣裙", | |
| 138 | + detected_language="zh", | |
| 139 | + translations={"en": "red dress"}, | |
| 140 | + keywords_queries={KEYWORDS_QUERY_BASE_KEY: "连衣 裙", "en": "dress"}, | |
| 141 | + ) | |
| 142 | + q = qb.build_query(query_text="连衣裙", parsed_query=parsed, enable_knn=False) | |
| 143 | + idx = _clauses_index(q) | |
| 144 | + base = idx["base_query"] | |
| 145 | + assert len(base["must"]) == 2 | |
| 146 | + assert base["must"][0]["combined_fields"]["query"] == "连衣裙" | |
| 147 | + assert base["must"][1]["combined_fields"]["query"] == "连衣 裙" | |
| 148 | + assert base["must"][1]["combined_fields"]["minimum_should_match"] == "50%" | |
| 149 | + assert base["must"][1]["combined_fields"]["fields"] == base["must"][0]["combined_fields"]["fields"] | |
| 150 | + trans = idx["base_query_trans_en"] | |
| 151 | + assert len(trans["must"]) == 2 | |
| 152 | + assert trans["must"][1]["combined_fields"]["query"] == "dress" | |
| 153 | + assert trans["must"][1]["combined_fields"]["minimum_should_match"] == "50%" | |
| 154 | + | |
| 155 | + | |
| 132 | 156 | def test_en_query_index_zh_en_includes_base_en_and_trans_zh(): |
| 133 | 157 | qb = _builder_multilingual_title_only(default_language="en") |
| 134 | 158 | q = _build( | ... | ... |
| ... | ... | @@ -0,0 +1,115 @@ |
| 1 | +import hanlp | |
| 2 | +from typing import List, Tuple, Dict, Any | |
| 3 | + | |
| 4 | +class KeywordExtractor: | |
| 5 | + """ | |
| 6 | + 基于 HanLP 的名词关键词提取器 | |
| 7 | + """ | |
| 8 | + def __init__(self): | |
| 9 | + # 加载带位置信息的分词模型(细粒度) | |
| 10 | + self.tok = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) | |
| 11 | + self.tok.config.output_spans = True # 启用位置输出 | |
| 12 | + | |
| 13 | + # 加载词性标注模型 | |
| 14 | + self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) | |
| 15 | + | |
| 16 | + def extract_keywords(self, query: str) -> str: | |
| 17 | + """ | |
| 18 | + 从查询中提取关键词(名词,长度 ≥ 2) | |
| 19 | + | |
| 20 | + Args: | |
| 21 | + query: 输入文本 | |
| 22 | + | |
| 23 | + Returns: | |
| 24 | + 拼接后的关键词字符串,非连续词之间自动插入空格 | |
| 25 | + """ | |
| 26 | + query = query.strip() | |
| 27 | + # 分词结果带位置:[[word, start, end], ...] | |
| 28 | + tok_result_with_position = self.tok(query) | |
| 29 | + tok_result = [x[0] for x in tok_result_with_position] | |
| 30 | + | |
| 31 | + # 词性标注 | |
| 32 | + pos_tag_result = list(zip(tok_result, self.pos_tag(tok_result))) | |
| 33 | + | |
| 34 | + # 需要忽略的词 | |
| 35 | + ignore_keywords = ['玩具'] | |
| 36 | + | |
| 37 | + keywords = [] | |
| 38 | + last_end_pos = 0 | |
| 39 | + | |
| 40 | + for (word, postag), (_, start_pos, end_pos) in zip(pos_tag_result, tok_result_with_position): | |
| 41 | + if len(word) >= 2 and postag.startswith('N'): | |
| 42 | + if word in ignore_keywords: | |
| 43 | + continue | |
| 44 | + # 如果当前词与上一个词在原文中不连续,插入空格 | |
| 45 | + if start_pos != last_end_pos and keywords: | |
| 46 | + keywords.append(" ") | |
| 47 | + keywords.append(word) | |
| 48 | + last_end_pos = end_pos | |
| 49 | + # 可选:打印调试信息 | |
| 50 | + # print(f'分词: {word} | 词性: {postag} | 起始: {start_pos} | 结束: {end_pos}') | |
| 51 | + | |
| 52 | + return "".join(keywords).strip() | |
| 53 | + | |
| 54 | + | |
| 55 | +# 测试代码 | |
| 56 | +if __name__ == "__main__": | |
| 57 | + extractor = KeywordExtractor() | |
| 58 | + | |
| 59 | + test_queries = [ | |
| 60 | + # 中文(保留 9 个代表性查询) | |
| 61 | + "2.4G遥控大蛇", | |
| 62 | + "充气的篮球", | |
| 63 | + "遥控 塑料 飞船 汽车 ", | |
| 64 | + "亚克力相框", | |
| 65 | + "8寸 搪胶蘑菇钉", | |
| 66 | + "7寸娃娃", | |
| 67 | + "太空沙套装", | |
| 68 | + "脚蹬工程车", | |
| 69 | + "捏捏乐钥匙扣", | |
| 70 | + | |
| 71 | + # 英文(新增) | |
| 72 | + "plastic toy car", | |
| 73 | + "remote control helicopter", | |
| 74 | + "inflatable beach ball", | |
| 75 | + "music keychain", | |
| 76 | + "sand play set", | |
| 77 | + # 常见商品搜索 | |
| 78 | + "plastic dinosaur toy", | |
| 79 | + "wireless bluetooth speaker", | |
| 80 | + "4K action camera", | |
| 81 | + "stainless steel water bottle", | |
| 82 | + "baby stroller with cup holder", | |
| 83 | + | |
| 84 | + # 疑问式 / 自然语言 | |
| 85 | + "what is the best smartphone under 500 dollars", | |
| 86 | + "how to clean a laptop screen", | |
| 87 | + "where can I buy organic coffee beans", | |
| 88 | + | |
| 89 | + # 含数字、特殊字符 | |
| 90 | + "USB-C to HDMI adapter 4K", | |
| 91 | + "LED strip lights 16.4ft", | |
| 92 | + "Nintendo Switch OLED model", | |
| 93 | + "iPhone 15 Pro Max case", | |
| 94 | + | |
| 95 | + # 简短词组 | |
| 96 | + "gaming mouse", | |
| 97 | + "mechanical keyboard", | |
| 98 | + "wireless earbuds", | |
| 99 | + | |
| 100 | + # 长尾词 | |
| 101 | + "rechargeable AA batteries with charger", | |
| 102 | + "foldable picnic blanket waterproof", | |
| 103 | + | |
| 104 | + # 商品属性组合 | |
| 105 | + "women's running shoes size 8", | |
| 106 | + "men's cotton t-shirt crew neck", | |
| 107 | + | |
| 108 | + | |
| 109 | + # 其他语种(保留原样,用于多语言测试) | |
| 110 | + "свет USB с пультом дистанционного управления красочные", # 俄语 | |
| 111 | + ] | |
| 112 | + | |
| 113 | + for q in test_queries: | |
| 114 | + keywords = extractor.extract_keywords(q) | |
| 115 | + print(f"{q:30} => {keywords}") | ... | ... |