"""Tests for ``search.rerank_client.fuse_scores_and_resort``.

As exercised below, ``fuse_scores_and_resort`` annotates each hit dict in place
(``_text_score``, ``_fused_score``, ``_knn_score``, ...), re-sorts the ``hits``
list in place, and returns a list of debug dicts when called with ``debug=True``.
"""

from math import isclose

from config.schema import RerankFusionConfig
from search.rerank_client import fuse_scores_and_resort

# Fusion parameters the default-config tests below assume. Keep in sync with the
# defaults of RerankFusionConfig; the rerank term is used with exponent 1.
_DEFAULT_RERANK_BIAS = 0.00001
_DEFAULT_TEXT_BIAS = 0.1
_DEFAULT_TEXT_EXPONENT = 0.35
_DEFAULT_KNN_BIAS = 0.6
_DEFAULT_KNN_EXPONENT = 0.2


def _expected_default_fused(rerank, text, knn):
    """Return the fused score expected under the default fusion parameters.

    Computes ``(rerank + b_r) * (text + b_t) ** e_t * (knn + b_k) ** e_k`` with
    the operations in the same order the original inline expressions used, so
    results are bit-identical to them.
    """
    return (
        (rerank + _DEFAULT_RERANK_BIAS)
        * ((text + _DEFAULT_TEXT_BIAS) ** _DEFAULT_TEXT_EXPONENT)
        * ((knn + _DEFAULT_KNN_BIAS) ** _DEFAULT_KNN_EXPONENT)
    )


def test_fuse_scores_and_resort_aggregates_text_components_and_keeps_rerank_primary():
    hits = [
        {
            "_id": "1",
            "_score": 3.2,
            "matched_queries": {
                "base_query": 2.4,
                "base_query_trans_zh": 1.8,
                "knn_query": 0.8,
            },
        },
        {
            "_id": "2",
            "_score": 2.8,
            "matched_queries": {
                "base_query": 9.0,
                "knn_query": 0.2,
            },
        },
    ]

    debug = fuse_scores_and_resort(hits, [0.9, 0.7], debug=True)

    # Text score = source score + 0.25 * weighted translation score, where the
    # weighted translation score is 0.8 * 1.8 == 1.44 (also asserted via debug).
    expected_text_1 = 2.4 + 0.25 * (0.8 * 1.8)
    expected_fused_1 = _expected_default_fused(0.9, expected_text_1, 0.8)
    expected_fused_2 = _expected_default_fused(0.7, 9.0, 0.2)

    by_id = {hit["_id"]: hit for hit in hits}
    assert isclose(by_id["1"]["_text_score"], expected_text_1, rel_tol=1e-9)
    assert isclose(by_id["1"]["_fused_score"], expected_fused_1, rel_tol=1e-9)
    assert isclose(by_id["2"]["_fused_score"], expected_fused_2, rel_tol=1e-9)

    # debug[0] describes hit "1" (the first input hit), not the top-ranked hit.
    assert debug[0]["text_source_score"] == 2.4
    assert debug[0]["text_translation_score"] == 1.8
    assert isclose(debug[0]["text_weighted_translation_score"], 1.44, rel_tol=1e-9)
    assert debug[0]["knn_score"] == 0.8
    assert isclose(debug[0]["rerank_factor"], 0.90001, rel_tol=1e-9)

    assert [hit["_id"] for hit in hits] == ["2", "1"]


def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
    hits = [
        {"_id": "1", "_score": 0.5},
        {"_id": "2", "_score": 2.0},
    ]

    debug = fuse_scores_and_resort(hits, [0.4, 0.3], debug=True)

    # Without matched_queries the ES _score is used as the text score and the
    # knn score contributes as 0.0.
    expected_1 = _expected_default_fused(0.4, 0.5, 0.0)
    expected_2 = _expected_default_fused(0.3, 2.0, 0.0)

    by_id = {hit["_id"]: hit for hit in hits}
    assert isclose(by_id["1"]["_text_score"], 0.5, rel_tol=1e-9)
    assert isclose(by_id["1"]["_fused_score"], expected_1, rel_tol=1e-9)
    assert isclose(by_id["2"]["_text_score"], 2.0, rel_tol=1e-9)
    assert isclose(by_id["2"]["_fused_score"], expected_2, rel_tol=1e-9)
    assert debug[0]["text_score_fallback_to_es"] is True
    assert debug[1]["text_score_fallback_to_es"] is True
    assert [hit["_id"] for hit in hits] == ["2", "1"]


def test_fuse_scores_and_resort_downweights_text_only_advantage():
    hits = [
        {
            "_id": "lexical-heavy",
            "_score": 10.0,
            "matched_queries": {
                "base_query": 10.0,
                "knn_query": 0.0,
            },
        },
        {
            "_id": "rerank-better",
            "_score": 6.0,
            "matched_queries": {
                "base_query": 6.0,
                "knn_query": 0.0,
            },
        },
    ]

    fuse_scores_and_resort(hits, [0.72, 0.98])

    # The higher rerank score must outweigh a purely lexical lead.
    assert [hit["_id"] for hit in hits] == ["rerank-better", "lexical-heavy"]


def test_fuse_scores_and_resort_uses_configurable_fusion_params():
    hits = [
        {
            "_id": "a",
            "_score": 1.0,
            "matched_queries": {"base_query": 2.0, "knn_query": 0.5},
        },
        {
            "_id": "b",
            "_score": 1.0,
            "matched_queries": {"base_query": 3.0, "knn_query": 0.0},
        },
    ]
    # Zero biases and unit exponents reduce fusion to a plain product.
    fusion = RerankFusionConfig(
        rerank_bias=0.0,
        rerank_exponent=1.0,
        text_bias=0.0,
        text_exponent=1.0,
        knn_bias=0.0,
        knn_exponent=1.0,
    )

    fuse_scores_and_resort(hits, [1.0, 1.0], fusion=fusion)

    # b has knn 0 -> its fused score is 0; a is 1 * 2 * 0.5 == 1.0.
    assert [h["_id"] for h in hits] == ["a", "b"]
    by_id = {h["_id"]: h for h in hits}
    assert isclose(by_id["a"]["_fused_score"], 1.0, rel_tol=1e-9)
    assert isclose(by_id["b"]["_fused_score"], 0.0, rel_tol=1e-9)


def test_fuse_scores_and_resort_boosts_hits_with_selected_sku():
    hits = [
        {
            "_id": "style-selected",
            "_score": 1.0,
            "_style_rerank_suffix": "Blue XL",
            "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
        },
        {
            "_id": "plain",
            "_score": 1.0,
            "matched_queries": {"base_query": 1.0, "knn_query": 0.0},
        },
    ]

    debug = fuse_scores_and_resort(
        hits,
        [1.0, 1.0],
        style_intent_selected_sku_boost=1.2,
        debug=True,
    )

    # Identical inputs except for the style suffix: the boost is the only
    # difference between the two fused scores.
    by_id = {h["_id"]: h for h in hits}
    assert isclose(
        by_id["style-selected"]["_fused_score"],
        by_id["plain"]["_fused_score"] * 1.2,
        rel_tol=1e-9,
    )
    assert by_id["style-selected"]["_style_intent_selected_sku_boost"] == 1.2
    assert by_id["plain"]["_style_intent_selected_sku_boost"] == 1.0
    assert [h["_id"] for h in hits] == ["style-selected", "plain"]
    assert debug[0]["style_intent_selected_sku"] is True
    assert debug[0]["style_intent_selected_sku_boost"] == 1.2


def test_fuse_scores_and_resort_uses_max_of_text_and_image_knn_scores():
    hits = [
        {
            "_id": "mm-hit",
            "_score": 1.0,
            "matched_queries": {
                "base_query": 1.5,
                "knn_query": 0.2,
                "image_knn_query": 0.7,
            },
        }
    ]

    debug = fuse_scores_and_resort(hits, [0.8], debug=True)

    # Image knn (0.7) is the larger of the two and must be used.
    assert isclose(hits[0]["_knn_score"], 0.7, rel_tol=1e-9)
    assert isclose(debug[0]["knn_score"], 0.7, rel_tol=1e-9)


def test_fuse_scores_and_resort_keeps_text_knn_when_larger_than_image_knn():
    # Companion to the test above: checks that the knn score is a true max
    # rather than "prefer image_knn_query whenever present".
    hits = [
        {
            "_id": "text-knn-wins",
            "_score": 1.0,
            "matched_queries": {
                "base_query": 1.5,
                "knn_query": 0.9,
                "image_knn_query": 0.3,
            },
        }
    ]

    debug = fuse_scores_and_resort(hits, [0.8], debug=True)

    assert isclose(hits[0]["_knn_score"], 0.9, rel_tol=1e-9)
    assert isclose(debug[0]["knn_score"], 0.9, rel_tol=1e-9)