"""
Shared tokenization helpers for query understanding.
"""

from __future__ import annotations

from dataclasses import dataclass
import re
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple


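# One token per run of CJK characters, or per Latin/digit/underscore run with
# optional internal hyphens (e.g. "gpt-4").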
_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*")


def normalize_query_text(text: Optional[str]) -> str:
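    """Casefold and collapse whitespace; ``None`` becomes an empty string."""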
    if text is None:
        return ""
    return " ".join(str(text).strip().casefold().split())


def simple_tokenize_query(text: str) -> List[str]:
    """
    Lightweight tokenizer for coarse query matching.

    - Consecutive CJK characters form one token
    - Latin/digit/underscore runs (optionally joined by internal hyphens) form tokens
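
    Example (illustrative):

    >>> simple_tokenize_query("GPT-4 中文 query")
    ['GPT-4', '中文', 'query']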
    """
    if not text:
        return []
    return _TOKEN_PATTERN.findall(text)


def extract_token_strings(tokenizer_result: Any) -> List[str]:
    """Normalize tokenizer output into a flat token string list."""
    if not tokenizer_result:
        return []
    if isinstance(tokenizer_result, str):
        token = tokenizer_result.strip()
        return [token] if token else []

    tokens: List[str] = []
    for item in tokenizer_result:
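        # Accept plain strings, (token, ...) pairs such as (token, score),
        # or any other object coercible to str; None entries are skipped.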
        token: Optional[str] = None
        if isinstance(item, str):
            token = item
        elif isinstance(item, (list, tuple)) and item:
            token = str(item[0])
        elif item is not None:
            token = str(item)

        if token is None:
            continue
        token = token.strip()
        if token:
            tokens.append(token)
    return tokens


def _dedupe_preserve_order(values: Iterable[str]) -> List[str]:
    result: List[str] = []
    seen = set()
    for value in values:
        normalized = normalize_query_text(value)
        if not normalized or normalized in seen:
            continue
        seen.add(normalized)
        result.append(normalized)
    return result


def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]:
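    """
    Build all 1..max_ngram word n-grams over ``tokens``, preserving order.

    Illustrative example: ["new", "york", "city"] with max_ngram=2 yields
    ["new", "york", "city", "new york", "york city"].
    """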
    if not tokens:
        return []

    phrases: List[str] = []
    upper = max(1, int(max_ngram))
    for size in range(1, upper + 1):
        if size > len(tokens):
            break
        for start in range(0, len(tokens) - size + 1):
            phrase = " ".join(tokens[start:start + size]).strip()
            if phrase:
                phrases.append(phrase)
    return phrases


@dataclass(frozen=True)
class TokenizedText:
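    """
    Immutable tokenization result: the raw text, its normalized form, fine
    (tokenizer-provided) and coarse (regex) tokens, and the combined candidate
    strings (tokens, n-grams, and the full normalized text).
    """
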
    text: str
    normalized_text: str
    fine_tokens: Tuple[str, ...]
    coarse_tokens: Tuple[str, ...]
    candidates: Tuple[str, ...]


def tokenize_text(
    text: str,
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
    max_ngram: int = 3,
) -> TokenizedText:
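    """
    Tokenize ``text`` into fine tokens (from ``tokenizer``, if given), coarse
    regex tokens, and deduplicated match candidates (tokens, n-grams up to
    ``max_ngram`` words, and the normalized full text).

    Example (illustrative, no external tokenizer):

    >>> result = tokenize_text("New York pizza")
    >>> result.coarse_tokens
    ('new', 'york', 'pizza')
    >>> result.candidates
    ('new', 'york', 'pizza', 'new york', 'york pizza', 'new york pizza')
    """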
    normalized_text = normalize_query_text(text)
    coarse_tokens = _dedupe_preserve_order(simple_tokenize_query(text))

    fine_raw: List[str] = []
    if tokenizer is not None and text:
        fine_raw = extract_token_strings(tokenizer(text))
    fine_tokens = _dedupe_preserve_order(fine_raw)

    candidates = _dedupe_preserve_order(
        list(fine_tokens)
        + list(coarse_tokens)
        + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram)
        + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram)
        + ([normalized_text] if normalized_text else [])
    )

    return TokenizedText(
        text=text,
        normalized_text=normalized_text,
        fine_tokens=tuple(fine_tokens),
        coarse_tokens=tuple(coarse_tokens),
        candidates=tuple(candidates),
    )
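

if __name__ == "__main__":
    # Minimal smoke test. ``_toy_tokenizer`` is a purely illustrative stand-in
    # for a real tokenizer callable that yields (token, score) pairs.
    def _toy_tokenizer(query: str) -> List[Tuple[str, float]]:
        return [(part, 1.0) for part in query.split()]

    demo = tokenize_text("New York pizza", tokenizer=_toy_tokenizer, max_ngram=2)
    print(demo.fine_tokens)
    print(demo.candidates)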