# query_rewriter.py
"""
Query rewriter for handling synonyms, brand mappings, and query transformations.
"""

import logging
import re
from typing import Dict, Optional, Tuple


class QueryRewriter:
    """Rewrites queries based on configured dictionary rules.

    Each rule maps a literal pattern to a replacement expression.

    - If the whole query equals a pattern, the query is replaced by that
      pattern's replacement.
    - Otherwise, the longest non-empty pattern contained in the query has
      every occurrence substituted.
    - At most one rule is applied per query.
    """

    # Rewrites are reported through logging (DEBUG) rather than printed to
    # stdout, so library users control whether per-query output appears.
    _logger = logging.getLogger(__name__)

    def __init__(self, rewrite_dict: Optional[Dict[str, str]] = None):
        """
        Initialize query rewriter.

        Args:
            rewrite_dict: Dictionary mapping query patterns to rewrite expressions
                         e.g., {"芭比": "brand:芭比 OR name:芭比娃娃"}.
                         The mapping is copied, so later add_rule / remove_rule /
                         clear_rules calls never mutate the caller's dict.
        """
        self.rewrite_dict: Dict[str, str] = dict(rewrite_dict) if rewrite_dict else {}

    def rewrite(self, query: str) -> str:
        """
        Rewrite query based on dictionary rules.

        Args:
            query: Original query string

        Returns:
            The replacement for an exact match. Otherwise, the query with every
            occurrence of the longest contained pattern replaced. The query is
            returned unchanged when no rule applies or it is empty/blank.
        """
        if not query or not query.strip():
            return query

        rules = self.rewrite_dict

        # Exact match: the whole query is a known pattern.
        if query in rules:
            rewritten = rules[query]
            self._logger.debug(
                "[QueryRewriter] Exact match: '%s' -> '%s'", query, rewritten
            )
            return rewritten

        # Partial match. Skip empty patterns: "" is contained in every string,
        # and str.replace("", x) would insert x between every character.
        candidates = [pattern for pattern in rules if pattern and pattern in query]
        if candidates:
            # Prefer the most specific (longest) pattern so that a short key
            # cannot shadow a longer one that also matches. max() returns the
            # first of equally long candidates, i.e. rule insertion order.
            pattern = max(candidates, key=len)
            rewritten = query.replace(pattern, rules[pattern])
            self._logger.debug(
                "[QueryRewriter] Partial match: '%s' -> '%s'", query, rewritten
            )
            return rewritten

        # No rewrite needed
        return query

    def add_rule(self, pattern: str, replacement: str) -> None:
        """
        Add a rewrite rule, overwriting any existing rule for the same pattern.

        Args:
            pattern: Query pattern to match
            replacement: Replacement expression
        """
        self.rewrite_dict[pattern] = replacement

    def remove_rule(self, pattern: str) -> bool:
        """
        Remove a rewrite rule.

        Args:
            pattern: Query pattern to remove

        Returns:
            True if rule was removed, False if not found
        """
        if pattern in self.rewrite_dict:
            del self.rewrite_dict[pattern]
            return True
        return False

    def get_rules(self) -> Dict[str, str]:
        """Return a shallow copy of all rewrite rules."""
        return self.rewrite_dict.copy()

    def clear_rules(self) -> None:
        """Clear all rewrite rules."""
        self.rewrite_dict.clear()


class QueryNormalizer:
    """Normalizes queries for consistent processing."""

    # `\w` is Unicode-aware for str patterns, so letters and digits of every
    # script are kept (not just ASCII, CJK and Cyrillic). `\w` also matches
    # "_", which is punctuation here, so it is removed via the `|_` branch.
    # With operators: also keep ( ) | & ! - .
    _PUNCT_KEEP_OPERATORS = re.compile(r'[^\w\s()|&!-]|_')
    # Without operators: keep only word characters and whitespace.
    _PUNCT_STRICT = re.compile(r'[^\w\s]|_')
    # "<domain>:<text>". DOTALL lets the text span embedded newlines.
    _DOMAIN_PREFIX = re.compile(r'(\w+):(.+)', re.DOTALL)

    @staticmethod
    def normalize(query: str) -> str:
        """
        Normalize query string.

        - Trim leading/trailing whitespace
        - Collapse every internal whitespace run (spaces, tabs, newlines) to a
          single space

        Punctuation is not touched here; see remove_punctuation.

        Args:
            query: Original query

        Returns:
            Normalized query ("" for an empty or None query)
        """
        if not query:
            return ""

        # Trim and collapse whitespace
        return " ".join(query.split())

    @staticmethod
    def remove_punctuation(query: str, keep_operators: bool = True) -> str:
        """
        Remove punctuation and symbols from query.

        Letters and digits of any script, and whitespace, are always kept.
        The underscore is treated as punctuation and removed.

        Args:
            query: Original query
            keep_operators: Whether to keep the operator characters
                            ( ) | & ! -. Word operators such as AND/OR are
                            letters and are kept either way.

        Returns:
            Query without punctuation ("" for an empty or None query)
        """
        if not query:
            return ""

        if keep_operators:
            pattern = QueryNormalizer._PUNCT_KEEP_OPERATORS
        else:
            pattern = QueryNormalizer._PUNCT_STRICT

        return pattern.sub('', query)

    @staticmethod
    def extract_domain_query(query: str) -> Tuple[str, str]:
        """
        Extract domain prefix from query if present.

        Examples:
            "brand:Nike shoes" -> ("brand", "Nike shoes")
            "category:toys" -> ("category", "toys")
            "default query" -> ("default", "default query")

        Args:
            query: Query string (None is treated as empty)

        Returns:
            Tuple of (domain, query_text). The domain is "default" when the
            query has no "<word>:<text>" prefix.
        """
        text = query.strip() if query else ""

        # Check for domain:query pattern
        match = QueryNormalizer._DOMAIN_PREFIX.fullmatch(text)
        if match:
            return match.group(1), match.group(2).strip()

        # No domain specified, use default
        return "default", text