# searcher.py 11.3 KB
"""
Main Searcher module - executes search queries against Elasticsearch.

Handles query parsing, boolean expressions, ranking, and result formatting.
"""

from typing import Dict, Any, List, Optional
import time

from config import CustomerConfig
from utils.es_client import ESClient
from query import QueryParser, ParsedQuery
from indexer import MappingGenerator
from .boolean_parser import BooleanParser, QueryNode
from .es_query_builder import ESQueryBuilder
from .multilang_query_builder import MultiLanguageQueryBuilder
from .ranking_engine import RankingEngine


class SearchResult:
    """Value object bundling everything a single search call produced."""

    def __init__(
        self,
        hits: List[Dict[str, Any]],
        total: int,
        max_score: float,
        took_ms: int,
        aggregations: Optional[Dict[str, Any]] = None,
        query_info: Optional[Dict[str, Any]] = None
    ):
        """
        Store the result fields.

        Args:
            hits: Result documents for the current page
            total: Total number of matching documents
            max_score: Highest score reported for the query
            took_ms: Time taken, in milliseconds
            aggregations: Aggregation results; a falsy value becomes ``{}``
            query_info: Information about the query; a falsy value becomes ``{}``
        """
        self.hits, self.total = hits, total
        self.max_score, self.took_ms = max_score, took_ms
        # Falsy inputs (None or an empty mapping) are replaced by a fresh dict.
        self.aggregations = aggregations if aggregations else {}
        self.query_info = query_info if query_info else {}

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary keyed by attribute name."""
        return dict(
            hits=self.hits,
            total=self.total,
            max_score=self.max_score,
            took_ms=self.took_ms,
            aggregations=self.aggregations,
            query_info=self.query_info,
        )


class Searcher:
    """
    Main search engine class.

    Handles:
    - Query parsing and translation
    - Boolean expression parsing
    - ES query building
    - Result ranking and formatting
    """

    def __init__(
        self,
        config: CustomerConfig,
        es_client: ESClient,
        query_parser: Optional[QueryParser] = None
    ):
        """
        Wire up the searcher and the components it delegates to.

        Args:
            config: Customer configuration
            es_client: Elasticsearch client used to run queries
            query_parser: Query parser to use; when falsy, a
                ``QueryParser(config)`` is created instead
        """
        self.config = config
        self.es_client = es_client
        if query_parser:
            self.query_parser = query_parser
        else:
            self.query_parser = QueryParser(config)

        # Boolean-expression parsing and custom ranking helpers
        self.boolean_parser = BooleanParser()
        ranking_expression = config.ranking.expression
        self.ranking_engine = RankingEngine(ranking_expression)

        # Field information derived from the index mapping
        mappings = MappingGenerator(config)
        self.match_fields = mappings.get_match_fields_for_domain("default")
        self.text_embedding_field = mappings.get_text_embedding_field()
        self.image_embedding_field = mappings.get_image_embedding_field()

        # The multi-language builder produces the ES request bodies
        builder_kwargs = {
            "config": config,
            "index_name": config.es_index_name,
            "text_embedding_field": self.text_embedding_field,
            "image_embedding_field": self.image_embedding_field,
        }
        self.query_builder = MultiLanguageQueryBuilder(**builder_kwargs)

    def search(
        self,
        query: str,
        size: int = 10,
        from_: int = 0,
        filters: Optional[Dict[str, Any]] = None,
        enable_translation: bool = True,
        enable_embedding: bool = True,
        enable_rerank: bool = True,
        min_score: Optional[float] = None
    ) -> SearchResult:
        """
        Execute a text search query.

        Pipeline: parse the query, detect boolean expressions, build the
        multi-language ES query (optionally with KNN, SPU collapse and facet
        aggregations), execute it, then optionally re-score and re-sort the
        returned page of hits.

        Args:
            query: Search query string
            size: Number of results to return
            from_: Offset for pagination
            filters: Additional filters (field: value pairs). For every key
                ``k`` whose ``k_keyword`` field exists in the configuration,
                an aggregation on ``k_keyword`` is requested as well
            enable_translation: Accepted for interface compatibility; this
                method does not currently read it or forward it to the parser
            enable_embedding: Whether to generate a query vector and use
                semantic (KNN) search
            enable_rerank: Whether to apply custom ranking. Re-sorting only
                affects the hits returned for this page
            min_score: Minimum score threshold

        Returns:
            SearchResult object. ``took_ms`` is the wall-clock duration of
            this whole call; ``max_score`` is the score reported by
            Elasticsearch (``0.0`` when it reports none)
        """
        # Local import: only needed for the diagnostic query dump below.
        import json

        start_time = time.time()

        print(f"\n{'='*60}")
        print(f"[Searcher] Starting search for: '{query}'")
        print(f"{'='*60}")

        # Step 1: Parse query
        parsed_query = self.query_parser.parse(
            query,
            generate_vector=enable_embedding
        )

        # Step 2: Parse into an expression tree only for boolean queries;
        # simple queries leave query_node as None.
        query_node = None
        if not self.boolean_parser.is_simple_query(parsed_query.rewritten_query):
            query_node = self.boolean_parser.parse(parsed_query.rewritten_query)
            print(f"[Searcher] Parsed boolean expression: {query_node}")

        # Step 3: Build ES query using multi-language builder
        es_query = self.query_builder.build_multilang_query(
            parsed_query=parsed_query,
            query_vector=parsed_query.query_vector if enable_embedding else None,
            query_node=query_node,
            filters=filters,
            size=size,
            from_=from_,
            enable_knn=enable_embedding and parsed_query.query_vector is not None,
            min_score=min_score
        )

        # Add SPU collapse if configured
        if self.config.spu_config.enabled:
            es_query = self.query_builder.add_spu_collapse(
                es_query,
                self.config.spu_config.spu_field,
                self.config.spu_config.inner_hits_size
            )

        # Add aggregations for faceted search
        if filters:
            # Build the lookup once instead of once per filter key.
            known_fields = {field.name for field in self.config.fields}
            agg_fields = [
                f"{key}_keyword"
                for key in filters
                if f"{key}_keyword" in known_fields
            ]
            if agg_fields:
                es_query = self.query_builder.add_aggregations(es_query, agg_fields)

        # size/from are passed to the ES client as parameters, not in the body
        body_for_es = {k: v for k, v in es_query.items() if k not in ('size', 'from')}

        print("[Searcher] ES Query:")
        # default=str keeps this diagnostic dump from aborting the search when
        # the query holds values json cannot encode natively (e.g. a vector).
        print(json.dumps(es_query, indent=2, default=str))

        # Step 4: Execute search
        print("[Searcher] Executing ES query...")
        es_response = self.es_client.search(
            index_name=self.config.es_index_name,
            body=body_for_es,
            size=size,
            from_=from_
        )

        # Step 5: Process results
        hits_section = es_response.get('hits') or {}
        knn_used = 'knn' in es_query
        hits = []
        for hit in hits_section.get('hits') or []:
            # Elasticsearch may report a null score for a hit.
            raw_score = hit.get('_score')
            result_doc = {
                '_id': hit['_id'],
                '_score': raw_score,
                '_source': hit['_source']
            }

            # Apply custom ranking if enabled
            if enable_rerank:
                # Treat a missing score as 0.0 so the arithmetic below is safe.
                base_score = raw_score if raw_score is not None else 0.0

                # The KNN share is not reported separately, so it is estimated
                # as a fixed 20% of the combined score.
                # NOTE(review): confirm 0.2 matches the builder's weighting.
                knn_score = base_score * 0.2 if knn_used else None

                result_doc['_custom_score'] = self.ranking_engine.calculate_score(
                    hit,
                    base_score,
                    knn_score
                )
                result_doc['_original_score'] = raw_score

            hits.append(result_doc)

        # Re-sort by custom score if reranking enabled (every hit has one)
        if enable_rerank:
            hits.sort(key=lambda doc: doc['_custom_score'], reverse=True)

        # Total is either {"value": n, ...} or a plain integer
        total = hits_section.get('total')
        if isinstance(total, dict):
            total_value = total.get('value', 0)
        else:
            total_value = total or 0

        # max_score is null when there are no scored hits; normalise to 0.0
        max_score = hits_section.get('max_score') or 0.0

        # Extract aggregations
        aggregations = es_response.get('aggregations', {})

        # Calculate elapsed time
        elapsed_ms = int((time.time() - start_time) * 1000)

        # Build result
        result = SearchResult(
            hits=hits,
            total=total_value,
            max_score=max_score,
            took_ms=elapsed_ms,
            aggregations=aggregations,
            query_info=parsed_query.to_dict()
        )

        print(f"[Searcher] Search complete: {total_value} results in {elapsed_ms}ms")
        print(f"{'='*60}\n")

        return result

    def search_by_image(
        self,
        image_url: str,
        size: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> SearchResult:
        """
        Search by image similarity.

        The image is encoded to a vector and used in a KNN search against the
        configured image embedding field. Filters are applied inside the KNN
        clause, so only documents passing them can be returned.

        Args:
            image_url: URL of query image
            size: Number of results
            filters: Additional filters

        Returns:
            SearchResult object. ``took_ms`` is the ``took`` value reported by
            Elasticsearch; ``max_score`` is ``0.0`` when none is reported

        Raises:
            ValueError: If no image embedding field is configured, or the
                image could not be encoded
        """
        if not self.image_embedding_field:
            raise ValueError("Image embedding field not configured")

        # Generate image embedding
        # NOTE(review): a new encoder is constructed on every call; consider
        # caching it if construction is expensive — confirm with the encoder.
        from embeddings import CLIPImageEncoder
        image_encoder = CLIPImageEncoder()
        image_vector = image_encoder.encode_image_from_url(image_url)

        if image_vector is None:
            raise ValueError(f"Failed to encode image: {image_url}")

        # Build KNN clause. Elasticsearch rejects num_candidates above 10000.
        knn_clause = {
            "field": self.image_embedding_field,
            "query_vector": image_vector.tolist(),
            "k": size,
            "num_candidates": min(size * 10, 10000)
        }

        if filters:
            # A top-level "query" next to "knn" is OR-ed with the KNN matches
            # and would not restrict them; "knn.filter" restricts candidates.
            knn_clause["filter"] = {
                "bool": {
                    "filter": self.query_builder._build_filters(filters)
                }
            }

        # size is passed as a client parameter (as in search()), not in the body
        es_query = {"knn": knn_clause}

        # Execute search
        es_response = self.es_client.search(
            index_name=self.config.es_index_name,
            body=es_query,
            size=size
        )

        # Process results (same shape as text search, without re-ranking)
        hits_section = es_response.get('hits') or {}
        hits = [
            {
                '_id': hit['_id'],
                '_score': hit['_score'],
                '_source': hit['_source']
            }
            for hit in hits_section.get('hits') or []
        ]

        # Total is either {"value": n, ...} or a plain integer
        total = hits_section.get('total')
        if isinstance(total, dict):
            total_value = total.get('value', 0)
        else:
            total_value = total or 0

        return SearchResult(
            hits=hits,
            total=total_value,
            # max_score is null when nothing matched; normalise to 0.0
            max_score=hits_section.get('max_score') or 0.0,
            took_ms=es_response.get('took', 0),
            query_info={'image_url': image_url, 'search_type': 'image_similarity'}
        )

    def get_domain_summary(self) -> Dict[str, Any]:
        """
        Get summary of all configured domains.

        Returns:
            Dictionary with domain information
        """
        return self.query_builder.get_domain_summary()

    def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """
        Get single document by ID.

        Args:
            doc_id: Document ID

        Returns:
            Document or None if not found
        """
        try:
            response = self.es_client.client.get(
                index=self.config.es_index_name,
                id=doc_id
            )
            return response.get('_source')
        except Exception as e:
            print(f"[Searcher] Failed to get document {doc_id}: {e}")
            return None