# es_client.py
"""
Elasticsearch client wrapper.
"""

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from typing import Dict, Any, List, Optional
import logging

from config.loader import get_app_config

logger = logging.getLogger(__name__)


class ESClient:
    """Wrapper for Elasticsearch client with common operations."""

    def __init__(
        self,
        hosts: Optional[List[str]] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize Elasticsearch client.

        Args:
            hosts: List of ES host URLs. When omitted, falls back to the
                single host configured under
                ``get_app_config().infrastructure.elasticsearch``.
            username: ES username (optional)
            password: ES password (optional)
            **kwargs: Additional ES client parameters; on key collision
                these override the defaults built below.
        """
        if hosts is None:
            # No explicit hosts: use the application-configured host.
            hosts = [get_app_config().infrastructure.elasticsearch.host]

        # Default transport settings; callers can override via kwargs.
        client_config: Dict[str, Any] = {
            'hosts': hosts,
            'timeout': 30,
            'max_retries': 3,
            'retry_on_timeout': True,
        }

        # Only attach credentials when both parts are present.
        if username and password:
            client_config['basic_auth'] = (username, password)

        # Caller-supplied kwargs take precedence over the defaults above.
        client_config.update(kwargs)

        self.client = Elasticsearch(**client_config)

    def ping(self) -> bool:
        """
        Test connection to Elasticsearch.

        Returns:
            True if connected, False otherwise
        """
        try:
            return self.client.ping()
        except Exception as e:
            logger.error(f"Failed to ping Elasticsearch: {e}", exc_info=True)
            return False

    def create_index(self, index_name: str, body: Dict[str, Any]) -> bool:
        """
        Create an index.

        Args:
            index_name: Name of the index
            body: Index configuration (settings + mappings)

        Returns:
            True if the index was created (or exists after the attempt),
            False otherwise
        """
        try:
            # Short, non-retrying request scope just for the create call.
            scoped = self.client.options(request_timeout=30, max_retries=0)
            scoped.indices.create(
                index=index_name,
                body=body,
                wait_for_active_shards="0",
            )
            logger.info(f"Index '{index_name}' created successfully")
            return True
        except Exception as e:
            # The request can fail (e.g. timeout) even though the index was
            # actually created server-side; re-check before reporting failure.
            if self.index_exists(index_name):
                logger.warning(
                    "Create index request for '%s' raised %s, but the index now exists; treating it as created",
                    index_name,
                    type(e).__name__,
                    exc_info=True,
                )
                return True
            logger.error(f"Failed to create index '{index_name}': {e}", exc_info=True)
            return False

    def wait_for_index_ready(self, index_name: str, timeout: str = "10s") -> Dict[str, Any]:
        """Wait until an index primary shard is allocated and searchable."""
        try:
            resp = self.client.cluster.health(
                index=index_name,
                wait_for_status="yellow",
                timeout=timeout,
                level="indices",
            )
            index_info = ((resp.get("indices") or {}).get(index_name) or {})
            status = index_info.get("status") or resp.get("status")
            timed_out = bool(resp.get("timed_out"))
            return {
                "ok": (not timed_out) and status in {"yellow", "green"},
                "status": status,
                "timed_out": timed_out,
                "response": resp,
            }
        except Exception as e:
            logger.error("Failed waiting for index '%s' readiness: %s", index_name, e, exc_info=True)
            return {
                "ok": False,
                "status": "unknown",
                "timed_out": False,
                "error": str(e),
            }

    def get_allocation_explain(self, index_name: str, shard: int = 0, primary: bool = True) -> Optional[Dict[str, Any]]:
        """Explain why a shard can or cannot be allocated."""
        try:
            return self.client.cluster.allocation_explain(
                body={"index": index_name, "shard": shard, "primary": primary}
            )
        except Exception as e:
            logger.warning(
                "Failed to get allocation explain for index '%s' shard=%s primary=%s: %s",
                index_name,
                shard,
                primary,
                e,
                exc_info=True,
            )
            return None

    def put_alias(self, index_name: str, alias_name: str) -> bool:
        """Add alias for an index."""
        try:
            self.client.indices.put_alias(index=index_name, name=alias_name)
            return True
        except Exception as e:
            logger.error(
                "Failed to put alias '%s' for index '%s': %s",
                alias_name,
                index_name,
                e,
                exc_info=True,
            )
            return False

    def alias_exists(self, alias_name: str) -> bool:
        """Check if alias exists."""
        try:
            return self.client.indices.exists_alias(name=alias_name)
        except Exception as e:
            logger.error("Failed to check alias exists '%s': %s", alias_name, e, exc_info=True)
            return False

    def get_alias_indices(self, alias_name: str) -> List[str]:
        """Get concrete indices behind alias."""
        try:
            result = self.client.indices.get_alias(name=alias_name)
            return sorted(list((result or {}).keys()))
        except Exception:
            return []

    def update_aliases(self, actions: List[Dict[str, Any]]) -> bool:
        """Atomically update aliases."""
        try:
            self.client.indices.update_aliases(body={"actions": actions})
            return True
        except Exception as e:
            logger.error("Failed to update aliases: %s", e, exc_info=True)
            return False

    def list_indices(self, pattern: str) -> List[str]:
        """List indices by wildcard pattern."""
        try:
            result = self.client.indices.get(index=pattern, allow_no_indices=True)
            return sorted(list((result or {}).keys()))
        except Exception:
            return []

    def delete_index(self, index_name: str) -> bool:
        """
        Delete an index.

        Args:
            index_name: Name of the index

        Returns:
            True if the index existed and was deleted; False if it did not
            exist or the delete failed.
        """
        try:
            if not self.client.indices.exists(index=index_name):
                logger.warning(f"Index '{index_name}' does not exist")
                return False
            self.client.indices.delete(index=index_name)
            logger.info(f"Index '{index_name}' deleted successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to delete index '{index_name}': {e}", exc_info=True)
            return False

    def index_exists(self, index_name: str) -> bool:
        """Check if index exists."""
        return self.client.indices.exists(index=index_name)

    def bulk_index(self, index_name: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Bulk index documents.

        Args:
            index_name: Name of the index
            docs: List of documents to index. An optional '_id' key is used
                as the ES document id and excluded from the stored _source.
                Caller documents are NOT modified.

        Returns:
            Dict with 'success' (count), 'failed' (count) and 'errors'
            (per-item error details, or the exception string on total failure).
        """
        actions = []
        for doc in docs:
            # Shallow-copy so extracting '_id' does not mutate the caller's
            # dict (the previous implementation deleted '_id' in place).
            source = dict(doc)
            action: Dict[str, Any] = {
                '_index': index_name,
                '_source': source,
            }
            if '_id' in source:
                action['_id'] = source.pop('_id')
            actions.append(action)

        try:
            success, failed = bulk(self.client, actions, raise_on_error=False)
            return {
                'success': success,
                'failed': len(failed),
                'errors': failed
            }
        except Exception as e:
            logger.error(f"Bulk indexing failed: {e}", exc_info=True)
            return {
                'success': 0,
                'failed': len(docs),
                'errors': [str(e)]
            }

    def bulk_actions(self, actions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Execute generic bulk actions.

        Args:
            actions: elasticsearch.helpers.bulk compatible action list
        """
        if not actions:
            return {'success': 0, 'failed': 0, 'errors': []}
        try:
            success, failed = bulk(self.client, actions, raise_on_error=False)
            return {
                'success': success,
                'failed': len(failed),
                'errors': failed
            }
        except Exception as e:
            logger.error("Bulk actions failed: %s", e, exc_info=True)
            return {
                'success': 0,
                'failed': len(actions),
                'errors': [str(e)],
            }

    def search(
        self,
        index_name: str,
        body: Dict[str, Any],
        size: int = 10,
        from_: int = 0,
        routing: Optional[str] = None,
        include_named_queries_score: bool = False,
    ) -> Dict[str, Any]:
        """
        Execute search query.

        Args:
            index_name: Name of the index
            body: Search query body
            size: Number of results to return
            from_: Offset for pagination

        Returns:
            Search results
        """
        # Safety guard: collapse is no longer needed (index is already SPU-level).
        # If any caller accidentally adds a collapse clause (e.g. on product_id),
        # strip it here to avoid 400 errors like:
        # "no mapping found for `product_id` in order to collapse on"
        if isinstance(body, dict) and "collapse" in body:
            logger.warning(
                "Removing unsupported 'collapse' clause from ES query body: %s",
                body.get("collapse")
            )
            body = dict(body)  # shallow copy to avoid mutating caller
            body.pop("collapse", None)

        try:
            response = self.client.search(
                index=index_name,
                body=body,
                size=size,
                from_=from_,
                routing=routing,
                include_named_queries_score=include_named_queries_score,
            )
            # elasticsearch-py 8.x returns ObjectApiResponse; normalize to mutable dict
            # so caller can safely patch hits/took during post-processing.
            if hasattr(response, "body"):
                payload = response.body
                if isinstance(payload, dict):
                    return dict(payload)
                return payload
            if isinstance(response, dict):
                return response
            return dict(response)
        except Exception as e:
            logger.error(f"Search failed: {e}", exc_info=True)
            raise RuntimeError(f"Elasticsearch search failed for index '{index_name}': {e}") from e

    def get_mapping(self, index_name: str) -> Dict[str, Any]:
        """Get index mapping."""
        try:
            return self.client.indices.get_mapping(index=index_name)
        except Exception as e:
            logger.error(f"Failed to get mapping for '{index_name}': {e}", exc_info=True)
            return {}

    def refresh(self, index_name: str) -> bool:
        """Refresh index to make documents searchable."""
        try:
            self.client.indices.refresh(index=index_name)
            return True
        except Exception as e:
            logger.error(f"Failed to refresh index '{index_name}': {e}", exc_info=True)
            return False

    def count(self, index_name: str, body: Optional[Dict[str, Any]] = None) -> int:
        """
        Count documents in index.

        Args:
            index_name: Name of the index
            body: Optional query body

        Returns:
            Document count
        """
        try:
            result = self.client.count(index=index_name, body=body)
            return result['count']
        except Exception as e:
            logger.error(f"Count failed: {e}", exc_info=True)
            raise RuntimeError(f"Elasticsearch count failed for index '{index_name}': {e}") from e


def get_es_client_from_env() -> ESClient:
    """
    Create an ES client from the application configuration.

    NOTE: despite the legacy name, this does not read ES_HOST/ES_USERNAME/
    ES_PASSWORD environment variables directly; it pulls host, username and
    password from ``get_app_config().infrastructure.elasticsearch`` (which
    may itself be populated from the environment by the config loader —
    verify against config.loader).

    Returns:
        ESClient instance
    """
    cfg = get_app_config().infrastructure.elasticsearch
    return ESClient(
        hosts=[cfg.host],
        username=cfg.username,
        password=cfg.password,
    )