# multilang_query_builder.py (11.5 KB)
"""
Multi-language query builder for handling domain-specific searches.

This module extends the ESQueryBuilder to support multi-language field mappings,
allowing queries to be routed to appropriate language-specific fields while
maintaining a unified external interface.
"""

import logging
from typing import Dict, Any, List, Optional

import numpy as np

from config import CustomerConfig, IndexConfig
from query import ParsedQuery
from .es_query_builder import ESQueryBuilder


class MultiLanguageQueryBuilder(ESQueryBuilder):
    """
    Enhanced query builder with multi-language support.

    Handles routing queries to appropriate language-specific fields based on:
    1. Detected query language
    2. Available translations
    3. Domain configuration (language_field_mapping)

    The scoring knobs are class attributes so they can be tuned per subclass
    or per instance without rewriting the query construction.
    """

    # Diagnostic output goes through logging (debug level) instead of
    # unconditional print() calls on every query.
    _logger = logging.getLogger(__name__)

    # multi_match parameters shared by every text clause this builder emits.
    MINIMUM_SHOULD_MATCH = "67%"
    TIE_BREAKER = 0.9
    # Multipliers applied on top of the domain's configured boost.
    DETECTED_LANGUAGE_BOOST = 1.5  # clause querying the detected language
    FALLBACK_BOOST = 0.8  # clause querying all mapped fields when nothing matched

    def __init__(
        self,
        config: CustomerConfig,
        index_name: str,
        text_embedding_field: Optional[str] = None,
        image_embedding_field: Optional[str] = None
    ):
        """
        Initialize multi-language query builder.

        Args:
            config: Customer configuration
            index_name: ES index name
            text_embedding_field: Field name for text embeddings
            image_embedding_field: Field name for image embeddings
        """
        # Must be set before _get_domain_fields, which reads self.config.
        self.config = config

        # The "default" domain's fields are the parent builder's match fields.
        default_fields = self._get_domain_fields("default")

        super().__init__(
            index_name=index_name,
            match_fields=default_fields,
            text_embedding_field=text_embedding_field,
            image_embedding_field=image_embedding_field
        )

        # Build domain configurations
        self.domain_configs = self._build_domain_configs()

    def _build_domain_configs(self) -> Dict[str, IndexConfig]:
        """
        Build a mapping of domain name to IndexConfig.

        If several indexes share a name, the first one wins. This matches
        _get_domain_fields, which also stops at the first match.
        """
        configs: Dict[str, IndexConfig] = {}
        for index in self.config.indexes:
            configs.setdefault(index.name, index)
        return configs

    def _get_domain_fields(self, domain_name: str) -> List[str]:
        """
        Return the fields of the first index named *domain_name*, with boosts.

        Returns an empty list when no index has that name.
        """
        for index in self.config.indexes:
            if index.name == domain_name:
                return self._apply_field_boosts(index.fields)
        return []

    def _get_field_by_name(self, field_name: str):
        """Return the first field config named *field_name*, or None."""
        return next(
            (field for field in self.config.fields if field.name == field_name),
            None,
        )

    def build_multilang_query(
        self,
        parsed_query: ParsedQuery,
        query_vector: Optional[np.ndarray] = None,
        filters: Optional[Dict[str, Any]] = None,
        size: int = 10,
        from_: int = 0,
        enable_knn: bool = True,
        knn_k: int = 50,
        knn_num_candidates: int = 200,
        min_score: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Build ES query with multi-language support.

        If the parsed query's domain is unknown, the "default" domain is used.
        If that is not configured either, the parent build_query is used.

        Args:
            parsed_query: Parsed query with language info and translations
            query_vector: Query embedding for KNN search (any array-like)
            filters: Additional filters. They restrict both the text query and
                the KNN clause.
            size: Number of results
            from_: Offset for pagination
            enable_knn: Whether to use KNN search
            knn_k: K value for KNN
            knn_num_candidates: Number of candidates for KNN
            min_score: Minimum score threshold

        Returns:
            ES query DSL dictionary
        """
        domain = parsed_query.domain
        domain_config = self.domain_configs.get(domain)

        if not domain_config:
            # Fallback to default domain
            domain = "default"
            domain_config = self.domain_configs.get("default")

        if not domain_config:
            # No domain configuration at all: use the parent's behavior.
            return super().build_query(
                query_text=parsed_query.rewritten_query,
                query_vector=query_vector,
                filters=filters,
                size=size,
                from_=from_,
                enable_knn=enable_knn,
                knn_k=knn_k,
                knn_num_candidates=knn_num_candidates,
                min_score=min_score
            )

        translations = parsed_query.translations or {}
        self._logger.debug("Building query for domain: %s", domain)
        self._logger.debug("Detected language: %s", parsed_query.detected_language)
        self._logger.debug("Available translations: %s", list(translations))

        # Build query clause with multi-language support
        query_clause = self._build_multilang_text_query(parsed_query, domain_config)

        es_query: Dict[str, Any] = {
            "size": size,
            "from": from_
        }

        # Build the filter clause once; it is reused for the KNN clause below.
        filter_clause = None
        if filters:
            filter_clause = self._build_filters(filters)
            es_query["query"] = {
                "bool": {
                    "must": [query_clause],
                    "filter": filter_clause
                }
            }
        else:
            es_query["query"] = query_clause

        # Add KNN search if enabled and vector provided
        if enable_knn and query_vector is not None and self.text_embedding_field:
            knn_clause: Dict[str, Any] = {
                "field": self.text_embedding_field,
                "query_vector": np.asarray(query_vector).tolist(),
                "k": knn_k,
                "num_candidates": knn_num_candidates
            }
            if filter_clause is not None:
                # Top-level knn hits are combined with query hits as a
                # disjunction, so filters that live only in query.bool.filter
                # do not restrict them. The filter must be repeated here.
                # Wrapping it in a bool/filter accepts either a single clause
                # or a list of clauses.
                knn_clause["filter"] = {"bool": {"filter": filter_clause}}
            es_query["knn"] = knn_clause

        # Add minimum score filter
        if min_score is not None:
            es_query["min_score"] = min_score

        return es_query

    def _multi_match(
        self,
        query_text: str,
        field_names: List[str],
        boost: float,
        name: str
    ) -> Dict[str, Any]:
        """Build a named multi_match clause over *field_names*, with field boosts."""
        return {
            "multi_match": {
                "query": query_text,
                "fields": self._apply_field_boosts(field_names),
                "minimum_should_match": self.MINIMUM_SHOULD_MATCH,
                "tie_breaker": self.TIE_BREAKER,
                "boost": boost,
                "_name": name
            }
        }

    def _build_multilang_text_query(
        self,
        parsed_query: ParsedQuery,
        domain_config: IndexConfig
    ) -> Dict[str, Any]:
        """
        Build text query with multi-language field routing.

        Without a (non-empty) language_field_mapping, a single multi_match runs
        over all domain fields. Otherwise, one clause is built per matching
        language:

        - the detected language, with an extra boost;
        - each non-blank translation;
        - a lower-boosted fallback over all mapped fields, used only when
          neither of the above applied.

        Args:
            parsed_query: Parsed query with language info
            domain_config: Domain configuration

        Returns:
            ES query clause: a single multi_match, or a bool/should of them
        """
        mapping = domain_config.language_field_mapping
        if not mapping:
            # No multi-language mapping, use all fields with default analyzer
            return self._multi_match(
                parsed_query.rewritten_query,
                domain_config.fields,
                domain_config.boost,
                f"{domain_config.name}_query",
            )

        should_clauses: List[Dict[str, Any]] = []

        # 1. Query in detected language (if it exists in mapping)
        detected_lang = parsed_query.detected_language
        if detected_lang in mapping:
            target_fields = mapping[detected_lang]
            should_clauses.append(self._multi_match(
                parsed_query.rewritten_query,
                target_fields,
                domain_config.boost * self.DETECTED_LANGUAGE_BOOST,
                f"{domain_config.name}_{detected_lang}_query",
            ))
            self._logger.debug(
                "Added query for detected language '%s' on fields: %s",
                detected_lang, target_fields,
            )

        # 2. Query in translated languages (only for languages in mapping)
        translations = parsed_query.translations or {}
        for lang, translation in translations.items():
            if lang in mapping and translation and translation.strip():
                target_fields = mapping[lang]
                should_clauses.append(self._multi_match(
                    translation,
                    target_fields,
                    domain_config.boost,
                    f"{domain_config.name}_{lang}_translated_query",
                ))
                self._logger.debug(
                    "Added translated query for language '%s' on fields: %s",
                    lang, target_fields,
                )

        # 3. Fallback: query all fields in mapping if no language-specific query was built
        if not should_clauses:
            self._logger.debug("No language mapping matched, using all fields from mapping")
            # Union of all mapped fields, de-duplicated, keeping first-seen order.
            unique_fields = list(dict.fromkeys(
                field_name
                for lang_fields in mapping.values()
                for field_name in lang_fields
            ))
            should_clauses.append(self._multi_match(
                parsed_query.rewritten_query,
                unique_fields,
                domain_config.boost * self.FALLBACK_BOOST,
                f"{domain_config.name}_fallback_query",
            ))

        if len(should_clauses) == 1:
            return should_clauses[0]
        return {
            "bool": {
                "should": should_clauses,
                "minimum_should_match": 1
            }
        }

    def _apply_field_boosts(self, field_names: List[str]) -> List[str]:
        """
        Return *field_names* in ES "name^boost" notation.

        A field gets "^boost" only when it has a config entry whose boost is
        not 1.0. All other names are returned unchanged.
        """
        result = []
        for field_name in field_names:
            field = self._get_field_by_name(field_name)
            if field and field.boost != 1.0:
                result.append(f"{field_name}^{field.boost}")
            else:
                result.append(field_name)
        return result

    def get_domain_summary(self) -> Dict[str, Any]:
        """Return a per-domain summary of label, fields, analyzer, boost and languages."""
        summary = {}
        for domain_name, domain_config in self.domain_configs.items():
            mapping = domain_config.language_field_mapping
            summary[domain_name] = {
                "label": domain_config.label,
                # Copy so callers cannot mutate the live config through the summary.
                "fields": list(domain_config.fields),
                "analyzer": domain_config.analyzer.value,
                "boost": domain_config.boost,
                # Same truthiness rule as _build_multilang_text_query: an empty
                # mapping counts as "no multi-language mapping".
                "has_multilang_mapping": bool(mapping),
                "supported_languages": list(mapping.keys()) if mapping else []
            }
        return summary