# mapping_generator.py (6.26 KB)
"""
Elasticsearch mapping generator.

Generates Elasticsearch index mappings from search configuration.
"""

from typing import Any, Dict, List, Optional

from config import (
    SearchConfig,
    FieldConfig,
    get_es_mapping_for_field,
    get_default_analyzers,
    get_default_similarity
)


class MappingGenerator:
    """Generates Elasticsearch mapping from search configuration.

    Wraps a search configuration object and derives from it the index
    ``settings`` / ``mappings`` body, plus a handful of lookups over
    ``config.fields`` and ``config.indexes`` (domain fields, embedding
    fields, boosted match fields).
    """

    # Settings that always receive an explicit value (configured or default)
    # and are therefore skipped when the remaining custom settings are merged.
    _CORE_SETTING_KEYS = ("number_of_shards", "number_of_replicas", "refresh_interval")

    def __init__(self, config: "SearchConfig"):
        """Store the search configuration every other method reads from.

        Args:
            config: Search configuration; this class reads its
                ``es_settings``, ``fields`` and ``indexes`` attributes.
        """
        self.config = config

    def generate_mapping(self) -> Dict[str, Any]:
        """
        Generate complete Elasticsearch index configuration including
        settings and mappings.

        Returns:
            Dictionary with a ``"settings"`` and a ``"mappings"`` entry
        """
        return {
            "settings": self._generate_settings(),
            "mappings": self._generate_mappings(),
        }

    def _generate_settings(self) -> Dict[str, Any]:
        """Generate index settings.

        Layers are applied in this order, later ones winning on key clashes:
        core values (1 shard, 0 replicas, "30s" refresh unless configured),
        default similarity, default analyzers, then every other key found in
        ``config.es_settings``. The merge is shallow: a clashing top-level
        key replaces the earlier value wholesale.

        Returns:
            Dictionary of index settings
        """
        custom = self.config.es_settings
        settings: Dict[str, Any] = {
            "number_of_shards": custom.get("number_of_shards", 1),
            "number_of_replicas": custom.get("number_of_replicas", 0),
            "refresh_interval": custom.get("refresh_interval", "30s"),
        }

        # Add similarity configuration (modified BM25)
        settings.update(get_default_similarity())

        # Add analyzer configuration
        settings.update(get_default_analyzers())

        # Custom settings go last so that they override the defaults above.
        settings.update(
            {
                key: value
                for key, value in custom.items()
                if key not in self._CORE_SETTING_KEYS
            }
        )
        return settings

    def _generate_mappings(self) -> Dict[str, Any]:
        """Generate field mappings.

        Returns:
            ``{"properties": {...}}`` with one entry per configured field,
            keyed by field name
        """
        return {
            "properties": {
                field.name: get_es_mapping_for_field(field)
                for field in self.config.fields
            }
        }

    def _find_index(self, name: str) -> Optional[Any]:
        """Return the first entry of ``config.indexes`` named ``name``, or None."""
        return next(
            (index for index in self.config.indexes if index.name == name),
            None,
        )

    def get_default_domain_fields(self) -> List[str]:
        """
        Get list of fields in the 'default' domain.

        Returns:
            The ``fields`` of the first index named "default" (the
            configuration's own list, not a copy), or an empty list if
            there is no such index
        """
        index = self._find_index("default")
        return index.fields if index is not None else []

    def get_text_embedding_field(self) -> str:
        """
        Get the primary text embedding field name.

        A field named exactly "name_embedding" is preferred; otherwise the
        first field whose name contains "embedding" but not "image" is used.

        Returns:
            Field name or empty string if not found
        """
        fields = self.config.fields

        preferred = next(
            (field.name for field in fields if field.name == "name_embedding"),
            None,
        )
        if preferred is not None:
            return preferred

        # Fall back to the first embedding field that is not an image one.
        return next(
            (
                field.name
                for field in fields
                if "embedding" in field.name and "image" not in field.name
            ),
            "",
        )

    def get_image_embedding_field(self) -> str:
        """
        Get the primary image embedding field name.

        Returns:
            Name of the first field containing both "image" and
            "embedding", or empty string if not found
        """
        return next(
            (
                field.name
                for field in self.config.fields
                if "image" in field.name and "embedding" in field.name
            ),
            "",
        )

    def get_field_by_name(self, field_name: str) -> Optional["FieldConfig"]:
        """
        Get field configuration by name.

        Args:
            field_name: Field name

        Returns:
            First field whose ``name`` equals ``field_name``, or None if
            not found
        """
        return next(
            (field for field in self.config.fields if field.name == field_name),
            None,
        )

    def get_match_fields_for_domain(self, domain_name: str = "default") -> List[str]:
        """
        Get list of text fields for matching in a domain.

        Args:
            domain_name: Name of the query domain

        Returns:
            List of field names with optional boost (e.g., ["name^2.0", "category"]).
            A name is left bare when its boost is 1.0 or when no field with
            that name is configured. Empty list if the domain is unknown.
        """
        index = self._find_index(domain_name)
        if index is None:
            return []

        match_fields = []
        for field_name in index.fields:
            field = self.get_field_by_name(field_name)
            # Explicit None check: only a missing field should skip the boost.
            if field is not None and field.boost != 1.0:
                match_fields.append(f"{field_name}^{field.boost}")
            else:
                match_fields.append(field_name)
        return match_fields


def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, Any]) -> bool:
    """
    Ensure an Elasticsearch index exists, creating it when it is missing.

    Args:
        es_client: Elasticsearch client instance
        index_name: Name of the index to create
        mapping: Index configuration passed as the request body

    Returns:
        True if index was created, False if it already exists
    """
    already_present = es_client.indices.exists(index=index_name)
    if not already_present:
        es_client.indices.create(index=index_name, body=mapping)
        print(f"Index '{index_name}' created successfully")
        return True

    # Nothing to do: the existing index is left untouched.
    print(f"Index '{index_name}' already exists")
    return False


def delete_index_if_exists(es_client, index_name: str) -> bool:
    """
    Remove an Elasticsearch index when it is present.

    Args:
        es_client: Elasticsearch client instance
        index_name: Name of the index to delete

    Returns:
        True if index was deleted, False if it didn't exist
    """
    index_found = es_client.indices.exists(index=index_name)
    if index_found:
        es_client.indices.delete(index=index_name)
        print(f"Index '{index_name}' deleted successfully")
        return True

    # No such index: report it and leave the cluster unchanged.
    print(f"Index '{index_name}' does not exist")
    return False


def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bool:
    """
    Add new field mappings to an existing index.

    Args:
        es_client: Elasticsearch client instance
        index_name: Name of the index
        new_fields: New field mappings to add, keyed by field name

    Returns:
        True if the mapping request was sent, False if the index does not exist
    """
    indices = es_client.indices
    if indices.exists(index=index_name):
        # The new fields are wrapped under "properties" for the request body.
        indices.put_mapping(index=index_name, body={"properties": new_fields})
        print(f"Mapping updated for index '{index_name}'")
        return True

    print(f"Index '{index_name}' does not exist")
    return False