7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
1
|
"""Ranking metrics for graded e-commerce relevance labels."""
|
c81b0fc1
tangwang
scripts/evaluatio...
|
2
3
4
|
from __future__ import annotations
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
5
6
|
import math
from typing import Dict, Iterable, Sequence
|
c81b0fc1
tangwang
scripts/evaluatio...
|
7
|
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
8
9
10
11
12
13
14
15
16
|
from .constants import (
RELEVANCE_EXACT,
RELEVANCE_GAIN_MAP,
RELEVANCE_GRADE_MAP,
RELEVANCE_HIGH,
RELEVANCE_IRRELEVANT,
RELEVANCE_LOW,
RELEVANCE_NON_IRRELEVANT,
RELEVANCE_STRONG,
|
30b490e1
tangwang
添加ERR评估指标
|
17
|
STOP_PROB_MAP,
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
18
|
)
|
c81b0fc1
tangwang
scripts/evaluatio...
|
19
20
|
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
def _normalize_label(label: str) -> str:
    """Return *label* unchanged if it is a known grade, else the irrelevant grade."""
    return label if label in RELEVANCE_GRADE_MAP else RELEVANCE_IRRELEVANT
def _gains_for_labels(labels: Sequence[str]) -> list[float]:
    """Translate graded labels into their numeric gain values (unknown -> 0.0)."""
    gains: list[float] = []
    for label in labels:
        gains.append(float(RELEVANCE_GAIN_MAP.get(_normalize_label(label), 0.0)))
    return gains
def _binary_hits(labels: Sequence[str], relevant: Iterable[str]) -> list[int]:
    """Mark each label with 1 when its normalized grade belongs to *relevant*."""
    wanted = frozenset(relevant)
    return [int(_normalize_label(label) in wanted) for label in labels]
def _precision_at_k_from_hits(hits: Sequence[int], k: int) -> float:
    """Fraction of hit positions within the first *k* retrieved results.

    Divides by the number of positions actually present (<= k), so a short
    result list is not penalized for slots it never filled.
    """
    if k <= 0:
        return 0.0
    window = list(hits[:k])
    count = len(window)
    if count == 0:
        return 0.0
    return sum(window) / float(count)
def _success_at_k_from_hits(hits: Sequence[int], k: int) -> float:
    """1.0 when at least one hit occurs within the first *k* positions, else 0.0."""
    if k <= 0:
        return 0.0
    return float(any(hits[:k]))
def _reciprocal_rank_from_hits(hits: Sequence[int], k: int) -> float:
    """Reciprocal rank (1-based) of the first hit in the top *k*; 0.0 if none."""
    if k > 0:
        rank = 1
        for hit in hits[:k]:
            if hit:
                return 1.0 / float(rank)
            rank += 1
    return 0.0
|
c81b0fc1
tangwang
scripts/evaluatio...
|
58
59
|
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
60
61
62
63
64
65
|
def _dcg_at_k(gains: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the first *k* gain values.

    Positions are 1-based and discounted by log2(position + 1); non-positive
    gains contribute nothing.
    """
    if k <= 0:
        return 0.0
    return sum(
        (
            gain / math.log2(pos + 1.0)
            for pos, gain in enumerate(gains[:k], start=1)
            if gain > 0.0
        ),
        0.0,
    )
def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
    """NDCG@k of *labels* against the best possible ordering of *ideal_labels*."""
    realized = _dcg_at_k(_gains_for_labels(labels), k)
    best = _dcg_at_k(sorted(_gains_for_labels(ideal_labels), reverse=True), k)
    return realized / best if best > 0.0 else 0.0
|
30b490e1
tangwang
添加ERR评估指标
|
81
82
83
84
85
86
87
88
89
90
91
92
93
|
def _err_at_k(labels: Sequence[str], k: int) -> float:
    """Expected Reciprocal Rank over the first *k* positions (truncated list).

    Cascade model: the user scans downward and stops at position ``i`` with
    probability ``STOP_PROB_MAP[label_i]``; labels missing from the map never
    stop the user.
    """
    if k <= 0:
        return 0.0
    score = 0.0
    not_stopped = 1.0
    for rank, label in enumerate(labels[:k], start=1):
        stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0))
        score += (1.0 / float(rank)) * stop * not_stopped
        not_stopped *= 1.0 - stop
    return score
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
94
95
96
|
def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
    """Share of the total judged gain that the top-*k* results recover.

    The denominator is the gain of the entire judged pool (not its top-k),
    so this measures how much of the available relevance mass was retrieved.
    """
    denominator = sum(_gains_for_labels(ideal_labels))
    if denominator <= 0.0:
        return 0.0
    return sum(_gains_for_labels(labels[:k])) / denominator
|
c81b0fc1
tangwang
scripts/evaluatio...
|
100
101
|
|
7ddd4cb3
tangwang
评估体系从三等级->四等级 Exa...
|
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
def _grade_avg_at_k(labels: Sequence[str], k: int) -> float:
    """Mean numeric grade of the first *k* results; 0.0 when nothing is there."""
    if k <= 0:
        return 0.0
    graded = [_normalize_label(label) for label in labels[:k]]
    if not graded:
        return 0.0
    total = sum(float(RELEVANCE_GRADE_MAP.get(label, 0)) for label in graded)
    return total / float(len(graded))
def compute_query_metrics(
    labels: Sequence[str],
    *,
    ideal_labels: Sequence[str] | None = None,
) -> Dict[str, float]:
    """Compute graded ranking metrics plus binary diagnostic slices.

    ``labels`` are the ranked results returned by search.
    ``ideal_labels`` is the judged label pool for the same query; when omitted
    we fall back to the retrieved labels, which still keeps the metrics
    well-defined.
    """
    ideal = list(ideal_labels if ideal_labels is not None else labels)
    metrics: Dict[str, float] = {}

    # Binary relevance slices at three strictness levels.
    exact_hits = _binary_hits(labels, [RELEVANCE_EXACT])
    strong_hits = _binary_hits(labels, RELEVANCE_STRONG)
    useful_hits = _binary_hits(labels, RELEVANCE_NON_IRRELEVANT)

    # Graded list-quality metrics across the full cutoff ladder.
    for k in (5, 10, 20, 50):
        metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6)
        metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6)

    # Strict precision at shallow cutoffs; lenient precision/recall deeper.
    for k in (5, 10, 20):
        metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6)
        metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6)
    for k in (10, 20, 50):
        metrics[f"Useful_Precision@{k}"] = round(_precision_at_k_from_hits(useful_hits, k), 6)
        metrics[f"Gain_Recall@{k}"] = round(_gain_recall_at_k(labels, ideal, k), 6)
    for k in (5, 10):
        metrics[f"Exact_Success@{k}"] = round(_success_at_k_from_hits(exact_hits, k), 6)
        metrics[f"Strong_Success@{k}"] = round(_success_at_k_from_hits(strong_hits, k), 6)

    # Single-position diagnostics at the standard depth of 10.
    metrics["MRR_Exact@10"] = round(_reciprocal_rank_from_hits(exact_hits, 10), 6)
    metrics["MRR_Strong@10"] = round(_reciprocal_rank_from_hits(strong_hits, 10), 6)
    metrics["Avg_Grade@10"] = round(_grade_avg_at_k(labels, 10), 6)
    return metrics
def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric across per-query dicts (missing keys count as 0.0)."""
    if not metric_items:
        return {}
    keys: set[str] = set()
    for item in metric_items:
        keys.update(item)
    count = len(metric_items)
    averaged: Dict[str, float] = {}
    # Sorted key order keeps the aggregate deterministic and readable.
    for key in sorted(keys):
        averaged[key] = round(
            sum(float(item.get(key, 0.0)) for item in metric_items) / count, 6
        )
    return averaged
def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
    """Count how many labels fall into each of the four relevance grades.

    Only exact string matches are counted; unknown labels are ignored here
    (unlike the metric helpers, which normalize them to irrelevant).
    """
    counts = {
        RELEVANCE_EXACT: 0,
        RELEVANCE_HIGH: 0,
        RELEVANCE_LOW: 0,
        RELEVANCE_IRRELEVANT: 0,
    }
    for label in labels:
        if label in counts:
            counts[label] += 1
    return counts
|