Blame view

scripts/evaluation/eval_framework/metrics.py 4.99 KB
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
1
  """Ranking metrics for graded e-commerce relevance labels."""
c81b0fc1   tangwang   scripts/evaluatio...
2
3
4
  
  from __future__ import annotations
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
5
6
  import math
  from typing import Dict, Iterable, Sequence
c81b0fc1   tangwang   scripts/evaluatio...
7
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
8
9
10
11
12
13
14
15
16
17
  from .constants import (
      RELEVANCE_EXACT,
      RELEVANCE_GAIN_MAP,
      RELEVANCE_GRADE_MAP,
      RELEVANCE_HIGH,
      RELEVANCE_IRRELEVANT,
      RELEVANCE_LOW,
      RELEVANCE_NON_IRRELEVANT,
      RELEVANCE_STRONG,
  )
c81b0fc1   tangwang   scripts/evaluatio...
18
19
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
  def _normalize_label(label: str) -> str:
      """Return `label` when it is a known grade, else the irrelevant grade.

      A label counts as known when it is a key of `RELEVANCE_GRADE_MAP`.
      """
      return label if label in RELEVANCE_GRADE_MAP else RELEVANCE_IRRELEVANT
  
  
  def _gains_for_labels(labels: Sequence[str]) -> list[float]:
      """Map each label to its float gain, preserving input order.

      Labels are normalized first; a normalized label that has no entry in
      `RELEVANCE_GAIN_MAP` contributes a gain of 0.0.
      """
      gains: list[float] = []
      for raw_label in labels:
          grade = _normalize_label(raw_label)
          gains.append(float(RELEVANCE_GAIN_MAP.get(grade, 0.0)))
      return gains
  
  
  def _binary_hits(labels: Sequence[str], relevant: Iterable[str]) -> list[int]:
      """Return a 0/1 flag per label: 1 when its normalized grade is in `relevant`."""
      # Materialize once so membership checks are cheap and `relevant` may be
      # any iterable.
      accepted = set(relevant)
      return [int(_normalize_label(label) in accepted) for label in labels]
  
  
  def _precision_at_k_from_hits(hits: Sequence[int], k: int) -> float:
c81b0fc1   tangwang   scripts/evaluatio...
36
37
      if k <= 0:
          return 0.0
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
38
      sliced = list(hits[:k])
c81b0fc1   tangwang   scripts/evaluatio...
39
40
      if not sliced:
          return 0.0
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
      return sum(sliced) / float(len(sliced))
  
  
  def _success_at_k_from_hits(hits: Sequence[int], k: int) -> float:
      if k <= 0:
          return 0.0
      return 1.0 if any(hits[:k]) else 0.0
  
  
  def _reciprocal_rank_from_hits(hits: Sequence[int], k: int) -> float:
      if k <= 0:
          return 0.0
      for idx, hit in enumerate(hits[:k], start=1):
          if hit:
              return 1.0 / float(idx)
      return 0.0
c81b0fc1   tangwang   scripts/evaluatio...
57
58
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
59
60
61
62
63
64
  def _dcg_at_k(gains: Sequence[float], k: int) -> float:
      if k <= 0:
          return 0.0
      total = 0.0
      for idx, gain in enumerate(gains[:k], start=1):
          if gain <= 0.0:
c81b0fc1   tangwang   scripts/evaluatio...
65
              continue
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
          total += gain / math.log2(idx + 1.0)
      return total
  
  
  def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
      """Return DCG@k of `labels` divided by the best achievable DCG@k.

      The ideal ordering is obtained by sorting the gains of `ideal_labels`
      in descending order. Returns 0.0 when the ideal DCG is not positive.
      """
      ranked_gains = _gains_for_labels(labels)
      best_gains = _gains_for_labels(ideal_labels)
      best_gains.sort(reverse=True)
      observed = _dcg_at_k(ranked_gains, k)
      best = _dcg_at_k(best_gains, k)
      return 0.0 if best <= 0.0 else observed / best
  
  
  def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
      """Return the share of the judged pool's total gain found in the top `k`.

      Args:
          labels: Ranked result labels.
          ideal_labels: Judged label pool whose summed gain is the denominator.
          k: Cutoff rank.

      Returns:
          Sum of gains of `labels[:k]` divided by the summed gain of
          `ideal_labels`. Returns 0.0 when `k` is not positive or when the
          pool's total gain is not positive. The ratio is not clamped.
      """
      # Guard like the other @k helpers: without it a negative k would make
      # `labels[:k]` drop items from the tail instead of selecting nothing.
      if k <= 0:
          return 0.0
      ideal_total_gain = sum(_gains_for_labels(ideal_labels))
      if ideal_total_gain <= 0.0:
          return 0.0
      actual_gain = sum(_gains_for_labels(labels[:k]))
      return actual_gain / ideal_total_gain
c81b0fc1   tangwang   scripts/evaluatio...
86
87
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  def _grade_avg_at_k(labels: Sequence[str], k: int) -> float:
      """Return the mean numeric grade of the first `k` labels.

      Labels are normalized before lookup in `RELEVANCE_GRADE_MAP`. The mean
      is taken over the labels actually present in the top-k window. Returns
      0.0 when `k` is not positive or the window is empty.
      """
      if k <= 0:
          return 0.0
      grades = [
          float(RELEVANCE_GRADE_MAP.get(_normalize_label(raw_label), 0))
          for raw_label in labels[:k]
      ]
      if not grades:
          return 0.0
      return sum(grades) / float(len(grades))
  
  
  def compute_query_metrics(
      labels: Sequence[str],
      *,
      ideal_labels: Sequence[str] | None = None,
  ) -> Dict[str, float]:
      """Build the per-query metric dictionary for one ranked result list.

      `labels` is the ranked list of relevance labels returned by search.
      `ideal_labels` is the judged label pool for the same query; when it is
      None the ranked labels themselves serve as the pool.

      Every value is rounded to 6 decimal places. Keys are inserted in this
      order: NDCG, Exact/Strong precision, Useful precision and gain recall,
      Exact/Strong success, the two MRR variants, and the average grade.
      """

      if ideal_labels is None:
          pool = list(labels)
      else:
          pool = list(ideal_labels)

      # Binary views of the ranking at three strictness levels.
      hits_exact = _binary_hits(labels, [RELEVANCE_EXACT])
      hits_strong = _binary_hits(labels, RELEVANCE_STRONG)
      hits_useful = _binary_hits(labels, RELEVANCE_NON_IRRELEVANT)

      result: Dict[str, float] = {}

      for cutoff in (5, 10, 20, 50):
          result[f"NDCG@{cutoff}"] = round(_ndcg_at_k(labels, pool, cutoff), 6)

      for cutoff in (5, 10, 20):
          result[f"Exact_Precision@{cutoff}"] = round(
              _precision_at_k_from_hits(hits_exact, cutoff), 6
          )
          result[f"Strong_Precision@{cutoff}"] = round(
              _precision_at_k_from_hits(hits_strong, cutoff), 6
          )

      for cutoff in (10, 20, 50):
          result[f"Useful_Precision@{cutoff}"] = round(
              _precision_at_k_from_hits(hits_useful, cutoff), 6
          )
          result[f"Gain_Recall@{cutoff}"] = round(_gain_recall_at_k(labels, pool, cutoff), 6)

      for cutoff in (5, 10):
          result[f"Exact_Success@{cutoff}"] = round(
              _success_at_k_from_hits(hits_exact, cutoff), 6
          )
          result[f"Strong_Success@{cutoff}"] = round(
              _success_at_k_from_hits(hits_strong, cutoff), 6
          )

      result["MRR_Exact@10"] = round(_reciprocal_rank_from_hits(hits_exact, 10), 6)
      result["MRR_Strong@10"] = round(_reciprocal_rank_from_hits(hits_strong, 10), 6)
      result["Avg_Grade@10"] = round(_grade_avg_at_k(labels, 10), 6)
      return result
  
  
  def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
      """Average per-query metric dictionaries key by key.

      The result covers the union of all keys, in sorted order. A key missing
      from an item counts as 0.0 for that item, and the divisor is always the
      total number of items. Each mean is rounded to 6 decimal places. An
      empty input yields an empty dict.
      """
      if not metric_items:
          return {}
      item_count = len(metric_items)
      names = set()
      for item in metric_items:
          names.update(item.keys())
      averaged: Dict[str, float] = {}
      for name in sorted(names):
          column = [float(item.get(name, 0.0)) for item in metric_items]
          averaged[name] = round(sum(column) / item_count, 6)
      return averaged
  
  
  def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
      """Count how many labels equal each of the four relevance grades.

      Labels are compared as-is (no normalization), so a label that matches
      none of the four grades is not counted anywhere. Keys appear in the
      order exact, high, low, irrelevant.
      """
      grades = (
          RELEVANCE_EXACT,
          RELEVANCE_HIGH,
          RELEVANCE_LOW,
          RELEVANCE_IRRELEVANT,
      )
      return {
          grade: sum(1 for label in labels if label == grade)
          for grade in grades
      }