# rerank_client.py
"""
重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。

流程:
1. 从 ES hits 构造用于重排的文档文本列表
2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序
"""

import logging
import math
import string
from typing import Dict, Any, List, Optional, Tuple

from config.schema import RerankFusionConfig
from providers import create_rerank_provider

logger: logging.Logger = logging.getLogger(__name__)

# Legacy linear-fusion weights. They are kept only so that callers passing
# ``weight_es`` / ``weight_ai`` keep working; the multiplicative fusion
# formula does not read them.
DEFAULT_WEIGHT_ES: float = 0.4
DEFAULT_WEIGHT_AI: float = 0.6
# Default rerank-service timeout in seconds (raise ``timeout_sec`` in config for large batches).
DEFAULT_TIMEOUT_SEC: float = 15.0


def build_docs_from_hits(
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
    doc_template: str = "{title}",
    debug_rows: Optional[List[Dict[str, Any]]] = None,
) -> List[str]:
    """
    从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。

    使用 doc_template 将文档字段组装为重排服务输入。
    支持占位符:{title} {brief} {vendor} {description} {category_path}

    Args:
        es_hits: ES 返回的 hits 列表,每项含 _source
        language: 语言代码,如 "zh"、"en"

    Returns:
        与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs
    """
    lang = (language or "zh").strip().lower()
    if lang not in ("zh", "en"):
        lang = "zh"

    def pick_lang_text(obj: Any) -> str:
        if obj is None:
            return ""
        if isinstance(obj, dict):
            return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
        return str(obj).strip()

    class _SafeDict(dict):
        def __missing__(self, key: str) -> str:
            return ""

    docs: List[str] = []
    only_title = "{title}" == doc_template
    need_brief = "{brief}" in doc_template
    need_vendor = "{vendor}" in doc_template
    need_description = "{description}" in doc_template
    need_category_path = "{category_path}" in doc_template
    for hit in es_hits:
        src = hit.get("_source") or {}
        title_suffix = str(hit.get("_style_rerank_suffix") or "").strip()

        title_str=(
            f"{pick_lang_text(src.get('title'))} {title_suffix}".strip()
            if title_suffix
            else pick_lang_text(src.get("title"))
        )
        title_str = str(title_str).strip()

        if only_title:
            doc_text = title_str
            if debug_rows is not None:
                preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..."
                debug_rows.append({
                    "doc_template": doc_template,
                    "title_suffix": title_suffix or None,
                    "fields": {
                        "title": title_str,
                    },
                    "doc_preview": preview,
                    "doc_length": len(doc_text),
                })
        else:
            values = _SafeDict(
                title=title_str,
                brief=pick_lang_text(src.get("brief")) if need_brief else "",
                vendor=pick_lang_text(src.get("vendor")) if need_vendor else "",
                description=pick_lang_text(src.get("description")) if need_description else "",
                category_path=pick_lang_text(src.get("category_path")) if need_category_path else "",
            )
            doc_text = str(doc_template).format_map(values)

            if debug_rows is not None:
                preview = doc_text if len(doc_text) <= 300 else f"{doc_text[:300]}..."
                debug_rows.append({
                    "doc_template": doc_template,
                    "title_suffix": title_suffix or None,
                    "fields": {
                        "title": title_str,
                        "brief": values.get("brief") or None,
                        "vendor": values.get("vendor") or None,
                        "category_path": values.get("category_path") or None
                    },
                    "doc_preview": preview,
                    "doc_length": len(doc_text),
                })
        docs.append(doc_text)

    return docs


def call_rerank_service(
    query: str,
    docs: List[str],
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    top_n: Optional[int] = None,
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
    """
    Send ``query`` and ``docs`` to the configured rerank provider.

    The provider is obtained from ``create_rerank_provider()``.

    Returns:
        ``([], {})`` when ``docs`` is empty, without making a request.
        Otherwise, whatever the provider's ``rerank`` returns (presumably
        ``(scores, meta)``; TODO confirm against the provider interface).
        On any exception the error is logged as a warning and
        ``(None, None)`` is returned, so callers can keep the ES order.
    """
    if docs:
        try:
            provider = create_rerank_provider()
            return provider.rerank(
                query=query,
                docs=docs,
                timeout_sec=timeout_sec,
                top_n=top_n,
            )
        except Exception as e:
            logger.warning("Rerank request failed: %s", e, exc_info=True)
            return None, None
    return [], {}


def _to_score(value: Any) -> float:
    try:
        if value is None:
            return 0.0
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _extract_named_query_score(matched_queries: Any, name: str) -> float:
    if isinstance(matched_queries, dict):
        return _to_score(matched_queries.get(name))
    if isinstance(matched_queries, list):
        return 1.0 if name in matched_queries else 0.0
    return 0.0

"""
原始变量:
ES总分
source_score:从 ES 返回的 matched_queries 里取 base_query 这条 named query 的分(dict 用具体分数;list 形式则“匹配到名字就算 1.0”)。
translation_score:所有名字以 base_query_trans_ 开头的 named query 的分,在 dict 里取 最大值;在 list 里只要存在这类名字就记为 1.0。

中间变量:计算原始query得分和翻译query得分
weighted_source :
weighted_translation : 0.8 * translation_score

区分主信号和辅助信号:
合成primary_text_score和support_text_score,取 更强 的那一路(原文检索 vs 翻译检索)作为主信号
primary_text_score : max(weighted_source, weighted_translation)
support_text_score : weighted_source + weighted_translation - primary_text_score

主信号和辅助信号的融合:dismax融合公式
最终text_score:主信号 + 0.25 * 辅助信号
text_score : primary_text_score + 0.25 * support_text_score
"""
def _collect_text_score_components(
    matched_queries: Any,
    fallback_es_score: float,
    *,
    translation_weight: float = 0.8,
    support_weight: float = 0.25,
) -> Dict[str, float]:
    """
    Combine original-query and translated-query named-query scores into one text score.

    The two raw inputs are:
    - ``source_score``: the ``base_query`` named-query score;
    - ``translation_score``: the best score among ``base_query_trans_*`` queries.

    For list-form ``matched_queries`` a matched name counts as 1.0.

    The stronger of ``source_score`` and ``translation_weight * translation_score``
    is the primary signal and the other is the support signal::

        text_score = primary + support_weight * support

    If that result is <= 0 (no text named query contributed), ``fallback_es_score``
    is used as the text score instead.

    Args:
        matched_queries: the hit's ``matched_queries``. It may be a dict of
            name -> score or a list of names; anything else counts as no match.
        fallback_es_score: score used when no text signal is present.
        translation_weight: down-weight applied to translated-query matches.
        support_weight: contribution of the weaker (support) signal.

    Returns:
        Dict with the raw source/translation scores, their weighted values,
        the primary/support scores and the final ``text_score``.
    """
    trans_prefix = "base_query_trans_"
    source_score = _extract_named_query_score(matched_queries, "base_query")

    translation_score = 0.0
    if isinstance(matched_queries, dict):
        for query_name, score in matched_queries.items():
            # Only translated-query entries contribute; other names are skipped.
            if isinstance(query_name, str) and query_name.startswith(trans_prefix):
                translation_score = max(translation_score, _to_score(score))
    elif isinstance(matched_queries, list):
        if any(
            isinstance(query_name, str) and query_name.startswith(trans_prefix)
            for query_name in matched_queries
        ):
            translation_score = 1.0

    weighted_source = source_score
    weighted_translation = translation_weight * translation_score
    primary_text_score = max(weighted_source, weighted_translation)
    support_text_score = weighted_source + weighted_translation - primary_text_score
    text_score = primary_text_score + support_weight * support_text_score

    if text_score <= 0.0:
        # No usable text signal: fall back to the ES total score.
        text_score = fallback_es_score
        weighted_source = fallback_es_score
        primary_text_score = fallback_es_score
        support_text_score = 0.0

    return {
        "source_score": source_score,
        "translation_score": translation_score,
        "weighted_source_score": weighted_source,
        "weighted_translation_score": weighted_translation,
        "primary_text_score": primary_text_score,
        "support_text_score": support_text_score,
        "text_score": text_score,
    }


def _multiply_fusion_factors(
    rerank_score: float,
    text_score: float,
    knn_score: float,
    fusion: RerankFusionConfig,
) -> Tuple[float, float, float, float]:
    """
    Compute the multiplicative fusion factors for one hit.

    Each factor is ``(max(score, 0) + bias) ** exponent``. The bias and
    exponent are read from ``fusion``.

    Returns:
        ``(rerank_factor, text_factor, knn_factor, product)``, where ``product``
        is the fused score before any style/SKU boost.
    """

    def factor(score: float, bias: float, exponent: float) -> float:
        # Negative scores are clamped to 0 before the bias is added.
        return (max(score, 0.0) + bias) ** exponent

    rerank_factor = factor(rerank_score, fusion.rerank_bias, fusion.rerank_exponent)
    text_factor = factor(text_score, fusion.text_bias, fusion.text_exponent)
    knn_factor = factor(knn_score, fusion.knn_bias, fusion.knn_exponent)
    fused = rerank_factor * text_factor * knn_factor
    return rerank_factor, text_factor, knn_factor, fused


def _has_selected_sku(hit: Dict[str, Any]) -> bool:
    return bool(str(hit.get("_style_rerank_suffix") or "").strip())


def fuse_scores_and_resort(
    es_hits: List[Dict[str, Any]],
    rerank_scores: List[float],
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
    fusion: Optional[RerankFusionConfig] = None,
    style_intent_selected_sku_boost: float = 1.2,
    debug: bool = False,
    rerank_debug_rows: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
    """
    Fuse ES scores with rerank scores multiplicatively and re-sort ``es_hits`` in place.

    The original ``_score`` is not modified. Fusion form (bias/exponent come from
    ``fusion``, or ``RerankFusionConfig()`` when it is None)::

        fused = (max(rerank,0)+b_r)^e_r * (max(text,0)+b_t)^e_t * (max(knn,0)+b_k)^e_k * sku_boost

    ``sku_boost`` is ``style_intent_selected_sku_boost`` (default 1.2, configurable
    via ``query.style_intent.selected_sku_boost``) when the hit has a non-blank
    ``_style_rerank_suffix``, otherwise 1.0.

    Each hit receives:
    - ``_original_score``: the original ES ``_score``;
    - ``_rerank_score``: the rerank-service score;
    - ``_fused_score``: the fused score;
    - ``_text_score``: text relevance (named queries base_query / translations, else ES score);
    - ``_knn_score``: the ``knn_query`` named-query score;
    - ``_style_intent_selected_sku_boost``: the boost factor applied.

    With ``debug`` the hit additionally receives the ``_text_*`` component scores.

    Args:
        es_hits: ES hits (mutated and re-sorted in place).
        rerank_scores: rerank scores, same length and order as ``es_hits``.
        weight_es: kept for backward compatibility; unused.
        weight_ai: kept for backward compatibility; unused.
        fusion: bias/exponent configuration.
        style_intent_selected_sku_boost: multiplier for hits with a selected SKU.
        debug: when true, collect a per-hit debug dict.
        rerank_debug_rows: optional per-hit rerank-input debug rows
            (as produced by ``build_docs_from_hits``), attached as ``rerank_input``.

    Returns:
        Per-hit debug dicts, in the pre-sort ES order, when ``debug`` is true;
        otherwise an empty list. The result is also empty, with ``es_hits``
        left untouched, when ``es_hits`` is empty or the lengths differ.
    """
    n = len(es_hits)
    if n == 0 or len(rerank_scores) != n:
        return []

    f = fusion if fusion is not None else RerankFusionConfig()
    fused_debug: List[Dict[str, Any]] = []

    for idx, hit in enumerate(es_hits):
        es_score = _to_score(hit.get("_score"))
        rerank_score = _to_score(rerank_scores[idx])
        matched_queries = hit.get("matched_queries")
        knn_score = _extract_named_query_score(matched_queries, "knn_query")
        text_components = _collect_text_score_components(matched_queries, es_score)
        text_score = text_components["text_score"]
        rerank_factor, text_factor, knn_factor, fused = _multiply_fusion_factors(
            rerank_score, text_score, knn_score, f
        )
        sku_selected = _has_selected_sku(hit)
        style_boost = style_intent_selected_sku_boost if sku_selected else 1.0
        fused *= style_boost

        hit["_original_score"] = hit.get("_score")
        hit["_rerank_score"] = rerank_score
        hit["_text_score"] = text_score
        hit["_knn_score"] = knn_score
        hit["_fused_score"] = fused
        hit["_style_intent_selected_sku_boost"] = style_boost

        if not debug:
            continue

        hit["_text_source_score"] = text_components["source_score"]
        hit["_text_translation_score"] = text_components["translation_score"]
        hit["_text_primary_score"] = text_components["primary_text_score"]
        hit["_text_support_score"] = text_components["support_text_score"]

        debug_entry = {
            "doc_id": hit.get("_id"),
            "es_score": es_score,
            "rerank_score": rerank_score,
            "text_score": text_score,
            "text_source_score": text_components["source_score"],
            "text_translation_score": text_components["translation_score"],
            "text_weighted_source_score": text_components["weighted_source_score"],
            "text_weighted_translation_score": text_components["weighted_translation_score"],
            "text_primary_score": text_components["primary_text_score"],
            "text_support_score": text_components["support_text_score"],
            # True when no text named query matched and the ES score was used instead.
            "text_score_fallback_to_es": (
                text_score == es_score
                and text_components["source_score"] <= 0.0
                and text_components["translation_score"] <= 0.0
            ),
            "knn_score": knn_score,
            "rerank_factor": rerank_factor,
            "text_factor": text_factor,
            "knn_factor": knn_factor,
            "style_intent_selected_sku": sku_selected,
            "style_intent_selected_sku_boost": style_boost,
            "matched_queries": matched_queries,
            "fused_score": fused,
        }
        if rerank_debug_rows is not None and idx < len(rerank_debug_rows):
            debug_entry["rerank_input"] = rerank_debug_rows[idx]
        fused_debug.append(debug_entry)

    # list.sort is stable (also with reverse=True): equal fused scores keep ES order.
    es_hits.sort(
        key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
        reverse=True,
    )
    return fused_debug


def run_rerank(
    query: str,
    es_response: Dict[str, Any],
    language: str = "zh",
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
    rerank_query_template: str = "{query}",
    rerank_doc_template: str = "{title}",
    top_n: Optional[int] = None,
    debug: bool = False,
    fusion: Optional[RerankFusionConfig] = None,
    style_intent_selected_sku_boost: float = 1.2,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Run the full rerank pipeline on an ES response.

    Pipeline: take hits from ``es_response``, build docs, call the rerank
    service, fuse scores and re-sort, then update ``max_score``. The provider
    is obtained via ``create_rerank_provider()``. ``top_n`` is optional; when
    given it is passed through to /rerank (for backends that rerank only
    page+size).

    ``es_response`` is modified in place: hits are annotated and re-sorted, and
    ``hits.max_score`` is set to the top fused score.

    Returns:
        ``(es_response, meta, fused_debug)``. When there are no hits, the
        service call fails, or it returns a score count different from the hit
        count, the response is returned unchanged with ``(None, [])``.
    """
    hits_container = es_response.get("hits") or {}
    hits = hits_container.get("hits") or []
    if not hits:
        return es_response, None, []

    class _SafeQueryDict(dict):
        # Same behaviour as the doc template: unknown placeholders render as "".
        def __missing__(self, key: str) -> str:
            return ""

    query_text = str(rerank_query_template).format_map(_SafeQueryDict(query=query))
    rerank_debug_rows: Optional[List[Dict[str, Any]]] = [] if debug else None
    docs = build_docs_from_hits(
        hits,
        language=language,
        doc_template=rerank_doc_template,
        debug_rows=rerank_debug_rows,
    )
    scores, meta = call_rerank_service(
        query_text,
        docs,
        timeout_sec=timeout_sec,
        top_n=top_n,
    )

    if scores is None:
        # Service failure was already logged by call_rerank_service.
        return es_response, None, []
    if len(scores) != len(hits):
        logger.warning(
            "Rerank returned %d scores for %d hits; keeping ES order",
            len(scores),
            len(hits),
        )
        return es_response, None, []

    fused_debug = fuse_scores_and_resort(
        hits,
        scores,
        weight_es=weight_es,
        weight_ai=weight_ai,
        fusion=fusion,
        style_intent_selected_sku_boost=style_intent_selected_sku_boost,
        debug=debug,
        rerank_debug_rows=rerank_debug_rows,
    )

    # max_score reflects the highest fused score (hits are sorted descending).
    top_hit = hits[0]
    hits_container["max_score"] = top_hit.get("_fused_score", top_hit.get("_score", 0.0)) or 0.0

    return es_response, meta, fused_debug