# metrics.py (paste header/line-number gutter removed — see module docstring below)
"""IR metrics for labeled result lists."""
from __future__ import annotations
from typing import Dict, Sequence
from .constants import RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_PARTIAL
def precision_at_k(labels: Sequence[str], k: int, relevant: Sequence[str]) -> float:
    """Return precision over the top-*k* ranked labels.

    Args:
        labels: Ranked relevance labels, best result first.
        k: Rank cutoff; non-positive values yield 0.0.
        relevant: Labels that count as a hit.

    Returns:
        Fraction of the first min(k, len(labels)) labels that are relevant,
        or 0.0 when there is nothing to score. Note the denominator is the
        number of results actually present, not k itself.
    """
    if k <= 0:
        return 0.0
    top = list(labels[:k])
    if not top:
        return 0.0
    # Build the set once: O(1) membership per label instead of scanning the
    # `relevant` sequence for every label.
    relevant_set = set(relevant)
    hits = sum(1 for label in top if label in relevant_set)
    # len(top) == min(k, len(labels)) by construction of the slice, so the
    # original `min(k, len(sliced))` was redundant.
    return hits / len(top)
def average_precision(labels: Sequence[str], relevant: Sequence[str]) -> float:
    """Mean of the precision values at each rank where a relevant label occurs.

    Args:
        labels: Ranked relevance labels, best result first.
        relevant: Labels that count as a hit.

    Returns:
        The average of precision@rank over the hit ranks, or 0.0 when no
        label in *labels* is relevant. (Normalized by hits retrieved, not
        by total relevant documents.)
    """
    precisions: list = []
    seen = 0
    rank = 0
    for label in labels:
        rank += 1
        if label in relevant:
            seen += 1
            precisions.append(seen / rank)
    if not precisions:
        return 0.0
    return sum(precisions) / len(precisions)
def compute_query_metrics(labels: Sequence[str]) -> Dict[str, float]:
    """Compute the per-query metric suite for one ranked label list.

    Produces P@{5,10,20,50} and MAP under two relevance regimes:
    exact-only ("_3"-suffixed MAP) and exact-or-partial ("_2_3").
    All values are rounded to 6 decimal places.
    """
    exact_only = [RELEVANCE_EXACT]
    exact_or_partial = [RELEVANCE_EXACT, RELEVANCE_PARTIAL]
    result: Dict[str, float] = {}
    for cutoff in (5, 10, 20, 50):
        result[f"P@{cutoff}"] = round(precision_at_k(labels, cutoff, exact_only), 6)
        result[f"P@{cutoff}_2_3"] = round(precision_at_k(labels, cutoff, exact_or_partial), 6)
    result["MAP_3"] = round(average_precision(labels, exact_only), 6)
    result["MAP_2_3"] = round(average_precision(labels, exact_or_partial), 6)
    return result
def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average per-query metric dicts into a single summary dict.

    Fix: keys are collected as the union over *all* items — the original
    used only the first item's keys, silently dropping any metric that was
    absent from item 0. An item missing a key contributes 0.0 to that
    key's average, matching the original's `.get(key, 0.0)` behavior.

    Args:
        metric_items: One metrics dict per query.

    Returns:
        Mapping of metric name to its mean over all items, rounded to 6
        decimal places; an empty dict for empty input.
    """
    if not metric_items:
        return {}
    # Union of keys across every item so no metric is dropped.
    keys = sorted(set().union(*(item.keys() for item in metric_items)))
    count = len(metric_items)
    return {
        key: round(sum(float(item.get(key, 0.0)) for item in metric_items) / count, 6)
        for key in keys
    }
def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
    """Count occurrences of each known relevance label in *labels*.

    Labels outside the three known relevance values are ignored. Keys are
    always present (with count 0 when absent from the input).
    """
    counts = {
        RELEVANCE_EXACT: 0,
        RELEVANCE_PARTIAL: 0,
        RELEVANCE_IRRELEVANT: 0,
    }
    # Single pass instead of three independent scans of `labels`.
    for label in labels:
        if label in counts:
            counts[label] += 1
    return counts