"""IR metrics for labeled result lists.""" from __future__ import annotations from typing import Dict, Sequence from .constants import RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_PARTIAL def precision_at_k(labels: Sequence[str], k: int, relevant: Sequence[str]) -> float: if k <= 0: return 0.0 sliced = list(labels[:k]) if not sliced: return 0.0 hits = sum(1 for label in sliced if label in relevant) return hits / float(min(k, len(sliced))) def average_precision(labels: Sequence[str], relevant: Sequence[str]) -> float: hit_count = 0 precision_sum = 0.0 for idx, label in enumerate(labels, start=1): if label not in relevant: continue hit_count += 1 precision_sum += hit_count / idx if hit_count == 0: return 0.0 return precision_sum / hit_count def compute_query_metrics(labels: Sequence[str]) -> Dict[str, float]: metrics: Dict[str, float] = {} for k in (5, 10, 20, 50): metrics[f"P@{k}"] = round(precision_at_k(labels, k, [RELEVANCE_EXACT]), 6) metrics[f"P@{k}_2_3"] = round(precision_at_k(labels, k, [RELEVANCE_EXACT, RELEVANCE_PARTIAL]), 6) metrics["MAP_3"] = round(average_precision(labels, [RELEVANCE_EXACT]), 6) metrics["MAP_2_3"] = round(average_precision(labels, [RELEVANCE_EXACT, RELEVANCE_PARTIAL]), 6) return metrics def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]: if not metric_items: return {} keys = sorted(metric_items[0].keys()) return { key: round(sum(float(item.get(key, 0.0)) for item in metric_items) / len(metric_items), 6) for key in keys } def label_distribution(labels: Sequence[str]) -> Dict[str, int]: return { RELEVANCE_EXACT: sum(1 for label in labels if label == RELEVANCE_EXACT), RELEVANCE_PARTIAL: sum(1 for label in labels if label == RELEVANCE_PARTIAL), RELEVANCE_IRRELEVANT: sum(1 for label in labels if label == RELEVANCE_IRRELEVANT), }