Blame view

search/rerank_client.py 14 KB
506c39b7   tangwang   feat(search): 统一重...
1
2
3
4
5
6
  """
  重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。
  
  流程:
  1.  ES hits 构造用于重排的文档文本列表
  2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
a47416ec   tangwang   把融合逻辑改成乘法公式,并把 ES...
7
  3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序
506c39b7   tangwang   feat(search): 统一重...
8
9
10
  """
  
  from typing import Dict, Any, List, Optional, Tuple
506c39b7   tangwang   feat(search): 统一重...
11
12
  import logging
  
814e352b   tangwang   乘法公式配置化
13
  from config.schema import RerankFusionConfig
42e3aea6   tangwang   tidy
14
15
  from providers import create_rerank_provider
  
506c39b7   tangwang   feat(search): 统一重...
16
17
  logger = logging.getLogger(__name__)

  # Legacy linear-fusion weights. They are kept only so existing call
  # signatures keep working; the current multiplicative fusion ignores them.
  DEFAULT_WEIGHT_ES: float = 0.4
  DEFAULT_WEIGHT_AI: float = 0.6

  # Default rerank-service timeout in seconds. Large document batches need
  # more time; raise ``timeout_sec`` in the config when that happens.
  DEFAULT_TIMEOUT_SEC: float = 15.0
  
  
  def build_docs_from_hits(
      es_hits: List[Dict[str, Any]],
      language: str = "zh",
      doc_template: str = "{title}",
      debug_rows: Optional[List[Dict[str, Any]]] = None,
  ) -> List[str]:
      """
      Build the rerank-service document texts from ES hits (one text per hit, same order).

      ``doc_template`` is rendered with ``str.format_map`` against fields of the
      hit's ``_source``. Supported placeholders:
      ``{title}`` ``{brief}`` ``{vendor}`` ``{description}`` ``{category_path}``.
      Unknown placeholders render as an empty string. A ``None`` template is
      treated as ``"{title}"``.

      Fields given as dicts are resolved by ``language`` first, then ``"zh"``,
      then ``"en"``. A non-blank ``_style_rerank_suffix`` on the hit is appended
      to the title.

      Args:
          es_hits: ES hits list; each item may carry a ``_source`` dict.
          language: Language code such as "zh" or "en"; anything else falls back to "zh".
          doc_template: Template used to assemble each document text.
          debug_rows: When not None, one debug dict per hit is appended to it.

      Returns:
          A list of strings with the same length as ``es_hits``, used as the
          ``docs`` of POST /rerank.
      """
      lang = (language or "zh").strip().lower()
      if lang not in ("zh", "en"):
          lang = "zh"
      # A None template would crash the ``in`` checks below; use the default instead.
      template = "{title}" if doc_template is None else str(doc_template)

      def pick_lang_text(obj: Any) -> str:
          if obj is None:
              return ""
          if isinstance(obj, dict):
              return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
          return str(obj).strip()

      class _SafeDict(dict):
          # Placeholders missing from the mapping render as "" instead of raising KeyError.
          def __missing__(self, key: str) -> str:
              return ""

      optional_fields = ("brief", "vendor", "description", "category_path")
      only_title = template == "{title}"
      # Resolve only the fields the template references; this runs once per hit.
      needed = {name for name in optional_fields if "{" + name + "}" in template}

      docs: List[str] = []
      for hit in es_hits:
          src = hit.get("_source") or {}
          title_suffix = str(hit.get("_style_rerank_suffix") or "").strip()
          title_str = pick_lang_text(src.get("title"))
          if title_suffix:
              title_str = f"{title_str} {title_suffix}".strip()

          if only_title:
              doc_text = title_str
              fields: Dict[str, Any] = {"title": title_str}
          else:
              values = _SafeDict(title=title_str)
              for name in optional_fields:
                  values[name] = pick_lang_text(src.get(name)) if name in needed else ""
              doc_text = template.format_map(values)
              fields = {"title": title_str}
              fields.update({name: values[name] or None for name in optional_fields})

          if debug_rows is not None:
              preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..."
              debug_rows.append({
                  "doc_template": doc_template,
                  "title_suffix": title_suffix or None,
                  "fields": fields,
                  "doc_preview": preview,
                  "doc_length": len(doc_text),
              })
          docs.append(doc_text)

      return docs
  
  
  def call_rerank_service(
      query: str,
      docs: List[str],
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      top_n: Optional[int] = None,
  ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
      """
      Send ``query`` and ``docs`` to the rerank service (POST /rerank).

      Provider and URL are read from services_config via ``create_rerank_provider``.

      Returns:
          ``(scores, meta)`` exactly as returned by the provider's ``rerank`` call;
          ``([], {})`` when ``docs`` is empty (no request is made);
          ``(None, None)`` when creating the provider or the request raises
          (the error is logged as a warning with traceback).
      """
      if docs:
          try:
              provider = create_rerank_provider()
              return provider.rerank(
                  query=query,
                  docs=docs,
                  timeout_sec=timeout_sec,
                  top_n=top_n,
              )
          except Exception as exc:
              # Best-effort: a failed rerank must not fail the search request.
              logger.warning("Rerank request failed: %s", exc, exc_info=True)
              return None, None
      return [], {}
  
  
c90f80ed   tangwang   相关性优化
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
  def _to_score(value: Any) -> float:
      try:
          if value is None:
              return 0.0
          return float(value)
      except (TypeError, ValueError):
          return 0.0
  
  
  def _extract_named_query_score(matched_queries: Any, name: str) -> float:
      if isinstance(matched_queries, dict):
          return _to_score(matched_queries.get(name))
      if isinstance(matched_queries, list):
          return 1.0 if name in matched_queries else 0.0
      return 0.0
  
  
  def _collect_text_score_components(
      matched_queries: Any,
      fallback_es_score: float,
      *,
      translation_weight: float = 0.8,
      support_weight: float = 0.25,
  ) -> Dict[str, float]:
      """
      Derive the text-relevance score and its components from named queries.

      - ``source_score``: score of the ``base_query`` named query.
      - ``translation_score``: best score among ``base_query_trans_*`` queries
        (1.0 if any is matched when ``matched_queries`` is a list of names).
      - The translation score is multiplied by ``translation_weight``; the larger
        weighted component becomes the primary score and the other one adds
        ``support_weight`` times its value.
      - If the resulting ``text_score`` is not positive, ``fallback_es_score``
        is used as the text score instead.

      Args:
          matched_queries: ES ``matched_queries`` (dict name->score or list of names).
          fallback_es_score: Score used when no positive text score can be derived.
          translation_weight: Multiplier for the translation score (default 0.8).
          support_weight: Multiplier for the non-primary component (default 0.25).

      Returns:
          Dict with the raw and weighted source/translation scores, the primary
          and support scores, and the final ``text_score``.
      """
      source_score = _extract_named_query_score(matched_queries, "base_query")
      translation_score = 0.0

      if isinstance(matched_queries, dict):
          for query_name, score in matched_queries.items():
              if isinstance(query_name, str) and query_name.startswith("base_query_trans_"):
                  translation_score = max(translation_score, _to_score(score))
      elif isinstance(matched_queries, list):
          if any(
              isinstance(query_name, str) and query_name.startswith("base_query_trans_")
              for query_name in matched_queries
          ):
              translation_score = 1.0

      weighted_source = source_score
      weighted_translation = translation_weight * translation_score
      weighted_components = [weighted_source, weighted_translation]
      primary_text_score = max(weighted_components)
      support_text_score = sum(weighted_components) - primary_text_score
      text_score = primary_text_score + support_weight * support_text_score

      if text_score <= 0.0:
          # No usable named-query signal: fall back to the raw ES score.
          text_score = fallback_es_score
          weighted_source = fallback_es_score
          primary_text_score = fallback_es_score
          support_text_score = 0.0

      return {
          "source_score": source_score,
          "translation_score": translation_score,
          "weighted_source_score": weighted_source,
          "weighted_translation_score": weighted_translation,
          "primary_text_score": primary_text_score,
          "support_text_score": support_text_score,
          "text_score": text_score,
      }
  
  
814e352b   tangwang   乘法公式配置化
197
198
199
200
201
202
  def _multiply_fusion_factors(
      rerank_score: float,
      text_score: float,
      knn_score: float,
      fusion: RerankFusionConfig,
  ) -> Tuple[float, float, float, float]:
      """
      Compute the three multiplicative fusion factors and their product.

      Each factor is ``(max(score, 0) + bias) ** exponent``, with bias and
      exponent read from ``fusion``. The product excludes the selected-SKU boost.

      Returns:
          ``(rerank_factor, text_factor, knn_factor, fused_without_style_boost)``.
      """

      def _factor(score: float, bias: float, exponent: float) -> float:
          return (max(score, 0.0) + bias) ** exponent

      rerank_factor = _factor(rerank_score, fusion.rerank_bias, fusion.rerank_exponent)
      text_factor = _factor(text_score, fusion.text_bias, fusion.text_exponent)
      knn_factor = _factor(knn_score, fusion.knn_bias, fusion.knn_exponent)
      product = rerank_factor * text_factor * knn_factor
      return rerank_factor, text_factor, knn_factor, product
  
  
87cacb1b   tangwang   融合公式优化。加入意图匹配因子
210
211
212
213
  def _has_selected_sku(hit: Dict[str, Any]) -> bool:
      return bool(str(hit.get("_style_rerank_suffix") or "").strip())
  
  
506c39b7   tangwang   feat(search): 统一重...
214
215
216
217
218
  def fuse_scores_and_resort(
      es_hits: List[Dict[str, Any]],
      rerank_scores: List[float],
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
      fusion: Optional[RerankFusionConfig] = None,
      style_intent_selected_sku_boost: float = 1.2,
      debug: bool = False,
      rerank_debug_rows: Optional[List[Dict[str, Any]]] = None,
  ) -> List[Dict[str, Any]]:
      """
      Fuse ES and rerank scores multiplicatively and re-sort hits by the fused score.

      The original ``_score`` is not modified. Hits are sorted in place,
      highest fused score first.

      Fusion form (bias / exponent come from ``fusion``)::

          fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost

      ``sku_boost`` applies only when the hit has a selected SKU. It defaults to
      1.2 and is configurable via ``query.style_intent.selected_sku_boost``.

      Fields written on every hit:
      - ``_original_score``: original ES score
      - ``_rerank_score``: score returned by the rerank service
      - ``_fused_score``: fused score
      - ``_text_score``: text relevance score (prefers the ``base_query`` named-query score)
      - ``_knn_score``: score of the ``knn_query`` named query (0.0 when absent)
      - ``_style_intent_selected_sku_boost``: boost applied (1.0 when no SKU is selected)

      When ``debug`` is true, the text-score components are also written:
      ``_text_source_score``, ``_text_translation_score``, ``_text_primary_score``
      and ``_text_support_score``.

      Args:
          es_hits: ES hits list (modified and sorted in place).
          rerank_scores: Rerank scores, same length as ``es_hits``.
          weight_es: Kept for compatibility; unused.
          weight_ai: Kept for compatibility; unused.
          fusion: Fusion config; ``RerankFusionConfig()`` is used when falsy.
          style_intent_selected_sku_boost: Multiplier for hits with a selected SKU.
          debug: Whether to build per-hit debug entries.
          rerank_debug_rows: Optional per-hit rerank-input rows. When present, a row
              is attached to the debug entry at the same index.

      Returns:
          Per-hit debug entries in the original (pre-sort) hit order when ``debug``
          is true, otherwise an empty list. The result is also empty, and the hits
          are left untouched, when ``es_hits`` is empty or the lengths differ.
      """
      n = len(es_hits)
      if n == 0 or len(rerank_scores) != n:
          return []

      f = fusion or RerankFusionConfig()
      fused_debug: List[Dict[str, Any]] = []

      for idx, hit in enumerate(es_hits):
          es_score = _to_score(hit.get("_score"))
          rerank_score = _to_score(rerank_scores[idx])
          matched_queries = hit.get("matched_queries")
          knn_score = _extract_named_query_score(matched_queries, "knn_query")
          text_components = _collect_text_score_components(matched_queries, es_score)
          text_score = text_components["text_score"]
          rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors(
              rerank_score, text_score, knn_score, f
          )
          sku_selected = _has_selected_sku(hit)
          style_boost = style_intent_selected_sku_boost if sku_selected else 1.0
          fused *= style_boost

          hit["_original_score"] = hit.get("_score")
          hit["_rerank_score"] = rerank_score
          hit["_text_score"] = text_score
          hit["_knn_score"] = knn_score
          hit["_fused_score"] = fused
          hit["_style_intent_selected_sku_boost"] = style_boost

          if not debug:
              continue

          hit["_text_source_score"] = text_components["source_score"]
          hit["_text_translation_score"] = text_components["translation_score"]
          hit["_text_primary_score"] = text_components["primary_text_score"]
          hit["_text_support_score"] = text_components["support_text_score"]

          debug_entry = {
              "doc_id": hit.get("_id"),
              "es_score": es_score,
              "rerank_score": rerank_score,
              "text_score": text_score,
              "text_source_score": text_components["source_score"],
              "text_translation_score": text_components["translation_score"],
              "text_weighted_source_score": text_components["weighted_source_score"],
              "text_weighted_translation_score": text_components["weighted_translation_score"],
              "text_primary_score": text_components["primary_text_score"],
              "text_support_score": text_components["support_text_score"],
              # True when no named text query contributed and the ES score was used.
              "text_score_fallback_to_es": (
                  text_score == es_score
                  and text_components["source_score"] <= 0.0
                  and text_components["translation_score"] <= 0.0
              ),
              "knn_score": knn_score,
              "rerank_factor": rerank_factor,
              "text_factor": text_factor,
              "knn_factor": knn_factor,
              "style_intent_selected_sku": sku_selected,
              "style_intent_selected_sku_boost": style_boost,
              "matched_queries": matched_queries,
              "fused_score": fused,
          }
          if rerank_debug_rows is not None and idx < len(rerank_debug_rows):
              debug_entry["rerank_input"] = rerank_debug_rows[idx]
          fused_debug.append(debug_entry)

      es_hits.sort(
          key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
          reverse=True,
      )
      return fused_debug
  
  
  def run_rerank(
      query: str,
      es_response: Dict[str, Any],
      language: str = "zh",
      timeout_sec: float = DEFAULT_TIMEOUT_SEC,
      weight_es: float = DEFAULT_WEIGHT_ES,
      weight_ai: float = DEFAULT_WEIGHT_AI,
      rerank_query_template: str = "{query}",
      rerank_doc_template: str = "{title}",
      top_n: Optional[int] = None,
      debug: bool = False,
      fusion: Optional[RerankFusionConfig] = None,
      style_intent_selected_sku_boost: float = 1.2,
  ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
      """
      Run the full rerank pipeline on an ES response.

      Steps: take the hits from ``es_response``, build the docs, call the
      service, fuse the scores and re-sort, then update ``max_score``.
      Provider and URL are read from services_config. ``top_n`` is optional;
      when given it is passed through to /rerank (lets cloud backends do a
      partial rerank for page+size).

      Returns:
          ``(es_response, meta, fused_debug)``. ``es_response`` is modified in
          place when reranking succeeds. When there are no hits, the request fails,
          or the number of scores does not match the number of hits,
          ``es_response`` is returned unchanged with ``(None, [])``.
      """
      # Tolerate a missing or null "hits" section instead of raising AttributeError.
      hits_section = es_response.get("hits") or {}
      hits = hits_section.get("hits") or []
      if not hits:
          return es_response, None, []

      query_text = str(rerank_query_template).format_map({"query": query})
      rerank_debug_rows: Optional[List[Dict[str, Any]]] = [] if debug else None
      docs = build_docs_from_hits(
          hits,
          language=language,
          doc_template=rerank_doc_template,
          debug_rows=rerank_debug_rows,
      )
      scores, meta = call_rerank_service(
          query_text,
          docs,
          timeout_sec=timeout_sec,
          top_n=top_n,
      )

      if scores is None:
          # call_rerank_service already logs request exceptions.
          return es_response, None, []
      if len(scores) != len(hits):
          # Previously a silent fallback; surface it so misconfigured backends are visible.
          logger.warning(
              "Rerank returned %d scores for %d hits (top_n=%s); keeping ES order",
              len(scores),
              len(hits),
              top_n,
          )
          return es_response, None, []

      fused_debug = fuse_scores_and_resort(
          hits,
          scores,
          weight_es=weight_es,
          weight_ai=weight_ai,
          fusion=fusion,
          style_intent_selected_sku_boost=style_intent_selected_sku_boost,
          debug=debug,
          rerank_debug_rows=rerank_debug_rows,
      )

      # ``hits`` is the non-empty list stored in es_response["hits"]["hits"] and was
      # sorted in place, so hits[0] carries the highest fused score.
      top = hits[0]
      hits_section["max_score"] = top.get("_fused_score", top.get("_score", 0.0)) or 0.0

      return es_response, meta, fused_debug