""" Shared tokenization helpers for query understanding. """ from __future__ import annotations from dataclasses import dataclass import re from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple _HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]") _TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[^\W_]+(?:[-'][^\W_]+)*", re.UNICODE) def normalize_query_text(text: Optional[str]) -> str: if text is None: return "" return " ".join(str(text).strip().casefold().split()) def simple_tokenize_query(text: str) -> List[str]: """ Lightweight tokenizer for coarse query matching. - Consecutive CJK characters form one token - Latin / digit runs (with internal hyphens) form tokens """ if not text: return [] return _TOKEN_PATTERN.findall(text) def contains_han_text(text: Optional[str]) -> bool: return bool(text and _HAN_PATTERN.search(str(text))) def extract_token_strings(tokenizer_result: Any) -> List[str]: """Normalize tokenizer output into a flat token string list.""" if not tokenizer_result: return [] if isinstance(tokenizer_result, str): token = tokenizer_result.strip() return [token] if token else [] tokens: List[str] = [] for item in tokenizer_result: token: Optional[str] = None if isinstance(item, str): token = item elif isinstance(item, (list, tuple)) and item: token = str(item[0]) elif item is not None: token = str(item) if token is None: continue token = token.strip() if token: tokens.append(token) return tokens def _dedupe_preserve_order(values: Iterable[str]) -> List[str]: result: List[str] = [] seen = set() for value in values: normalized = normalize_query_text(value) if not normalized or normalized in seen: continue seen.add(normalized) result.append(normalized) return result def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]: if not tokens: return [] phrases: List[str] = [] upper = max(1, int(max_ngram)) for size in range(1, upper + 1): if size > len(tokens): break for start in range(0, len(tokens) - size + 1): phrase = " ".join(tokens[start:start + size]).strip() if phrase: phrases.append(phrase) return phrases def _build_coarse_tokens( text: str, *, language_hint: Optional[str], tokenizer_tokens: Sequence[str], ) -> List[str]: normalized_language = normalize_query_text(language_hint) if normalized_language == "zh" or (contains_han_text(text) and tokenizer_tokens): # Chinese coarse tokenization should follow the model tokenizer rather than a # regex that collapses the whole sentence into one CJK span. return list(_dedupe_preserve_order(tokenizer_tokens)) return _dedupe_preserve_order(simple_tokenize_query(text)) @dataclass(frozen=True) class TokenizedText: text: str normalized_text: str fine_tokens: Tuple[str, ...] coarse_tokens: Tuple[str, ...] candidates: Tuple[str, ...] 
class QueryTextAnalysisCache:
    """Per-parse cache for tokenizer output and derived token bundles."""

    def __init__(self, *, tokenizer: Optional[Callable[[str], Any]] = None) -> None:
        self.tokenizer = tokenizer
        self._tokenizer_results: Dict[str, Any] = {}
        self._tokenized_texts: Dict[Tuple[str, int], TokenizedText] = {}
        self._language_hints: Dict[str, str] = {}

    @staticmethod
    def _normalize_input(text: Optional[str]) -> str:
        return str(text or "").strip()

    def set_language_hint(self, text: Optional[str], language: Optional[str]) -> None:
        """Record a language hint (e.g. "zh") for a specific query string."""
        normalized_input = self._normalize_input(text)
        normalized_language = normalize_query_text(language)
        if normalized_input and normalized_language:
            self._language_hints[normalized_input] = normalized_language

    def get_language_hint(self, text: Optional[str]) -> Optional[str]:
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return None
        return self._language_hints.get(normalized_input)

    def _should_use_model_tokenizer(self, text: str) -> bool:
        # The model tokenizer is only worth calling when it is configured and the
        # text actually contains Han characters; a "zh" hint without Han text still
        # falls back to the regex tokenizer.
        if self.tokenizer is None:
            return False
        return contains_han_text(text)

    def get_tokenizer_result(self, text: Optional[str]) -> Any:
        """Return (and memoize) the raw tokenizer output for the given text."""
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return []
        if not self._should_use_model_tokenizer(normalized_input):
            return simple_tokenize_query(normalized_input)
        if normalized_input not in self._tokenizer_results:
            self._tokenizer_results[normalized_input] = self.tokenizer(normalized_input)
        return self._tokenizer_results[normalized_input]

    def get_tokenized_text(self, text: Optional[str], *, max_ngram: int = 3) -> TokenizedText:
        """Return a cached TokenizedText bundle keyed by (text, max_ngram)."""
        normalized_input = self._normalize_input(text)
        cache_key = (normalized_input, max(1, int(max_ngram)))
        cached = self._tokenized_texts.get(cache_key)
        if cached is not None:
            return cached
        normalized_text = normalize_query_text(normalized_input)
        fine_raw = extract_token_strings(self.get_tokenizer_result(normalized_input))
        fine_tokens = _dedupe_preserve_order(fine_raw)
        coarse_tokens = _build_coarse_tokens(
            normalized_input,
            language_hint=self.get_language_hint(normalized_input),
            tokenizer_tokens=fine_tokens,
        )
        bundle = TokenizedText(
            text=normalized_input,
            normalized_text=normalized_text,
            fine_tokens=tuple(fine_tokens),
            coarse_tokens=tuple(coarse_tokens),
            candidates=tuple(
                _dedupe_preserve_order(
                    list(fine_tokens)
                    + list(coarse_tokens)
                    + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram)
                    + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram)
                    + ([normalized_text] if normalized_text else [])
                )
            ),
        )
        self._tokenized_texts[cache_key] = bundle
        return bundle


def tokenize_text(
    text: str,
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
    max_ngram: int = 3,
) -> TokenizedText:
    """Convenience wrapper that tokenizes a single text with a throwaway cache."""
    return QueryTextAnalysisCache(tokenizer=tokenizer).get_tokenized_text(
        text,
        max_ngram=max_ngram,
    )
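

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API). With no
    # model tokenizer supplied, tokenize_text falls back to the regex-based
    # simple_tokenize_query for both fine and coarse tokens.
    example = tokenize_text("Wireless Router Setup", max_ngram=2)
    print(example.normalized_text)  # wireless router setup
    print(example.fine_tokens)      # ('wireless', 'router', 'setup')
    print(example.candidates)       # unigrams, bigrams, plus the full normalized text

    # The cache memoizes bundles per (text, max_ngram), so repeated lookups within
    # one parse reuse the same TokenizedText object.
    cache = QueryTextAnalysisCache()
    assert cache.get_tokenized_text("Wireless Router Setup") is cache.get_tokenized_text(
        "Wireless Router Setup"
    )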