es_client.py 7.68 KB
"""
Elasticsearch client wrapper.
"""

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from typing import Dict, Any, List, Optional
import os
import logging

# Try to import ES_CONFIG, but allow import to fail
try:
    from config.env_config import ES_CONFIG
except ImportError:
    ES_CONFIG = None

logger = logging.getLogger(__name__)


class ESClient:
    """Thin wrapper around the Elasticsearch client with common operations.

    Most methods catch any exception raised by the underlying client, log it
    with a traceback, and return a neutral fallback value (``False``, ``{}``,
    ``0`` or an empty result set) instead of propagating the error.
    ``index_exists`` is the exception: it lets client errors propagate.
    """

    def __init__(
        self,
        hosts: Optional[List[str]] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize Elasticsearch client.

        Args:
            hosts: List of ES host URLs. When omitted, a single host is read
                from the ES_HOST environment variable
                (default: 'http://localhost:9200').
            username: ES username (optional)
            password: ES password (optional). Authentication is configured
                only when both username and password are truthy.
            **kwargs: Additional ES client parameters. They are merged last,
                so they override the defaults set here (timeout,
                max_retries, retry_on_timeout) as well as hosts/auth.
        """
        if hosts is None:
            hosts = [os.getenv('ES_HOST', 'http://localhost:9200')]

        # Build client config
        client_config = {
            'hosts': hosts,
            'timeout': 30,
            'max_retries': 3,
            'retry_on_timeout': True,
        }

        # Add authentication if provided
        if username and password:
            client_config['http_auth'] = (username, password)

        # Merge additional kwargs (caller-supplied values win)
        client_config.update(kwargs)

        self.client = Elasticsearch(**client_config)

    def ping(self) -> bool:
        """
        Test connection to Elasticsearch.

        Returns:
            True if connected, False otherwise (including when the client
            raises; the error is logged).
        """
        try:
            return self.client.ping()
        except Exception as e:
            logger.error(f"Failed to ping Elasticsearch: {e}", exc_info=True)
            return False

    def create_index(self, index_name: str, body: Dict[str, Any]) -> bool:
        """
        Create an index.

        Args:
            index_name: Name of the index
            body: Index configuration (settings + mappings)

        Returns:
            True if successful, False if the client raised (error is logged)
        """
        try:
            self.client.indices.create(index=index_name, body=body)
            logger.info(f"Index '{index_name}' created successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True)
            return False

    def delete_index(self, index_name: str) -> bool:
        """
        Delete an index.

        Args:
            index_name: Name of the index

        Returns:
            True if the index existed and was deleted; False if it does not
            exist (a warning is logged) or if the client raised (error is
            logged).
        """
        try:
            if self.client.indices.exists(index=index_name):
                self.client.indices.delete(index=index_name)
                logger.info(f"Index '{index_name}' deleted successfully")
                return True
            else:
                logger.warning(f"Index '{index_name}' does not exist")
                return False
        except Exception as e:
            logger.error(f"Failed to delete index '{index_name}': {e}", exc_info=True)
            return False

    def index_exists(self, index_name: str) -> bool:
        """
        Check if index exists.

        Unlike the other methods of this class, exceptions raised by the
        client are not caught here and propagate to the caller.

        Args:
            index_name: Name of the index

        Returns:
            True if the index exists, False otherwise
        """
        # Coerce to a plain bool so the return value matches the annotation
        # even when the client hands back a response object.
        return bool(self.client.indices.exists(index=index_name))

    @staticmethod
    def _build_bulk_actions(
        index_name: str,
        docs: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Build bulk-helper actions for ``docs`` without mutating them.

        A document carrying an ``_id`` key has that value lifted into the
        action's ``_id`` and left out of ``_source``; the caller's dict is
        not modified (a stripped shallow copy is used instead).
        """
        actions = []
        for doc in docs:
            if '_id' in doc:
                # Copy without '_id' rather than deleting it from the
                # caller's document.
                source = {key: value for key, value in doc.items() if key != '_id'}
                actions.append({
                    '_index': index_name,
                    '_id': doc['_id'],
                    '_source': source
                })
            else:
                actions.append({
                    '_index': index_name,
                    '_source': doc
                })
        return actions

    def bulk_index(self, index_name: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Bulk index documents.

        The documents passed in are left untouched; an ``_id`` key, when
        present, is used as the document id and excluded from the indexed
        source.

        Args:
            index_name: Name of the index
            docs: List of documents to index

        Returns:
            Dictionary with keys:
                'success': number of successfully executed actions
                'failed': number of failed actions
                'errors': list of error details (or a single-element list
                    with the exception text if the whole call raised, in
                    which case 'failed' is len(docs))
        """
        actions = self._build_bulk_actions(index_name, docs)

        try:
            # raise_on_error=False: collect per-document errors instead of
            # aborting on the first failure.
            success, failed = bulk(self.client, actions, raise_on_error=False)
            return {
                'success': success,
                'failed': len(failed),
                'errors': failed
            }
        except Exception as e:
            logger.error(f"Bulk indexing failed: {e}", exc_info=True)
            return {
                'success': 0,
                'failed': len(docs),
                'errors': [str(e)]
            }

    def search(
        self,
        index_name: str,
        body: Dict[str, Any],
        size: int = 10,
        from_: int = 0
    ) -> Dict[str, Any]:
        """
        Execute search query.

        Args:
            index_name: Name of the index
            body: Search query body. A top-level 'collapse' clause is
                removed (on a copy) before the query is sent.
            size: Number of results to return
            from_: Offset for pagination

        Returns:
            Search results as returned by the client. If the client raises,
            an empty result ({'hits': {'total': {'value': 0}, 'hits': []}})
            with an additional 'error' key holding the exception text.
        """
        # Safety guard: collapse is no longer needed (index is already SPU-level).
        # If any caller accidentally adds a collapse clause (e.g. on product_id),
        # strip it here to avoid 400 errors like:
        # "no mapping found for `product_id` in order to collapse on"
        if isinstance(body, dict) and "collapse" in body:
            logger.warning(
                "Removing unsupported 'collapse' clause from ES query body: %s",
                body.get("collapse")
            )
            body = dict(body)  # shallow copy to avoid mutating caller
            body.pop("collapse", None)

        try:
            return self.client.search(
                index=index_name,
                body=body,
                size=size,
                from_=from_
            )
        except Exception as e:
            logger.error(f"Search failed: {e}", exc_info=True)
            return {
                'hits': {
                    'total': {'value': 0},
                    'hits': []
                },
                'error': str(e)
            }

    def get_mapping(self, index_name: str) -> Dict[str, Any]:
        """
        Get index mapping.

        Args:
            index_name: Name of the index

        Returns:
            The mapping as returned by the client, or {} if the client
            raised (error is logged)
        """
        try:
            return self.client.indices.get_mapping(index=index_name)
        except Exception as e:
            logger.error(f"Failed to get mapping for '{index_name}': {e}", exc_info=True)
            return {}

    def refresh(self, index_name: str) -> bool:
        """
        Refresh index to make documents searchable.

        Args:
            index_name: Name of the index

        Returns:
            True if successful, False if the client raised (error is logged)
        """
        try:
            self.client.indices.refresh(index=index_name)
            return True
        except Exception as e:
            logger.error(f"Failed to refresh index '{index_name}': {e}", exc_info=True)
            return False

    def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int:
        """
        Count documents in index.

        Args:
            index_name: Name of the index
            body: Optional query body

        Returns:
            Document count, or 0 if the client raised (error is logged)
        """
        try:
            result = self.client.count(index=index_name, body=body)
            return result['count']
        except Exception as e:
            logger.error(f"Count failed: {e}", exc_info=True)
            return 0


def get_es_client_from_env() -> ESClient:
    """
    Build an ESClient from project configuration or the environment.

    When the module-level ES_CONFIG is available and non-empty, its 'host',
    'username' and 'password' entries are used ('host' is required; the
    other two are optional). Otherwise the settings are read from
    environment variables:

        ES_HOST: Elasticsearch host URL (default: http://localhost:9200)
        ES_USERNAME: Username (optional)
        ES_PASSWORD: Password (optional)

    Returns:
        ESClient instance
    """
    if ES_CONFIG:
        # Project configuration takes precedence over the environment.
        host = ES_CONFIG['host']
        user = ES_CONFIG.get('username')
        secret = ES_CONFIG.get('password')
    else:
        # Fallback to env variables
        host = os.getenv('ES_HOST', 'http://localhost:9200')
        user = os.getenv('ES_USERNAME')
        secret = os.getenv('ES_PASSWORD')

    return ESClient(hosts=[host], username=user, password=secret)