"""
HanLP-based noun keyword extraction for lexical constraints (token POS
starts with "N", token length >= 2).

``ParsedQuery.keywords_queries`` uses the same key layout as text variants:
``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like
language codes for each ``ParsedQuery.translations`` entry (non-empty
extractions only).
"""
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code).
KEYWORDS_QUERY_BASE_KEY = "base"

# HanLP tokenizer model selection notes:
# | Scenario                    | Recommended model                                                |
# | :-------------------------- | :--------------------------------------------------------------- |
# | Chinese only, best accuracy | CTB9_TOK_ELECTRA_BASE_CRF or MSR_TOK_ELECTRA_BASE_CRF            |
# | Chinese only, speed first   | FINE_ELECTRA_SMALL_ZH (fine) or COARSE_ELECTRA_SMALL_ZH (coarse) |
# | **Mixed Chinese/English**   | `UD_TOK_MMINILMV2L6` or `UD_TOK_MMINILMV2L12` (differ in number of Transformer encoder layers) |


class KeywordExtractor:
    """HanLP-based noun keyword extractor.

    Keywords keep their relative positions from the source text: nouns that
    are adjacent in the original query are concatenated directly, while a
    single space is inserted between non-contiguous noun spans.
    """

    def __init__(
        self,
        tokenizer: Optional[Any] = None,
        *,
        pos_tagger: Optional[Any] = None,
        ignore_keywords: Optional[List[str]] = None,
    ):
        """
        :param tokenizer: callable ``str -> [(word, start, end), ...]``.
            When ``None``, the HanLP ``UD_TOK_MMINILMV2L6`` model is loaded
            with span output enabled.
        :param pos_tagger: callable ``[word, ...] -> [tag, ...]``.
            When ``None``, the HanLP ``CTB9_POS_ELECTRA_SMALL`` model is
            loaded. (New, backward-compatible: mirrors ``tokenizer``
            injection so both models can be stubbed in tests.)
        :param ignore_keywords: nouns to drop from results; defaults to
            ``["玩具"]``.
        """
        if tokenizer is not None:
            self.tok = tokenizer
        else:
            # Lazy import: hanlp is a heavy optional dependency, only
            # needed when no model is injected. This also keeps the module
            # importable in environments without hanlp installed.
            import hanlp  # type: ignore

            self.tok = hanlp.load(hanlp.pretrained.tok.UD_TOK_MMINILMV2L6)
            # Spans (start/end offsets) are required to detect adjacency
            # between accepted nouns in extract_keywords().
            self.tok.config.output_spans = True

        if pos_tagger is not None:
            self.pos_tag = pos_tagger
        else:
            import hanlp  # type: ignore

            self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)

        self.ignore_keywords = frozenset(ignore_keywords or ["玩具"])

    def extract_keywords(self, query: str) -> str:
        """
        Extract noun keywords (POS tag starting with ``"N"``, length >= 2)
        from *query*, separating non-contiguous spans with a space.

        :param query: raw query text; ``None``/blank yields ``""``.
        :returns: the keyword string, or ``""`` when nothing qualifies.
        """
        query = (query or "").strip()
        if not query:
            return ""

        # Tokenizer is expected to emit (word, start, end) triples.
        spans = self.tok(query)
        words = [span[0] for span in spans]
        if not words:
            return ""

        tags = self.pos_tag(words)

        pieces: List[str] = []
        prev_end = 0
        for (word, start, end), tag in zip(spans, tags):
            if len(word) < 2 or not str(tag).startswith("N"):
                continue
            if word in self.ignore_keywords:
                # NOTE: prev_end is deliberately left untouched, matching
                # the non-noun skip path: the ignored span still counts as
                # a gap between surviving keywords.
                continue
            # Separate only non-adjacent accepted nouns (and never lead
            # with a space when this is the first accepted noun).
            if start != prev_end and pieces:
                pieces.append(" ")
            pieces.append(word)
            prev_end = end
        return "".join(pieces).strip()


def collect_keywords_queries(
    extractor: KeywordExtractor,
    rewritten_query: str,
    translations: Dict[str, str],
) -> Dict[str, str]:
    """
    Build the keyword map for all lexical variants (base + translations).

    :param extractor: the keyword extractor to run on each text variant.
    :param rewritten_query: source query, stored under
        ``KEYWORDS_QUERY_BASE_KEY``.
    :param translations: language code -> translated text; keys are
        normalized to lowercase, blank keys/texts are skipped.
    :returns: variant key -> keyword string. Entries whose extraction
        yields an empty string are omitted.
    """
    out: Dict[str, str] = {}

    base_kw = extractor.extract_keywords(rewritten_query)
    if base_kw:
        out[KEYWORDS_QUERY_BASE_KEY] = base_kw

    for lang, text in translations.items():
        lang_key = str(lang or "").strip().lower()
        if not lang_key or not (text or "").strip():
            continue
        kw = extractor.extract_keywords(text)
        if kw:
            out[lang_key] = kw
    return out