"""
Shared tokenization helpers for query understanding.
"""

from __future__ import annotations

from dataclasses import dataclass
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple


# A single Han (CJK Unified Ideographs) character.
_HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]")
# Either a run of Han characters, or a Latin/digit run that may contain
# internal hyphens or apostrophes (e.g. "state-of-the-art", "o'clock").
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[^\W_]+(?:[-'][^\W_]+)*", re.UNICODE)


def normalize_query_text(text: Optional[str]) -> str:
    """Casefold and collapse whitespace; ``None`` becomes the empty string."""
    if text is None:
        return ""
    return " ".join(str(text).strip().casefold().split())


def simple_tokenize_query(text: str) -> List[str]:
    """
    Lightweight tokenizer for coarse query matching.

    - Consecutive CJK characters form one token
    - Latin / digit runs (with internal hyphens) form tokens
    """
    if not text:
        return []
    return _TOKEN_PATTERN.findall(text)
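
# Illustrative example (hypothetical input, not from the source): the pattern
# above turns "GPT-4 o'clock 机器学习 benchmarks" into
# ["GPT-4", "o'clock", "机器学习", "benchmarks"]; case is preserved because
# normalization is left to normalize_query_text().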


def contains_han_text(text: Optional[str]) -> bool:
    """Return True if the text contains at least one Han character."""
    return bool(text and _HAN_PATTERN.search(str(text)))


def extract_token_strings(tokenizer_result: Any) -> List[str]:
    """Normalize tokenizer output into a flat token string list."""
    if not tokenizer_result:
        return []
    if isinstance(tokenizer_result, str):
        token = tokenizer_result.strip()
        return [token] if token else []

    tokens: List[str] = []
    for item in tokenizer_result:
        token: Optional[str] = None
        if isinstance(item, str):
            token = item
        elif isinstance(item, (list, tuple)) and item:
            # Pair-like entries (e.g. token plus weight or span): keep the first field.
            token = str(item[0])
        elif item is not None:
            token = str(item)

        if token is None:
            continue
        token = token.strip()
        if token:
            tokens.append(token)
    return tokens
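
# Illustrative example (hypothetical values): heterogeneous tokenizer output is
# flattened to plain strings, e.g.
#   extract_token_strings(["foo", ("bar", 0.9), None, "  "]) -> ["foo", "bar"]
#   extract_token_strings(" baz ") -> ["baz"]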


def _dedupe_preserve_order(values: Iterable[str]) -> List[str]:
    """Normalize values and drop duplicates while preserving first-seen order."""
    result: List[str] = []
    seen = set()
    for value in values:
        normalized = normalize_query_text(value)
        if not normalized or normalized in seen:
            continue
        seen.add(normalized)
        result.append(normalized)
    return result
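
# Illustrative example (hypothetical values): entries are normalized before
# comparison, so
#   _dedupe_preserve_order(["New York", "new  york", "", "LA"]) -> ["new york", "la"]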


def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]:
    """Build all contiguous 1..max_ngram word phrases over the token sequence."""
    if not tokens:
        return []

    phrases: List[str] = []
    upper = max(1, int(max_ngram))
    for size in range(1, upper + 1):
        if size > len(tokens):
            break
        for start in range(0, len(tokens) - size + 1):
            phrase = " ".join(tokens[start:start + size]).strip()
            if phrase:
                phrases.append(phrase)
    return phrases
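
# Illustrative example (hypothetical values): sliding n-grams up to max_ngram,
#   _build_phrase_candidates(["deep", "learning", "models"], max_ngram=2)
#   -> ["deep", "learning", "models", "deep learning", "learning models"]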


def _build_coarse_tokens(text: str, fine_tokens: Sequence[str]) -> List[str]:
    """Prefer fine tokens for Han text; otherwise fall back to the regex tokenizer."""
    if contains_han_text(text) and fine_tokens:
        return _dedupe_preserve_order(fine_tokens)
    return _dedupe_preserve_order(simple_tokenize_query(text))


@dataclass(frozen=True)
class TokenizedText:
    """Immutable bundle of tokenization results for one piece of query text."""

    text: str                       # stripped original input
    normalized_text: str            # casefolded, whitespace-collapsed form
    fine_tokens: Tuple[str, ...]    # deduped tokenizer tokens
    coarse_tokens: Tuple[str, ...]  # regex tokens (fine tokens for Han text)
    candidates: Tuple[str, ...]     # tokens, n-gram phrases, and normalized text

class QueryTextAnalysisCache:
    """Per-parse cache for tokenizer output and derived token bundles."""

    def __init__(self, *, tokenizer: Optional[Callable[[str], Any]] = None) -> None:
        self.tokenizer = tokenizer
        self._tokenizer_results: Dict[str, Any] = {}
        self._tokenized_texts: Dict[Tuple[str, int], TokenizedText] = {}
        self._language_hints: Dict[str, str] = {}

    @staticmethod
    def _normalize_input(text: Optional[str]) -> str:
        return str(text or "").strip()

    def set_language_hint(self, text: Optional[str], language: Optional[str]) -> None:
        normalized_input = self._normalize_input(text)
        normalized_language = normalize_query_text(language)
        if normalized_input and normalized_language:
            self._language_hints[normalized_input] = normalized_language

    def get_language_hint(self, text: Optional[str]) -> Optional[str]:
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return None
        return self._language_hints.get(normalized_input)

    def _should_use_model_tokenizer(self, text: str) -> bool:
        # Language hints are recorded via set_language_hint, but the decision
        # currently hinges only on the presence of Han text.
        if self.tokenizer is None:
            return False
        return contains_han_text(text)

    def get_tokenizer_result(self, text: Optional[str]) -> Any:
        """Return (and cache) raw tokenizer output for the given text."""
        normalized_input = self._normalize_input(text)
        if not normalized_input:
            return []
        if not self._should_use_model_tokenizer(normalized_input):
            return simple_tokenize_query(normalized_input)
        if normalized_input not in self._tokenizer_results:
            self._tokenizer_results[normalized_input] = self.tokenizer(normalized_input)
        return self._tokenizer_results[normalized_input]

    def get_tokenized_text(self, text: Optional[str], *, max_ngram: int = 3) -> TokenizedText:
        """Return (and cache) the full token bundle derived from the text."""
        normalized_input = self._normalize_input(text)
        cache_key = (normalized_input, max(1, int(max_ngram)))
        cached = self._tokenized_texts.get(cache_key)
        if cached is not None:
            return cached

        normalized_text = normalize_query_text(normalized_input)
        fine_raw = extract_token_strings(self.get_tokenizer_result(normalized_input))
        fine_tokens = _dedupe_preserve_order(fine_raw)
        coarse_tokens = _build_coarse_tokens(normalized_input, fine_tokens)

        bundle = TokenizedText(
            text=normalized_input,
            normalized_text=normalized_text,
            fine_tokens=tuple(fine_tokens),
            coarse_tokens=tuple(coarse_tokens),
            candidates=tuple(
                _dedupe_preserve_order(
                    list(fine_tokens)
                    + list(coarse_tokens)
                    + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram)
                    + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram)
                    + ([normalized_text] if normalized_text else [])
                )
            ),
        )
        self._tokenized_texts[cache_key] = bundle
        return bundle
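
# Illustrative cache usage (hypothetical tokenizer callable, not something this
# module mandates): a parse pass typically creates one cache, optionally records
# language hints, and reuses it so repeated lookups skip re-tokenization, e.g.
#   cache = QueryTextAnalysisCache(tokenizer=some_segmenter)
#   cache.set_language_hint("机器学习 模型", "zh")
#   bundle = cache.get_tokenized_text("机器学习 模型", max_ngram=2)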


def tokenize_text(
    text: str,
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
    max_ngram: int = 3,
) -> TokenizedText:
    """One-shot convenience wrapper around QueryTextAnalysisCache."""
    return QueryTextAnalysisCache(tokenizer=tokenizer).get_tokenized_text(
        text,
        max_ngram=max_ngram,
    )
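

# Minimal usage sketch (assumptions flagged inline). With no ``tokenizer`` the
# regex fallback handles segmentation; a jieba-style callable could be passed
# for finer Han segmentation, but nothing in this module requires a specific
# tokenizer.
if __name__ == "__main__":
    # Hypothetical sample query, used only to illustrate the output shape.
    bundle = tokenize_text("New York-based 机器学习 engineers", max_ngram=2)
    # fine_tokens -> ('new', 'york-based', '机器学习', 'engineers')
    print(bundle.fine_tokens)
    # candidates also include bigrams and the full normalized text.
    print(bundle.candidates)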