# test_rerank_client.py (2.98 KB)
from math import isclose

from search.rerank_client import fuse_scores_and_resort


def test_fuse_scores_and_resort_aggregates_text_components_and_keeps_rerank_primary():
    """Fuse lexical, translated-lexical, kNN and rerank signals into one score.

    Hit "1" matches both the source-language query (``base_query``) and the
    translated query (``base_query_trans_zh``). Hit "2" matches only the
    source-language query. The expected values below are computed by hand as::

        weighted_translation = 0.8 * translation
        text  = source + 0.25 * weighted_translation
        fused = (rerank + 0.00001) * (text + 0.1) ** 0.35 * (knn + 0.6) ** 0.2

    Hit "2" ends up first despite its lower rerank score because its lexical
    score is much stronger.
    """
    hits = [
        {
            "_id": "1",
            "_score": 3.2,
            "matched_queries": {
                "base_query": 2.4,
                "base_query_trans_zh": 1.8,
                "knn_query": 0.8,
            },
        },
        {
            "_id": "2",
            "_score": 2.8,
            "matched_queries": {
                "base_query": 9.0,
                "knn_query": 0.2,
            },
        },
    ]

    debug = fuse_scores_and_resort(hits, [0.9, 0.7], debug=True)

    # The translated match is scaled by 0.8, then blended into the source score at 0.25.
    expected_weighted_translation_1 = 0.8 * 1.8
    expected_text_1 = 2.4 + 0.25 * expected_weighted_translation_1
    expected_fused_1 = (0.9 + 0.00001) * ((expected_text_1 + 0.1) ** 0.35) * ((0.8 + 0.6) ** 0.2)
    # Hit "2" has no translated match, so its text score is just its base_query score.
    expected_text_2 = 9.0
    expected_fused_2 = (0.7 + 0.00001) * ((expected_text_2 + 0.1) ** 0.35) * ((0.2 + 0.6) ** 0.2)

    by_id = {hit["_id"]: hit for hit in hits}

    assert isclose(by_id["1"]["_text_score"], expected_text_1, rel_tol=1e-9)
    assert isclose(by_id["1"]["_fused_score"], expected_fused_1, rel_tol=1e-9)
    assert isclose(by_id["2"]["_text_score"], expected_text_2, rel_tol=1e-9)
    assert isclose(by_id["2"]["_fused_score"], expected_fused_2, rel_tol=1e-9)
    # debug[0] describes hit "1" (the first input hit), not the top-ranked hit
    # after re-sorting. Its source score is 2.4.
    assert debug[0]["text_source_score"] == 2.4
    assert debug[0]["text_translation_score"] == 1.8
    assert isclose(debug[0]["text_weighted_translation_score"], expected_weighted_translation_1, rel_tol=1e-9)
    assert debug[0]["knn_score"] == 0.8
    assert isclose(debug[0]["rerank_factor"], 0.9 + 0.00001, rel_tol=1e-9)
    assert [hit["_id"] for hit in hits] == ["2", "1"]


def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
    """Use the raw ES ``_score`` as the text score when ``matched_queries`` is absent.

    Neither hit carries per-query scores. The expected fused scores therefore
    take their text term from ``_score`` and treat the kNN score as 0.0.
    Hit "2" ranks first: its stronger ES score outweighs its lower rerank score.
    """
    es_score_by_id = {"1": 0.5, "2": 2.0}
    hits = [{"_id": doc_id, "_score": es_score} for doc_id, es_score in es_score_by_id.items()]

    debug = fuse_scores_and_resort(hits, [0.4, 0.3], debug=True)

    knn_term = (0.0 + 0.6) ** 0.2
    expected_fused_by_id = {
        "1": (0.4 + 0.00001) * ((0.5 + 0.1) ** 0.35) * knn_term,
        "2": (0.3 + 0.00001) * ((2.0 + 0.1) ** 0.35) * knn_term,
    }

    fused_hits = {hit["_id"]: hit for hit in hits}
    for doc_id, es_score in es_score_by_id.items():
        fused_hit = fused_hits[doc_id]
        assert isclose(fused_hit["_text_score"], es_score, rel_tol=1e-9)
        assert isclose(fused_hit["_fused_score"], expected_fused_by_id[doc_id], rel_tol=1e-9)

    for position in (0, 1):
        assert debug[position]["text_score_fallback_to_es"] is True

    ranked_ids = [hit["_id"] for hit in hits]
    assert ranked_ids == ["2", "1"]


def test_fuse_scores_and_resort_downweights_text_only_advantage():
    """A lexical-only lead must not outrank a clearly better rerank score.

    "lexical-heavy" has the higher lexical score (10.0 vs 6.0) but the lower
    rerank score (0.72 vs 0.98). Both hits have a kNN score of 0.0, so the
    only things competing are the text and rerank signals. After fusion the
    rerank-preferred hit must come first.
    """
    hits = [
        {
            "_id": "lexical-heavy",
            "_score": 10.0,
            "matched_queries": {
                "base_query": 10.0,
                "knn_query": 0.0,
            },
        },
        {
            "_id": "rerank-better",
            "_score": 6.0,
            "matched_queries": {
                "base_query": 6.0,
                "knn_query": 0.0,
            },
        },
    ]

    fuse_scores_and_resort(hits, [0.72, 0.98])

    by_id = {hit["_id"]: hit for hit in hits}

    # Guard the premise: the losing hit really does carry the stronger text
    # signal, so the ordering below comes from rerank outweighing it.
    assert by_id["lexical-heavy"]["_text_score"] > by_id["rerank-better"]["_text_score"]
    assert by_id["rerank-better"]["_fused_score"] > by_id["lexical-heavy"]["_fused_score"]
    assert [hit["_id"] for hit in hits] == ["rerank-better", "lexical-heavy"]