a47416ec
tangwang
把融合逻辑改成乘法公式,并把 ES...
|
1
2
|
from math import isclose
|
814e352b
tangwang
乘法公式配置化
|
3
|
from config.schema import RerankFusionConfig
|
a47416ec
tangwang
把融合逻辑改成乘法公式,并把 ES...
|
4
5
6
|
from search.rerank_client import fuse_scores_and_resort
|
c90f80ed
tangwang
相关性优化
|
7
|
def test_fuse_scores_and_resort_aggregates_text_components_and_keeps_rerank_primary():
    """Check text-score aggregation, fused scores, debug output and ordering.

    Hit "1" matches ``base_query``, ``base_query_trans_zh`` and ``knn_query``;
    hit "2" matches only ``base_query`` and ``knn_query``. Rerank scores
    ``[0.9, 0.7]`` are supplied positionally for the two hits.

    NOTE(review): the constants in the expected values (0.25, 0.8, 0.00001,
    0.1, 0.35, 0.6, 0.2) presumably mirror the default fusion parameters of
    ``fuse_scores_and_resort`` -- confirm against the implementation.
    """
    hit_with_translation = {
        "_id": "1",
        "_score": 3.2,
        "matched_queries": {
            "base_query": 2.4,
            "base_query_trans_zh": 1.8,
            "knn_query": 0.8,
        },
    }
    hit_without_translation = {
        "_id": "2",
        "_score": 2.8,
        "matched_queries": {
            "base_query": 9.0,
            "knn_query": 0.2,
        },
    }
    hits = [hit_with_translation, hit_without_translation]

    debug = fuse_scores_and_resort(hits, [0.9, 0.7], debug=True)

    # Expected text score for hit "1": base score plus a weighted share of
    # the translated-query score.
    text_score_1 = 2.4 + 0.25 * (0.8 * 1.8)
    fused_score_1 = (
        (0.9 + 0.00001)
        * ((text_score_1 + 0.1) ** 0.35)
        * ((0.8 + 0.6) ** 0.2)
    )
    fused_score_2 = (
        (0.7 + 0.00001)
        * ((9.0 + 0.1) ** 0.35)
        * ((0.2 + 0.6) ** 0.2)
    )

    # Index the (possibly reordered) hits by id so the checks below do not
    # depend on list position.
    scored = {}
    for hit in hits:
        scored[hit["_id"]] = hit

    assert isclose(scored["1"]["_text_score"], text_score_1, rel_tol=1e-9)
    assert isclose(scored["1"]["_fused_score"], fused_score_1, rel_tol=1e-9)
    assert isclose(scored["2"]["_fused_score"], fused_score_2, rel_tol=1e-9)

    # The first debug entry carries the intermediate values of hit "1".
    first_debug = debug[0]
    assert first_debug["text_source_score"] == 2.4
    assert first_debug["text_translation_score"] == 1.8
    assert isclose(first_debug["text_weighted_translation_score"], 1.44, rel_tol=1e-9)
    assert first_debug["knn_score"] == 0.8
    assert isclose(first_debug["rerank_factor"], 0.90001, rel_tol=1e-9)

    ranked_ids = [hit["_id"] for hit in hits]
    assert ranked_ids == ["2", "1"]
|
a47416ec
tangwang
把融合逻辑改成乘法公式,并把 ES...
|
45
46
47
48
49
50
51
52
|
def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
    """Check the fallback path for hits that carry no ``matched_queries``.

    Both hits only have ``_id`` and ``_score``. The test expects the text
    score to equal ``_score``, the debug entries to flag
    ``text_score_fallback_to_es``, and the expected fused values to use a kNN
    input of 0.0.

    NOTE(review): the constants 0.00001, 0.1, 0.35, 0.6 and 0.2 presumably
    mirror the default fusion parameters -- confirm against the implementation.
    """
    hits = [
        {"_id": doc_id, "_score": es_score}
        for doc_id, es_score in (("1", 0.5), ("2", 2.0))
    ]

    debug = fuse_scores_and_resort(hits, [0.4, 0.3], debug=True)

    # doc id -> (expected text score, expected fused score)
    expectations = {
        "1": (
            0.5,
            (0.4 + 0.00001) * ((0.5 + 0.1) ** 0.35) * ((0.0 + 0.6) ** 0.2),
        ),
        "2": (
            2.0,
            (0.3 + 0.00001) * ((2.0 + 0.1) ** 0.35) * ((0.0 + 0.6) ** 0.2),
        ),
    }
    scored = {hit["_id"]: hit for hit in hits}

    for doc_id, (text_score, fused_score) in expectations.items():
        assert isclose(scored[doc_id]["_text_score"], text_score, rel_tol=1e-9)
        assert isclose(scored[doc_id]["_fused_score"], fused_score, rel_tol=1e-9)

    # Both debug entries must report the fallback to the ES score.
    for position in (0, 1):
        assert debug[position]["text_score_fallback_to_es"] is True

    ranked_ids = [hit["_id"] for hit in hits]
    assert ranked_ids == ["2", "1"]
|
c90f80ed
tangwang
相关性优化
|
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
def test_fuse_scores_and_resort_downweights_text_only_advantage():
    """Check that a higher rerank score can outrank a larger text score.

    Both hits have a kNN score of 0.0. "lexical-heavy" has the larger text
    score (10.0 vs 6.0) but the lower rerank score (0.72 vs 0.98); the test
    expects "rerank-better" to be ranked first after fusion.
    """

    def _text_only_hit(doc_id, text_score):
        # The ES score and the base-query score are identical in this fixture.
        return {
            "_id": doc_id,
            "_score": text_score,
            "matched_queries": {
                "base_query": text_score,
                "knn_query": 0.0,
            },
        }

    hits = [
        _text_only_hit("lexical-heavy", 10.0),
        _text_only_hit("rerank-better", 6.0),
    ]

    fuse_scores_and_resort(hits, [0.72, 0.98])

    ranked_ids = [hit["_id"] for hit in hits]
    assert ranked_ids == ["rerank-better", "lexical-heavy"]
|
814e352b
tangwang
乘法公式配置化
|
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
def test_fuse_scores_and_resort_uses_configurable_fusion_params():
    """Check that a custom ``RerankFusionConfig`` drives the fused score.

    All biases are set to 0.0 and all exponents to 1.0, and both rerank
    scores are 1.0. The test then expects the fused score to be the plain
    product of the rerank, text and kNN scores.
    """
    hits = [
        {
            "_id": "a",
            "_score": 1.0,
            "matched_queries": {"base_query": 2.0, "knn_query": 0.5},
        },
        {
            "_id": "b",
            "_score": 1.0,
            "matched_queries": {"base_query": 3.0, "knn_query": 0.0},
        },
    ]
    fusion = RerankFusionConfig(
        rerank_bias=0.0,
        rerank_exponent=1.0,
        text_bias=0.0,
        text_exponent=1.0,
        knn_bias=0.0,
        knn_exponent=1.0,
    )
    fuse_scores_and_resort(hits, [1.0, 1.0], fusion=fusion)
    # b has a kNN score of 0 -> fused score is 0; a is 1 * 2 * 0.5.
    assert [h["_id"] for h in hits] == ["a", "b"]
    by_id = {h["_id"]: h for h in hits}
    assert isclose(by_id["a"]["_fused_score"], 1.0, rel_tol=1e-9)
    # A relative tolerance is meaningless when the expected value is 0.0
    # (it degenerates to exact equality), so use an absolute tolerance here.
    assert isclose(by_id["b"]["_fused_score"], 0.0, abs_tol=1e-12)
|