Blame view

tests/test_rerank_client.py 2.8 KB
a47416ec   tangwang   把融合逻辑改成乘法公式,并把 ES...
1
2
3
4
5
  from math import isclose
  
  from search.rerank_client import fuse_scores_and_resort
  
  
c90f80ed   tangwang   相关性优化
6
  def test_fuse_scores_and_resort_aggregates_text_components_and_keeps_rerank_primary():
      """Verify text-score aggregation, multiplicative fusion, and re-sorting.

      Hit "1" matches the source, translated, and fallback text queries plus
      the kNN query; hit "2" matches only the source and kNN queries. The test
      pins the per-hit ``_text_score`` / ``_fused_score`` values written onto
      the hits, the component scores reported in the returned debug entries,
      and the final ordering of ``hits``.
      """
      hits = [
          {
              "_id": "1",
              "_score": 3.2,
              "matched_queries": {
                  "base_query": 2.4,
                  "base_query_trans_zh": 1.8,
                  "fallback_original_query_zh": 1.2,
                  "knn_query": 0.8,
              },
          },
          {
              "_id": "2",
              "_score": 2.8,
              "matched_queries": {
                  "base_query": 9.0,
                  "knn_query": 0.2,
              },
          },
      ]

      debug = fuse_scores_and_resort(hits, [0.9, 0.7])

      # Expected text score for hit "1": the source-query score plus a 0.25
      # weighted sum of the translation (x0.8) and fallback (x0.55) scores.
      expected_text_1 = 2.4 + 0.25 * ((0.8 * 1.8) + (0.55 * 1.2))
      # Expected fused score: (rerank + 1e-5) * (text + 0.1)^0.35 * (knn + 0.6)^0.2.
      expected_fused_1 = (0.9 + 0.00001) * ((expected_text_1 + 0.1) ** 0.35) * ((0.8 + 0.6) ** 0.2)
      expected_fused_2 = (0.7 + 0.00001) * ((9.0 + 0.1) ** 0.35) * ((0.2 + 0.6) ** 0.2)

      # Index by id so the value checks do not depend on the post-sort order.
      by_id = {hit["_id"]: hit for hit in hits}

      assert isclose(by_id["1"]["_text_score"], expected_text_1, rel_tol=1e-9)
      assert isclose(by_id["1"]["_fused_score"], expected_fused_1, rel_tol=1e-9)
      assert isclose(by_id["2"]["_fused_score"], expected_fused_2, rel_tol=1e-9)
      # The first debug entry is expected to carry hit "1"'s component scores.
      # NOTE(review): presumably debug entries follow the input order rather
      # than the re-sorted order -- confirm against fuse_scores_and_resort.
      assert debug[0]["text_source_score"] == 2.4
      assert debug[0]["text_translation_score"] == 1.8
      assert debug[0]["text_fallback_score"] == 1.2
      assert debug[0]["knn_score"] == 0.8
      # Hit "2" has the larger expected fused score, so it must be ranked first.
      assert [hit["_id"] for hit in hits] == ["2", "1"]
a47416ec   tangwang   把融合逻辑改成乘法公式,并把 ES...
44
45
46
47
48
49
50
51
52
53
  
  
  def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
      """Verify the fallback path for hits that carry no ``matched_queries``.

      With only ``_score`` available, the test expects ``_text_score`` to
      equal the hit's ``_score`` and the kNN component to be treated as 0.0 in
      the fused-score formula.
      """
      hits = [
          {"_id": "1", "_score": 0.5},
          {"_id": "2", "_score": 2.0},
      ]

      fuse_scores_and_resort(hits, [0.4, 0.3])

      # Expected fused score: (rerank + 1e-5) * (text + 0.1)^0.35 * (knn + 0.6)^0.2,
      # with text = _score and knn = 0.0 for both hits.
      expected_1 = (0.4 + 0.00001) * ((0.5 + 0.1) ** 0.35) * ((0.0 + 0.6) ** 0.2)
      expected_2 = (0.3 + 0.00001) * ((2.0 + 0.1) ** 0.35) * ((0.0 + 0.6) ** 0.2)

      # Index by id so the value checks do not depend on the post-sort order.
      by_id = {hit["_id"]: hit for hit in hits}

      assert isclose(by_id["1"]["_text_score"], 0.5, rel_tol=1e-9)
      assert isclose(by_id["1"]["_fused_score"], expected_1, rel_tol=1e-9)
      assert isclose(by_id["2"]["_text_score"], 2.0, rel_tol=1e-9)
      assert isclose(by_id["2"]["_fused_score"], expected_2, rel_tol=1e-9)
      # Hit "2" has the larger expected fused score, so it must be ranked first.
      assert [hit["_id"] for hit in hits] == ["2", "1"]
c90f80ed   tangwang   相关性优化
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  
  
  def test_fuse_scores_and_resort_downweights_text_only_advantage():
      """A higher rerank score should outrank a purely lexical advantage.

      Both hits have a zero kNN score. The first has the larger text score
      (10.0 vs 6.0) but the smaller rerank score (0.72 vs 0.98); after fusion
      the hit with the better rerank score is expected to come first.
      """
      lexical_heavy = {
          "_id": "lexical-heavy",
          "_score": 10.0,
          "matched_queries": {"base_query": 10.0, "knn_query": 0.0},
      }
      rerank_better = {
          "_id": "rerank-better",
          "_score": 6.0,
          "matched_queries": {"base_query": 6.0, "knn_query": 0.0},
      }
      hits = [lexical_heavy, rerank_better]
      rerank_scores = [0.72, 0.98]

      fuse_scores_and_resort(hits, rerank_scores)

      # The ordering is read back from the same list that was passed in.
      ranked_ids = [hit["_id"] for hit in hits]
      assert ranked_ids == ["rerank-better", "lexical-heavy"]