"""
HanLP-based noun keyword string for lexical constraints (token POS starts with N, length >= 2).
``ParsedQuery.keywords_queries`` uses the same key layout as text variants:
``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like language
codes for each ``ParsedQuery.translations`` entry (non-empty extractions only).
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from .english_keyword_extractor import EnglishKeywordExtractor
from .tokenization import QueryTextAnalysisCache
logger = logging.getLogger(__name__)
import hanlp # type: ignore
# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code).
KEYWORDS_QUERY_BASE_KEY = "base"
# | Scenario                          | Recommended model |
# | :-------------------------------- | :---------------------------------------------------------------- |
# | Chinese only, highest accuracy    | CTB9_TOK_ELECTRA_BASE_CRF or MSR_TOK_ELECTRA_BASE_CRF |
# | Chinese only, speed first         | FINE_ELECTRA_SMALL_ZH (fine-grained) or COARSE_ELECTRA_SMALL_ZH (coarse-grained) |
# | **Mixed Chinese/English**         | `UD_TOK_MMINILMV2L6` or `UD_TOK_MMINILMV2L12` (differ in number of Transformer encoder layers) |
class KeywordExtractor:
    """HanLP-based noun keyword extractor.

    Keeps tokens whose POS tag starts with ``N`` and whose length is >= 2,
    preserving token order. A single space is inserted between kept tokens
    that were NOT adjacent in the original query (detected via token spans),
    so contiguous noun runs stay glued together.
    """

    def __init__(
        self,
        tokenizer: Optional[Any] = None,
        *,
        ignore_keywords: Optional[List[str]] = None,
        english_extractor: Optional[EnglishKeywordExtractor] = None,
    ):
        # Fall back to the fine-grained Chinese tokenizer when none is injected.
        self.tok = (
            tokenizer
            if tokenizer is not None
            else hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
        )
        # Spans (start/end character offsets) are required to detect adjacency.
        self.tok.config.output_spans = True
        self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)
        self.ignore_keywords = frozenset(ignore_keywords or ["玩具"])
        self.english_extractor = english_extractor or EnglishKeywordExtractor()

    def extract_keywords(
        self,
        query: str,
        *,
        language_hint: Optional[str] = None,
        tokenizer_result: Optional[Any] = None,
    ) -> str:
        """Extract noun keywords (length >= 2); non-adjacent runs are space-separated.

        English queries are delegated to ``english_extractor``; languages other
        than ``zh``/``en`` yield an empty string. ``tokenizer_result`` may carry
        a precomputed ``(word, start, end)`` list to skip re-tokenization.
        """
        text = (query or "").strip()
        if not text:
            return ""
        hint = str(language_hint or "").strip().lower()
        if hint == "en":
            return self.english_extractor.extract_keywords(text)
        if hint and hint != "zh":
            return ""
        spans = tokenizer_result if tokenizer_result is not None else self.tok(text)
        words = [span[0] for span in spans]
        if not words:
            return ""
        tags = self.pos_tag(words)
        parts: List[str] = []
        prev_end = 0
        for tag, (word, start, end) in zip(tags, spans):
            # Keep only nouns of length >= 2; skipped tokens do not advance
            # prev_end, so the next kept noun is treated as non-adjacent.
            if len(word) < 2 or not str(tag).startswith("N"):
                continue
            if word in self.ignore_keywords:
                continue
            if parts and start != prev_end:
                parts.append(" ")
            parts.append(word)
            prev_end = end
        return "".join(parts).strip()
def collect_keywords_queries(
    extractor: KeywordExtractor,
    rewritten_query: str,
    translations: Dict[str, str],
    *,
    source_language: Optional[str] = None,
    text_analysis_cache: Optional[QueryTextAnalysisCache] = None,
    base_keywords_query: Optional[str] = None,
) -> Dict[str, str]:
    """
    Build the keyword map for all lexical variants (base + translations).

    Keys are ``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query and
    the lower-cased language code of each non-empty translation. Entries whose
    extraction yields an empty string are omitted.

    :param extractor: keyword extractor used for every variant.
    :param rewritten_query: the (already rewritten) source query.
    :param translations: language code -> translated query text.
    :param source_language: optional explicit language hint for the base query;
        falls back to ``text_analysis_cache`` when absent.
    :param text_analysis_cache: optional cache providing precomputed language
        hints and tokenizer results.
    :param base_keywords_query: precomputed base keywords; when given (even
        empty), base extraction is skipped entirely.
    :return: mapping of variant key -> non-empty keyword string.
    """

    def _cached_tokens(text: str) -> Optional[Any]:
        # Reuse precomputed tokenization when a cache is supplied.
        if text_analysis_cache is None:
            return None
        return text_analysis_cache.get_tokenizer_result(text)

    out: Dict[str, str] = {}
    base_kw = base_keywords_query
    if base_kw is None:
        # Prefer the caller-provided language; otherwise consult the cache.
        base_hint = source_language or (
            text_analysis_cache.get_language_hint(rewritten_query)
            if text_analysis_cache is not None
            else None
        )
        base_kw = extractor.extract_keywords(
            rewritten_query,
            language_hint=base_hint,
            tokenizer_result=_cached_tokens(rewritten_query),
        )
    if base_kw:
        out[KEYWORDS_QUERY_BASE_KEY] = base_kw
    for lang, text in translations.items():
        lang_key = str(lang or "").strip().lower()
        if not lang_key or not (text or "").strip():
            continue
        # lang_key is guaranteed non-empty here, so it IS the language hint.
        # (The previous "lang_key or cache-hint" fallback was dead code.)
        kw = extractor.extract_keywords(
            text,
            language_hint=lang_key,
            tokenizer_result=_cached_tokens(text),
        )
        if kw:
            out[lang_key] = kw
    return out