ranking_engine.py 4.88 KB
"""
Ranking engine for configurable search result scoring.

Supports weighted-sum ranking expressions built from terms such as:
- bm25(): Base BM25 text relevance score
- text_embedding_relevance(): KNN embedding similarity
- field_value(field): Numeric value of a document field
- timeliness(date_field): Time decay function (currently a constant placeholder)
- field_name: A bare field name, used directly as a numeric term
"""

import re
from typing import Dict, Any, List, Optional
import math


class RankingEngine:
    """Evaluates ranking expressions and applies them to search results.

    An expression is a weighted sum of terms. Each term is either:

    - a function call: ``bm25()``, ``text_embedding_relevance()``,
      ``field_value(field)`` or ``timeliness(date_field)``; or
    - a bare field name, whose value is read from the hit's ``_source``.

    A term may be preceded by a sign and/or a numeric coefficient, e.g.
    ``- 0.2*text_embedding_relevance()``. The ``*`` after a leading
    coefficient is optional. A term may also be followed by ``*<number>``,
    e.g. ``general_score*2``. Leading and trailing coefficients multiply
    together. Text that does not match the term syntax is ignored.
    """

    # A numeric literal: integer, decimal, or scientific notation.
    _NUMBER = r'(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?'

    # One term of the expression. The sign is captured separately from the
    # number because float() rejects whitespace between them ("+ 0.2").
    _TERM_PATTERN = re.compile(
        r'(?P<sign>[+-])?\s*'
        r'(?:(?P<lead>' + _NUMBER + r')\s*\*?\s*)?'
        r'(?P<name>[a-zA-Z_]\w*)'
        r'(?:\s*\((?P<args>[^)]*)\))?'
        r'(?:\s*\*\s*(?P<trail>' + _NUMBER + r'))?'
    )

    def __init__(self, ranking_expression: str):
        """
        Initialize ranking engine.

        Args:
            ranking_expression: Ranking expression string
                              Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2"
        """
        self.expression = ranking_expression
        self.parsed_terms = self._parse_expression(ranking_expression)

    def _parse_expression(self, expression: str) -> List[Dict[str, Any]]:
        """
        Parse a ranking expression into weighted terms.

        Args:
            expression: Ranking expression

        Returns:
            List of term dictionaries. Function terms have the keys
            ``type='function'``, ``name``, ``args`` (list of stripped
            argument strings) and ``coefficient``. Field terms have the keys
            ``type='field'``, ``name`` and ``coefficient``.
        """
        terms: List[Dict[str, Any]] = []

        for match in self._TERM_PATTERN.finditer(expression):
            coefficient = -1.0 if match.group('sign') == '-' else 1.0
            # The regex only admits valid float literals, so float() cannot fail.
            for key in ('lead', 'trail'):
                literal = match.group(key)
                if literal is not None:
                    coefficient *= float(literal)

            name = match.group('name')
            args_str = match.group('args')

            if args_str is None:
                # No parentheses: direct field reference.
                terms.append({
                    'type': 'field',
                    'name': name,
                    'coefficient': coefficient
                })
            else:
                # An empty or whitespace-only argument list means no arguments.
                args = ([arg.strip() for arg in args_str.split(',')]
                        if args_str.strip() else [])
                terms.append({
                    'type': 'function',
                    'name': name,
                    'args': args,
                    'coefficient': coefficient
                })

        return terms

    @staticmethod
    def _field_as_float(source: Dict[str, Any], field_name: str) -> float:
        """Return ``source[field_name]`` as a finite float.

        Returns 0.0 if the field is missing, is not convertible to float,
        or is NaN/infinite.
        """
        try:
            value = float(source[field_name])
        except (KeyError, ValueError, TypeError, OverflowError):
            return 0.0
        # NaN/inf would poison the summed score and make sorting unreliable.
        return value if math.isfinite(value) else 0.0

    def _term_value(
        self,
        term: Dict[str, Any],
        source: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float]
    ) -> float:
        """Return the unweighted value of one parsed term for a document."""
        if term['type'] == 'field':
            return self._field_as_float(source, term['name'])
        if term['type'] != 'function':
            return 0.0

        func_name = term['name']
        args = term['args']

        if func_name == 'bm25':
            return base_score
        if func_name == 'text_embedding_relevance':
            return knn_score if knn_score is not None else 0.0
        if func_name == 'timeliness':
            # NOTE(review): placeholder. No decay is computed yet. The value
            # is 1.0 when no date field is named or the named field is
            # present, and 0.0 when the named field is missing.
            if not args or args[0] in source:
                return 1.0
            return 0.0
        if func_name == 'field_value':
            return self._field_as_float(source, args[0]) if args else 0.0

        # Unrecognized function names contribute nothing to the score.
        return 0.0

    def calculate_score(
        self,
        hit: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float] = None
    ) -> float:
        """
        Calculate final score for a search result.

        Args:
            hit: ES hit document; field values are read from its ``_source``
                (a missing or null ``_source`` is treated as empty)
            base_score: Base BM25 score
            knn_score: KNN similarity score (if available; None counts as 0.0)

        Returns:
            Final calculated score: the sum of coefficient * term value
            over all parsed terms
        """
        source = hit.get('_source') or {}

        score = 0.0
        for term in self.parsed_terms:
            score += term['coefficient'] * self._term_value(
                term, source, base_score, knn_score
            )
        return score

    def get_expression(self) -> str:
        """Get ranking expression."""
        return self.expression

    def get_terms(self) -> List[Dict[str, Any]]:
        """Get parsed expression terms."""
        return self.parsed_terms