be52af70
tangwang
first commit
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
"""
Ranking engine for configurable search result scoring.
Supports expression-based ranking with functions like:
- bm25(): Base BM25 text relevance score
- text_embedding_relevance(): KNN embedding similarity
- field_value(field): Use field value in scoring
- timeliness(date_field): Time decay function
"""
import re
from typing import Dict, Any, List, Optional
import math
class RankingEngine:
    """Evaluates ranking expressions and applies them to search results.

    An expression is treated as a sum of terms. Each term is an identifier
    (a bare field name or a function call such as ``bm25()``) with an
    optional sign, an optional numeric coefficient in front of it
    (``0.2*text_embedding_relevance()``) and an optional numeric multiplier
    after it (``general_score*2``). Products of two identifiers and
    parenthesised sub-expressions are not interpreted.
    """

    # Unsigned decimal literal: "2", "2.", "0.25" or ".5".
    _NUMBER = r'(?:\d+\.?\d*|\.\d+)'

    # Groups: 1 = sign, 2 = leading coefficient, 3 = identifier / call,
    # 4 = trailing multiplier.
    # The sign is captured separately from the number so that whitespace
    # between them ("+ 0.2") is tolerated; float("+ 0.2") would raise.
    # The trailing multiplier is only taken when the number is not itself the
    # start of another term (e.g. the "2" in "a*2*b" still belongs to "b").
    _TERM_PATTERN = re.compile(
        r'([+-]?)\s*'
        r'(?:(' + _NUMBER + r')\s*\*?\s*)?'
        r'([a-zA-Z_]\w*(?:\([^)]*\))?)'
        r'(?:\s*\*\s*(' + _NUMBER + r')(?![\w.]|\s*[*(]))?'
    )

    def __init__(self, ranking_expression: str):
        """
        Initialize ranking engine.

        Args:
            ranking_expression: Ranking expression string
                Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2"
        """
        self.expression = ranking_expression
        self.parsed_terms = self._parse_expression(ranking_expression)

    def _parse_expression(self, expression: str) -> List[Dict[str, Any]]:
        """
        Parse ranking expression into terms.

        Args:
            expression: Ranking expression

        Returns:
            List of term dictionaries. Function terms have the keys
            ``type`` ('function'), ``name``, ``args`` and ``coefficient``;
            field terms have ``type`` ('field'), ``name`` and ``coefficient``.
        """
        terms: List[Dict[str, Any]] = []

        for match in self._TERM_PATTERN.finditer(expression):
            sign, leading, ident, trailing = match.groups()

            # Effective coefficient = sign * leading number * trailing number,
            # where every missing part counts as 1.
            coefficient = -1.0 if sign == '-' else 1.0
            for number in (leading, trailing):
                if number:
                    coefficient *= float(number)

            if '(' in ident:
                # Function call: split "name(arg1, arg2)" into name and args.
                open_idx = ident.index('(')
                func_name = ident[:open_idx]
                args_str = ident[open_idx + 1:ident.rindex(')')]
                # A blank argument list ("f()" or "f( )") means no arguments.
                if args_str.strip():
                    args = [arg.strip() for arg in args_str.split(',')]
                else:
                    args = []
                terms.append({
                    'type': 'function',
                    'name': func_name,
                    'args': args,
                    'coefficient': coefficient
                })
            else:
                # Bare identifier: reference to a document field.
                terms.append({
                    'type': 'field',
                    'name': ident,
                    'coefficient': coefficient
                })

        return terms

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a raw field value to float, yielding 0.0 if not numeric."""
        try:
            return float(value)
        except (ValueError, TypeError):
            return 0.0

    def _function_value(
        self,
        term: Dict[str, Any],
        source: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float]
    ) -> float:
        """Evaluate one function term; unknown function names yield 0.0."""
        func_name = term['name']
        args = term['args']

        if func_name == 'bm25':
            return base_score

        if func_name == 'text_embedding_relevance':
            return knn_score if knn_score is not None else 0.0

        if func_name == 'timeliness':
            # Placeholder time decay: no real decay is computed yet. The term
            # contributes 1.0 unless a date field is named but absent from
            # the document, in which case it contributes 0.0.
            if args and args[0] not in source:
                return 0.0
            return 1.0

        if func_name == 'field_value':
            if args and args[0] in source:
                return self._to_float(source[args[0]])
            return 0.0

        return 0.0

    def calculate_score(
        self,
        hit: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float] = None
    ) -> float:
        """
        Calculate final score for a search result.

        Args:
            hit: ES hit document; field values are read from its '_source'
            base_score: Base BM25 score
            knn_score: KNN similarity score (if available)

        Returns:
            Final calculated score: the sum of coefficient * value over all
            parsed terms. Missing or non-numeric fields contribute 0.0.
        """
        # Treat a missing or null '_source' as an empty document.
        source = hit.get('_source') or {}
        score = 0.0

        for term in self.parsed_terms:
            term_value = 0.0
            if term['type'] == 'function':
                term_value = self._function_value(
                    term, source, base_score, knn_score
                )
            elif term['type'] == 'field':
                field_name = term['name']
                if field_name in source:
                    term_value = self._to_float(source[field_name])

            score += term['coefficient'] * term_value

        return score

    def get_expression(self) -> str:
        """Get ranking expression."""
        return self.expression

    def get_terms(self) -> List[Dict[str, Any]]:
        """Get parsed expression terms."""
        return self.parsed_terms
|