506c39b7
tangwang
feat(search): 统一重...
|
1
2
3
4
5
6
7
8
9
10
|
"""
重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。
流程:
1. 从 ES hits 构造用于重排的文档文本列表
2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序
"""
from typing import Dict, Any, List, Optional, Tuple
|
506c39b7
tangwang
feat(search): 统一重...
|
11
12
|
import logging
|
42e3aea6
tangwang
tidy
|
13
14
|
from providers import create_rerank_provider
|
506c39b7
tangwang
feat(search): 统一重...
|
15
16
17
18
19
20
21
22
23
24
25
26
|
logger = logging.getLogger(__name__)
# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1)
DEFAULT_WEIGHT_ES = 0.4
DEFAULT_WEIGHT_AI = 0.6
# 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大)
DEFAULT_TIMEOUT_SEC = 15.0
def build_docs_from_hits(
es_hits: List[Dict[str, Any]],
language: str = "zh",
|
ff32d894
tangwang
rerank
|
27
|
doc_template: str = "{title}",
|
506c39b7
tangwang
feat(search): 统一重...
|
28
29
30
31
|
) -> List[str]:
"""
从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。
|
ff32d894
tangwang
rerank
|
32
33
|
使用 doc_template 将文档字段组装为重排服务输入。
支持占位符:{title} {brief} {vendor} {description} {category_path}
|
506c39b7
tangwang
feat(search): 统一重...
|
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
Args:
es_hits: ES 返回的 hits 列表,每项含 _source
language: 语言代码,如 "zh"、"en"
Returns:
与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs
"""
lang = (language or "zh").strip().lower()
if lang not in ("zh", "en"):
lang = "zh"
def pick_lang_text(obj: Any) -> str:
if obj is None:
return ""
if isinstance(obj, dict):
return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
return str(obj).strip()
|
ff32d894
tangwang
rerank
|
53
54
55
56
|
class _SafeDict(dict):
def __missing__(self, key: str) -> str:
return ""
|
506c39b7
tangwang
feat(search): 统一重...
|
57
|
docs: List[str] = []
|
ff32d894
tangwang
rerank
|
58
59
60
61
62
|
only_title = "{title}" == doc_template
need_brief = "{brief}" in doc_template
need_vendor = "{vendor}" in doc_template
need_description = "{description}" in doc_template
need_category_path = "{category_path}" in doc_template
|
506c39b7
tangwang
feat(search): 统一重...
|
63
64
|
for hit in es_hits:
src = hit.get("_source") or {}
|
ff32d894
tangwang
rerank
|
65
66
67
68
69
70
71
72
73
74
75
|
if only_title:
docs.append(pick_lang_text(src.get("title")))
else:
values = _SafeDict(
title=pick_lang_text(src.get("title")),
brief=pick_lang_text(src.get("brief")) if need_brief else "",
vendor=pick_lang_text(src.get("vendor")) if need_vendor else "",
description=pick_lang_text(src.get("description")) if need_description else "",
category_path=pick_lang_text(src.get("category_path")) if need_category_path else "",
)
docs.append(str(doc_template).format_map(values))
|
506c39b7
tangwang
feat(search): 统一重...
|
76
77
78
79
80
81
|
return docs
def call_rerank_service(
query: str,
docs: List[str],
|
506c39b7
tangwang
feat(search): 统一重...
|
82
|
timeout_sec: float = DEFAULT_TIMEOUT_SEC,
|
d31c7f65
tangwang
补充云服务reranker
|
83
|
top_n: Optional[int] = None,
|
506c39b7
tangwang
feat(search): 统一重...
|
84
85
86
|
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
"""
调用重排服务 POST /rerank,返回分数列表与 meta。
|
42e3aea6
tangwang
tidy
|
87
|
Provider 和 URL 从 services_config 读取。
|
506c39b7
tangwang
feat(search): 统一重...
|
88
89
90
91
|
"""
if not docs:
return [], {}
try:
|
42e3aea6
tangwang
tidy
|
92
|
client = create_rerank_provider()
|
d31c7f65
tangwang
补充云服务reranker
|
93
|
return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n)
|
506c39b7
tangwang
feat(search): 统一重...
|
94
95
96
97
98
99
100
101
102
103
104
105
|
except Exception as e:
logger.warning("Rerank request failed: %s", e, exc_info=True)
return None, None
def fuse_scores_and_resort(
es_hits: List[Dict[str, Any]],
rerank_scores: List[float],
weight_es: float = DEFAULT_WEIGHT_ES,
weight_ai: float = DEFAULT_WEIGHT_AI,
) -> List[Dict[str, Any]]:
"""
|
af827ce9
tangwang
rerank
|
106
|
将 ES 分数与重排分数线性融合(不修改原始 _score),并按融合分数降序重排。
|
506c39b7
tangwang
feat(search): 统一重...
|
107
108
109
|
对每条 hit 会写入:
- _original_score: 原始 ES 分数
|
33f8f578
tangwang
tidy
|
110
|
- _rerank_score: 重排服务返回的分数
|
506c39b7
tangwang
feat(search): 统一重...
|
111
|
- _fused_score: 融合分数
|
506c39b7
tangwang
feat(search): 统一重...
|
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
Args:
es_hits: ES hits 列表(会被原地修改)
rerank_scores: 与 es_hits 等长的重排分数列表
weight_es: ES 归一化分数权重
weight_ai: 重排分数权重
Returns:
每条文档的融合调试信息列表,用于 debug_info
"""
n = len(es_hits)
if n == 0 or len(rerank_scores) != n:
return []
# 收集 ES 原始分数
es_scores: List[float] = []
for hit in es_hits:
raw = hit.get("_score")
try:
es_scores.append(float(raw) if raw is not None else 0.0)
except (TypeError, ValueError):
es_scores.append(0.0)
max_es = max(es_scores) if es_scores else 0.0
fused_debug: List[Dict[str, Any]] = []
for idx, hit in enumerate(es_hits):
es_score = es_scores[idx]
ai_score_raw = rerank_scores[idx]
try:
|
33f8f578
tangwang
tidy
|
142
|
rerank_score = float(ai_score_raw)
|
506c39b7
tangwang
feat(search): 统一重...
|
143
|
except (TypeError, ValueError):
|
33f8f578
tangwang
tidy
|
144
|
rerank_score = 0.0
|
506c39b7
tangwang
feat(search): 统一重...
|
145
146
|
es_norm = (es_score / max_es) if max_es > 0 else 0.0
|
33f8f578
tangwang
tidy
|
147
|
fused = weight_es * es_norm + weight_ai * rerank_score
|
506c39b7
tangwang
feat(search): 统一重...
|
148
149
|
hit["_original_score"] = hit.get("_score")
|
33f8f578
tangwang
tidy
|
150
|
hit["_rerank_score"] = rerank_score
|
506c39b7
tangwang
feat(search): 统一重...
|
151
|
hit["_fused_score"] = fused
|
506c39b7
tangwang
feat(search): 统一重...
|
152
153
154
155
156
|
fused_debug.append({
"doc_id": hit.get("_id"),
"es_score": es_score,
"es_score_norm": es_norm,
|
33f8f578
tangwang
tidy
|
157
|
"rerank_score": rerank_score,
|
506c39b7
tangwang
feat(search): 统一重...
|
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
|
"fused_score": fused,
})
# 按融合分数降序重排
es_hits.sort(
key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
reverse=True,
)
return fused_debug
def run_rerank(
query: str,
es_response: Dict[str, Any],
language: str = "zh",
|
506c39b7
tangwang
feat(search): 统一重...
|
173
174
175
|
timeout_sec: float = DEFAULT_TIMEOUT_SEC,
weight_es: float = DEFAULT_WEIGHT_ES,
weight_ai: float = DEFAULT_WEIGHT_AI,
|
ff32d894
tangwang
rerank
|
176
177
|
rerank_query_template: str = "{query}",
rerank_doc_template: str = "{title}",
|
d31c7f65
tangwang
补充云服务reranker
|
178
|
top_n: Optional[int] = None,
|
506c39b7
tangwang
feat(search): 统一重...
|
179
180
181
|
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
"""
完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。
|
42e3aea6
tangwang
tidy
|
182
|
Provider 和 URL 从 services_config 读取。
|
d31c7f65
tangwang
补充云服务reranker
|
183
|
top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。
|
506c39b7
tangwang
feat(search): 统一重...
|
184
|
"""
|
506c39b7
tangwang
feat(search): 统一重...
|
185
186
187
188
|
hits = es_response.get("hits", {}).get("hits") or []
if not hits:
return es_response, None, []
|
ff32d894
tangwang
rerank
|
189
190
|
query_text = str(rerank_query_template).format_map({"query": query})
docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template)
|
42e3aea6
tangwang
tidy
|
191
192
193
194
|
scores, meta = call_rerank_service(
query_text,
docs,
timeout_sec=timeout_sec,
|
d31c7f65
tangwang
补充云服务reranker
|
195
|
top_n=top_n,
|
42e3aea6
tangwang
tidy
|
196
|
)
|
506c39b7
tangwang
feat(search): 统一重...
|
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
|
if scores is None or len(scores) != len(hits):
return es_response, None, []
fused_debug = fuse_scores_and_resort(
hits,
scores,
weight_es=weight_es,
weight_ai=weight_ai,
)
# 更新 max_score 为融合后的最高分
if hits:
top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
if "hits" in es_response:
es_response["hits"]["max_score"] = top
return es_response, meta, fused_debug
|