rerank_client.py 7 KB
"""
重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。

流程:
1. 从 ES hits 构造用于重排的文档文本列表
2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序
"""

from typing import Dict, Any, List, Optional, Tuple
import logging

from providers import create_rerank_provider

logger = logging.getLogger(__name__)

# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1)
DEFAULT_WEIGHT_ES = 0.4
DEFAULT_WEIGHT_AI = 0.6
# 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大)
DEFAULT_TIMEOUT_SEC = 15.0


def build_docs_from_hits(
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
    doc_template: str = "{title}",
) -> List[str]:
    """
    从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。

    使用 doc_template 将文档字段组装为重排服务输入。
    支持占位符:{title} {brief} {vendor} {description} {category_path}

    Args:
        es_hits: ES 返回的 hits 列表,每项含 _source
        language: 语言代码,如 "zh"、"en"

    Returns:
        与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs
    """
    lang = (language or "zh").strip().lower()
    if lang not in ("zh", "en"):
        lang = "zh"

    def pick_lang_text(obj: Any) -> str:
        if obj is None:
            return ""
        if isinstance(obj, dict):
            return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
        return str(obj).strip()

    class _SafeDict(dict):
        def __missing__(self, key: str) -> str:
            return ""

    docs: List[str] = []
    only_title = "{title}" == doc_template
    need_brief = "{brief}" in doc_template
    need_vendor = "{vendor}" in doc_template
    need_description = "{description}" in doc_template
    need_category_path = "{category_path}" in doc_template
    for hit in es_hits:
        src = hit.get("_source") or {}
        if only_title:
            docs.append(pick_lang_text(src.get("title")))
        else:
            values = _SafeDict(
                title=pick_lang_text(src.get("title")),
                brief=pick_lang_text(src.get("brief")) if need_brief else "",
                vendor=pick_lang_text(src.get("vendor")) if need_vendor else "",
                description=pick_lang_text(src.get("description")) if need_description else "",
                category_path=pick_lang_text(src.get("category_path")) if need_category_path else "",
            )
            docs.append(str(doc_template).format_map(values))
    return docs


def call_rerank_service(
    query: str,
    docs: List[str],
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    top_n: Optional[int] = None,
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
    """
    调用重排服务 POST /rerank,返回分数列表与 meta。
    Provider 和 URL 从 services_config 读取。
    """
    if not docs:
        return [], {}
    try:
        client = create_rerank_provider()
        return client.rerank(query=query, docs=docs, timeout_sec=timeout_sec, top_n=top_n)
    except Exception as e:
        logger.warning("Rerank request failed: %s", e, exc_info=True)
        return None, None


def fuse_scores_and_resort(
    es_hits: List[Dict[str, Any]],
    rerank_scores: List[float],
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
) -> List[Dict[str, Any]]:
    """
    将 ES 分数与重排分数线性融合(不修改原始 _score),并按融合分数降序重排。

    对每条 hit 会写入:
    - _original_score: 原始 ES 分数
    - _rerank_score: 重排服务返回的分数
    - _fused_score: 融合分数

    Args:
        es_hits: ES hits 列表(会被原地修改)
        rerank_scores: 与 es_hits 等长的重排分数列表
        weight_es: ES 归一化分数权重
        weight_ai: 重排分数权重

    Returns:
        每条文档的融合调试信息列表,用于 debug_info
    """
    n = len(es_hits)
    if n == 0 or len(rerank_scores) != n:
        return []

    # 收集 ES 原始分数
    es_scores: List[float] = []
    for hit in es_hits:
        raw = hit.get("_score")
        try:
            es_scores.append(float(raw) if raw is not None else 0.0)
        except (TypeError, ValueError):
            es_scores.append(0.0)

    max_es = max(es_scores) if es_scores else 0.0
    fused_debug: List[Dict[str, Any]] = []

    for idx, hit in enumerate(es_hits):
        es_score = es_scores[idx]
        ai_score_raw = rerank_scores[idx]
        try:
            rerank_score = float(ai_score_raw)
        except (TypeError, ValueError):
            rerank_score = 0.0

        es_norm = (es_score / max_es) if max_es > 0 else 0.0
        fused = weight_es * es_norm + weight_ai * rerank_score

        hit["_original_score"] = hit.get("_score")
        hit["_rerank_score"] = rerank_score
        hit["_fused_score"] = fused

        fused_debug.append({
            "doc_id": hit.get("_id"),
            "es_score": es_score,
            "es_score_norm": es_norm,
            "rerank_score": rerank_score,
            "fused_score": fused,
        })

    # 按融合分数降序重排
    es_hits.sort(
        key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
        reverse=True,
    )
    return fused_debug


def run_rerank(
    query: str,
    es_response: Dict[str, Any],
    language: str = "zh",
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
    rerank_query_template: str = "{query}",
    rerank_doc_template: str = "{title}",
    top_n: Optional[int] = None,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。
    Provider 和 URL 从 services_config 读取。
    top_n 可选;若传入,会透传给 /rerank(供云后端按 page+size 做部分重排)。
    """
    hits = es_response.get("hits", {}).get("hits") or []
    if not hits:
        return es_response, None, []

    query_text = str(rerank_query_template).format_map({"query": query})
    docs = build_docs_from_hits(hits, language=language, doc_template=rerank_doc_template)
    scores, meta = call_rerank_service(
        query_text,
        docs,
        timeout_sec=timeout_sec,
        top_n=top_n,
    )

    if scores is None or len(scores) != len(hits):
        return es_response, None, []

    fused_debug = fuse_scores_and_resort(
        hits,
        scores,
        weight_es=weight_es,
        weight_ai=weight_ai,
    )

    # 更新 max_score 为融合后的最高分
    if hits:
        top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
        if "hits" in es_response:
            es_response["hits"]["max_score"] = top

    return es_response, meta, fused_debug