tests/test_eval_metrics.py
30b490e1   tangwang   Add ERR evaluation metric
  """Tests for search evaluation ranking metrics (NDCG, ERR)."""
  
  from scripts.evaluation.eval_framework.constants import (
      RELEVANCE_EXACT,
      RELEVANCE_HIGH,
      RELEVANCE_IRRELEVANT,
      RELEVANCE_LOW,
  )
  from scripts.evaluation.eval_framework.metrics import compute_query_metrics
  
  
  def test_err_matches_documented_three_item_examples():
      # Model A: [Exact, Irrelevant, High] -> ERR ≈ 0.992667
      m_a = compute_query_metrics(
          [RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH],
          ideal_labels=[RELEVANCE_EXACT],
      )
      assert abs(m_a["ERR@5"] - (0.99 + (1.0 / 3.0) * 0.8 * 0.01)) < 1e-5
  
      # Model B: [High, Low, Exact] -> ERR ≈ 0.8694
      m_b = compute_query_metrics(
          [RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_EXACT],
          ideal_labels=[RELEVANCE_EXACT],
      )
      expected_b = 0.8 + 0.5 * 0.1 * 0.2 + (1.0 / 3.0) * 0.99 * 0.18
      assert abs(m_b["ERR@5"] - expected_b) < 1e-5
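

# The expected values above follow the standard ERR cascade model
# (Chapelle et al., CIKM 2009): ERR = sum_r (1/r) * R_r * prod_{i<r} (1 - R_i),
# where R_r is the probability the user is satisfied at rank r. Below is a
# minimal sketch of that computation, assuming the framework maps labels to
# the stop probabilities implied by the assertions above (Exact -> 0.99,
# High -> 0.8, Low -> 0.1, Irrelevant -> 0.0); the mapping and the helper
# name are illustrative assumptions, not taken from the framework's source.
def _err_sketch(stop_probs):
    """ERR for a ranked list of per-position satisfaction probabilities."""
    err = 0.0
    reach = 1.0  # probability the user examines the current rank
    for rank, prob in enumerate(stop_probs, start=1):
        err += reach * prob / rank  # user stops here with probability `prob`
        reach *= 1.0 - prob         # otherwise continues to the next rank
    return err


def test_err_sketch_agrees_with_documented_examples():
    assert abs(_err_sketch([0.99, 0.0, 0.8]) - 0.992667) < 1e-5  # Model A
    assert abs(_err_sketch([0.8, 0.1, 0.99]) - 0.8694) < 1e-4    # Model B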