# metrics.py
"""Ranking metrics for graded e-commerce relevance labels."""

from __future__ import annotations

import math
from typing import Dict, Iterable, Sequence

from .constants import (
    RELEVANCE_EXACT,
    RELEVANCE_GAIN_MAP,
    RELEVANCE_GRADE_MAP,
    RELEVANCE_HIGH,
    RELEVANCE_IRRELEVANT,
    RELEVANCE_LOW,
    RELEVANCE_NON_IRRELEVANT,
    RELEVANCE_STRONG,
    STOP_PROB_MAP,
)


def _normalize_label(label: str) -> str:
    """Return *label* if it is a known grade, otherwise the irrelevant grade."""
    return label if label in RELEVANCE_GRADE_MAP else RELEVANCE_IRRELEVANT


def _gains_for_labels(labels: Sequence[str]) -> list[float]:
    """Look up the graded gain per label; unknown grades contribute 0.0."""
    gain_of = RELEVANCE_GAIN_MAP.get  # hoist the bound method out of the loop
    return [float(gain_of(_normalize_label(lbl), 0.0)) for lbl in labels]


def _binary_hits(labels: Sequence[str], relevant: Iterable[str]) -> list[int]:
    """Per-position 1/0 indicator: normalized label is in *relevant*."""
    wanted = frozenset(relevant)
    return [int(_normalize_label(lbl) in wanted) for lbl in labels]


def _precision_at_k_from_hits(hits: Sequence[int], k: int) -> float:
    if k <= 0:
        return 0.0
    sliced = list(hits[:k])
    if not sliced:
        return 0.0
    return sum(sliced) / float(len(sliced))


def _success_at_k_from_hits(hits: Sequence[int], k: int) -> float:
    if k <= 0:
        return 0.0
    return 1.0 if any(hits[:k]) else 0.0


def _reciprocal_rank_from_hits(hits: Sequence[int], k: int) -> float:
    if k <= 0:
        return 0.0
    for idx, hit in enumerate(hits[:k], start=1):
        if hit:
            return 1.0 / float(idx)
    return 0.0


def _dcg_at_k(gains: Sequence[float], k: int) -> float:
    if k <= 0:
        return 0.0
    total = 0.0
    for idx, gain in enumerate(gains[:k], start=1):
        if gain <= 0.0:
            continue
        total += gain / math.log2(idx + 1.0)
    return total


def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
    """NDCG@k: DCG of the ranked labels normalized by the ideal-ordering DCG."""
    # Ideal DCG sorts the judged gains descending before truncating at k.
    idcg = _dcg_at_k(sorted(_gains_for_labels(ideal_labels), reverse=True), k)
    if idcg <= 0.0:
        return 0.0
    return _dcg_at_k(_gains_for_labels(labels), k) / idcg


def _err_at_k(labels: Sequence[str], k: int) -> float:
    """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list)."""
    if k <= 0:
        return 0.0
    err = 0.0
    still_looking = 1.0  # probability the user has not stopped yet
    for rank, label in enumerate(labels[:k], start=1):
        stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0))
        err += (1.0 / rank) * stop * still_looking
        still_looking *= 1.0 - stop
    return err


def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
    """Share of the total judged gain that is captured within the top ``k``."""
    denominator = sum(_gains_for_labels(ideal_labels))
    if denominator <= 0.0:
        return 0.0
    return sum(_gains_for_labels(labels[:k])) / denominator


def _grade_avg_at_k(labels: Sequence[str], k: int) -> float:
    """Mean numeric grade of the first ``k`` labels (0.0 for empty/invalid)."""
    if k <= 0:
        return 0.0
    window = labels[:k]
    if not window:
        return 0.0
    total = sum(float(RELEVANCE_GRADE_MAP.get(_normalize_label(lbl), 0)) for lbl in window)
    return total / float(len(window))


def compute_query_metrics(
    labels: Sequence[str],
    *,
    ideal_labels: Sequence[str] | None = None,
) -> Dict[str, float]:
    """Compute graded ranking metrics plus binary diagnostic slices.

    `labels` are the ranked results returned by search.
    `ideal_labels` is the judged label pool for the same query; when omitted we fall back
    to the retrieved labels, which still keeps the metrics well-defined.
    """

    judged_pool = list(ideal_labels) if ideal_labels is not None else list(labels)
    out: Dict[str, float] = {}

    # Binary views of the ranked list at three strictness levels.
    exact = _binary_hits(labels, [RELEVANCE_EXACT])
    strong = _binary_hits(labels, RELEVANCE_STRONG)
    useful = _binary_hits(labels, RELEVANCE_NON_IRRELEVANT)

    def put(name: str, value: float) -> None:
        # Centralize the 6-decimal rounding applied to every reported metric.
        out[name] = round(value, 6)

    for k in (5, 10, 20, 50):
        put(f"NDCG@{k}", _ndcg_at_k(labels, judged_pool, k))
        put(f"ERR@{k}", _err_at_k(labels, k))
    for k in (5, 10, 20):
        put(f"Exact_Precision@{k}", _precision_at_k_from_hits(exact, k))
        put(f"Strong_Precision@{k}", _precision_at_k_from_hits(strong, k))
    for k in (10, 20, 50):
        put(f"Useful_Precision@{k}", _precision_at_k_from_hits(useful, k))
        put(f"Gain_Recall@{k}", _gain_recall_at_k(labels, judged_pool, k))
    for k in (5, 10):
        put(f"Exact_Success@{k}", _success_at_k_from_hits(exact, k))
        put(f"Strong_Success@{k}", _success_at_k_from_hits(strong, k))
    put("MRR_Exact@10", _reciprocal_rank_from_hits(exact, 10))
    put("MRR_Strong@10", _reciprocal_rank_from_hits(strong, 10))
    put("Avg_Grade@10", _grade_avg_at_k(labels, 10))
    return out


def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric key across queries; a missing key counts as 0.0."""
    if not metric_items:
        return {}
    count = len(metric_items)
    keys: set = set()
    for item in metric_items:
        keys.update(item)
    averaged: Dict[str, float] = {}
    for key in sorted(keys):
        total = sum(float(item.get(key, 0.0)) for item in metric_items)
        averaged[key] = round(total / count, 6)
    return averaged


def label_distribution(labels: Sequence[str]) -> Dict[str, int]:
    """Count exact (un-normalized) occurrences of each canonical grade."""
    pool = list(labels)
    grades = (RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_IRRELEVANT)
    return {grade: pool.count(grade) for grade in grades}