Blame view

search/rerank_client.py 10.1 KB
506c39b7   tangwang   feat(search): 统一重...
1
2
3
4
5
6
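  """
  Rerank client: calls an external (BGE) rerank service through the configured
  rerank provider and fuses the ES scores with the rerank scores.

  Flow:
  1. Build the list of document texts to rerank from the ES hits.
  2. POST the query and documents to the rerank service /rerank to obtain a
     relevance score for each document.
  3. Extract the ES text / vector (KNN) named-query scores, fuse them with the
     rerank scores using a multiplicative formula, and re-sort the hits.
  """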
  
  import logging
  import math
  from typing import Dict, Any, List, Optional, Tuple

  from providers import create_rerank_provider
  
506c39b7   tangwang   feat(search): 统一重...
15
16
  logger: logging.Logger = logging.getLogger(__name__)

  # Legacy linear-fusion weights. They are kept only so that callers passing
  # weight_es / weight_ai keep working; the multiplicative fusion formula used
  # by this module does not read them.
  DEFAULT_WEIGHT_ES: float = 0.4
  DEFAULT_WEIGHT_AI: float = 0.6

  # Default timeout (seconds) for the rerank service call. Larger document
  # batches need more time; consider raising timeout_sec in the config.
  DEFAULT_TIMEOUT_SEC: float = 15.0
  
  
  def build_docs_from_hits(
      es_hits: List[Dict[str, Any]],
      language: str = "zh",
      doc_template: str = "{title}",
  ) -> List[str]:
      """Render one rerank document string per ES hit, in the order of *es_hits*.

      Each hit's ``_source`` fields are substituted into *doc_template* via
      ``str.format_map``. Supported placeholders are ``{title}``, ``{brief}``,
      ``{vendor}``, ``{description}`` and ``{category_path}``; any other
      placeholder renders as an empty string. A field stored as a dict of
      per-language texts is resolved to *language*, then ``"zh"``, then
      ``"en"``; any other non-None value is converted with ``str``. Every
      resolved field value is stripped of surrounding whitespace.

      Args:
          es_hits: ES hits; each item is expected to carry ``_source``.
          language: language code such as "zh" or "en"; after strip/lower,
              anything other than "zh"/"en" falls back to "zh".
          doc_template: format template for the rerank document text.

      Returns:
          A list of strings with the same length as *es_hits*, used as the
          ``docs`` of POST /rerank.
      """
      normalized = (language or "zh").strip().lower()
      lang = normalized if normalized in ("zh", "en") else "zh"

      def text_for(value: Any) -> str:
          # Resolve a (possibly multilingual) source field to a stripped string.
          if value is None:
              return ""
          if not isinstance(value, dict):
              return str(value).strip()
          return str(value.get(lang) or value.get("zh") or value.get("en") or "").strip()

      class _BlankForMissing(dict):
          # format_map() consults __missing__ for unknown placeholders -> "".
          def __missing__(self, key: str) -> str:
              return ""

      optional_fields = ("brief", "vendor", "description", "category_path")
      title_only = "{title}" == doc_template
      # Only resolve the optional fields the template actually references.
      wanted = {name: "{" + name + "}" in doc_template for name in optional_fields}

      def render(source: Dict[str, Any]) -> str:
          if title_only:
              # Fast path for the default template: no formatting needed.
              return text_for(source.get("title"))
          values = _BlankForMissing(title=text_for(source.get("title")))
          for name in optional_fields:
              values[name] = text_for(source.get(name)) if wanted[name] else ""
          return str(doc_template).format_map(values)

      return [render(hit.get("_source") or {}) for hit in es_hits]
  
  
  def call_rerank_service(
      query: str,
      docs: List[str],
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      top_n: Optional[int] = None,
  ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
      """Score *docs* against *query* through the rerank service (POST /rerank).

      The provider (and its URL) come from ``create_rerank_provider()``, which
      reads them from services_config.

      Args:
          query: query text sent to the reranker.
          docs: document texts, one per hit.
          timeout_sec: request timeout forwarded to the provider.
          top_n: optional, forwarded to the provider unchanged.

      Returns:
          ``(scores, meta)`` exactly as returned by the provider's ``rerank``;
          ``([], {})`` when *docs* is empty (no request is made); ``(None, None)``
          when creating the provider or the request raises (the error is logged).
      """
      if docs:
          try:
              provider = create_rerank_provider()
              return provider.rerank(
                  query=query,
                  docs=docs,
                  timeout_sec=timeout_sec,
                  top_n=top_n,
              )
          except Exception as exc:
              # Best effort: a failed rerank must not fail the search; the
              # (None, None) result lets run_rerank keep the ES order.
              logger.warning("Rerank request failed: %s", exc, exc_info=True)
              return None, None
      return [], {}
  
  
c90f80ed   tangwang   相关性优化
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
  def _to_score(value: Any) -> float:
      try:
          if value is None:
              return 0.0
          return float(value)
      except (TypeError, ValueError):
          return 0.0
  
  
  def _extract_named_query_score(matched_queries: Any, name: str) -> float:
      if isinstance(matched_queries, dict):
          return _to_score(matched_queries.get(name))
      if isinstance(matched_queries, list):
          return 1.0 if name in matched_queries else 0.0
      return 0.0
  
  
  def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]:
      """Derive a hit's text-relevance score from its named-query scores.

      Components:
      - source score: score of the ``base_query`` named query;
      - translation score: best score among ``base_query_trans_*`` named
        queries (for a list of names: 1.0 if any is present), weighted by 0.8.
      The larger weighted component is the primary score, and the other one
      adds 0.25x of itself as support. If the combined score is not positive,
      *fallback_es_score* (fuse_scores_and_resort passes the hit's ES
      ``_score``) becomes the text, weighted-source and primary score, and the
      support score is zeroed.

      Returns:
          Dict with the raw, weighted, primary, support and final text scores.
      """
      trans_prefix = "base_query_trans_"
      raw_source = _extract_named_query_score(matched_queries, "base_query")

      if isinstance(matched_queries, dict):
          raw_translation = max(
              [0.0]
              + [
                  _to_score(score)
                  for query_name, score in matched_queries.items()
                  if isinstance(query_name, str) and query_name.startswith(trans_prefix)
              ]
          )
      elif isinstance(matched_queries, list):
          has_translation = any(
              isinstance(query_name, str) and query_name.startswith(trans_prefix)
              for query_name in matched_queries
          )
          raw_translation = 1.0 if has_translation else 0.0
      else:
          raw_translation = 0.0

      weighted_source = raw_source
      weighted_translation = 0.8 * raw_translation
      components = [weighted_source, weighted_translation]
      primary = max(components)
      support = sum(components) - primary
      combined = primary + 0.25 * support

      if combined <= 0.0:
          # No positive named-query signal: fall back to the overall ES score.
          combined = fallback_es_score
          weighted_source = fallback_es_score
          primary = fallback_es_score
          support = 0.0

      return {
          "source_score": raw_source,
          "translation_score": raw_translation,
          "weighted_source_score": weighted_source,
          "weighted_translation_score": weighted_translation,
          "primary_text_score": primary,
          "support_text_score": support,
          "text_score": combined,
      }
  
  
  def _fuse_score(rerank_score: float, text_score: float, knn_score: float) -> float:
      rerank_factor = max(rerank_score, 0.0) + 0.00001
      text_factor = (max(text_score, 0.0) + 0.1) ** 0.35
      knn_factor = (max(knn_score, 0.0) + 0.6) ** 0.2
      return rerank_factor * text_factor * knn_factor
  
  
506c39b7   tangwang   feat(search): 统一重...
165
166
167
168
169
170
171
  def fuse_scores_and_resort(
      es_hits: List[Dict[str, Any]],
      rerank_scores: List[float],
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
  ) -> List[Dict[str, Any]]:
      """Fuse ES and rerank scores multiplicatively and re-sort *es_hits* in place.

      The hits' original ``_score`` is not modified. Each hit gains:
      - ``_original_score``: the hit's ``_score``;
      - ``_rerank_score``: the rerank service score, coerced to float;
      - ``_text_score``: text relevance, taken from the ``base_query`` /
        ``base_query_trans_*`` named-query scores (falling back to ``_score``);
      - ``_knn_score``: score of the ``knn_query`` named query;
      - ``_text_source_score`` / ``_text_translation_score`` /
        ``_text_primary_score`` / ``_text_support_score``: text components;
      - ``_fused_score``: the fused score used for ordering.
      Hits are then sorted by ``_fused_score`` descending. ``list.sort`` is
      stable, so ties keep the ES order.

      Args:
          es_hits: ES hits (mutated and re-sorted in place).
          rerank_scores: rerank scores, one per hit and in the same order.
          weight_es: kept for backward compatibility; unused.
          weight_ai: kept for backward compatibility; unused.

      Returns:
          One debug dict per hit, in pre-sort order, for debug_info. Returns
          ``[]`` without touching the hits when there are no hits or the
          lengths differ.
      """
      hit_count = len(es_hits)
      if hit_count == 0 or len(rerank_scores) != hit_count:
          return []

      debug_rows: List[Dict[str, Any]] = []

      for hit, raw_rerank in zip(es_hits, rerank_scores):
          es_score = _to_score(hit.get("_score"))
          rerank_score = _to_score(raw_rerank)
          matched_queries = hit.get("matched_queries")
          knn_score = _extract_named_query_score(matched_queries, "knn_query")
          parts = _collect_text_score_components(matched_queries, es_score)
          text_score = parts["text_score"]
          fused = _fuse_score(rerank_score, text_score, knn_score)

          annotations = {
              "_original_score": hit.get("_score"),
              "_rerank_score": rerank_score,
              "_text_score": text_score,
              "_knn_score": knn_score,
              "_text_source_score": parts["source_score"],
              "_text_translation_score": parts["translation_score"],
              "_text_primary_score": parts["primary_text_score"],
              "_text_support_score": parts["support_text_score"],
              "_fused_score": fused,
          }
          for key, value in annotations.items():
              hit[key] = value

          debug_rows.append(
              {
                  "doc_id": hit.get("_id"),
                  "es_score": es_score,
                  "rerank_score": rerank_score,
                  "text_score": text_score,
                  "text_source_score": parts["source_score"],
                  "text_translation_score": parts["translation_score"],
                  "text_primary_score": parts["primary_text_score"],
                  "text_support_score": parts["support_text_score"],
                  "knn_score": knn_score,
                  "matched_queries": matched_queries,
                  "fused_score": fused,
              }
          )

      def _by_fused_score(item: Dict[str, Any]) -> Any:
          # Every hit has _fused_score at this point; the fallback is defensive.
          return item.get("_fused_score", item.get("_score", 0.0))

      es_hits.sort(key=_by_fused_score, reverse=True)
      return debug_rows
  
  
  def run_rerank(
      query: str,
      es_response: Dict[str, Any],
      language: str = "zh",
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
      rerank_query_template: str = "{query}",
      rerank_doc_template: str = "{title}",
      top_n: Optional[int] = None,
  ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
      """Run the full rerank pipeline on an ES response, updating it in place.

      Steps: take the hits from *es_response* -> build rerank docs -> call the
      rerank service -> fuse scores and re-sort the hits -> set
      ``hits.max_score`` to the best fused score. The provider and its URL are
      read from services_config (via ``call_rerank_service``).

      Args:
          query: user query; rendered through *rerank_query_template*.
          es_response: ES search response. Its ``hits.hits`` list is re-sorted
              in place, and each hit gains the fusion fields.
          language: language used to pick multilingual source fields.
          timeout_sec: rerank request timeout in seconds.
          weight_es: legacy, forwarded but unused by the fusion formula.
          weight_ai: legacy, forwarded but unused by the fusion formula.
          rerank_query_template: ``str.format`` template in which ``{query}``
              is replaced by *query*. Unknown placeholders render as empty
              strings, the same way the doc template is rendered.
          rerank_doc_template: document template (see build_docs_from_hits).
          top_n: optional; forwarded to /rerank so cloud backends can rerank
              only part of the list (page + size).

      Returns:
          ``(es_response, meta, fused_debug)``. When there are no hits, the
          rerank call fails, or the score count does not match the hit count,
          returns ``(es_response, None, [])`` with the ES order untouched.
      """
      # `or {}` also tolerates an explicit "hits": None in the response.
      hits = (es_response.get("hits") or {}).get("hits") or []
      if not hits:
          return es_response, None, []

      class _SafeDict(dict):
          # Unknown placeholders render as "" instead of raising KeyError,
          # matching how the doc template is rendered.
          def __missing__(self, key: str) -> str:
              return ""

      query_text = str(rerank_query_template).format_map(_SafeDict(query=query))
      docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template)
      scores, meta = call_rerank_service(
          query_text,
          docs,
          timeout_sec=timeout_sec,
          top_n=top_n,
      )

      if scores is None:
          # The rerank call failed (call_rerank_service logs exceptions).
          return es_response, None, []
      if len(scores) != len(hits):
          # Do not fail silently: a partial score list would otherwise just
          # disable reranking with no trace.
          logger.warning(
              "Rerank score count mismatch: got %d scores for %d hits; keeping ES order",
              len(scores),
              len(hits),
          )
          return es_response, None, []

      fused_debug = fuse_scores_and_resort(
          hits,
          scores,
          weight_es=weight_es,
          weight_ai=weight_ai,
      )

      # `hits` is the list stored in es_response["hits"] and was sorted in
      # place, so hits[0] now holds the highest fused score.
      top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
      es_response["hits"]["max_score"] = top

      return es_response, meta, fused_debug