# es_query_builder.py (11 KB)
"""
Elasticsearch query builder.

Converts parsed queries and search parameters into ES DSL queries.
"""

from typing import Dict, Any, List, Optional, Union
import numpy as np
from .boolean_parser import QueryNode


class ESQueryBuilder:
    """Builds Elasticsearch DSL queries.

    Produces plain ``dict`` request bodies: a BM25 text (or boolean-tree)
    query, optional exact/range filters, an optional top-level KNN clause,
    plus helpers that add SPU collapsing, sorting and facet aggregations.
    """

    def __init__(
        self,
        index_name: str,
        match_fields: List[str],
        text_embedding_field: Optional[str] = None,
        image_embedding_field: Optional[str] = None
    ):
        """
        Initialize query builder.

        Args:
            index_name: ES index name (stored; not embedded in the built
                query bodies).
            match_fields: Fields searched by the BM25 ``multi_match`` clause.
            text_embedding_field: Dense-vector field used for KNN search. When
                unset, ``build_query`` never emits a ``knn`` clause.
            image_embedding_field: Field name for image embeddings (stored;
                not referenced by this class's query building).
        """
        self.index_name = index_name
        self.match_fields = match_fields
        self.text_embedding_field = text_embedding_field
        self.image_embedding_field = image_embedding_field

    def build_query(
        self,
        query_text: str,
        query_vector: Optional[np.ndarray] = None,
        query_node: Optional["QueryNode"] = None,
        filters: Optional[Dict[str, Any]] = None,
        range_filters: Optional[Dict[str, Any]] = None,
        size: int = 10,
        from_: int = 0,
        enable_knn: bool = True,
        knn_k: int = 50,
        knn_num_candidates: int = 200,
        min_score: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Build the complete ES search request body.

        Args:
            query_text: Query text for BM25 matching. Used when ``query_node``
                is absent or is a single ``TERM`` node.
            query_vector: Query embedding for KNN search (ndarray or any
                array-like of numbers).
            query_node: Parsed boolean expression tree.
            filters: Exact-match filters ``{field: value}``; list/tuple/set
                values become a ``terms`` (any-of) filter.
            range_filters: Range filters ``{field: RangeFilter | dict}``.
            size: Number of results.
            from_: Offset for pagination.
            enable_knn: Whether to add a KNN clause (also requires
                ``query_vector`` and ``self.text_embedding_field``).
            knn_k: K value for KNN.
            knn_num_candidates: Number of candidates for KNN.
            min_score: Minimum score threshold.

        Returns:
            ES query DSL dictionary. When filters are present they are applied
            both to the main query and to the ``knn`` clause.
        """
        es_query: Dict[str, Any] = {
            "size": size,
            "from": from_
        }

        # Main (scoring) query. A bare TERM node carries nothing beyond the
        # raw text, so it goes through the simple text path.
        if query_node and query_node.operator != 'TERM':
            query_clause = self._build_boolean_query(query_node)
        else:
            query_clause = self._build_text_query(query_text)

        # Computed once: shared by the main query and the KNN clause.
        filter_clauses = self._build_filters(filters, range_filters)

        if filter_clauses:
            es_query["query"] = {
                "bool": {
                    "must": [query_clause],
                    "filter": filter_clauses
                }
            }
        else:
            es_query["query"] = query_clause

        # Add KNN search if enabled and a vector + target field are available.
        if enable_knn and query_vector is not None and self.text_embedding_field:
            knn_clause: Dict[str, Any] = {
                "field": self.text_embedding_field,
                # asarray lets callers pass plain lists as well as ndarrays.
                "query_vector": np.asarray(query_vector).tolist(),
                "k": knn_k,
                "num_candidates": knn_num_candidates
            }
            # Top-level knn hits are NOT restricted by the filters inside the
            # "query" section, so the same filters must be given to the KNN
            # search explicitly; otherwise filtered-out documents can come
            # back through the vector side of the search.
            if filter_clauses:
                knn_clause["filter"] = {"bool": {"filter": list(filter_clauses)}}
            es_query["knn"] = knn_clause

        if min_score is not None:
            es_query["min_score"] = min_score

        return es_query

    def _build_text_query(self, query_text: str) -> Dict[str, Any]:
        """
        Build a simple BM25 ``multi_match`` clause over ``self.match_fields``.

        Args:
            query_text: Query text.

        Returns:
            ES query clause, tagged with ``_name="base_query"`` (an ES named
            query).
        """
        return {
            "multi_match": {
                "query": query_text,
                "fields": self.match_fields,
                "minimum_should_match": "67%",
                "tie_breaker": 0.9,
                "boost": 1.0,
                "_name": "base_query"
            }
        }

    def _build_boolean_query(self, node: "QueryNode") -> Dict[str, Any]:
        """
        Recursively translate a boolean expression tree into an ES clause.

        Supported operators: ``TERM`` (leaf text query), ``AND`` (all
        children must match), ``OR`` (at least one child must match),
        ``ANDNOT`` (first child must match, every later child must not) and
        ``RANK`` (children only contribute to scoring via ``should``).
        Unknown operators fall back to ``match_all``.

        Args:
            node: Query tree node (reads ``operator``, ``value``, ``terms``).

        Returns:
            ES query clause.
        """
        if node.operator == 'TERM':
            # Leaf node - simple text query.
            return self._build_text_query(node.value)

        elif node.operator == 'AND':
            return {
                "bool": {
                    "must": [self._build_boolean_query(term) for term in node.terms]
                }
            }

        elif node.operator == 'OR':
            return {
                "bool": {
                    "should": [self._build_boolean_query(term) for term in node.terms],
                    "minimum_should_match": 1
                }
            }

        elif node.operator == 'ANDNOT':
            if not node.terms:
                # Malformed node (previously an IndexError): treat it like an
                # unknown operator.
                return {"match_all": {}}
            if len(node.terms) == 1:
                return self._build_boolean_query(node.terms[0])
            # Every term after the first is excluded, not just the second.
            return {
                "bool": {
                    "must": [self._build_boolean_query(node.terms[0])],
                    "must_not": [
                        self._build_boolean_query(term) for term in node.terms[1:]
                    ]
                }
            }

        elif node.operator == 'RANK':
            # Like OR but without an explicit minimum_should_match: all terms
            # contribute to the score.
            return {
                "bool": {
                    "should": [self._build_boolean_query(term) for term in node.terms]
                }
            }

        # Unknown operator.
        return {"match_all": {}}

    @staticmethod
    def _range_to_dict(range_filter: Any) -> Dict[str, Any]:
        """Convert one range filter (Pydantic model, dict or None) to ES range bounds.

        ``None`` bounds are dropped in every case. Objects other than dicts
        and ``None`` are expected to provide ``model_dump`` (Pydantic v2).
        """
        if range_filter is None:
            return {}
        if isinstance(range_filter, dict):
            return {key: val for key, val in range_filter.items() if val is not None}
        return range_filter.model_dump(exclude_none=True)

    def _build_filters(
        self,
        filters: Optional[Dict[str, Any]] = None,
        range_filters: Optional[Dict[str, 'RangeFilter']] = None
    ) -> List[Dict[str, Any]]:
        """
        Build filter clauses.

        Args:
            filters: Exact-match filters ``{field: value}``. List, tuple or set
                values produce a multi-value ``terms`` filter (any-of);
                anything else produces a single-value ``term`` filter.
            range_filters: Range filters ``{field: RangeFilter}`` where
                RangeFilter is a Pydantic model; plain dicts of bounds are also
                accepted. Fields whose bounds are all ``None`` are skipped.

        Returns:
            List of ES filter clauses (possibly empty).
        """
        filter_clauses: List[Dict[str, Any]] = []

        # 1. Exact-match filters.
        if filters:
            for field, value in filters.items():
                if isinstance(value, (list, tuple, set, frozenset)):
                    # Multi-value match (OR). A tuple/set passed to a "term"
                    # query would serialize as an array, which ES rejects.
                    filter_clauses.append({"terms": {field: list(value)}})
                else:
                    filter_clauses.append({"term": {field: value}})

        # 2. Range filters.
        if range_filters:
            for field, range_filter in range_filters.items():
                range_dict = self._range_to_dict(range_filter)
                if range_dict:
                    filter_clauses.append({"range": {field: range_dict}})

        return filter_clauses

    def add_spu_collapse(
        self,
        es_query: Dict[str, Any],
        spu_field: str,
        inner_hits_size: int = 3
    ) -> Dict[str, Any]:
        """
        Add SPU collapsing plus a unique-SPU count aggregation to a query.

        Mutates ``es_query`` in place (and returns the same object): sets
        ``collapse`` and adds ``aggs.unique_count`` (a cardinality agg on
        ``spu_field``), keeping any other existing aggregations.

        Args:
            es_query: Existing ES query.
            spu_field: Field containing the SPU ID.
            inner_hits_size: Number of SKUs returned per SPU (``top_docs``
                inner hits, without ``_source``).

        Returns:
            The modified ES query.
        """
        es_query["collapse"] = {
            "field": spu_field,
            "inner_hits": {
                "_source": False,
                "name": "top_docs",
                "size": inner_hits_size
            }
        }

        es_query.setdefault("aggs", {})["unique_count"] = {
            "cardinality": {
                "field": spu_field
            }
        }

        return es_query

    def add_sorting(
        self,
        es_query: Dict[str, Any],
        sort_by: str,
        sort_order: str = "desc"
    ) -> Dict[str, Any]:
        """
        Append a field sort to an ES query (in place).

        Args:
            es_query: Existing ES query.
            sort_by: Field name for sorting; a falsy value leaves the query
                untouched.
            sort_order: 'asc' or 'desc' (case-insensitive); falsy means 'desc'.

        Returns:
            The (possibly modified) ES query. The new sort is appended after
            any sorts already present.
        """
        if not sort_by:
            return es_query

        order = (sort_order or "desc").lower()
        es_query.setdefault("sort", []).append({sort_by: {"order": order}})

        return es_query

    @staticmethod
    def _terms_agg(field: str, size: Any) -> Dict[str, Any]:
        """Build a terms aggregation on ``field`` ordered by descending doc count."""
        return {
            "terms": {
                "field": field,
                "size": size,
                "order": {"_count": "desc"}
            }
        }

    def build_facets(
        self,
        facet_configs: Optional[List[Union[str, 'FacetConfig']]] = None
    ) -> Dict[str, Any]:
        """
        Build facet aggregations.

        Args:
            facet_configs: List of facet configurations:
                - str: a field name; uses a default terms aggregation (size 10).
                - FacetConfig: detailed config object (reads ``field``,
                  ``type``, ``size`` and, for range facets, ``ranges``).

        Returns:
            ES aggregations dict keyed ``"<field>_facet"``. Range facets
            without ranges and facets of any other type are skipped.
        """
        if not facet_configs:
            return {}

        aggs: Dict[str, Any] = {}

        for config in facet_configs:
            # Simple form: field name only.
            if isinstance(config, str):
                aggs[f"{config}_facet"] = self._terms_agg(config, 10)
                continue

            # Detailed form: FacetConfig object.
            field = config.field
            facet_type = config.type
            agg_name = f"{field}_facet"

            if facet_type == 'terms':
                aggs[agg_name] = self._terms_agg(field, config.size)
            elif facet_type == 'range' and config.ranges:
                aggs[agg_name] = {
                    "range": {
                        "field": field,
                        "ranges": config.ranges
                    }
                }

        return aggs