# test_rerank_client.py (9.29 KB)
from math import isclose

from config.schema import RerankFusionConfig
from search.rerank_client import fuse_scores_and_resort, run_lightweight_rerank


def test_fuse_scores_and_resort_aggregates_text_components_and_keeps_rerank_primary():
    """Source and translated text matches combine into one text score before fusion.

    The expected values spell out the default fusion formula this test pins:

    * The translated match is first weighted by 0.8 (reported as
      ``text_weighted_translation_score``).
    * That weighted value is added to the source match with a 0.25 factor.
    * fused = (rerank + 1e-5) * (text + 0.1) ** 0.35 * (knn + 0.6) ** 0.2

    The rerank score enters linearly, while text and kNN are dampened by their
    exponents. Hit "2" still sorts first here because of its much larger source
    text score. Debug rows are indexed in input order: row 0 describes hit "1"
    even though hit "1" sorts second.
    """
    hit_one = {
        "_id": "1",
        "_score": 3.2,
        "matched_queries": {
            "base_query": 2.4,
            "base_query_trans_zh": 1.8,
            "knn_query": 0.8,
        },
    }
    hit_two = {
        "_id": "2",
        "_score": 2.8,
        "matched_queries": {"base_query": 9.0, "knn_query": 0.2},
    }
    hits = [hit_one, hit_two]

    debug_rows = fuse_scores_and_resort(hits, [0.9, 0.7], debug=True)

    text_one = 2.4 + 0.25 * (0.8 * 1.8)
    fused_one = (0.9 + 0.00001) * ((text_one + 0.1) ** 0.35) * ((0.8 + 0.6) ** 0.2)
    fused_two = (0.7 + 0.00001) * ((9.0 + 0.1) ** 0.35) * ((0.2 + 0.6) ** 0.2)

    # Re-index after the call: the function reorders ``hits``.
    hits_by_id = {}
    for hit in hits:
        hits_by_id[hit["_id"]] = hit

    assert isclose(hits_by_id["1"]["_text_score"], text_one, rel_tol=1e-9)
    assert isclose(hits_by_id["1"]["_fused_score"], fused_one, rel_tol=1e-9)
    assert isclose(hits_by_id["2"]["_fused_score"], fused_two, rel_tol=1e-9)

    first_row = debug_rows[0]
    assert first_row["text_source_score"] == 2.4
    assert first_row["text_translation_score"] == 1.8
    assert isclose(first_row["text_weighted_translation_score"], 1.44, rel_tol=1e-9)
    assert first_row["knn_score"] == 0.8
    assert isclose(first_row["rerank_factor"], 0.90001, rel_tol=1e-9)

    assert [h["_id"] for h in hits] == ["2", "1"]


def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
    """Without ``matched_queries``, the raw ES ``_score`` stands in for the text score.

    The kNN component is then treated as 0.0, so both hits share the
    ``(0.0 + 0.6) ** 0.2`` factor. Each debug row must flag the fallback.
    """
    es_scores = (("1", 0.5), ("2", 2.0))
    hits = [{"_id": doc_id, "_score": es_score} for doc_id, es_score in es_scores]

    rows = fuse_scores_and_resort(hits, [0.4, 0.3], debug=True)

    knn_factor = (0.0 + 0.6) ** 0.2
    expected_fused = {
        "1": (0.4 + 0.00001) * ((0.5 + 0.1) ** 0.35) * knn_factor,
        "2": (0.3 + 0.00001) * ((2.0 + 0.1) ** 0.35) * knn_factor,
    }

    by_id = {hit["_id"]: hit for hit in hits}
    for doc_id, es_score in es_scores:
        assert isclose(by_id[doc_id]["_text_score"], es_score, rel_tol=1e-9)
        assert isclose(by_id[doc_id]["_fused_score"], expected_fused[doc_id], rel_tol=1e-9)

    for index in (0, 1):
        assert rows[index]["text_score_fallback_to_es"] is True

    assert [hit["_id"] for hit in hits] == ["2", "1"]


def test_fuse_scores_and_resort_downweights_text_only_advantage():
    """A lexical lead alone must not outrank a clearly better rerank score.

    Neither hit has kNN support. "lexical-heavy" wins on text (10.0 vs 6.0),
    but "rerank-better" has the higher rerank score (0.98 vs 0.72). With the
    default fusion parameters, "rerank-better" must come first.
    """

    def make_hit(doc_id, text_score):
        # The ES score and the base_query match carry the same value here.
        return {
            "_id": doc_id,
            "_score": text_score,
            "matched_queries": {"base_query": text_score, "knn_query": 0.0},
        }

    hits = [make_hit("lexical-heavy", 10.0), make_hit("rerank-better", 6.0)]

    fuse_scores_and_resort(hits, [0.72, 0.98])

    ordered_ids = [hit["_id"] for hit in hits]
    assert ordered_ids == ["rerank-better", "lexical-heavy"]


def test_fuse_scores_and_resort_uses_configurable_fusion_params():
    """A custom ``RerankFusionConfig`` replaces the default biases and exponents.

    With every bias at 0 and every exponent at 1, the fused score reduces to
    rerank * text * knn. Hit "b" has zero kNN, so it fuses to 0. Hit "a" fuses
    to 1 * 2 * 0.5 = 1.0 and sorts first.
    """
    neutral_params = dict(
        rerank_bias=0.0,
        rerank_exponent=1.0,
        text_bias=0.0,
        text_exponent=1.0,
        knn_bias=0.0,
        knn_exponent=1.0,
    )
    fusion = RerankFusionConfig(**neutral_params)

    hits = [
        {"_id": "a", "_score": 1.0, "matched_queries": {"base_query": 2.0, "knn_query": 0.5}},
        {"_id": "b", "_score": 1.0, "matched_queries": {"base_query": 3.0, "knn_query": 0.0}},
    ]

    fuse_scores_and_resort(hits, [1.0, 1.0], fusion=fusion)

    # "b" has zero kNN, so its product collapses to 0; "a" is 1 * 2 * 0.5.
    assert [h["_id"] for h in hits] == ["a", "b"]
    fused_by_id = {h["_id"]: h["_fused_score"] for h in hits}
    assert isclose(fused_by_id["a"], 1.0, rel_tol=1e-9)
    assert isclose(fused_by_id["b"], 0.0, rel_tol=1e-9)


def test_fuse_scores_and_resort_boosts_hits_with_selected_sku():
    """The hit carrying ``_style_rerank_suffix`` receives the selected-SKU boost.

    Both hits are otherwise identical. The boosted hit's fused score must be
    exactly ``boost`` times the plain hit's, and it must sort first. The plain
    hit records a neutral boost of 1.0.
    """
    boost = 1.2
    selected = {
        "_id": "style-selected",
        "_score": 1.0,
        "_style_rerank_suffix": "Blue XL",
        "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
    }
    plain = {
        "_id": "plain",
        "_score": 1.0,
        "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
    }
    hits = [selected, plain]

    rows = fuse_scores_and_resort(
        hits,
        [1.0, 1.0],
        style_intent_selected_sku_boost=boost,
        debug=True,
    )

    by_id = {h["_id"]: h for h in hits}
    boosted, baseline = by_id["style-selected"], by_id["plain"]
    assert isclose(boosted["_fused_score"], baseline["_fused_score"] * boost, rel_tol=1e-9)
    assert boosted["_style_intent_selected_sku_boost"] == boost
    assert baseline["_style_intent_selected_sku_boost"] == 1.0
    assert [h["_id"] for h in hits] == ["style-selected", "plain"]

    first_row = rows[0]
    assert first_row["style_intent_selected_sku"] is True
    assert first_row["style_intent_selected_sku_boost"] == boost


def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores():
    """With the default config, the kNN score is the larger of text and image kNN.

    The debug row must also report each kNN component separately.
    """
    text_knn, image_knn = 0.2, 0.7
    hits = [
        {
            "_id": "mm-hit",
            "_score": 1.0,
            "matched_queries": {
                "base_query": 1.5,
                "knn_query": text_knn,
                "image_knn_query": image_knn,
            },
        }
    ]

    debug_rows = fuse_scores_and_resort(hits, [0.8], debug=True)

    assert isclose(hits[0]["_knn_score"], max(text_knn, image_knn), rel_tol=1e-9)

    row = debug_rows[0]
    expected_row = (
        ("knn_score", 0.7),
        ("text_knn_score", 0.2),
        ("image_knn_score", 0.7),
    )
    for key, want in expected_row:
        assert isclose(row[key], want, rel_tol=1e-9), key


def test_fuse_scores_and_resort_applies_knn_dismax_weights_and_tie_breaker():
    """kNN components are weighted first, then combined dis-max style.

    * Text kNN 0.4 * weight 2.0 = 0.8 becomes the primary score.
    * Image kNN 0.5 * weight 1.0 = 0.5 is the support score.
    * The support score contributes through the 0.25 tie breaker:
      knn = 0.8 + 0.25 * 0.5.
    """
    fusion = RerankFusionConfig(
        rerank_bias=0.00001,
        rerank_exponent=1.0,
        text_bias=0.1,
        text_exponent=0.35,
        knn_text_weight=2.0,
        knn_image_weight=1.0,
        knn_tie_breaker=0.25,
        knn_bias=0.0,
        knn_exponent=1.0,
    )
    hits = [
        {
            "_id": "mm-hit",
            "_score": 1.0,
            "matched_queries": {
                "base_query": 1.5,
                "knn_query": 0.4,
                "image_knn_query": 0.5,
            },
        }
    ]

    rows = fuse_scores_and_resort(hits, [0.8], fusion=fusion, debug=True)

    primary, support, tie_breaker = 0.8, 0.5, 0.25
    assert isclose(hits[0]["_knn_score"], primary + tie_breaker * support, rel_tol=1e-9)

    row = rows[0]
    expected_row = {
        "weighted_text_knn_score": primary,
        "weighted_image_knn_score": support,
        "knn_primary_score": primary,
        "knn_support_score": support,
    }
    for key, want in expected_row.items():
        assert isclose(row[key], want, rel_tol=1e-9), key


def test_run_lightweight_rerank_sorts_by_fused_stage_score(monkeypatch):
    """``run_lightweight_rerank`` orders hits by fused score, not raw fine score.

    The stubbed rerank service gives "fine-raw-better" the higher raw score
    (0.9 vs 0.8). "fusion-better" has a far stronger text match (40.0 vs 0.5)
    and must therefore come first after fusion.
    """

    def fake_rerank_service(*args, **kwargs):
        # Stand-in for the remote call: fixed scores plus model metadata.
        return [0.9, 0.8], {"model": "fine-bge"}

    hits = [
        {
            "_id": "fine-raw-better",
            "_score": 1.0,
            "_source": {"title": {"en": "Alpha"}},
            "matched_queries": {"base_query": 0.5, "knn_query": 0.0},
        },
        {
            "_id": "fusion-better",
            "_score": 1.0,
            "_source": {"title": {"en": "Beta"}},
            "matched_queries": {"base_query": 40.0, "knn_query": 0.0},
        },
    ]
    monkeypatch.setattr("search.rerank_client.call_rerank_service", fake_rerank_service)

    scores, meta, debug_rows = run_lightweight_rerank(
        query="toy",
        es_hits=hits,
        language="en",
        debug=True,
    )

    assert scores == [0.9, 0.8]
    assert meta == {"model": "fine-bge"}
    assert [hit["_id"] for hit in hits] == ["fusion-better", "fine-raw-better"]

    leader, runner_up = hits
    assert leader["_fine_fused_score"] > runner_up["_fine_fused_score"]

    summary = debug_rows[0]["fusion_summary"]
    assert summary
    assert "fine_score=" in summary
    assert "text_score=" in summary


def test_fuse_scores_and_resort_uses_hit_level_fine_score_when_not_passed_separately():
    """A ``_fine_score`` stored on the hit feeds the fine factor.

    Only rerank scores are passed explicitly. The fine factor must therefore
    come from the hit's own ``_fine_score`` (0.7) plus the 1e-5 bias, and it
    must be reported in the debug inputs and summary.
    """
    hit_fine_score = 0.7
    hits = [
        {
            "_id": "with-fine",
            "_score": 1.0,
            "_fine_score": hit_fine_score,
            "matched_queries": {"base_query": 2.0, "knn_query": 0.5},
        }
    ]

    row = fuse_scores_and_resort(hits, [0.8], debug=True)[0]

    assert isclose(row["fine_factor"], (hit_fine_score + 0.00001), rel_tol=1e-9)
    assert row["fusion_inputs"]["fine_score"] == hit_fine_score
    assert "fine_score=" in row["fusion_summary"]


def test_fuse_scores_and_resort_can_include_raw_es_score_as_factor():
    """With ``es_exponent=1``, the raw ES ``_score`` becomes a multiplicative factor.

    The text and kNN factors are neutralised (exponent 0), and the rerank
    scores are equal. The ES score alone (100.0 vs 1.0) therefore decides the
    order, and the ES factor equals the raw score.
    """

    def es_hit(doc_id, es_score):
        return {
            "_id": doc_id,
            "_score": es_score,
            "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
        }

    hits = [es_hit("es-strong", 100.0), es_hit("es-weak", 1.0)]
    fusion = RerankFusionConfig(
        es_bias=0.0,
        es_exponent=1.0,
        rerank_bias=0.0,
        rerank_exponent=1.0,
        text_bias=0.0,
        text_exponent=0.0,
        knn_bias=1.0,
        knn_exponent=0.0,
    )

    rows = fuse_scores_and_resort(hits, [1.0, 1.0], fusion=fusion, debug=True)

    assert [hit["_id"] for hit in hits] == ["es-strong", "es-weak"]

    top_hit, top_row = hits[0], rows[0]
    assert isclose(top_hit["_raw_es_score"], 100.0, rel_tol=1e-9)
    assert isclose(top_row["es_factor"], 100.0, rel_tol=1e-9)
    assert top_row["fusion_inputs"]["es_score"] == 100.0