"""
Rerank client: call the external BGE rerank service and fuse ES scores with
rerank scores.

Flow:
1. Build the list of document texts to rerank from the ES hits.
2. POST to the rerank service (/rerank) to obtain a relevance score per document.
3. Extract the ES text / KNN clause scores, fuse them multiplicatively with
   the rerank score, and re-sort the hits.
"""

import logging
import math
from typing import Dict, Any, List, Optional, Tuple

from providers import create_rerank_provider

logger = logging.getLogger(__name__)

# Legacy settings, kept only for signature compatibility; the current
# multiplicative fusion formula no longer uses linear weights.
DEFAULT_WEIGHT_ES = 0.4
DEFAULT_WEIGHT_AI = 0.6
# Default rerank-service timeout (needs to be larger when there are many
# documents; raising timeout_sec in the config is recommended).
DEFAULT_TIMEOUT_SEC = 15.0

# Fusion formula:
#   fused = (rerank + eps) ** a * (knn + 0.6) ** b * (text + 0.1) ** c
# The additive terms are smoothing offsets so that a zero component does not
# zero out the whole product.
_RERANK_SMOOTHING = 0.00001
_RERANK_EXPONENT = 1.0
_KNN_SMOOTHING = 0.6
_KNN_EXPONENT = 0.2
_TEXT_SMOOTHING = 0.1
_TEXT_EXPONENT = 0.75

# Template placeholders other than {title} that are only extracted from
# _source when the template actually mentions them.
_OPTIONAL_DOC_FIELDS = ("brief", "vendor", "description", "category_path")


class _SafeFormatDict(dict):
    """Mapping for ``str.format_map`` that renders unknown placeholders as ""."""

    def __missing__(self, key: str) -> str:
        return ""


def _to_float(value: Any, default: float = 0.0) -> float:
    """Convert ``value`` to a finite float, returning ``default`` otherwise.

    ``None``, unparsable values, NaN and +/-inf all map to ``default`` so that
    downstream arithmetic and sorting always operate on ordinary numbers.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default


def build_docs_from_hits(
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
    doc_template: str = "{title}",
) -> List[str]:
    """
    Build the document texts sent to the rerank service, one per ES hit.

    ``doc_template`` assembles the document fields into the rerank input.
    Supported placeholders: {title} {brief} {vendor} {description}
    {category_path}. Unknown placeholders render as an empty string.

    Args:
        es_hits: ES hits list; each item is expected to carry ``_source``.
        language: Language code, "zh" or "en". Any other value (or an empty
            one) falls back to "zh".
        doc_template: Format template used to render each document.

    Returns:
        A list of strings with the same length and order as ``es_hits``,
        used as ``docs`` for POST /rerank.
    """
    lang = (language or "zh").strip().lower()
    if lang not in ("zh", "en"):
        lang = "zh"

    def pick_lang_text(obj: Any) -> str:
        # Multilingual fields are dicts keyed by language code: prefer the
        # requested language, then zh, then en. Anything else is stringified.
        if obj is None:
            return ""
        if isinstance(obj, dict):
            return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
        return str(obj).strip()

    # Fast path: the default template needs no formatting at all.
    only_title = doc_template == "{title}"
    # Skip extraction work for fields the template does not reference.
    wanted_fields = [
        name for name in _OPTIONAL_DOC_FIELDS if "{" + name + "}" in doc_template
    ]

    docs: List[str] = []
    for hit in es_hits:
        src = hit.get("_source") or {}
        title = pick_lang_text(src.get("title"))
        if only_title:
            docs.append(title)
            continue
        values = _SafeFormatDict(title=title)
        for name in wanted_fields:
            values[name] = pick_lang_text(src.get(name))
        docs.append(str(doc_template).format_map(values))
    return docs


def call_rerank_service(
    query: str,
    docs: List[str],
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    top_n: Optional[int] = None,
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
    """
    Call the rerank service (POST /rerank) and return the scores and meta.

    The provider and URL are read from services_config.

    Args:
        query: Query text sent to the reranker.
        docs: Document texts to score.
        timeout_sec: Timeout passed through to the provider.
        top_n: Optional value passed through to the provider unchanged.

    Returns:
        ``(scores, meta)`` as returned by the provider; ``([], {})`` when
        ``docs`` is empty; ``(None, None)`` when the provider raises (the
        failure is logged as a warning rather than propagated).
    """
    if not docs:
        return [], {}
    try:
        client = create_rerank_provider()
        return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n)
    except Exception as e:
        # Deliberate best-effort boundary: the caller falls back to the
        # original ES ordering when reranking is unavailable.
        logger.warning("Rerank request failed: %s", e, exc_info=True)
        return None, None


def fuse_scores_and_resort(
    es_hits: List[Dict[str, Any]],
    rerank_scores: List[float],
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
) -> List[Dict[str, Any]]:
    """
    Fuse ES scores with rerank scores multiplicatively and re-sort the hits
    by fused score, descending. The original ``_score`` is left untouched.

    Each hit gets these keys written:
    - _original_score: the original ES score
    - _rerank_score: the score returned by the rerank service
    - _fused_score: the fused score
    - _text_score: text relevance score (the named query ``base_query``
      score when available)
    - _knn_score: KNN score (the named query ``knn_query`` score when
      available)

    Scores that are missing, unparsable or non-finite (NaN/inf) are treated
    as 0.0. The bases of the fractional powers are clamped at 0 so that an
    unexpected negative score cannot produce a complex number and break the
    sort.

    Args:
        es_hits: ES hits list (modified and re-ordered in place).
        rerank_scores: Rerank scores, same length and order as ``es_hits``.
        weight_es: Kept for compatibility; currently unused.
        weight_ai: Kept for compatibility; currently unused.

    Returns:
        Per-document fusion debug records (in the original hit order), for
        use in debug_info. Empty when ``es_hits`` is empty or the lengths
        differ, in which case nothing is modified.
    """
    if not es_hits or len(rerank_scores) != len(es_hits):
        return []

    fused_debug: List[Dict[str, Any]] = []
    for hit, raw_rerank in zip(es_hits, rerank_scores):
        es_score = _to_float(hit.get("_score"))
        rerank_score = _to_float(raw_rerank)

        matched_queries = hit.get("matched_queries")
        text_score = 0.0
        knn_score = 0.0
        if isinstance(matched_queries, dict):
            # Dict form maps each named query to its score.
            text_score = _to_float(matched_queries.get("base_query"))
            knn_score = _to_float(matched_queries.get("knn_query"))
        elif isinstance(matched_queries, list):
            # List form only tells us which named queries matched.
            text_score = 1.0 if "base_query" in matched_queries else 0.0
            knn_score = 1.0 if "knn_query" in matched_queries else 0.0

        # Fallbacks:
        # - A missing text_score falls back to the raw _score so pure-text
        #   recalls are not wrongly scored as 0.
        # - A missing knn_score stays 0; the 0.6 smoothing term covers it.
        if text_score <= 0.0:
            text_score = es_score

        # Clamp the bases of the fractional powers: in Python 3 a negative
        # float raised to a non-integer power yields a complex number.
        knn_base = max(knn_score + _KNN_SMOOTHING, 0.0)
        text_base = max(text_score + _TEXT_SMOOTHING, 0.0)
        fused = (
            (rerank_score + _RERANK_SMOOTHING) ** _RERANK_EXPONENT
            * knn_base ** _KNN_EXPONENT
            * text_base ** _TEXT_EXPONENT
        )

        hit["_original_score"] = hit.get("_score")
        hit["_rerank_score"] = rerank_score
        hit["_text_score"] = text_score
        hit["_knn_score"] = knn_score
        hit["_fused_score"] = fused

        fused_debug.append({
            "doc_id": hit.get("_id"),
            "es_score": es_score,
            "rerank_score": rerank_score,
            "text_score": text_score,
            "knn_score": knn_score,
            "matched_queries": matched_queries,
            "fused_score": fused,
        })

    # Re-sort by fused score, descending (stable, so ties keep ES order).
    es_hits.sort(
        key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
        reverse=True,
    )
    return fused_debug


def run_rerank(
    query: str,
    es_response: Dict[str, Any],
    language: str = "zh",
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
    rerank_query_template: str = "{query}",
    rerank_doc_template: str = "{title}",
    top_n: Optional[int] = None,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Full rerank flow: take hits from ``es_response`` -> build docs -> call
    the service -> fuse scores and re-sort -> update ``max_score``.

    The provider and URL are read from services_config.
    ``top_n`` is optional; when given it is passed through to /rerank (so a
    cloud backend can do a partial rerank based on page + size).

    Args:
        query: The user query.
        es_response: ES response dict; its ``hits.hits`` list is modified
            and re-ordered in place on success.
        language: Language code used to pick multilingual field values.
        timeout_sec: Rerank service timeout.
        weight_es: Kept for compatibility; currently unused by the fusion.
        weight_ai: Kept for compatibility; currently unused by the fusion.
        rerank_query_template: Format template for the query text; supports
            ``{query}``. Unknown placeholders render as an empty string.
        rerank_doc_template: Format template for each document text.
        top_n: Optional value passed through to the rerank service.

    Returns:
        ``(es_response, meta, fused_debug)``. When there are no hits, the
        service call fails, or the number of scores does not match the
        number of hits, the response is returned unchanged as
        ``(es_response, None, [])``.
    """
    # ``hits`` may be absent or explicitly None in a degraded response.
    hits = (es_response.get("hits") or {}).get("hits") or []
    if not hits:
        return es_response, None, []

    query_text = str(rerank_query_template).format_map(_SafeFormatDict(query=query))
    docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template)
    scores, meta = call_rerank_service(
        query_text,
        docs,
        timeout_sec=timeout_sec,
        top_n=top_n,
    )

    if scores is None or len(scores) != len(hits):
        return es_response, None, []

    fused_debug = fuse_scores_and_resort(
        hits,
        scores,
        weight_es=weight_es,
        weight_ai=weight_ai,
    )

    # Update max_score to the highest fused score; hits is non-empty and
    # sorted descending here, so the first hit carries it.
    top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
    es_response["hits"]["max_score"] = top

    return es_response, meta, fused_debug