Blame view

scripts/evaluation/eval_framework/metrics.py 2.23 KB
c81b0fc1   tangwang   scripts/evaluatio...
1
2
3
4
5
6
  """IR metrics for labeled result lists."""
  
  from __future__ import annotations
  
  from typing import Dict, Sequence
  
a345b01f   tangwang   eval framework
7
  from .constants import RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_NON_IRRELEVANT
c81b0fc1   tangwang   scripts/evaluatio...
8
9
10
11
12
13
14
15
  
  
  def precision_at_k(labels: Sequence[str], k: int, relevant: Sequence[str]) -> float:
      """Return the share of the first ``k`` labels that appear in ``relevant``.

      The denominator is the number of labels actually present in the top-``k``
      window, so a result list shorter than ``k`` is scored over its own length
      rather than over ``k``.

      Args:
          labels: Relevance labels of the results, in ranked order.
          k: Cut-off rank.
          relevant: Label values that count as a hit.

      Returns:
          A value in ``[0.0, 1.0]``; ``0.0`` when ``k`` is not positive or when
          there are no labels inside the window.
      """
      window = list(labels[:k]) if k > 0 else []
      if not window:
          return 0.0
      accepted = set(relevant)
      matched = [label for label in window if label in accepted]
      # The window never holds more than k items, so its length is the denominator.
      return len(matched) / float(len(window))
  
  
  def average_precision(labels: Sequence[str], relevant: Sequence[str]) -> float:
      """Return the average precision of a ranked label list.

      For every position whose label is in ``relevant``, the precision at that
      position (hits so far divided by the 1-based rank) is accumulated; the
      total is then divided by the number of hits found in ``labels``.

      Args:
          labels: Relevance labels of the results, in ranked order.
          relevant: Label values that count as a hit.

      Returns:
          The average precision, or ``0.0`` when no label is a hit.
      """
      accepted = set(relevant)
      found = 0
      running_total = 0.0
      for position, label in enumerate(labels):
          if label in accepted:
              found += 1
              # position is 0-based, so the rank of this result is position + 1.
              running_total += found / (position + 1)
      return running_total / found if found else 0.0
  
  
  def compute_query_metrics(labels: Sequence[str]) -> Dict[str, float]:
      """Compute precision and average-precision metrics for one query.

      ``P@k`` and ``MAP_3`` treat only ``RELEVANCE_EXACT`` as a hit, while
      ``P@k_2_3`` and ``MAP_2_3`` treat every label in
      ``RELEVANCE_NON_IRRELEVANT`` as a hit (the suffixes are legacy metric
      names). Precision is reported at cut-offs 5, 10, 20 and 50, and every
      value is rounded to six decimal places.

      Args:
          labels: Relevance labels of the query's results, in ranked order.

      Returns:
          A dict mapping metric name to its rounded value.
      """
      exact_only = [RELEVANCE_EXACT]
      non_irrel = list(RELEVANCE_NON_IRRELEVANT)
      metrics: Dict[str, float] = {}
      for k in (5, 10, 20, 50):
          metrics.update(
              {
                  f"P@{k}": round(precision_at_k(labels, k, exact_only), 6),
                  f"P@{k}_2_3": round(precision_at_k(labels, k, non_irrel), 6),
              }
          )
      metrics.update(
          {
              "MAP_3": round(average_precision(labels, exact_only), 6),
              "MAP_2_3": round(average_precision(labels, non_irrel), 6),
          }
      )
      return metrics
  
  
  def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
      """Average per-query metric dicts into a single dict of mean values.

      Every metric name that appears in any item is aggregated; an item that
      lacks a given metric contributes ``0.0`` for it. Each mean is taken over
      all items and rounded to six decimal places.

      Args:
          metric_items: One metric dict per query.

      Returns:
          A dict mapping metric name to its mean, with keys in sorted order;
          an empty dict when ``metric_items`` is empty.
      """
      if not metric_items:
          return {}
      # Use the union of keys so a metric that is missing from the first item
      # is still aggregated instead of being silently dropped.
      keys = sorted({key for item in metric_items for key in item})
      count = len(metric_items)
      return {
          key: round(sum(float(item.get(key, 0.0)) for item in metric_items) / count, 6)
          for key in keys
      }
  
  
  def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
      """Count how many labels fall into each relevance tier.

      The result always has one entry per tier, in the order exact, high, low,
      irrelevant; labels equal to none of these constants are not counted.

      Args:
          labels: Relevance labels to tally.

      Returns:
          A dict mapping each relevance constant to its number of occurrences.
      """
      tiers = (RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_IRRELEVANT)
      return {tier: sum(1 for label in labels if label == tier) for tier in tiers}