""" 重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。 流程: 1. 从 ES hits 构造用于重排的文档文本列表 2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序 """ from typing import Dict, Any, List, Optional, Tuple import logging from providers import create_rerank_provider logger = logging.getLogger(__name__) # 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1) DEFAULT_WEIGHT_ES = 0.4 DEFAULT_WEIGHT_AI = 0.6 # 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) DEFAULT_TIMEOUT_SEC = 15.0 def build_docs_from_hits( es_hits: List[Dict[str, Any]], language: str = "zh", doc_template: str = "{title}", ) -> List[str]: """ 从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。 使用 doc_template 将文档字段组装为重排服务输入。 支持占位符:{title} {brief} {vendor} {description} {category_path} Args: es_hits: ES 返回的 hits 列表,每项含 _source language: 语言代码,如 "zh"、"en" Returns: 与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs """ lang = (language or "zh").strip().lower() if lang not in ("zh", "en"): lang = "zh" def pick_lang_text(obj: Any) -> str: if obj is None: return "" if isinstance(obj, dict): return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip() return str(obj).strip() class _SafeDict(dict): def __missing__(self, key: str) -> str: return "" docs: List[str] = [] only_title = "{title}" == doc_template need_brief = "{brief}" in doc_template need_vendor = "{vendor}" in doc_template need_description = "{description}" in doc_template need_category_path = "{category_path}" in doc_template for hit in es_hits: src = hit.get("_source") or {} if only_title: docs.append(pick_lang_text(src.get("title"))) else: values = _SafeDict( title=pick_lang_text(src.get("title")), brief=pick_lang_text(src.get("brief")) if need_brief else "", vendor=pick_lang_text(src.get("vendor")) if need_vendor else "", description=pick_lang_text(src.get("description")) if need_description else "", category_path=pick_lang_text(src.get("category_path")) if need_category_path else "", ) docs.append(str(doc_template).format_map(values)) return docs def call_rerank_service( query: str, docs: List[str], timeout_sec: float = DEFAULT_TIMEOUT_SEC, top_n: Optional[int] = None, ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: """ 调用重排服务 POST /rerank,返回分数列表与 meta。 Provider 和 URL 从 services_config 读取。 """ if not docs: return [], {} try: client = create_rerank_provider() return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n) except Exception as e: logger.warning("Rerank request failed: %s", e, exc_info=True) return None, None def fuse_scores_and_resort( es_hits: List[Dict[str, Any]], rerank_scores: List[float], weight_es: float = DEFAULT_WEIGHT_ES, weight_ai: float = DEFAULT_WEIGHT_AI, ) -> List[Dict[str, Any]]: """ 将 ES 分数与重排分数线性融合,写回每条 hit 的 _score,并按融合分数降序重排。 对每条 hit 会写入: - _original_score: 原始 ES 分数 - _ai_rerank_score: 重排服务返回的分数 - _fused_score: 融合分数 - _score: 置为融合分数(供后续 ResultFormatter 使用) Args: es_hits: ES hits 列表(会被原地修改) rerank_scores: 与 es_hits 等长的重排分数列表 weight_es: ES 归一化分数权重 weight_ai: 重排分数权重 Returns: 每条文档的融合调试信息列表,用于 debug_info """ n = len(es_hits) if n == 0 or len(rerank_scores) != n: return [] # 收集 ES 原始分数 es_scores: List[float] = [] for hit in es_hits: raw = hit.get("_score") try: es_scores.append(float(raw) if raw is not None else 0.0) except (TypeError, ValueError): es_scores.append(0.0) max_es = max(es_scores) if es_scores else 0.0 fused_debug: List[Dict[str, Any]] = [] for idx, hit in enumerate(es_hits): es_score = es_scores[idx] ai_score_raw = rerank_scores[idx] try: ai_score = float(ai_score_raw) except (TypeError, ValueError): ai_score = 0.0 es_norm = (es_score / max_es) if max_es > 0 else 0.0 fused = weight_es * es_norm + weight_ai * ai_score hit["_original_score"] = hit.get("_score") hit["_ai_rerank_score"] = ai_score hit["_fused_score"] = fused hit["_score"] = fused fused_debug.append({ "doc_id": hit.get("_id"), "es_score": es_score, "es_score_norm": es_norm, "ai_rerank_score": ai_score, "fused_score": fused, }) # 按融合分数降序重排 es_hits.sort( key=lambda h: h.get("_fused_score", h.get("_score", 0.0)), reverse=True, ) return fused_debug def run_rerank( query: str, es_response: Dict[str, Any], language: str = "zh", timeout_sec: float = DEFAULT_TIMEOUT_SEC, weight_es: float = DEFAULT_WEIGHT_ES, weight_ai: float = DEFAULT_WEIGHT_AI, rerank_query_template: str = "{query}", rerank_doc_template: str = "{title}", top_n: Optional[int] = None, ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: """ 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 Provider 和 URL 从 services_config 读取。 top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。 """ hits = es_response.get("hits", {}).get("hits") or [] if not hits: return es_response, None, [] query_text = str(rerank_query_template).format_map({"query": query}) docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template) scores, meta = call_rerank_service( query_text, docs, timeout_sec=timeout_sec, top_n=top_n, ) if scores is None or len(scores) != len(hits): return es_response, None, [] fused_debug = fuse_scores_and_resort( hits, scores, weight_es=weight_es, weight_ai=weight_ai, ) # 更新 max_score 为融合后的最高分 if hits: top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0 if "hits" in es_response: es_response["hits"]["max_score"] = top return es_response, meta, fused_debug