"""Tests for ``search.rerank_client.fuse_scores_and_resort``."""

from math import isclose

from search.rerank_client import fuse_scores_and_resort


def _expected_fused_score(rerank_score, knn_score, text_score):
    """Return the fused score these tests expect for a single hit.

    Encodes, in one place, the multiplicative formula the tests assert:
    ``(rerank + 0.00001) * (knn + 0.6) ** 0.2 * (text + 0.1) ** 0.75``.
    Keeping it here stops the individual expectations from drifting apart.
    The factors are multiplied left to right, matching the original inline
    expressions, so the floating-point result is identical.
    """
    return (
        (rerank_score + 0.00001)
        * ((knn_score + 0.6) ** 0.2)
        * ((text_score + 0.1) ** 0.75)
    )


def test_fuse_scores_and_resort_uses_multiplicative_formula_with_named_query_scores():
    """Scores under ``matched_queries`` feed the text and kNN factors."""
    hits = [
        {
            "_id": "1",
            "_score": 3.2,
            "matched_queries": {
                "base_query": 2.4,
                "knn_query": 0.8,
            },
        },
        {
            "_id": "2",
            "_score": 2.8,
            "matched_queries": {
                "base_query": 1.6,
                "knn_query": 0.2,
            },
        },
    ]

    # NOTE(review): the second argument presumably holds per-hit rerank
    # scores aligned with ``hits`` by position -- confirm against the
    # implementation of fuse_scores_and_resort.
    debug = fuse_scores_and_resort(hits, [0.9, 0.7])

    # Hit "1" already has the larger fused score, so the in-place resort
    # keeps the original order; the positional checks below rely on this.
    assert [hit["_id"] for hit in hits] == ["1", "2"]

    assert isclose(
        hits[0]["_fused_score"],
        _expected_fused_score(0.9, 0.8, 2.4),
        rel_tol=1e-9,
    )
    assert isclose(
        hits[1]["_fused_score"],
        _expected_fused_score(0.7, 0.2, 1.6),
        rel_tol=1e-9,
    )

    # The debug entry exposes the raw named-query scores that were used.
    assert debug[0]["text_score"] == 2.4
    assert debug[0]["knn_score"] == 0.8


def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
    """Without ``matched_queries``, ``_score`` is the text score and kNN is 0."""
    hits = [
        {"_id": "1", "_score": 0.5},
        {"_id": "2", "_score": 2.0},
    ]

    fuse_scores_and_resort(hits, [0.4, 0.3])

    # Hit "2" has the larger fused score and is moved to the front, in place.
    # Its rerank score (0.3) travels with it rather than staying at index 1.
    assert [hit["_id"] for hit in hits] == ["2", "1"]

    assert isclose(hits[0]["_text_score"], 2.0, rel_tol=1e-9)
    assert isclose(
        hits[0]["_fused_score"],
        _expected_fused_score(0.3, 0.0, 2.0),
        rel_tol=1e-9,
    )
    assert isclose(hits[1]["_text_score"], 0.5, rel_tol=1e-9)
    assert isclose(
        hits[1]["_fused_score"],
        _expected_fused_score(0.4, 0.0, 0.5),
        rel_tol=1e-9,
    )