Blame view

tests/test_eval_metrics.py 2.01 KB
30b490e1   tangwang   添加ERR评估指标
1
2
  """Tests for search evaluation ranking metrics (NDCG, ERR)."""
  
99b72698   tangwang   测试回归钩子梳理
3
4
5
6
7
8
  import math
  
  import pytest
  
  pytestmark = [pytest.mark.eval, pytest.mark.regression]
  
30b490e1   tangwang   添加ERR评估指标
9
  from scripts.evaluation.eval_framework.constants import (
99b72698   tangwang   测试回归钩子梳理
10
11
12
13
14
      RELEVANCE_LV0,
      RELEVANCE_LV1,
      RELEVANCE_LV2,
      RELEVANCE_LV3,
      STOP_PROB_MAP,
30b490e1   tangwang   添加ERR评估指标
15
16
17
18
  )
  from scripts.evaluation.eval_framework.metrics import compute_query_metrics
  
  
99b72698   tangwang   测试回归钩子梳理
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  def _expected_err(labels):
      """Reference implementation of the ERR cascade model.

      ERR = sum over ranks r of (1/r) * P(stop at r) * prod_{j<r}(1 - P(stop at j)),
      where the per-rank stop probability comes from ``STOP_PROB_MAP``.
      """
      total = 0.0
      continue_prob = 1.0
      for rank, label in enumerate(labels, start=1):
          stop_prob = STOP_PROB_MAP[label]
          total += stop_prob * continue_prob / rank
          continue_prob *= 1.0 - stop_prob
      return total
  
  
  def test_err_matches_cascade_formula_on_four_level_labels():
      """ERR@k must equal the textbook cascade formula against the four-level label set.

      The metric is the primary ranking signal (see `PRIMARY_METRIC_KEYS` in
      `eval_framework.metrics`); any regression here invalidates the whole
      evaluation pipeline.
      """

      strong_first = [RELEVANCE_LV3, RELEVANCE_LV0, RELEVANCE_LV2]
      weak_first = [RELEVANCE_LV2, RELEVANCE_LV1, RELEVANCE_LV3]

      metrics_strong = compute_query_metrics(strong_first, ideal_labels=[RELEVANCE_LV3])
      metrics_weak = compute_query_metrics(weak_first, ideal_labels=[RELEVANCE_LV3])

      # Both rankings must agree with the independent reference formula ...
      assert math.isclose(metrics_strong["ERR@5"], _expected_err(strong_first), abs_tol=1e-5)
      assert math.isclose(metrics_weak["ERR@5"], _expected_err(weak_first), abs_tol=1e-5)
      # ... and a strong hit at rank 1 must outrank a weaker-first ordering.
      assert metrics_strong["ERR@5"] > metrics_weak["ERR@5"]
  
  
  def test_ndcg_at_k_is_1_when_actual_equals_ideal():
      """A ranking identical to its ideal ordering scores a perfect NDCG at every cutoff."""
      ranking = [RELEVANCE_LV3, RELEVANCE_LV2, RELEVANCE_LV1]
      metrics = compute_query_metrics(ranking, ideal_labels=ranking)
      for key in ("NDCG@5", "NDCG@20"):
          assert math.isclose(metrics[key], 1.0, abs_tol=1e-9)
  
  
  def test_all_irrelevant_zeroes_out_primary_signals():
      """A result list of only LV0 items must drive every primary signal to exactly zero."""
      ranking = [RELEVANCE_LV0] * 3
      metrics = compute_query_metrics(ranking, ideal_labels=[RELEVANCE_LV3])
      for key in ("ERR@10", "NDCG@20", "Strong_Precision@10", "Primary_Metric_Score"):
          assert metrics[key] == 0.0