# rerank_engine.py (5.38 KB)
"""
Reranking engine for post-processing search result scoring.

Local rerank engine, used for a second-pass sort of the results returned by ES.
Current status: disabled; ES function_score is preferred.

Supports expression-based ranking with functions like:
- bm25(): Base BM25 text relevance score
- text_embedding_relevance(): KNN embedding similarity
- field_value(field): Use field value in scoring
- timeliness(date_field): Time decay function
"""

import re
from typing import Dict, Any, List, Optional
import math


class RerankEngine:
    """
    Local rerank engine (currently disabled by default).

    Purpose: re-score and re-sort the hits returned by ES.
    Use cases: complex custom ranking logic, real-time personalization, etc.

    The ranking expression is a sum of ``coefficient * operand`` terms, where
    an operand is one of:

    - ``bm25()``: the ``base_score`` passed to :meth:`calculate_score`
    - ``text_embedding_relevance()``: the ``knn_score`` (0.0 when missing)
    - ``timeliness(date_field)``: placeholder, no real time decay yet; 1.0
      when called without arguments or when ``date_field`` exists in
      ``_source``, otherwise 0.0
    - ``field_value(field)``: ``_source[field]`` converted to float
    - a bare field name: same as ``field_value(<name>)``

    A multiplier may be written before the operand (``0.2*f()``), after it
    (``score*2``), or both; a leading ``+``/``-`` may be separated from the
    number by whitespace. Anything that is not such a term (for example a bare
    constant) is ignored rather than rejected.
    """

    # One additive term of the expression. Named groups:
    #   sign    - optional leading '+' or '-'
    #   lead    - optional unsigned multiplier written before the operand
    #   operand - field name or function call (arguments may not contain ')')
    #   trail   - optional unsigned multiplier written after the operand
    _TERM_PATTERN = re.compile(
        r"""
        (?P<sign>[+-])?\s*                            # sign, may be spaced from the number
        (?:(?P<lead>\d+(?:\.\d*)?|\.\d+)\s*\*?\s*)?   # leading multiplier, e.g. 0.2*
        (?P<operand>[a-zA-Z_]\w*(?:\([^)]*\))?)       # field name or function call
        (?:\s*\*\s*(?P<trail>\d+(?:\.\d*)?|\.\d+)     # trailing multiplier, e.g. *2
           (?![\w.]|\s*\*))?                          # ...unless it leads the next term
        """,
        re.VERBOSE,
    )

    def __init__(self, ranking_expression: str, enabled: bool = False):
        """
        Initialize rerank engine.

        Args:
            ranking_expression: Ranking expression string
                              Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2"
            enabled: Whether local reranking is enabled (default: False).
                     The expression is only parsed when this is True.
        """
        self.enabled = enabled
        self.expression = ranking_expression
        self.parsed_terms: List[Dict[str, Any]] = []

        if enabled:
            self.parsed_terms = self._parse_expression(ranking_expression)

    def _parse_expression(self, expression: str) -> List[Dict[str, Any]]:
        """
        Parse ranking expression into terms.

        Args:
            expression: Ranking expression

        Returns:
            List of term dictionaries. Function terms have the keys
            ``type='function'``, ``name``, ``args`` and ``coefficient``;
            field terms have ``type='field'``, ``name`` and ``coefficient``.
        """
        terms: List[Dict[str, Any]] = []

        for match in self._TERM_PATTERN.finditer(expression):
            # Sign and multipliers are captured separately, so "+ 0.2" and
            # "- 0.5" (with a space) no longer fall back to 1.0 / lose the sign.
            coefficient = -1.0 if match.group('sign') == '-' else 1.0
            for multiplier in (match.group('lead'), match.group('trail')):
                if multiplier:
                    coefficient *= float(multiplier)

            operand = match.group('operand')
            name, paren, rest = operand.partition('(')

            if paren:
                # The pattern guarantees a call ends with ')'; drop it and
                # treat whitespace-only argument lists as empty.
                args_str = rest[:-1].strip()
                args = [arg.strip() for arg in args_str.split(',')] if args_str else []
                terms.append({
                    'type': 'function',
                    'name': name,
                    'args': args,
                    'coefficient': coefficient,
                })
            else:
                terms.append({
                    'type': 'field',
                    'name': name,
                    'coefficient': coefficient,
                })

        return terms

    @staticmethod
    def _field_as_float(source: Dict[str, Any], field_name: str) -> float:
        """Return ``source[field_name]`` as a float, or 0.0 if missing or non-numeric."""
        try:
            return float(source[field_name])
        except (KeyError, ValueError, TypeError):
            return 0.0

    def _term_value(
        self,
        term: Dict[str, Any],
        source: Dict[str, Any],
        base_score: Optional[float],
        knn_score: Optional[float],
    ) -> float:
        """Evaluate the operand of a single parsed term (before the coefficient)."""
        name = term['name']

        if term['type'] == 'field':
            return self._field_as_float(source, name)
        if term['type'] != 'function':
            return 0.0

        args = term['args']
        if name == 'bm25':
            return base_score if base_score is not None else 0.0
        if name == 'text_embedding_relevance':
            return knn_score if knn_score is not None else 0.0
        if name == 'timeliness':
            # Placeholder: no real time decay is implemented yet. The term
            # only reflects whether the date field is present.
            if not args:
                return 1.0
            return 1.0 if args[0] in source else 0.0
        if name == 'field_value':
            return self._field_as_float(source, args[0]) if args else 0.0

        # Unknown functions contribute nothing.
        return 0.0

    def calculate_score(
        self,
        hit: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float] = None
    ) -> float:
        """
        Calculate final score for a search result.

        Args:
            hit: ES hit document; fields are read from ``hit['_source']``
            base_score: Base BM25 score
            knn_score: KNN similarity score (if available)

        Returns:
            ``base_score`` unchanged when the engine is disabled; otherwise
            the sum of ``coefficient * value`` over all parsed terms.
        """
        if not self.enabled:
            return base_score

        # `_source` may be absent or explicitly None; treat both as empty.
        source = hit.get('_source') or {}

        score = 0.0
        for term in self.parsed_terms:
            score += term['coefficient'] * self._term_value(
                term, source, base_score, knn_score
            )
        return score

    def get_expression(self) -> str:
        """Get ranking expression."""
        return self.expression

    def get_terms(self) -> List[Dict[str, Any]]:
        """Get parsed expression terms."""
        return self.parsed_terms