Blame view

scripts/evaluation/eval_framework/metrics.py 6.31 KB
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
1
  """Ranking metrics for graded e-commerce relevance labels."""
c81b0fc1   tangwang   scripts/evaluatio...
2
3
4
  
  from __future__ import annotations
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
5
6
  import math
  from typing import Dict, Iterable, Sequence
c81b0fc1   tangwang   scripts/evaluatio...
7
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
8
9
10
11
12
13
14
15
16
  from .constants import (
      RELEVANCE_EXACT,
      RELEVANCE_GAIN_MAP,
      RELEVANCE_GRADE_MAP,
      RELEVANCE_HIGH,
      RELEVANCE_IRRELEVANT,
      RELEVANCE_LOW,
      RELEVANCE_NON_IRRELEVANT,
      RELEVANCE_STRONG,
30b490e1   tangwang   添加ERR评估指标
17
      STOP_PROB_MAP,
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
18
  )
c81b0fc1   tangwang   scripts/evaluatio...
19
  
465f90e1   tangwang   添加LTR数据收集
20
21
22
23
24
25
26
27
28
29
30
31
  # Names of the per-query metrics that ``primary_metric_score`` averages into
  # the single composite "Primary_Metric_Score" value.
  PRIMARY_METRIC_KEYS: tuple[str, ...] = (
      "NDCG@20",
      "NDCG@50",
      "ERR@10",
      "Strong_Precision@10",
      "Strong_Precision@20",
      "Useful_Precision@50",
      "Avg_Grade@10",
      "Gain_Recall@20",
  )
  # Divisor applied to "Avg_Grade@10" before it is averaged with the other
  # primary metrics. It is the largest value in RELEVANCE_GRADE_MAP; the
  # ``or 1.0`` substitutes 1.0 when that maximum is 0 so the divisor is never zero.
  PRIMARY_METRIC_GRADE_NORMALIZER = float(max(RELEVANCE_GRADE_MAP.values()) or 1.0)
  
c81b0fc1   tangwang   scripts/evaluatio...
32
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
  def _normalize_label(label: str) -> str:
      """Return ``label`` when it is a key of RELEVANCE_GRADE_MAP, else the irrelevant label."""
      return label if label in RELEVANCE_GRADE_MAP else RELEVANCE_IRRELEVANT
  
  
  def _gains_for_labels(labels: Sequence[str]) -> list[float]:
      """Map each label to its float gain from RELEVANCE_GAIN_MAP (0.0 when absent).

      Labels are passed through ``_normalize_label`` before the lookup.
      """
      gains: list[float] = []
      for raw_label in labels:
          canonical = _normalize_label(raw_label)
          gains.append(float(RELEVANCE_GAIN_MAP.get(canonical, 0.0)))
      return gains
  
  
  def _binary_hits(labels: Sequence[str], relevant: Iterable[str]) -> list[int]:
      """Return a 0/1 list marking which normalized labels belong to ``relevant``."""
      wanted = set(relevant)
      flags: list[int] = []
      for raw_label in labels:
          flags.append(1 if _normalize_label(raw_label) in wanted else 0)
      return flags
  
  
  def _precision_at_k_from_hits(hits: Sequence[int], k: int) -> float:
c81b0fc1   tangwang   scripts/evaluatio...
49
50
      if k <= 0:
          return 0.0
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
51
      sliced = list(hits[:k])
c81b0fc1   tangwang   scripts/evaluatio...
52
53
      if not sliced:
          return 0.0
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
      return sum(sliced) / float(len(sliced))
  
  
  def _success_at_k_from_hits(hits: Sequence[int], k: int) -> float:
      if k <= 0:
          return 0.0
      return 1.0 if any(hits[:k]) else 0.0
  
  
  def _reciprocal_rank_from_hits(hits: Sequence[int], k: int) -> float:
      if k <= 0:
          return 0.0
      for idx, hit in enumerate(hits[:k], start=1):
          if hit:
              return 1.0 / float(idx)
      return 0.0
c81b0fc1   tangwang   scripts/evaluatio...
70
71
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
72
73
74
75
76
77
  def _dcg_at_k(gains: Sequence[float], k: int) -> float:
      if k <= 0:
          return 0.0
      total = 0.0
      for idx, gain in enumerate(gains[:k], start=1):
          if gain <= 0.0:
c81b0fc1   tangwang   scripts/evaluatio...
78
              continue
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
79
80
81
82
83
84
85
86
87
88
89
90
91
92
          total += gain / math.log2(idx + 1.0)
      return total
  
  
  def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
      """DCG of ``labels`` at ``k`` divided by the DCG of the best ordering of ``ideal_labels``.

      The ideal ordering is the gains of ``ideal_labels`` sorted in descending
      order. Returns 0.0 when the ideal DCG is not positive.
      """
      best_gains = _gains_for_labels(ideal_labels)
      best_gains.sort(reverse=True)
      best_dcg = _dcg_at_k(best_gains, k)
      if best_dcg <= 0.0:
          return 0.0
      return _dcg_at_k(_gains_for_labels(labels), k) / best_dcg
  
  
30b490e1   tangwang   添加ERR评估指标
93
94
95
96
97
98
99
100
101
102
103
104
105
  def _err_at_k(labels: Sequence[str], k: int) -> float:
      """Expected Reciprocal Rank over the first ``k`` positions of the ranked list.

      Each position contributes ``(1 / rank) * p_stop * p_reach``, where
      ``p_stop`` comes from STOP_PROB_MAP for the normalized label (0.0 when
      absent) and ``p_reach`` is the running product of ``1 - p_stop`` over all
      earlier positions. Returns 0.0 for ``k <= 0``.
      """
      if k <= 0:
          return 0.0
      score = 0.0
      p_reach = 1.0
      for rank, raw_label in enumerate(labels[:k], start=1):
          stop = float(STOP_PROB_MAP.get(_normalize_label(raw_label), 0.0))
          score += (1.0 / float(rank)) * stop * p_reach
          p_reach *= 1.0 - stop
      return score
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
106
107
108
  def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
      """Share of the ideal pool's total gain captured by the first ``k`` results.

      Args:
          labels: Ranked labels returned by search.
          ideal_labels: Label pool whose summed gain forms the denominator.
          k: Cutoff; only ``labels[:k]`` contributes to the numerator.

      Returns:
          ``sum(gains of labels[:k]) / sum(gains of ideal_labels)``, or 0.0 when
          ``k`` is non-positive or the ideal pool carries no positive total gain.
      """
      # Guard like the other @k helpers: a negative k would make ``labels[:k]``
      # drop items from the END of the list rather than select nothing.
      if k <= 0:
          return 0.0
      ideal_total_gain = sum(_gains_for_labels(ideal_labels))
      if ideal_total_gain <= 0.0:
          return 0.0
      actual_gain = sum(_gains_for_labels(labels[:k]))
      return actual_gain / ideal_total_gain
c81b0fc1   tangwang   scripts/evaluatio...
112
113
  
  
7ddd4cb3   tangwang   评估体系从三等级->四等级 Exa...
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
  def _grade_avg_at_k(labels: Sequence[str], k: int) -> float:
      """Mean RELEVANCE_GRADE_MAP grade over the first ``k`` normalized labels.

      Returns 0.0 when ``k`` is non-positive or there are no labels in the window.
      """
      if k <= 0:
          return 0.0
      grades = [
          float(RELEVANCE_GRADE_MAP.get(_normalize_label(raw_label), 0))
          for raw_label in labels[:k]
      ]
      if not grades:
          return 0.0
      return sum(grades) / float(len(grades))
  
  
  def compute_query_metrics(
      labels: Sequence[str],
      *,
      ideal_labels: Sequence[str] | None = None,
  ) -> Dict[str, float]:
      """Build the per-query metric dictionary for one ranked result list.

      `labels` is the ranked list of labels returned by search.
      `ideal_labels` is the judged label pool for the query; if it is None the
      retrieved `labels` themselves serve as the pool.

      Every value is rounded to 6 decimals. "Primary_Metric_Score" is added
      last and is derived from the already-rounded entries.
      """

      pool = list(labels) if ideal_labels is None else list(ideal_labels)

      hits_exact = _binary_hits(labels, [RELEVANCE_EXACT])
      hits_strong = _binary_hits(labels, RELEVANCE_STRONG)
      hits_useful = _binary_hits(labels, RELEVANCE_NON_IRRELEVANT)

      result: Dict[str, float] = {}

      def record(name: str, value: float) -> None:
          # Single place where the 6-decimal rounding is applied.
          result[name] = round(value, 6)

      # Graded ranking metrics.
      for cutoff in (5, 10, 20, 50):
          record(f"NDCG@{cutoff}", _ndcg_at_k(labels, pool, cutoff))
          record(f"ERR@{cutoff}", _err_at_k(labels, cutoff))

      # Binary precision slices.
      for cutoff in (5, 10, 20):
          record(f"Exact_Precision@{cutoff}", _precision_at_k_from_hits(hits_exact, cutoff))
          record(f"Strong_Precision@{cutoff}", _precision_at_k_from_hits(hits_strong, cutoff))

      for cutoff in (10, 20, 50):
          record(f"Useful_Precision@{cutoff}", _precision_at_k_from_hits(hits_useful, cutoff))
          record(f"Gain_Recall@{cutoff}", _gain_recall_at_k(labels, pool, cutoff))

      for cutoff in (5, 10):
          record(f"Exact_Success@{cutoff}", _success_at_k_from_hits(hits_exact, cutoff))
          record(f"Strong_Success@{cutoff}", _success_at_k_from_hits(hits_strong, cutoff))

      record("MRR_Exact@10", _reciprocal_rank_from_hits(hits_exact, 10))
      record("MRR_Strong@10", _reciprocal_rank_from_hits(hits_strong, 10))
      record("Avg_Grade@10", _grade_avg_at_k(labels, 10))

      # Composite score reads the metrics recorded above, so it must come last.
      record("Primary_Metric_Score", primary_metric_score(result))
      return result
  
  
465f90e1   tangwang   添加LTR数据收集
161
162
163
164
165
166
167
168
169
170
171
172
173
  def primary_metric_score(metrics: Dict[str, float]) -> float:
      """Average the metrics named in PRIMARY_METRIC_KEYS into one score.

      Keys missing from ``metrics`` count as 0.0. "Avg_Grade@10" is divided by
      PRIMARY_METRIC_GRADE_NORMALIZER before averaging; all other values are
      used as-is. Returns 0.0 when there are no primary keys.
      """
      scaled = [
          float(metrics.get(name, 0.0)) / PRIMARY_METRIC_GRADE_NORMALIZER
          if name == "Avg_Grade@10"
          else float(metrics.get(name, 0.0))
          for name in PRIMARY_METRIC_KEYS
      ]
      return sum(scaled) / float(len(scaled)) if scaled else 0.0
  
  
c81b0fc1   tangwang   scripts/evaluatio...
174
175
176
  def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
      """Average each metric key across all per-query metric dictionaries.

      The result covers the union of keys, in sorted order. A dictionary that
      lacks a key contributes 0.0 for it, and the divisor is always the total
      number of dictionaries. Means are rounded to 6 decimals. An empty input
      yields an empty dict.
      """
      if not metric_items:
          return {}
      count = len(metric_items)
      seen_keys = set()
      for entry in metric_items:
          seen_keys.update(entry.keys())
      averaged: Dict[str, float] = {}
      for name in sorted(seen_keys):
          total = sum(float(entry.get(name, 0.0)) for entry in metric_items)
          averaged[name] = round(total / count, 6)
      return averaged
  
  
  def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
      """Count how many labels equal each of the four relevance grades.

      Labels are compared as given (no normalization); values matching none of
      the four grades are not counted. Keys appear in the order exact, high,
      low, irrelevant.
      """
      grade_order = (
          RELEVANCE_EXACT,
          RELEVANCE_HIGH,
          RELEVANCE_LOW,
          RELEVANCE_IRRELEVANT,
      )
      return {
          grade: sum(1 for label in labels if label == grade)
          for grade in grade_order
      }