Blame view

search/rerank_engine.py 5.38 KB
be52af70   tangwang   first commit
1
  """
43f1139f   tangwang   refactor: ES查询结构重...
2
3
4
5
  Reranking engine for post-processing search result scoring.
  
  Local reranking engine for second-pass ordering of results returned by ES.
  Current status: disabled; ES function_score is preferred.
be52af70   tangwang   first commit
6
7
8
9
10
11
12
13
14
15
16
17
18
  
  Supports expression-based ranking with functions like:
  - bm25(): Base BM25 text relevance score
  - text_embedding_relevance(): KNN embedding similarity
  - field_value(field): Use field value in scoring
  - timeliness(date_field): Time decay function
  """
  
  import re
  from typing import Dict, Any, List, Optional
  import math
  
  
43f1139f   tangwang   refactor: ES查询结构重...
19
20
21
22
23
24
25
  class RerankEngine:
      """
      Local reranking engine (disabled unless ``enabled=True``).

      Re-scores hits returned by Elasticsearch with an additive ranking
      expression such as
      ``"bm25() + 0.2*text_embedding_relevance() + general_score*2"``.
      Each term is either a function call (``bm25()``,
      ``text_embedding_relevance()``, ``field_value(field)``,
      ``timeliness(date_field)``) or a bare ``_source`` field name, optionally
      scaled by a numeric coefficient written before (``0.2*x``) or after
      (``x*2``) the term.
      """

      # One additive term:
      #   group 1 - optional sign and leading number (may contain spaces,
      #             e.g. "+ 0.2" when the expression is written "a + 0.2*b")
      #   group 2 - field name or function call with its argument list
      #   group 3 - optional trailing "* number"; the lookahead refuses to
      #             take a number that is itself followed by more digits,
      #             an identifier or another '*', so that number stays
      #             available as the leading coefficient of the next term.
      _TERM_PATTERN = re.compile(
          r'([+-]?\s*\d*\.?\d*)\s*\*?\s*'
          r'([a-zA-Z_]\w*(?:\([^)]*\))?)'
          r'(?:\s*\*\s*(\d+\.?\d*|\.\d+)(?![\w.]|\s*\*))?'
      )

      def __init__(self, ranking_expression: str, enabled: bool = False):
          """
          Initialize rerank engine.

          Args:
              ranking_expression: Ranking expression string
                                Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2"
              enabled: Whether local reranking is enabled (default: False).
                       The expression is only parsed when this is True.
          """
          self.enabled = enabled
          self.expression = ranking_expression
          self.parsed_terms: List[Dict[str, Any]] = []

          if enabled:
              self.parsed_terms = self._parse_expression(ranking_expression)

      @staticmethod
      def _parse_coefficient(raw: str) -> float:
          """Convert the leading coefficient text of a term into a float.

          Whitespace is removed first because ``float("+ 0.2")`` raises
          ``ValueError``; without this, "a + 0.2*b" would silently give ``b``
          a coefficient of 1.0. An empty string or a lone ``+`` means 1.0, a
          lone ``-`` means -1.0, and anything still unparsable (e.g. ".")
          falls back to 1.0.
          """
          compact = ''.join(raw.split())
          if compact in ('', '+'):
              return 1.0
          if compact == '-':
              return -1.0
          try:
              return float(compact)
          except ValueError:
              return 1.0

      def _parse_expression(self, expression: str) -> List[Dict[str, Any]]:
          """
          Parse ranking expression into terms.

          Args:
              expression: Ranking expression

          Returns:
              List of term dictionaries. Function terms have the keys
              ``type='function'``, ``name``, ``args`` and ``coefficient``;
              field terms have ``type='field'``, ``name`` and ``coefficient``.
              An empty or missing expression yields an empty list.
          """
          if not expression:
              return []

          terms: List[Dict[str, Any]] = []

          for match in self._TERM_PATTERN.finditer(expression):
              coefficient = self._parse_coefficient(match.group(1))

              # A trailing "* number" (e.g. "general_score*2") scales the term too.
              trailing = match.group(3)
              if trailing is not None:
                  coefficient *= float(trailing)

              name, open_paren, remainder = match.group(2).partition('(')

              if open_paren:
                  # Function call: the pattern guarantees the token ends with ')'.
                  args_str = remainder[:-1]
                  args = [arg.strip() for arg in args_str.split(',')] if args_str else []
                  terms.append({
                      'type': 'function',
                      'name': name,
                      'args': args,
                      'coefficient': coefficient
                  })
              else:
                  # Bare identifier: reference to a field in the hit's _source.
                  terms.append({
                      'type': 'field',
                      'name': name,
                      'coefficient': coefficient
                  })

          return terms

      @staticmethod
      def _numeric_field(source: Dict[str, Any], field_name: str) -> float:
          """Return ``source[field_name]`` as a float, or 0.0 if missing or non-numeric."""
          if field_name not in source:
              return 0.0
          try:
              return float(source[field_name])
          except (ValueError, TypeError):
              return 0.0

      def _function_value(
          self,
          term: Dict[str, Any],
          source: Dict[str, Any],
          base_score: float,
          knn_score: Optional[float]
      ) -> float:
          """Evaluate one function term; unknown function names contribute 0.0."""
          func_name = term['name']
          args = term.get('args') or []

          if func_name == 'bm25':
              return base_score

          if func_name == 'text_embedding_relevance':
              return knn_score if knn_score is not None else 0.0

          if func_name == 'timeliness':
              # Placeholder: no real time decay is implemented yet. The term is
              # 1.0 when no date field is named or the named field is present,
              # and 0.0 when the named field is absent.
              if not args:
                  return 1.0
              return 1.0 if args[0] in source else 0.0

          if func_name == 'field_value':
              return self._numeric_field(source, args[0]) if args else 0.0

          return 0.0

      def calculate_score(
          self,
          hit: Dict[str, Any],
          base_score: float,
          knn_score: Optional[float] = None
      ) -> float:
          """
          Calculate final score for a search result.

          Args:
              hit: ES hit document; field values are read from ``hit['_source']``
              base_score: Base BM25 score
              knn_score: KNN similarity score (if available)

          Returns:
              ``base_score`` unchanged when the engine is disabled; otherwise
              the sum of ``coefficient * value`` over all parsed terms.
          """
          if not self.enabled:
              return base_score

          # Tolerate hits whose _source is absent or explicitly None.
          source = hit.get('_source') or {}
          score = 0.0

          for term in self.parsed_terms:
              if term['type'] == 'function':
                  term_value = self._function_value(term, source, base_score, knn_score)
              elif term['type'] == 'field':
                  term_value = self._numeric_field(source, term['name'])
              else:
                  term_value = 0.0

              score += term['coefficient'] * term_value

          return score

      def get_expression(self) -> str:
          """Get ranking expression."""
          return self.expression

      def get_terms(self) -> List[Dict[str, Any]]:
          """Get parsed expression terms."""
          return self.parsed_terms
43f1139f   tangwang   refactor: ES查询结构重...