""" 重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。 流程: 1. 从 ES hits 构造用于重排的文档文本列表 2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序 """ from typing import Dict, Any, List, Optional, Tuple import logging from config.schema import RerankFusionConfig from providers import create_rerank_provider logger = logging.getLogger(__name__) # 历史配置项,保留签名兼容;当前乘法融合公式不再使用线性权重。 DEFAULT_WEIGHT_ES = 0.4 DEFAULT_WEIGHT_AI = 0.6 # 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) DEFAULT_TIMEOUT_SEC = 15.0 def build_docs_from_hits( es_hits: List[Dict[str, Any]], language: str = "zh", doc_template: str = "{title}", debug_rows: Optional[List[Dict[str, Any]]] = None, ) -> List[str]: """ 从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。 使用 doc_template 将文档字段组装为重排服务输入。 支持占位符:{title} {brief} {vendor} {description} {category_path} Args: es_hits: ES 返回的 hits 列表,每项含 _source language: 语言代码,如 "zh"、"en" Returns: 与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs """ lang = (language or "zh").strip().lower() if lang not in ("zh", "en"): lang = "zh" def pick_lang_text(obj: Any) -> str: if obj is None: return "" if isinstance(obj, dict): return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip() return str(obj).strip() class _SafeDict(dict): def __missing__(self, key: str) -> str: return "" docs: List[str] = [] only_title = "{title}" == doc_template need_brief = "{brief}" in doc_template need_vendor = "{vendor}" in doc_template need_description = "{description}" in doc_template need_category_path = "{category_path}" in doc_template for hit in es_hits: src = hit.get("_source") or {} title_suffix = str(hit.get("_style_rerank_suffix") or "").strip() title_str=( f"{pick_lang_text(src.get('title'))} {title_suffix}".strip() if title_suffix else pick_lang_text(src.get("title")) ) title_str = str(title_str).strip() if only_title: doc_text = title_str if debug_rows is not None: preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..." debug_rows.append({ "doc_template": doc_template, "title_suffix": title_suffix or None, "fields": { "title": title_str, }, "doc_preview": preview, "doc_length": len(doc_text), }) else: values = _SafeDict( title=title_str, brief=pick_lang_text(src.get("brief")) if need_brief else "", vendor=pick_lang_text(src.get("vendor")) if need_vendor else "", description=pick_lang_text(src.get("description")) if need_description else "", category_path=pick_lang_text(src.get("category_path")) if need_category_path else "", ) doc_text = str(doc_template).format_map(values) if debug_rows is not None: preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..." debug_rows.append({ "doc_template": doc_template, "title_suffix": title_suffix or None, "fields": { "title": title_str, "brief": values.get("brief") or None, "vendor": values.get("vendor") or None, "category_path": values.get("category_path") or None }, "doc_preview": preview, "doc_length": len(doc_text), }) docs.append(doc_text) return docs def call_rerank_service( query: str, docs: List[str], timeout_sec: float = DEFAULT_TIMEOUT_SEC, top_n: Optional[int] = None, ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: """ 调用重排服务 POST /rerank,返回分数列表与 meta。 Provider 和 URL 从 services_config 读取。 """ if not docs: return [], {} try: client = create_rerank_provider() return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n) except Exception as e: logger.warning("Rerank request failed: %s", e, exc_info=True) return None, None def _to_score(value: Any) -> float: try: if value is None: return 0.0 return float(value) except (TypeError, ValueError): return 0.0 def _extract_named_query_score(matched_queries: Any, name: str) -> float: if isinstance(matched_queries, dict): return _to_score(matched_queries.get(name)) if isinstance(matched_queries, list): return 1.0 if name in matched_queries else 0.0 return 0.0 def _collect_text_score_components(matched_queries: Any, fallback_es_score: float) -> Dict[str, float]: source_score = _extract_named_query_score(matched_queries, "base_query") translation_score = 0.0 if isinstance(matched_queries, dict): for query_name, score in matched_queries.items(): if not isinstance(query_name, str): continue numeric_score = _to_score(score) if query_name.startswith("base_query_trans_"): translation_score = max(translation_score, numeric_score) elif isinstance(matched_queries, list): for query_name in matched_queries: if not isinstance(query_name, str): continue if query_name.startswith("base_query_trans_"): translation_score = 1.0 weighted_source = source_score weighted_translation = 0.8 * translation_score weighted_components = [weighted_source, weighted_translation] primary_text_score = max(weighted_components) support_text_score = sum(weighted_components) - primary_text_score text_score = primary_text_score + 0.25 * support_text_score if text_score <= 0.0: text_score = fallback_es_score weighted_source = fallback_es_score primary_text_score = fallback_es_score support_text_score = 0.0 return { "source_score": source_score, "translation_score": translation_score, "weighted_source_score": weighted_source, "weighted_translation_score": weighted_translation, "primary_text_score": primary_text_score, "support_text_score": support_text_score, "text_score": text_score, } def _multiply_fusion_factors( rerank_score: float, text_score: float, knn_score: float, fusion: RerankFusionConfig, ) -> Tuple[float, float, float, float]: """(rerank_factor, text_factor, knn_factor, fused_without_style_boost).""" r = (max(rerank_score, 0.0) + fusion.rerank_bias) ** fusion.rerank_exponent t = (max(text_score, 0.0) + fusion.text_bias) ** fusion.text_exponent k = (max(knn_score, 0.0) + fusion.knn_bias) ** fusion.knn_exponent return r, t, k, r * t * k def _has_selected_sku(hit: Dict[str, Any]) -> bool: return bool(str(hit.get("_style_rerank_suffix") or "").strip()) def fuse_scores_and_resort( es_hits: List[Dict[str, Any]], rerank_scores: List[float], weight_es: float = DEFAULT_WEIGHT_ES, weight_ai: float = DEFAULT_WEIGHT_AI, fusion: Optional[RerankFusionConfig] = None, style_intent_selected_sku_boost: float = 1.2, debug: bool = False, rerank_debug_rows: Optional[List[Dict[str, Any]]] = None, ) -> List[Dict[str, Any]]: """ 将 ES 分数与重排分数按乘法公式融合(不修改原始 _score),并按融合分数降序重排。 融合形式(由 ``fusion`` 配置 bias / exponent):: fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost 其中 sku_boost 仅在当前 hit 已选中 SKU 时生效,默认值为 1.2,可通过 ``query.style_intent.selected_sku_boost`` 配置。 对每条 hit 会写入: - _original_score: 原始 ES 分数 - _rerank_score: 重排服务返回的分数 - _fused_score: 融合分数 - _text_score: 文本相关性分数(优先取 named queries 的 base_query 分数) - _knn_score: KNN 分数(优先取 named queries 的 knn_query 分数) Args: es_hits: ES hits 列表(会被原地修改) rerank_scores: 与 es_hits 等长的重排分数列表 weight_es: 兼容保留,当前未使用 weight_ai: 兼容保留,当前未使用 """ n = len(es_hits) if n == 0 or len(rerank_scores) != n: return [] f = fusion or RerankFusionConfig() fused_debug: List[Dict[str, Any]] = [] if debug else [] for idx, hit in enumerate(es_hits): es_score = _to_score(hit.get("_score")) rerank_score = _to_score(rerank_scores[idx]) matched_queries = hit.get("matched_queries") knn_score = _extract_named_query_score(matched_queries, "knn_query") text_components = _collect_text_score_components(matched_queries, es_score) text_score = text_components["text_score"] rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors( rerank_score, text_score, knn_score, f ) sku_selected = _has_selected_sku(hit) style_boost = style_intent_selected_sku_boost if sku_selected else 1.0 fused *= style_boost hit["_original_score"] = hit.get("_score") hit["_rerank_score"] = rerank_score hit["_text_score"] = text_score hit["_knn_score"] = knn_score hit["_fused_score"] = fused hit["_style_intent_selected_sku_boost"] = style_boost if debug: hit["_text_source_score"] = text_components["source_score"] hit["_text_translation_score"] = text_components["translation_score"] hit["_text_primary_score"] = text_components["primary_text_score"] hit["_text_support_score"] = text_components["support_text_score"] if debug: debug_entry = { "doc_id": hit.get("_id"), "es_score": es_score, "rerank_score": rerank_score, "text_score": text_score, "text_source_score": text_components["source_score"], "text_translation_score": text_components["translation_score"], "text_weighted_source_score": text_components["weighted_source_score"], "text_weighted_translation_score": text_components["weighted_translation_score"], "text_primary_score": text_components["primary_text_score"], "text_support_score": text_components["support_text_score"], "text_score_fallback_to_es": ( text_score == es_score and text_components["source_score"] <= 0.0 and text_components["translation_score"] <= 0.0 ), "knn_score": knn_score, "rerank_factor": rerank_factor, "text_factor": text_factor, "knn_factor": knn_factor, "style_intent_selected_sku": sku_selected, "style_intent_selected_sku_boost": style_boost, "matched_queries": matched_queries, "fused_score": fused, } if rerank_debug_rows is not None and idx < len(rerank_debug_rows): debug_entry["rerank_input"] = rerank_debug_rows[idx] fused_debug.append(debug_entry) es_hits.sort( key=lambda h: h.get("_fused_score", h.get("_score", 0.0)), reverse=True, ) return fused_debug def run_rerank( query: str, es_response: Dict[str, Any], language: str = "zh", timeout_sec: float = DEFAULT_TIMEOUT_SEC, weight_es: float = DEFAULT_WEIGHT_ES, weight_ai: float = DEFAULT_WEIGHT_AI, rerank_query_template: str = "{query}", rerank_doc_template: str = "{title}", top_n: Optional[int] = None, debug: bool = False, fusion: Optional[RerankFusionConfig] = None, style_intent_selected_sku_boost: float = 1.2, ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: """ 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 Provider 和 URL 从 services_config 读取。 top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。 """ hits = es_response.get("hits", {}).get("hits") or [] if not hits: return es_response, None, [] query_text = str(rerank_query_template).format_map({"query": query}) rerank_debug_rows: Optional[List[Dict[str, Any]]] = [] if debug else None docs = build_docs_from_hits( hits, language=language, doc_template=rerank_doc_template, debug_rows=rerank_debug_rows, ) scores, meta = call_rerank_service( query_text, docs, timeout_sec=timeout_sec, top_n=top_n, ) if scores is None or len(scores) != len(hits): return es_response, None, [] fused_debug = fuse_scores_and_resort( hits, scores, weight_es=weight_es, weight_ai=weight_ai, fusion=fusion, style_intent_selected_sku_boost=style_intent_selected_sku_boost, debug=debug, rerank_debug_rows=rerank_debug_rows, ) # 更新 max_score 为融合后的最高分 if hits: top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0 if "hits" in es_response: es_response["hits"]["max_score"] = top return es_response, meta, fused_debug