# es_query_builder.py
"""
Elasticsearch query builder.

Converts parsed queries and search parameters into ES DSL queries.
"""

from typing import Dict, Any, List, Optional
import numpy as np
from .boolean_parser import QueryNode


class ESQueryBuilder:
    """Builds Elasticsearch DSL queries.

    Holds the index name and field configuration and exposes helpers that
    create or extend plain-dict ES request bodies (query, kNN, filters,
    collapse, sorting and aggregations).
    """

    # Bucket label -> range bounds on the ``price`` field, used by the
    # ``price_ranges`` filter. The bounds are copied before being placed in a
    # query so callers that mutate a returned query cannot corrupt this table.
    _PRICE_RANGE_BOUNDS: Dict[str, Dict[str, int]] = {
        '0-50': {"lt": 50},
        '50-100': {"gte": 50, "lt": 100},
        '100-200': {"gte": 100, "lt": 200},
        '200+': {"gte": 200},
    }

    # Keys whose presence marks a dict filter value as an ES range spec.
    _RANGE_OPERATORS = ("gte", "lte", "gt", "lt")

    def __init__(
        self,
        index_name: str,
        match_fields: List[str],
        text_embedding_field: Optional[str] = None,
        image_embedding_field: Optional[str] = None
    ):
        """
        Initialize query builder.

        Args:
            index_name: ES index name
            match_fields: Fields to search for text matching
            text_embedding_field: Field name for text embeddings (required
                for the kNN clause to be emitted)
            image_embedding_field: Field name for image embeddings (stored
                only; not used by the methods of this class)
        """
        self.index_name = index_name
        self.match_fields = match_fields
        self.text_embedding_field = text_embedding_field
        self.image_embedding_field = image_embedding_field

    def build_query(
        self,
        query_text: str,
        query_vector: Optional[np.ndarray] = None,
        query_node: Optional["QueryNode"] = None,
        filters: Optional[Dict[str, Any]] = None,
        size: int = 10,
        from_: int = 0,
        enable_knn: bool = True,
        knn_k: int = 50,
        knn_num_candidates: int = 200,
        min_score: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Build complete ES query.

        Args:
            query_text: Query text for BM25 matching (used unless
                ``query_node`` is a non-TERM boolean tree)
            query_vector: Query embedding for KNN search; any array-like is
                accepted and flattened to a 1-D list of numbers
            query_node: Parsed boolean expression tree
            filters: Additional filters (term, range, etc.); see
                ``_build_filters`` for the accepted shapes
            size: Number of results
            from_: Offset for pagination
            enable_knn: Whether to use KNN search
            knn_k: K value for KNN
            knn_num_candidates: Number of candidates for KNN; raised to
                ``knn_k`` if smaller, because ES rejects num_candidates < k
            min_score: Minimum score threshold

        Returns:
            ES query DSL dictionary
        """
        es_query: Dict[str, Any] = {
            "size": size,
            "from": from_
        }

        # Main (lexical) query: a boolean tree when one was parsed, otherwise
        # a plain multi_match on the raw text.
        if query_node and query_node.operator != 'TERM':
            query_clause = self._build_boolean_query(query_node)
        else:
            query_clause = self._build_text_query(query_text)

        filter_clauses = self._build_filters(filters) if filters else []
        if filters:
            es_query["query"] = {
                "bool": {
                    "must": [query_clause],
                    "filter": filter_clauses
                }
            }
        else:
            es_query["query"] = query_clause

        if enable_knn and query_vector is not None and self.text_embedding_field:
            knn_clause: Dict[str, Any] = {
                "field": self.text_embedding_field,
                # Accept lists as well as ndarrays, and flatten (1, d)
                # embeddings; ES expects a flat array of numbers.
                "query_vector": np.asarray(query_vector).reshape(-1).tolist(),
                "k": knn_k,
                "num_candidates": max(knn_num_candidates, knn_k)
            }
            # The top-level ``knn`` section is evaluated independently of
            # ``query``: filters placed only in ``query`` do not restrict kNN
            # hits. Repeat them here so vector matches honour the same filters.
            if filter_clauses:
                knn_clause["filter"] = list(filter_clauses)
            es_query["knn"] = knn_clause

        if min_score is not None:
            es_query["min_score"] = min_score

        return es_query

    def _build_text_query(self, query_text: str) -> Dict[str, Any]:
        """
        Build simple text matching query (BM25).

        Args:
            query_text: Query text

        Returns:
            ES ``multi_match`` clause over ``self.match_fields``, named
            ``base_query``
        """
        return {
            "multi_match": {
                "query": query_text,
                # Copy so later edits to the returned query cannot mutate
                # the builder's configuration.
                "fields": list(self.match_fields),
                "minimum_should_match": "67%",
                "tie_breaker": 0.9,
                "boost": 1.0,
                "_name": "base_query"
            }
        }

    def _build_boolean_query(self, node: "QueryNode") -> Dict[str, Any]:
        """
        Build query from boolean expression tree.

        Supported operators: TERM (leaf, uses ``node.value``), AND, OR,
        ANDNOT (first term required, all remaining terms excluded) and RANK
        (optional clauses that only contribute to scoring).

        Args:
            node: Query tree node

        Returns:
            ES query clause; ``match_all`` for an unknown operator or an
            ANDNOT node without terms
        """
        operator = node.operator

        if operator == 'TERM':
            return self._build_text_query(node.value)

        if operator == 'AND':
            return {
                "bool": {
                    "must": [self._build_boolean_query(term) for term in node.terms]
                }
            }

        if operator == 'OR':
            return {
                "bool": {
                    "should": [self._build_boolean_query(term) for term in node.terms],
                    "minimum_should_match": 1
                }
            }

        if operator == 'ANDNOT':
            if not node.terms:
                # Nothing to include or exclude; mirror the unknown-operator
                # fallback instead of raising IndexError.
                return {"match_all": {}}
            included, *excluded = node.terms
            if not excluded:
                return self._build_boolean_query(included)
            return {
                "bool": {
                    "must": [self._build_boolean_query(included)],
                    "must_not": [self._build_boolean_query(term) for term in excluded]
                }
            }

        if operator == 'RANK':
            # No minimum_should_match: clauses only influence the score.
            return {
                "bool": {
                    "should": [self._build_boolean_query(term) for term in node.terms]
                }
            }

        # Unknown operator: fall back to matching everything.
        return {"match_all": {}}

    def _build_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Build filter clauses.

        Accepted value shapes per field:
            * ``price_ranges``: bucket label or list of labels
              ('0-50', '50-100', '100-200', '200+') applied to ``price``
            * dict with any of gte/lte/gt/lt: ``range`` clause
            * list: ``terms`` clause (match any)
            * anything else: ``term`` clause (exact match)

        Dicts without range operators and unrecognised price buckets are
        ignored.

        Args:
            filters: Filter specifications

        Returns:
            List of ES filter clauses
        """
        filter_clauses: List[Dict[str, Any]] = []

        for field, value in filters.items():
            if field == 'price_ranges':
                price_clause = self._build_price_range_filter(value)
                if price_clause is not None:
                    filter_clauses.append(price_clause)
            elif isinstance(value, dict):
                if any(op in value for op in self._RANGE_OPERATORS):
                    filter_clauses.append({"range": {field: value}})
            elif isinstance(value, list):
                filter_clauses.append({"terms": {field: value}})
            else:
                filter_clauses.append({"term": {field: value}})

        return filter_clauses

    def _build_price_range_filter(self, value: Any) -> Optional[Dict[str, Any]]:
        """
        Translate price bucket label(s) into one filter clause on ``price``.

        Args:
            value: A bucket label or a list of labels

        Returns:
            A ``range`` clause for one bucket, a ``bool``/``should`` of
            ranges for several, or None when no recognised bucket is given
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return None

        bounds = [
            dict(self._PRICE_RANGE_BOUNDS[label])
            for label in value
            if isinstance(label, str) and label in self._PRICE_RANGE_BOUNDS
        ]
        if not bounds:
            return None
        if len(bounds) == 1:
            return {"range": {"price": bounds[0]}}
        # A bool with only ``should`` clauses requires at least one of them
        # to match, so this is an OR over the selected ranges.
        return {
            "bool": {
                "should": [{"range": {"price": bound}} for bound in bounds]
            }
        }

    def add_spu_collapse(
        self,
        es_query: Dict[str, Any],
        spu_field: str,
        inner_hits_size: int = 3
    ) -> Dict[str, Any]:
        """
        Add SPU aggregation/collapse to query (mutates ``es_query``).

        Args:
            es_query: Existing ES query
            spu_field: Field containing SPU ID
            inner_hits_size: Number of SKUs to return per SPU

        Returns:
            The same ``es_query`` dict, modified in place
        """
        es_query["collapse"] = {
            "field": spu_field,
            "inner_hits": {
                "_source": False,
                "name": "top_docs",
                "size": inner_hits_size
            }
        }

        # Cardinality aggregation to count unique SPUs.
        es_query.setdefault("aggs", {})["unique_count"] = {
            "cardinality": {
                "field": spu_field
            }
        }

        return es_query

    def add_dynamic_aggregations(
        self,
        es_query: Dict[str, Any],
        aggregations: Optional[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Add dynamic aggregations based on request parameters.

        Args:
            es_query: Existing ES query (mutated in place)
            aggregations: Aggregation name -> ES aggregation spec; None or
                empty adds nothing beyond an empty ``aggs`` section

        Returns:
            The same ``es_query`` dict, modified in place
        """
        aggs = es_query.setdefault("aggs", {})
        if aggregations:
            aggs.update(aggregations)
        return es_query

    def add_sorting(
        self,
        es_query: Dict[str, Any],
        sort_by: str,
        sort_order: str = "desc"
    ) -> Dict[str, Any]:
        """
        Add sorting to ES query.

        Args:
            es_query: Existing ES query (mutated in place)
            sort_by: Field name for sorting; falsy means "no sort"
            sort_order: Sort order: 'asc' or 'desc' (case-insensitive,
                falsy defaults to 'desc')

        Returns:
            The same ``es_query`` dict, modified in place
        """
        if not sort_by:
            return es_query

        order = (sort_order or "desc").lower()
        es_query.setdefault("sort", []).append({sort_by: {"order": order}})

        return es_query

    def add_aggregations(
        self,
        es_query: Dict[str, Any],
        agg_fields: List[str],
        size: int = 20
    ) -> Dict[str, Any]:
        """
        Add terms aggregations for faceted search.

        Args:
            es_query: Existing ES query (mutated in place)
            agg_fields: Fields to aggregate on; each gets a ``<field>_agg``
                terms aggregation
            size: Number of buckets returned per aggregation

        Returns:
            The same ``es_query`` dict, modified in place
        """
        aggs = es_query.setdefault("aggs", {})

        for field in agg_fields:
            aggs[f"{field}_agg"] = {
                "terms": {
                    "field": field,
                    "size": size
                }
            }

        return es_query