"""
Rerank client: call the external BGE rerank service and fuse ES scores with
rerank scores.

Flow:
1. Build the list of document texts to rerank from the ES hits.
2. POST to the rerank service (/rerank) to get a relevance score per document.
3. Extract the ES text / vector sub-clause scores, fuse them multiplicatively
   with the rerank score, and re-sort the hits.
"""

import logging
import math
from typing import Any, Dict, List, Optional, Tuple

from providers import create_rerank_provider

logger = logging.getLogger(__name__)

# Legacy settings, kept only for signature compatibility; the current
# multiplicative fusion formula no longer uses linear weights.
DEFAULT_WEIGHT_ES = 0.4
DEFAULT_WEIGHT_AI = 0.6
# Default rerank service timeout (larger document batches need more time;
# raise timeout_sec in the config when necessary).
DEFAULT_TIMEOUT_SEC = 15.0

# Prefix of the named queries that carry translated variants of the base query.
_TRANSLATION_QUERY_PREFIX = "base_query_trans_"
# Down-weighting applied to the best translated-query score.
_TRANSLATION_WEIGHT = 0.8
# Share of the weaker text signal that is added on top of the stronger one.
_SUPPORT_WEIGHT = 0.25


class _SafeFormatDict(dict):
    """Mapping for ``str.format_map`` that renders unknown placeholders as ""."""

    def __missing__(self, key: str) -> str:
        return ""


def build_docs_from_hits(
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
    doc_template: str = "{title}",
) -> List[str]:
    """
    Build the document texts sent to the rerank service, one per ES hit.

    ``doc_template`` is filled from each hit's ``_source``. Supported
    placeholders: {title} {brief} {vendor} {description} {category_path}.
    Unknown placeholders render as an empty string.

    Args:
        es_hits: ES hits; each item is expected to carry ``_source``.
        language: Language code, "zh" or "en". Anything else (including an
            empty value) falls back to "zh".
        doc_template: Format template used to assemble each document text.

    Returns:
        A list of strings with the same length and order as ``es_hits``.
    """
    lang = (language or "zh").strip().lower()
    if lang not in ("zh", "en"):
        lang = "zh"

    def pick_lang_text(obj: Any) -> str:
        # Multilingual fields are dicts keyed by language code: prefer the
        # requested language, then "zh", then "en". Scalars are stringified.
        if obj is None:
            return ""
        if isinstance(obj, dict):
            return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
        return str(obj).strip()

    # Fast path: the default template needs no formatting at all.
    if doc_template == "{title}":
        return [
            pick_lang_text((hit.get("_source") or {}).get("title"))
            for hit in es_hits
        ]

    # Only extract the optional fields the template actually references.
    optional_fields = [
        name
        for name in ("brief", "vendor", "description", "category_path")
        if "{" + name + "}" in doc_template
    ]
    template = str(doc_template)

    docs: List[str] = []
    for hit in es_hits:
        src = hit.get("_source") or {}
        values = _SafeFormatDict(title=pick_lang_text(src.get("title")))
        for name in optional_fields:
            values[name] = pick_lang_text(src.get(name))
        docs.append(template.format_map(values))
    return docs


def call_rerank_service(
    query: str,
    docs: List[str],
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    top_n: Optional[int] = None,
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
    """
    Call the rerank service and return ``(scores, meta)``.

    The client comes from ``create_rerank_provider()``; per the original notes
    the provider and URL are read from services_config.

    Args:
        query: Query text sent to the reranker.
        docs: Document texts to score.
        timeout_sec: Request timeout forwarded to the provider.
        top_n: Optional limit forwarded to the provider.

    Returns:
        ``([], {})`` when ``docs`` is empty (no request is made),
        ``(None, None)`` when the request fails for any reason (the failure is
        logged), otherwise whatever the provider's ``rerank`` returns.
    """
    if not docs:
        return [], {}
    try:
        client = create_rerank_provider()
        return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n)
    except Exception as e:
        # Deliberate best-effort: a rerank failure must not break the search;
        # the caller falls back to the original ES ordering.
        logger.warning("Rerank request failed: %s", e, exc_info=True)
        return None, None


def _to_score(value: Any) -> float:
    """Coerce ``value`` to a float score; None, unparsable and NaN become 0.0."""
    if value is None:
        return 0.0
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN would poison the fused score and make the final sort order undefined.
    return 0.0 if math.isnan(score) else score


def _extract_named_query_score(matched_queries: Any, name: str) -> float:
    """
    Return the score of the named query ``name`` from ``matched_queries``.

    A dict is read as name -> score. A list only tells which queries matched,
    so a match counts as 1.0. Any other shape yields 0.0.
    """
    if isinstance(matched_queries, dict):
        return _to_score(matched_queries.get(name))
    if isinstance(matched_queries, list):
        return 1.0 if name in matched_queries else 0.0
    return 0.0


def _collect_text_score_components(
    matched_queries: Any,
    fallback_es_score: float,
) -> Dict[str, float]:
    """
    Derive the text relevance score and its components for one hit.

    The source score comes from the "base_query" named query, the translation
    score is the best score among the "base_query_trans_*" named queries
    (down-weighted by 0.8). The stronger of the two is the primary score and
    the weaker contributes 25% as support. When no named text score is
    available (result <= 0), ``fallback_es_score`` is used instead.

    Returns:
        Dict with keys source_score, translation_score, weighted_source_score,
        weighted_translation_score, primary_text_score, support_text_score and
        text_score.
    """
    source_score = _extract_named_query_score(matched_queries, "base_query")

    translation_score = 0.0
    if isinstance(matched_queries, dict):
        best_translation = max(
            (
                _to_score(score)
                for query_name, score in matched_queries.items()
                if isinstance(query_name, str)
                and query_name.startswith(_TRANSLATION_QUERY_PREFIX)
            ),
            default=0.0,
        )
        # Never let a negative translation score drag the value below zero.
        translation_score = max(translation_score, best_translation)
    elif isinstance(matched_queries, list):
        has_translation = any(
            isinstance(query_name, str)
            and query_name.startswith(_TRANSLATION_QUERY_PREFIX)
            for query_name in matched_queries
        )
        if has_translation:
            translation_score = 1.0

    weighted_source = source_score
    weighted_translation = _TRANSLATION_WEIGHT * translation_score
    primary_text_score = max(weighted_source, weighted_translation)
    support_text_score = weighted_source + weighted_translation - primary_text_score
    text_score = primary_text_score + _SUPPORT_WEIGHT * support_text_score

    if text_score <= 0.0:
        # No usable named-query score: fall back to the overall ES score.
        text_score = fallback_es_score
        weighted_source = fallback_es_score
        primary_text_score = fallback_es_score
        support_text_score = 0.0

    return {
        "source_score": source_score,
        "translation_score": translation_score,
        "weighted_source_score": weighted_source,
        "weighted_translation_score": weighted_translation,
        "primary_text_score": primary_text_score,
        "support_text_score": support_text_score,
        "text_score": text_score,
    }


def _fuse_score(rerank_score: float, text_score: float, knn_score: float) -> float:
    """
    Multiplicative fusion of the three signals.

    fused = (rerank + 1e-5) * (text + 0.1) ** 0.35 * (knn + 0.6) ** 0.2, with
    every input clamped to >= 0 first. The additive offsets keep each factor
    strictly positive, so a missing signal dampens the result instead of
    zeroing it.
    """
    rerank_factor = max(rerank_score, 0.0) + 0.00001
    text_factor = (max(text_score, 0.0) + 0.1) ** 0.35
    knn_factor = (max(knn_score, 0.0) + 0.6) ** 0.2
    return rerank_factor * text_factor * knn_factor


def fuse_scores_and_resort(
    es_hits: List[Dict[str, Any]],
    rerank_scores: List[float],
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
) -> List[Dict[str, Any]]:
    """
    Fuse ES scores with rerank scores and sort the hits by the fused score.

    The original ``_score`` is left untouched. Each hit additionally receives:
    - _original_score: the original ES score
    - _rerank_score: the score returned by the rerank service
    - _fused_score: the fused score
    - _text_score: text relevance score (named-query based, ES score fallback)
    - _knn_score: score of the "knn_query" named query
    - _text_source_score / _text_translation_score / _text_primary_score /
      _text_support_score: the components of the text score

    Args:
        es_hits: ES hits; modified and re-sorted in place (descending).
        rerank_scores: Rerank scores, same length and order as ``es_hits``.
        weight_es: Kept for compatibility; currently unused.
        weight_ai: Kept for compatibility; currently unused.

    Returns:
        Per-document fusion debug records (in the pre-sort order), or an empty
        list when ``es_hits`` is empty or the lengths differ (in which case
        nothing is modified).
    """
    if not es_hits or len(rerank_scores) != len(es_hits):
        return []

    fused_debug: List[Dict[str, Any]] = []
    for hit, raw_rerank_score in zip(es_hits, rerank_scores):
        es_score = _to_score(hit.get("_score"))
        rerank_score = _to_score(raw_rerank_score)
        matched_queries = hit.get("matched_queries")
        knn_score = _extract_named_query_score(matched_queries, "knn_query")
        text_components = _collect_text_score_components(matched_queries, es_score)
        text_score = text_components["text_score"]
        fused = _fuse_score(rerank_score, text_score, knn_score)

        hit["_original_score"] = hit.get("_score")
        hit["_rerank_score"] = rerank_score
        hit["_text_score"] = text_score
        hit["_knn_score"] = knn_score
        hit["_text_source_score"] = text_components["source_score"]
        hit["_text_translation_score"] = text_components["translation_score"]
        hit["_text_primary_score"] = text_components["primary_text_score"]
        hit["_text_support_score"] = text_components["support_text_score"]
        hit["_fused_score"] = fused

        fused_debug.append({
            "doc_id": hit.get("_id"),
            "es_score": es_score,
            "rerank_score": rerank_score,
            "text_score": text_score,
            "text_source_score": text_components["source_score"],
            "text_translation_score": text_components["translation_score"],
            "text_primary_score": text_components["primary_text_score"],
            "text_support_score": text_components["support_text_score"],
            "knn_score": knn_score,
            "matched_queries": matched_queries,
            "fused_score": fused,
        })

    # Sort by fused score, descending; the sort is stable, so ties keep the
    # original ES order.
    es_hits.sort(
        key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
        reverse=True,
    )
    return fused_debug


def run_rerank(
    query: str,
    es_response: Dict[str, Any],
    language: str = "zh",
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
    rerank_query_template: str = "{query}",
    rerank_doc_template: str = "{title}",
    top_n: Optional[int] = None,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Full rerank flow: take the hits from ``es_response``, build the docs, call
    the service, fuse and re-sort, then update ``max_score``.

    Args:
        query: The user query; substituted for ``{query}`` in
            ``rerank_query_template``.
        es_response: ES search response; its hits are modified in place.
        language: Language code used to pick multilingual field values.
        timeout_sec: Timeout forwarded to the rerank service call.
        weight_es: Kept for compatibility; currently unused by the fusion.
        weight_ai: Kept for compatibility; currently unused by the fusion.
        rerank_query_template: Format template for the query text. Unknown
            placeholders render as an empty string.
        rerank_doc_template: Format template for each document text.
        top_n: Optional; when given it is passed through to the rerank call
            (per the original notes, so cloud backends can rerank only the
            page + size window).

    Returns:
        ``(es_response, meta, fused_debug)``. When there are no hits, the
        rerank call fails, or the number of scores does not match the number
        of hits, the response is returned unchanged with ``None`` and ``[]``.
    """
    # Tolerate a missing or null "hits" section instead of raising.
    hits = (es_response.get("hits") or {}).get("hits") or []
    if not hits:
        return es_response, None, []

    # Same lenient placeholder handling as the document template.
    query_text = str(rerank_query_template).format_map(_SafeFormatDict(query=query))
    docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template)

    scores, meta = call_rerank_service(
        query_text,
        docs,
        timeout_sec=timeout_sec,
        top_n=top_n,
    )
    if scores is None:
        # The failure was already logged by call_rerank_service.
        return es_response, None, []
    if len(scores) != len(hits):
        logger.warning(
            "Rerank skipped: got %d scores for %d hits",
            len(scores),
            len(hits),
        )
        return es_response, None, []

    fused_debug = fuse_scores_and_resort(
        hits,
        scores,
        weight_es=weight_es,
        weight_ai=weight_ai,
    )

    # Update max_score to the highest fused score (hits are sorted descending).
    top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
    es_response["hits"]["max_score"] = top
    return es_response, meta, fused_debug