Blame view

search/rerank_client.py 15 KB
506c39b7   tangwang   feat(search): 统一重...
1
2
3
4
5
6
  """
  重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。
  
  流程:
  1. 从 ES hits 构造用于重排的文档文本列表
  2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
a47416ec   tangwang   把融合逻辑改成乘法公式,并把 ES...
7
  3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序
506c39b7   tangwang   feat(search): 统一重...
8
9
10
  """
  
  from typing import Dict, Any, List, Optional, Tuple
506c39b7   tangwang   feat(search): 统一重...
11
12
  import logging
  
814e352b   tangwang   乘法公式配置化
13
  from config.schema import RerankFusionConfig
42e3aea6   tangwang   tidy
14
15
  from providers import create_rerank_provider
  
506c39b7   tangwang   feat(search): 统一重...
16
17
  logger = logging.getLogger(__name__)

  # Legacy linear-fusion weights. They are kept only so existing call signatures
  # stay compatible; the multiplicative fusion formula no longer uses them.
  DEFAULT_WEIGHT_ES: float = 0.4
  DEFAULT_WEIGHT_AI: float = 0.6

  # Default rerank-service timeout in seconds. Large document batches need more
  # time; consider raising timeout_sec in the config.
  DEFAULT_TIMEOUT_SEC: float = 15.0
  
  
  def build_docs_from_hits(
      es_hits: List[Dict[str, Any]],
      language: str = "zh",
      doc_template: str = "{title}",
      debug_rows: Optional[List[Dict[str, Any]]] = None,
  ) -> List[str]:
      """
      Build the document texts sent to the rerank service, one per ES hit, in hit order.

      Each document is rendered from ``doc_template``. Supported placeholders are
      ``{title}`` ``{brief}`` ``{vendor}`` ``{description}`` ``{category_path}``;
      any other named placeholder renders as an empty string. Only the fields the
      template references are read from ``_source``. A non-blank
      ``_style_rerank_suffix`` on the hit is appended to the title.

      Args:
          es_hits: ES hits list; each item may contain ``_source``.
          language: Language code such as "zh" or "en"; other values fall back to "zh".
              Dict-valued (multilingual) fields prefer this language, then "zh", then "en".
          doc_template: Template used to render each document.
          debug_rows: When not None, one debug dict per hit is appended to it
              (template, title suffix, rendered fields, a preview capped at
              300 characters, and the document length).

      Returns:
          A list of strings with the same length as ``es_hits``, used as the
          ``docs`` of POST /rerank.
      """
      lang = (language or "zh").strip().lower()
      if lang not in ("zh", "en"):
          lang = "zh"

      def pick_lang_text(obj: Any) -> str:
          # Dict fields are keyed by language code; scalars are stringified as-is.
          if obj is None:
              return ""
          if isinstance(obj, dict):
              return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
          return str(obj).strip()

      class _SafeDict(dict):
          # Unknown template placeholders render as "" instead of raising KeyError.
          def __missing__(self, key: str) -> str:
              return ""

      def record_debug(doc_text: str, title_suffix: str, fields: Dict[str, Any]) -> None:
          # Only called when debug_rows is not None.
          preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..."
          debug_rows.append({
              "doc_template": doc_template,
              "title_suffix": title_suffix or None,
              "fields": fields,
              "doc_preview": preview,
              "doc_length": len(doc_text),
          })

      optional_fields = ("brief", "vendor", "description", "category_path")
      only_title = "{title}" == doc_template
      # Skip reading fields the template never references.
      needed_fields = {name for name in optional_fields if "{" + name + "}" in doc_template}

      docs: List[str] = []
      for hit in es_hits:
          src = hit.get("_source") or {}
          title_suffix = str(hit.get("_style_rerank_suffix") or "").strip()
          title_str = pick_lang_text(src.get("title"))
          if title_suffix:
              title_str = f"{title_str} {title_suffix}".strip()

          if only_title:
              # Fast path: the document is just the title, no template formatting.
              doc_text = title_str
              if debug_rows is not None:
                  record_debug(doc_text, title_suffix, {"title": title_str})
          else:
              values = _SafeDict(title=title_str)
              for name in optional_fields:
                  values[name] = pick_lang_text(src.get(name)) if name in needed_fields else ""
              doc_text = str(doc_template).format_map(values)
              if debug_rows is not None:
                  # Report every supported field, including description.
                  fields: Dict[str, Any] = {"title": title_str}
                  fields.update((name, values[name] or None) for name in optional_fields)
                  record_debug(doc_text, title_suffix, fields)

          docs.append(doc_text)

      return docs
  
  
  def call_rerank_service(
      query: str,
      docs: List[str],
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      top_n: Optional[int] = None,
  ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
      """
      Send ``docs`` to the rerank provider and return ``(scores, meta)``.

      The provider comes from ``create_rerank_provider()``. According to the original
      note, the provider and URL are read from services_config there.

      Returns:
          ``([], {})`` when ``docs`` is empty (no request is made).
          ``(None, None)`` when creating the provider or the rerank call raises;
          the failure is logged as a warning.
          Otherwise, whatever the provider's ``rerank`` returns.
      """
      if not docs:
          return [], {}
      try:
          provider = create_rerank_provider()
          result = provider.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n)
      except Exception as exc:
          # Best effort: a rerank failure must not break the search request.
          logger.warning("Rerank request failed: %s", exc, exc_info=True)
          return None, None
      return result
  
  
c90f80ed   tangwang   相关性优化
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
  def _to_score(value: Any) -> float:
      try:
          if value is None:
              return 0.0
          return float(value)
      except (TypeError, ValueError):
          return 0.0
  
  
  def _extract_named_query_score(matched_queries: Any, name: str) -> float:
      if isinstance(matched_queries, dict):
          return _to_score(matched_queries.get(name))
      if isinstance(matched_queries, list):
          return 1.0 if name in matched_queries else 0.0
      return 0.0
  
e38dc1be   tangwang   融合公式参数调整、以及展示信息优化
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
  """
  原始变量:
  es_score:ES 返回的总分 _score,仅在文本分数不大于 0 时作为回退值。
  source_score:从 ES 返回的 matched_queries 里取 base_query 这条 named query 的分(dict 用具体分数;list 形式则“匹配到名字就算 1.0”)。
  translation_score:所有名字以 base_query_trans_ 开头的 named query 的分,在 dict 里取最大值;在 list 里只要存在这类名字就记为 1.0。
  
  中间变量:计算原始 query 得分和翻译 query 得分
  weighted_source : source_score(原文权重为 1.0)
  weighted_translation : 0.8 * translation_score
  
  区分主信号和辅助信号:
  合成 primary_text_score 和 support_text_score,取更强的那一路(原文检索 vs 翻译检索)作为主信号
  primary_text_score : max(weighted_source, weighted_translation)
  support_text_score : weighted_source + weighted_translation - primary_text_score
  
  主信号和辅助信号的融合:dismax 融合公式
  最终 text_score:主信号 + 0.25 * 辅助信号
  text_score : primary_text_score + 0.25 * support_text_score
  若 text_score <= 0,则 text_score 回退为 es_score(此时 support_text_score 记为 0)。
  """
c90f80ed   tangwang   相关性优化
173
174
175
  def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]:
      """
      Combine original-query and translated-query named-query scores into one text score.

      - source_score: score of the ``base_query`` named query (dict form: its score;
        list form: 1.0 if the name is present).
      - translation_score: best score among ``base_query_trans_*`` named queries
        (dict form: max, floored at 0.0; list form: 1.0 if any such name is present).
      - The stronger of ``source_score`` and ``0.8 * translation_score`` is the
        primary signal; the other one is added with weight 0.25.
      - If the combined score is <= 0, ``fallback_es_score`` is used instead.

      Returns:
          Dict with the raw and weighted source/translation scores, the primary and
          support scores, and the final ``text_score``.
      """
      prefix = "base_query_trans_"
      source_score = _extract_named_query_score(matched_queries, "base_query")

      translation_score = 0.0
      if isinstance(matched_queries, dict):
          for name, raw in matched_queries.items():
              if not isinstance(name, str):
                  continue
              score = _to_score(raw)
              if name.startswith(prefix) and score > translation_score:
                  translation_score = score
      elif isinstance(matched_queries, list):
          if any(isinstance(name, str) and name.startswith(prefix) for name in matched_queries):
              translation_score = 1.0

      weighted_source = source_score
      weighted_translation = 0.8 * translation_score
      components = (weighted_source, weighted_translation)
      primary_text_score = max(components)
      support_text_score = sum(components) - primary_text_score
      # dismax-style blend: strongest signal plus a damped share of the other one.
      text_score = primary_text_score + 0.25 * support_text_score

      if text_score <= 0.0:
          # No usable text signal: fall back to the overall ES score.
          text_score = weighted_source = primary_text_score = fallback_es_score
          support_text_score = 0.0

      return {
          "source_score": source_score,
          "translation_score": translation_score,
          "weighted_source_score": weighted_source,
          "weighted_translation_score": weighted_translation,
          "primary_text_score": primary_text_score,
          "support_text_score": support_text_score,
          "text_score": text_score,
      }
  
  
814e352b   tangwang   乘法公式配置化
215
216
217
218
219
220
  def _multiply_fusion_factors(
      rerank_score: float,
      text_score: float,
      knn_score: float,
      fusion: RerankFusionConfig,
  ) -> Tuple[float, float, float, float]:
      """
      Compute the per-signal multiplicative factors and their product.

      Each factor is ``(max(score, 0) + bias) ** exponent``, with bias and exponent
      read from ``fusion``.

      Returns:
          (rerank_factor, text_factor, knn_factor, fused_without_style_boost)
      """
      def factor(score: float, bias: float, exponent: float) -> float:
          # Negative scores are clipped to 0 before the bias is added.
          return (max(score, 0.0) + bias) ** exponent

      rerank_factor = factor(rerank_score, fusion.rerank_bias, fusion.rerank_exponent)
      text_factor = factor(text_score, fusion.text_bias, fusion.text_exponent)
      knn_factor = factor(knn_score, fusion.knn_bias, fusion.knn_exponent)
      product = rerank_factor * text_factor * knn_factor
      return rerank_factor, text_factor, knn_factor, product
  
  
87cacb1b   tangwang   融合公式优化。加入意图匹配因子
228
229
230
231
  def _has_selected_sku(hit: Dict[str, Any]) -> bool:
      return bool(str(hit.get("_style_rerank_suffix") or "").strip())
  
  
506c39b7   tangwang   feat(search): 统一重...
232
233
234
235
236
  def fuse_scores_and_resort(
      es_hits: List[Dict[str, Any]],
      rerank_scores: List[float],
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
      fusion: Optional[RerankFusionConfig] = None,
      style_intent_selected_sku_boost: float = 1.2,
      debug: bool = False,
      rerank_debug_rows: Optional[List[Dict[str, Any]]] = None,
  ) -> List[Dict[str, Any]]:
      """
      Fuse ES scores with rerank scores multiplicatively and re-sort hits by the result.

      The original ``_score`` is not modified. ``es_hits`` is sorted in place by the
      fused score, in descending order.

      Fusion form (bias / exponent come from ``fusion``)::

          fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost

      ``sku_boost`` is ``style_intent_selected_sku_boost`` (default 1.2) when the hit
      has a selected SKU (non-blank ``_style_rerank_suffix``); otherwise it is 1.0.
      Per the original docs, the boost is configured via
      ``query.style_intent.selected_sku_boost``.

      Fields written on every hit:
      - _original_score: original ES ``_score``
      - _rerank_score: score returned by the rerank service
      - _fused_score: fused score
      - _text_score: text relevance score. It comes from the ``base_query`` /
        ``base_query_trans_*`` named queries and falls back to the ES score.
      - _knn_score: KNN score from the ``knn_query`` named query
      - _style_intent_selected_sku_boost: applied SKU boost (1.0 when not selected)

      When ``debug`` is set, these are also written: _text_source_score,
      _text_translation_score, _text_primary_score, _text_support_score.

      Args:
          es_hits: ES hits list (modified and re-sorted in place).
          rerank_scores: Rerank scores, same length as ``es_hits``.
          weight_es: Kept for backward compatibility; unused.
          weight_ai: Kept for backward compatibility; unused.
          fusion: Bias/exponent config. When it is falsy,
              ``RerankFusionConfig()`` defaults are used.
          style_intent_selected_sku_boost: Multiplier for hits with a selected SKU.
          debug: When True, collect a per-hit scoring breakdown.
          rerank_debug_rows: Optional per-hit rerank-input debug rows, aligned by
              index. Each is attached as ``rerank_input`` to the matching debug entry.

      Returns:
          The per-hit debug entries in the original (pre-sort) hit order when
          ``debug`` is True; otherwise an empty list.
          Also an empty list, with ``es_hits`` left untouched, when ``es_hits`` is
          empty or its length differs from ``rerank_scores``.
      """
      n = len(es_hits)
      if n == 0 or len(rerank_scores) != n:
          return []

      f = fusion or RerankFusionConfig()
      fused_debug: List[Dict[str, Any]] = []

      for idx, (hit, raw_rerank_score) in enumerate(zip(es_hits, rerank_scores)):
          es_score = _to_score(hit.get("_score"))
          rerank_score = _to_score(raw_rerank_score)
          matched_queries = hit.get("matched_queries")
          knn_score = _extract_named_query_score(matched_queries, "knn_query")
          text_components = _collect_text_score_components(matched_queries, es_score)
          text_score = text_components["text_score"]
          rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors(
              rerank_score, text_score, knn_score, f
          )
          sku_selected = _has_selected_sku(hit)
          style_boost = style_intent_selected_sku_boost if sku_selected else 1.0
          fused *= style_boost

          hit["_original_score"] = hit.get("_score")
          hit["_rerank_score"] = rerank_score
          hit["_text_score"] = text_score
          hit["_knn_score"] = knn_score
          hit["_fused_score"] = fused
          hit["_style_intent_selected_sku_boost"] = style_boost

          if not debug:
              continue

          hit["_text_source_score"] = text_components["source_score"]
          hit["_text_translation_score"] = text_components["translation_score"]
          hit["_text_primary_score"] = text_components["primary_text_score"]
          hit["_text_support_score"] = text_components["support_text_score"]

          debug_entry = {
              "doc_id": hit.get("_id"),
              "es_score": es_score,
              "rerank_score": rerank_score,
              "text_score": text_score,
              "text_source_score": text_components["source_score"],
              "text_translation_score": text_components["translation_score"],
              "text_weighted_source_score": text_components["weighted_source_score"],
              "text_weighted_translation_score": text_components["weighted_translation_score"],
              "text_primary_score": text_components["primary_text_score"],
              "text_support_score": text_components["support_text_score"],
              # Heuristic: no positive named-query text signal and text_score equals the ES score.
              "text_score_fallback_to_es": (
                  text_score == es_score
                  and text_components["source_score"] <= 0.0
                  and text_components["translation_score"] <= 0.0
              ),
              "knn_score": knn_score,
              "rerank_factor": rerank_factor,
              "text_factor": text_factor,
              "knn_factor": knn_factor,
              "style_intent_selected_sku": sku_selected,
              "style_intent_selected_sku_boost": style_boost,
              "matched_queries": matched_queries,
              "fused_score": fused,
          }
          if rerank_debug_rows is not None and idx < len(rerank_debug_rows):
              debug_entry["rerank_input"] = rerank_debug_rows[idx]
          fused_debug.append(debug_entry)

      # list.sort is stable, so hits with equal fused scores keep their ES order.
      es_hits.sort(
          key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
          reverse=True,
      )
      return fused_debug
  
  
  def run_rerank(
      query: str,
      es_response: Dict[str, Any],
      language: str = "zh",
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
      rerank_query_template: str = "{query}",
      rerank_doc_template: str = "{title}",
      top_n: Optional[int] = None,
      debug: bool = False,
      fusion: Optional[RerankFusionConfig] = None,
      style_intent_selected_sku_boost: float = 1.2,
  ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
      """
      Run the full rerank pipeline on an ES response.

      Steps: take hits from ``es_response``, build docs, call the rerank service,
      fuse scores and re-sort, then update ``max_score``.

      The rerank provider is created inside ``call_rerank_service``. According to
      the original note, the provider and URL are read from services_config.
      ``top_n`` is optional. When given, it is passed through to /rerank; per the
      original note, this lets cloud backends rerank only the page+size window.
      ``rerank_query_template`` only substitutes ``{query}``; any other placeholder
      raises ``KeyError``.

      Returns:
          ``(es_response, meta, fused_debug)``.
          On success, ``es_response`` is modified in place: its hits are re-sorted
          and ``hits.max_score`` is set to the top fused score.
          If there are no hits, the rerank call fails, or the score count does not
          match the hit count, ``es_response`` is returned unchanged together with
          ``None`` and ``[]``.
      """
      hits = es_response.get("hits", {}).get("hits") or []
      if not hits:
          return es_response, None, []

      query_text = str(rerank_query_template).format_map({"query": query})
      rerank_debug_rows: Optional[List[Dict[str, Any]]] = [] if debug else None
      docs = build_docs_from_hits(
          hits,
          language=language,
          doc_template=rerank_doc_template,
          debug_rows=rerank_debug_rows,
      )
      scores, meta = call_rerank_service(
          query_text,
          docs,
          timeout_sec=timeout_sec,
          top_n=top_n,
      )

      # Service failure or a score count that does not line up with the hits: keep ES order.
      if scores is None or len(scores) != len(hits):
          return es_response, None, []

      fused_debug = fuse_scores_and_resort(
          hits,
          scores,
          weight_es=weight_es,
          weight_ai=weight_ai,
          fusion=fusion,
          style_intent_selected_sku_boost=style_intent_selected_sku_boost,
          debug=debug,
          rerank_debug_rows=rerank_debug_rows,
      )

      # hits is non-empty and was sorted in place, so hits[0] carries the top fused score;
      # it also proves es_response["hits"] exists.
      top_hit = hits[0]
      es_response["hits"]["max_score"] = top_hit.get("_fused_score", top_hit.get("_score", 0.0)) or 0.0

      return es_response, meta, fused_debug