"""Ranking metrics for graded e-commerce relevance labels.""" from __future__ import annotations import math from typing import Dict, Iterable, Sequence from .constants import ( RELEVANCE_EXACT, RELEVANCE_GAIN_MAP, RELEVANCE_GRADE_MAP, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RELEVANCE_LOW, RELEVANCE_NON_IRRELEVANT, RELEVANCE_STRONG, STOP_PROB_MAP, ) def _normalize_label(label: str) -> str: if label in RELEVANCE_GRADE_MAP: return label return RELEVANCE_IRRELEVANT def _gains_for_labels(labels: Sequence[str]) -> list[float]: return [float(RELEVANCE_GAIN_MAP.get(_normalize_label(label), 0.0)) for label in labels] def _binary_hits(labels: Sequence[str], relevant: Iterable[str]) -> list[int]: relevant_set = set(relevant) return [1 if _normalize_label(label) in relevant_set else 0 for label in labels] def _precision_at_k_from_hits(hits: Sequence[int], k: int) -> float: if k <= 0: return 0.0 sliced = list(hits[:k]) if not sliced: return 0.0 return sum(sliced) / float(len(sliced)) def _success_at_k_from_hits(hits: Sequence[int], k: int) -> float: if k <= 0: return 0.0 return 1.0 if any(hits[:k]) else 0.0 def _reciprocal_rank_from_hits(hits: Sequence[int], k: int) -> float: if k <= 0: return 0.0 for idx, hit in enumerate(hits[:k], start=1): if hit: return 1.0 / float(idx) return 0.0 def _dcg_at_k(gains: Sequence[float], k: int) -> float: if k <= 0: return 0.0 total = 0.0 for idx, gain in enumerate(gains[:k], start=1): if gain <= 0.0: continue total += gain / math.log2(idx + 1.0) return total def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float: actual_gains = _gains_for_labels(labels) ideal_gains = sorted(_gains_for_labels(ideal_labels), reverse=True) dcg = _dcg_at_k(actual_gains, k) idcg = _dcg_at_k(ideal_gains, k) if idcg <= 0.0: return 0.0 return dcg / idcg def _err_at_k(labels: Sequence[str], k: int) -> float: """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list).""" if k <= 0: return 0.0 err = 0.0 product = 1.0 for i, label in enumerate(labels[:k], start=1): p_stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0)) err += (1.0 / float(i)) * p_stop * product product *= 1.0 - p_stop return err def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float: ideal_total_gain = sum(_gains_for_labels(ideal_labels)) if ideal_total_gain <= 0.0: return 0.0 actual_gain = sum(_gains_for_labels(labels[:k])) return actual_gain / ideal_total_gain def _grade_avg_at_k(labels: Sequence[str], k: int) -> float: if k <= 0: return 0.0 sliced = [_normalize_label(label) for label in labels[:k]] if not sliced: return 0.0 return sum(float(RELEVANCE_GRADE_MAP.get(label, 0)) for label in sliced) / float(len(sliced)) def compute_query_metrics( labels: Sequence[str], *, ideal_labels: Sequence[str] | None = None, ) -> Dict[str, float]: """Compute graded ranking metrics plus binary diagnostic slices. `labels` are the ranked results returned by search. `ideal_labels` is the judged label pool for the same query; when omitted we fall back to the retrieved labels, which still keeps the metrics well-defined. """ ideal = list(ideal_labels) if ideal_labels is not None else list(labels) metrics: Dict[str, float] = {} exact_hits = _binary_hits(labels, [RELEVANCE_EXACT]) strong_hits = _binary_hits(labels, RELEVANCE_STRONG) useful_hits = _binary_hits(labels, RELEVANCE_NON_IRRELEVANT) for k in (5, 10, 20, 50): metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6) metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6) for k in (5, 10, 20): metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6) metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6) for k in (10, 20, 50): metrics[f"Useful_Precision@{k}"] = round(_precision_at_k_from_hits(useful_hits, k), 6) metrics[f"Gain_Recall@{k}"] = round(_gain_recall_at_k(labels, ideal, k), 6) for k in (5, 10): metrics[f"Exact_Success@{k}"] = round(_success_at_k_from_hits(exact_hits, k), 6) metrics[f"Strong_Success@{k}"] = round(_success_at_k_from_hits(strong_hits, k), 6) metrics["MRR_Exact@10"] = round(_reciprocal_rank_from_hits(exact_hits, 10), 6) metrics["MRR_Strong@10"] = round(_reciprocal_rank_from_hits(strong_hits, 10), 6) metrics["Avg_Grade@10"] = round(_grade_avg_at_k(labels, 10), 6) return metrics def aggregate_metrics(metric_items: Sequence[Dict[str, float]]) -> Dict[str, float]: if not metric_items: return {} all_keys = sorted({key for item in metric_items for key in item.keys()}) return { key: round(sum(float(item.get(key, 0.0)) for item in metric_items) / len(metric_items), 6) for key in all_keys } def label_distribution(labels: Sequence[str]) -> Dict[str, int]: return { RELEVANCE_EXACT: sum(1 for label in labels if label == RELEVANCE_EXACT), RELEVANCE_HIGH: sum(1 for label in labels if label == RELEVANCE_HIGH), RELEVANCE_LOW: sum(1 for label in labels if label == RELEVANCE_LOW), RELEVANCE_IRRELEVANT: sum(1 for label in labels if label == RELEVANCE_IRRELEVANT), }