rerank_engine.py
5.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Reranking engine for post-processing search result scoring.
Local rerank engine used for second-pass sorting of results returned by ES.
Current status: disabled; ES function_score is preferred.
Supports expression-based ranking with functions like:
- bm25(): Base BM25 text relevance score
- text_embedding_relevance(): KNN embedding similarity
- field_value(field): Use field value in scoring
- timeliness(date_field): Time decay function
"""
import re
from typing import Dict, Any, List, Optional
import math
class RerankEngine:
    """
    Local rerank engine (disabled by default).

    Re-scores hits returned by ES using a simple additive ranking
    expression such as
    ``"bm25() + 0.2*text_embedding_relevance() + general_score*2"``.
    Each term is ``[sign][coefficient*]name[(args)][*multiplier]``; terms are
    summed. Products of two names are not supported.
    """

    # group 1: optional sign and leading coefficient (may contain whitespace)
    # group 2: field name or function call
    # group 3: optional trailing numeric multiplier ("general_score*2").
    #          The lookaheads refuse a number that is glued to an identifier
    #          or that is itself followed by '*', so "a*2*b" still reads as
    #          "a" plus "2*b" exactly as before.
    _TERM_PATTERN = re.compile(
        r'([+-]?\s*\d*\.?\d*)\s*\*?\s*'
        r'([a-zA-Z_]\w*(?:\([^)]*\))?)'
        r'(?:\s*\*\s*(\d+\.?\d*|\.\d+)(?![\w.])(?!\s*\*))?'
    )

    def __init__(self, ranking_expression: str, enabled: bool = False):
        """
        Initialize rerank engine.

        Args:
            ranking_expression: Ranking expression string
                Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2"
            enabled: Whether local reranking is enabled (default: False).
                The expression is only parsed when this is True.
        """
        self.enabled = enabled
        self.expression = ranking_expression
        self.parsed_terms = []
        if enabled:
            self.parsed_terms = self._parse_expression(ranking_expression)

    @staticmethod
    def _parse_coefficient(raw: str) -> float:
        """Convert the raw sign/number capture into a float coefficient."""
        # float() rejects "+ 0.2" / "- 0.5", so drop inner whitespace first.
        compact = ''.join(raw.split())
        if compact in ('', '+'):
            return 1.0
        if compact == '-':
            return -1.0
        try:
            return float(compact)
        except ValueError:
            # Unparseable residue such as "." or "-.": keep at least the sign.
            return -1.0 if compact.startswith('-') else 1.0

    def _parse_expression(self, expression: str) -> List[Dict[str, Any]]:
        """
        Parse ranking expression into terms.

        Args:
            expression: Ranking expression

        Returns:
            List of term dictionaries. Function terms carry the keys
            'type', 'name', 'args', 'coefficient'; field terms carry
            'type', 'name', 'coefficient'.
        """
        terms: List[Dict[str, Any]] = []
        for match in self._TERM_PATTERN.finditer(expression):
            coefficient = self._parse_coefficient(match.group(1))
            if match.group(3) is not None:
                # Trailing multiplier, e.g. the "*2" in "general_score*2".
                coefficient *= float(match.group(3))

            func_str = match.group(2).strip()
            if '(' in func_str:
                # Function call: split "name(arg1, arg2)" into name and args.
                open_idx = func_str.index('(')
                args_str = func_str[open_idx + 1:func_str.rindex(')')]
                args = [
                    arg.strip() for arg in args_str.split(',') if arg.strip()
                ]
                terms.append({
                    'type': 'function',
                    'name': func_str[:open_idx],
                    'args': args,
                    'coefficient': coefficient
                })
            else:
                # Bare identifier: direct reference to a document field.
                terms.append({
                    'type': 'field',
                    'name': func_str,
                    'coefficient': coefficient
                })
        return terms

    @staticmethod
    def _numeric_field(source: Dict[str, Any], field_name: str) -> float:
        """Return source[field_name] as a float, or 0.0 if missing/non-numeric."""
        if field_name not in source:
            return 0.0
        try:
            return float(source[field_name])
        except (ValueError, TypeError):
            return 0.0

    def _term_value(
        self,
        term: Dict[str, Any],
        source: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float]
    ) -> float:
        """Evaluate a single parsed term (before applying its coefficient)."""
        if term['type'] == 'field':
            return self._numeric_field(source, term['name'])
        if term['type'] != 'function':
            return 0.0

        func_name = term['name']
        args = term['args']
        if func_name == 'bm25':
            return base_score
        if func_name == 'text_embedding_relevance':
            return knn_score if knn_score is not None else 0.0
        if func_name == 'timeliness':
            # Placeholder: no real time decay is implemented yet. The term
            # contributes 1.0 unless a date field is named and is absent.
            if args and args[0] not in source:
                return 0.0
            return 1.0
        if func_name == 'field_value':
            return self._numeric_field(source, args[0]) if args else 0.0
        # Unknown functions contribute nothing.
        return 0.0

    def calculate_score(
        self,
        hit: Dict[str, Any],
        base_score: float,
        knn_score: Optional[float] = None
    ) -> float:
        """
        Calculate final score for a search result.

        Args:
            hit: ES hit document; fields are read from its '_source' entry
            base_score: Base BM25 score
            knn_score: KNN similarity score (if available)

        Returns:
            base_score unchanged when the engine is disabled, otherwise the
            sum of coefficient * term value over all parsed terms.
        """
        if not self.enabled:
            return base_score

        # Treat a missing or null '_source' as an empty document.
        source = hit.get('_source') or {}
        score = 0.0
        for term in self.parsed_terms:
            score += term['coefficient'] * self._term_value(
                term, source, base_score, knn_score
            )
        return score

    def get_expression(self) -> str:
        """Get ranking expression."""
        return self.expression

    def get_terms(self) -> List[Dict[str, Any]]:
        """Get parsed expression terms."""
        return self.parsed_terms