tokenization.py
"""
Shared tokenization helpers for query understanding.
"""
from __future__ import annotations
from dataclasses import dataclass
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
_HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]")
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[^\W_]+(?:[-'][^\W_]+)*", re.UNICODE)
def normalize_query_text(text: Optional[str]) -> str:
if text is None:
return ""
return " ".join(str(text).strip().casefold().split())
def simple_tokenize_query(text: str) -> List[str]:
"""
Lightweight tokenizer for coarse query matching.
- Consecutive CJK characters form one token
- Latin / digit runs (with internal hyphens) form tokens
"""
if not text:
return []
return _TOKEN_PATTERN.findall(text)
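
# Illustrative behaviour of the regex path (derived from _TOKEN_PATTERN above):
#     simple_tokenize_query("Wi-Fi 路由器 setup")  ->  ['Wi-Fi', '路由器', 'setup']
# Hyphenated Latin runs stay intact and each contiguous CJK run becomes one token.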


def contains_han_text(text: Optional[str]) -> bool:
    return bool(text and _HAN_PATTERN.search(str(text)))


def extract_token_strings(tokenizer_result: Any) -> List[str]:
    """Normalize tokenizer output into a flat token string list."""
    if not tokenizer_result:
        return []
    if isinstance(tokenizer_result, str):
        token = tokenizer_result.strip()
        return [token] if token else []
    tokens: List[str] = []
    for item in tokenizer_result:
        token: Optional[str] = None
        if isinstance(item, str):
            token = item
        elif isinstance(item, (list, tuple)) and item:
            token = str(item[0])
        elif item is not None:
            token = str(item)
        if token is None:
            continue
        token = token.strip()
        if token:
            tokens.append(token)
    return tokens
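
# Illustrative normalization of heterogeneous tokenizer output: plain strings,
# (token, ...) pairs, and None/blank entries are flattened to clean strings:
#     extract_token_strings(["Wi-Fi", ("router", 0.9), None, "  "])  ->  ['Wi-Fi', 'router']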


def _dedupe_preserve_order(values: Iterable[str]) -> List[str]:
    result: List[str] = []
    seen = set()
    for value in values:
        normalized = normalize_query_text(value)
        if not normalized or normalized in seen:
            continue
        seen.add(normalized)
        result.append(normalized)
    return result


def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]:
    if not tokens:
        return []
    phrases: List[str] = []
    upper = max(1, int(max_ngram))
    for size in range(1, upper + 1):
        if size > len(tokens):
            break
        for start in range(0, len(tokens) - size + 1):
            phrase = " ".join(tokens[start:start + size]).strip()
            if phrase:
                phrases.append(phrase)
    return phrases
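
# Illustrative n-gram expansion used for phrase candidates (order preserved,
# shorter spans first):
#     _build_phrase_candidates(["red", "running", "shoes"], max_ngram=2)
#     ->  ['red', 'running', 'shoes', 'red running', 'running shoes']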


def _build_coarse_tokens(
    text: str,
    *,
    language_hint: Optional[str],
    tokenizer_tokens: Sequence[str],
) -> List[str]:
    normalized_language = normalize_query_text(language_hint)
    if normalized_language == "zh" or (contains_han_text(text) and tokenizer_tokens):
        # Chinese coarse tokenization should follow the model tokenizer rather than a
        # regex that collapses the whole sentence into one CJK span.
        return list(_dedupe_preserve_order(tokenizer_tokens))
    return _dedupe_preserve_order(simple_tokenize_query(text))


@dataclass(frozen=True)
class TokenizedText:
    text: str
    normalized_text: str
    fine_tokens: Tuple[str, ...]
    coarse_tokens: Tuple[str, ...]
    candidates: Tuple[str, ...]


class QueryTextAnalysisCache:
    """Per-parse cache for tokenizer output and derived token bundles."""

    def __init__(self, *, tokenizer: Optional[Callable[[str], Any]] = None) -> None:
        self.tokenizer = tokenizer
        self._tokenizer_results: Dict[str, Any] = {}
        self._tokenized_texts: Dict[Tuple[str, int], TokenizedText] = {}
        self._language_hints: Dict[str, str] = {}

    @staticmethod
    def _normalize_input(text: Optional[str]) -> str:
        return str(text or "").strip()

    def set_language_hint(self, text: Optional[str], language: Optional[str]) -> None:
        normalized_input = self._normalize_input(text)
        normalized_language = normalize_query_text(language)
        if normalized_input and normalized_language:
            self._language_hints[normalized_input] = normalized_language

    def get_language_hint(self, text: Optional[str]) -> Optional[str]:
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return None
        return self._language_hints.get(normalized_input)

    def _should_use_model_tokenizer(self, text: str) -> bool:
        if self.tokenizer is None:
            return False
        # With or without a "zh" language hint, the model tokenizer is only
        # worth invoking when the text actually contains Han characters; pure
        # Latin queries are handled by the regex tokenizer.
        return contains_han_text(text)

    def get_tokenizer_result(self, text: Optional[str]) -> Any:
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return []
        if not self._should_use_model_tokenizer(normalized_input):
            return simple_tokenize_query(normalized_input)
        if normalized_input not in self._tokenizer_results:
            self._tokenizer_results[normalized_input] = self.tokenizer(normalized_input)
        return self._tokenizer_results[normalized_input]

    def get_tokenized_text(self, text: Optional[str], *, max_ngram: int = 3) -> TokenizedText:
        normalized_input = self._normalize_input(text)
        cache_key = (normalized_input, max(1, int(max_ngram)))
        cached = self._tokenized_texts.get(cache_key)
        if cached is not None:
            return cached
        normalized_text = normalize_query_text(normalized_input)
        fine_raw = extract_token_strings(self.get_tokenizer_result(normalized_input))
        fine_tokens = _dedupe_preserve_order(fine_raw)
        coarse_tokens = _build_coarse_tokens(
            normalized_input,
            language_hint=self.get_language_hint(normalized_input),
            tokenizer_tokens=fine_tokens,
        )
        bundle = TokenizedText(
            text=normalized_input,
            normalized_text=normalized_text,
            fine_tokens=tuple(fine_tokens),
            coarse_tokens=tuple(coarse_tokens),
            candidates=tuple(
                _dedupe_preserve_order(
                    list(fine_tokens)
                    + list(coarse_tokens)
                    + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram)
                    + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram)
                    + ([normalized_text] if normalized_text else [])
                )
            ),
        )
        self._tokenized_texts[cache_key] = bundle
        return bundle
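
# Sketch of wiring the cache to an external Chinese tokenizer. Any callable
# returning strings or (token, ...) tuples fits; jieba.lcut is one example,
# assuming jieba is installed (it is not a dependency of this module):
#
#     cache = QueryTextAnalysisCache(tokenizer=jieba.lcut)
#     cache.set_language_hint("无线路由器 设置", "zh")
#     bundle = cache.get_tokenized_text("无线路由器 设置", max_ngram=2)
#     bundle.fine_tokens    # deduplicated, casefolded model-tokenizer tokens
#     bundle.candidates     # tokens + n-gram phrases + full normalized text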


def tokenize_text(
    text: str,
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
    max_ngram: int = 3,
) -> TokenizedText:
    return QueryTextAnalysisCache(tokenizer=tokenizer).get_tokenized_text(
        text,
        max_ngram=max_ngram,
    )
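

if __name__ == "__main__":
    # Minimal usage sketch. No model tokenizer is supplied, so the regex path
    # handles everything, including the CJK run in the sample query.
    bundle = tokenize_text("Wi-Fi 路由器 setup guide", max_ngram=2)
    print("fine:   ", bundle.fine_tokens)
    print("coarse: ", bundle.coarse_tokens)
    print("phrases:", bundle.candidates)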