# metrics.py 1.99 KB
"""IR metrics for labeled result lists."""

from __future__ import annotations

from typing import Dict, Sequence

from .constants import RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_PARTIAL


def precision_at_k(labels: Sequence[str], k: int, relevant: Sequence[str]) -> float:
    """Return the fraction of the first *k* labels that are in *relevant*.

    The denominator is the number of labels actually examined, so a list
    shorter than *k* is scored over its own length rather than over *k*.

    Args:
        labels: Ranked result labels, best-ranked first.
        k: Cutoff rank; only ``labels[:k]`` is inspected.
        relevant: Labels that count as a hit.

    Returns:
        Precision in ``[0.0, 1.0]``; ``0.0`` when *k* is not positive or
        there are no labels to inspect.
    """
    # A non-positive cutoff inspects nothing, which is scored the same way
    # as an empty result list.
    window = list(labels[:k]) if k > 0 else []
    if not window:
        return 0.0
    matched = [item for item in window if item in relevant]
    # The window never holds more than k items, so its length is the divisor.
    return len(matched) / float(len(window))


def average_precision(labels: Sequence[str], relevant: Sequence[str]) -> float:
    """Return the average precision of a ranked label list.

    Precision is sampled at every 1-based rank whose label is in *relevant*
    and the samples are averaged. Normalisation uses the number of relevant
    labels found in *labels*, not any externally known total.

    Args:
        labels: Ranked result labels, best-ranked first.
        relevant: Labels that count as a hit.

    Returns:
        The averaged precision, or ``0.0`` when no label is relevant.
    """
    found = 0
    running_total = 0.0
    for rank, label in enumerate(labels, 1):
        if label in relevant:
            found += 1
            # Precision at this rank: hits so far over positions seen so far.
            running_total += found / rank
    return running_total / found if found else 0.0


def compute_query_metrics(labels: Sequence[str]) -> Dict[str, float]:
    """Compute the per-query precision and MAP metrics for one result list.

    Two relevance definitions are evaluated: exact matches only, and exact
    or partial matches together (the keys carrying the ``_2_3`` suffix).

    Args:
        labels: Ranked relevance labels for a single query's results.

    Returns:
        A dict with ``P@k`` and ``P@k_2_3`` for k in 5, 10, 20, 50, plus
        ``MAP_3`` and ``MAP_2_3``; every value is rounded to 6 decimals.
    """
    exact_only = [RELEVANCE_EXACT]
    exact_or_partial = [RELEVANCE_EXACT, RELEVANCE_PARTIAL]

    result: Dict[str, float] = {}
    for cutoff in (5, 10, 20, 50):
        strict = precision_at_k(labels, cutoff, exact_only)
        lenient = precision_at_k(labels, cutoff, exact_or_partial)
        result[f"P@{cutoff}"] = round(strict, 6)
        result[f"P@{cutoff}_2_3"] = round(lenient, 6)

    result["MAP_3"] = round(average_precision(labels, exact_only), 6)
    result["MAP_2_3"] = round(average_precision(labels, exact_or_partial), 6)
    return result


def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average per-query metric dicts into a single dict of means.

    Every metric name that occurs in any item is aggregated. An item that
    lacks a metric contributes ``0.0`` for it, and the mean is always taken
    over the total number of items.

    Args:
        metric_items: One metric dict per query.

    Returns:
        A dict mapping each metric name (in sorted order) to its mean,
        rounded to 6 decimals; an empty dict when *metric_items* is empty.
    """
    if not metric_items:
        return {}
    total = len(metric_items)
    # Use the union of keys, not just the first item's: otherwise a metric
    # missing from the first item would be silently dropped from the result.
    keys = sorted({key for item in metric_items for key in item})
    return {
        key: round(sum(float(item.get(key, 0.0)) for item in metric_items) / total, 6)
        for key in keys
    }


def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
    """Count how many labels fall into each of the three relevance classes.

    Args:
        labels: Relevance labels to tally.

    Returns:
        A dict keyed by the exact, partial and irrelevant label constants
        (in that order) with the number of occurrences of each. Labels
        equal to none of the three are not counted.
    """

    def occurrences(target: str) -> int:
        # One equality scan of the labels per relevance class.
        return len([label for label in labels if label == target])

    classes = (RELEVANCE_EXACT, RELEVANCE_PARTIAL, RELEVANCE_IRRELEVANT)
    return {target: occurrences(target) for target in classes}