diff --git a/scripts/evaluation/eval_framework/constants.py b/scripts/evaluation/eval_framework/constants.py index 19e9194..2b88029 100644 --- a/scripts/evaluation/eval_framework/constants.py +++ b/scripts/evaluation/eval_framework/constants.py @@ -26,11 +26,20 @@ RELEVANCE_GRADE_MAP = { RELEVANCE_LOW: 1, RELEVANCE_IRRELEVANT: 0, } + RELEVANCE_GAIN_MAP = { label: (2 ** grade) - 1 for label, grade in RELEVANCE_GRADE_MAP.items() } +# P(stop | relevance) for ERR (Expected Reciprocal Rank); cascade model (Chapelle et al., 2009). +STOP_PROB_MAP = { + RELEVANCE_EXACT: 0.99, + RELEVANCE_HIGH: 0.8, + RELEVANCE_LOW: 0.1, + RELEVANCE_IRRELEVANT: 0.0, +} + _LEGACY_LABEL_MAP = { "Exact": RELEVANCE_EXACT, "Partial": RELEVANCE_HIGH, diff --git a/scripts/evaluation/eval_framework/framework.py b/scripts/evaluation/eval_framework/framework.py index d71a5d7..a807fbe 100644 --- a/scripts/evaluation/eval_framework/framework.py +++ b/scripts/evaluation/eval_framework/framework.py @@ -28,6 +28,7 @@ from .constants import ( RELEVANCE_EXACT, RELEVANCE_GAIN_MAP, RELEVANCE_HIGH, + STOP_PROB_MAP, RELEVANCE_IRRELEVANT, RELEVANCE_LOW, RELEVANCE_NON_IRRELEVANT, @@ -55,8 +56,10 @@ def _metric_context_payload() -> Dict[str, Any]: return { "primary_metric": "NDCG@10", "gain_scheme": dict(RELEVANCE_GAIN_MAP), + "stop_prob_scheme": dict(STOP_PROB_MAP), "notes": [ "NDCG uses graded gains derived from the four relevance labels.", + "ERR (Expected Reciprocal Rank) uses per-grade stop probabilities in stop_prob_scheme.", "Strong metrics treat Exact Match and High Relevant as strong business positives.", "Useful metrics treat any non-irrelevant item as useful recall coverage.", ], diff --git a/scripts/evaluation/eval_framework/metrics.py b/scripts/evaluation/eval_framework/metrics.py index 7848023..9d73944 100644 --- a/scripts/evaluation/eval_framework/metrics.py +++ b/scripts/evaluation/eval_framework/metrics.py @@ -14,6 +14,7 @@ from .constants import ( RELEVANCE_LOW, RELEVANCE_NON_IRRELEVANT, RELEVANCE_STRONG, + STOP_PROB_MAP, ) @@ -77,6 +78,19 @@ def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> fl return dcg / idcg +def _err_at_k(labels: Sequence[str], k: int) -> float: + """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list).""" + if k <= 0: + return 0.0 + err = 0.0 + product = 1.0 + for i, label in enumerate(labels[:k], start=1): + p_stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0)) + err += (1.0 / float(i)) * p_stop * product + product *= 1.0 - p_stop + return err + + def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float: ideal_total_gain = sum(_gains_for_labels(ideal_labels)) if ideal_total_gain <= 0.0: @@ -115,6 +129,7 @@ def compute_query_metrics( for k in (5, 10, 20, 50): metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6) + metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6) for k in (5, 10, 20): metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6) metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6) diff --git a/scripts/evaluation/eval_framework/reports.py b/scripts/evaluation/eval_framework/reports.py index 2df34d3..bd60a05 100644 --- a/scripts/evaluation/eval_framework/reports.py +++ b/scripts/evaluation/eval_framework/reports.py @@ -8,7 +8,19 @@ from .constants import RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RE def _append_metric_block(lines: list[str], metrics: Dict[str, Any]) -> None: - primary_keys = ("NDCG@5", "NDCG@10", "NDCG@20", "Exact_Precision@10", "Strong_Precision@10", "Gain_Recall@50") + primary_keys = ( + "NDCG@5", + "NDCG@10", + "NDCG@20", + "NDCG@50", + "ERR@5", + "ERR@10", + "ERR@20", + "ERR@50", + "Exact_Precision@10", + "Strong_Precision@10", + "Gain_Recall@50", + ) included = set() for key in primary_keys: if key in metrics: @@ -38,7 +50,8 @@ def render_batch_report_markdown(payload: Dict[str, Any]) -> str: lines.extend( [ f"- Primary metric: {metric_context.get('primary_metric', 'N/A')}", - f"- Gain scheme: {metric_context.get('gain_scheme', {})}", + f"- Gain scheme (NDCG): {metric_context.get('gain_scheme', {})}", + f"- Stop probabilities (ERR): {metric_context.get('stop_prob_scheme', {})}", "", ] ) diff --git a/scripts/evaluation/eval_framework/static/eval_web.css b/scripts/evaluation/eval_framework/static/eval_web.css index 2123d40..0e73cd9 100644 --- a/scripts/evaluation/eval_framework/static/eval_web.css +++ b/scripts/evaluation/eval_framework/static/eval_web.css @@ -30,15 +30,21 @@ button { border: 0; background: var(--accent); color: white; padding: 12px 16px; border-radius: 14px; cursor: pointer; font-weight: 600; } button.secondary { background: #d9e6e3; color: #12433d; } .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 12px; margin-bottom: 16px; } - .metric-context { margin: 0 0 12px; line-height: 1.5; } - .metric-section { margin-bottom: 18px; } - .metric-section-head { display: flex; align-items: baseline; justify-content: space-between; gap: 12px; margin-bottom: 10px; } - .metric-section-head h3 { margin: 0; font-size: 14px; color: #12433d; } - .metric-section-head p { margin: 0; color: var(--muted); font-size: 12px; } - .metric-grid { margin-bottom: 0; } - .metric { background: var(--panel); border: 1px solid var(--line); border-radius: 16px; padding: 14px; } - .metric .label { font-size: 12px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.04em; } - .metric .value { font-size: 24px; font-weight: 700; margin-top: 4px; } + .metric-context { margin: 0 0 12px; line-height: 1.5; font-size: 12px; color: var(--muted); } + .metrics-columns { + display: flex; flex-wrap: wrap; gap: 12px 18px; align-items: flex-start; margin-bottom: 8px; + } + .metric-column { + min-width: 132px; flex: 0 1 auto; padding: 10px 12px; + background: var(--panel); border: 1px solid var(--line); border-radius: 12px; + } + .metric-column-title { + margin: 0 0 8px; font-size: 11px; font-weight: 700; color: #12433d; + text-transform: uppercase; letter-spacing: 0.05em; + } + .metric-row { font-size: 13px; line-height: 1.5; font-variant-numeric: tabular-nums; } + .metric-row-name { color: var(--muted); } + .metric-row-value { font-weight: 600; color: var(--ink); margin-left: 2px; } .results { display: grid; gap: 10px; } .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } @@ -101,7 +107,6 @@ @media (max-width: 960px) { .app { grid-template-columns: 1fr; } .sidebar { border-right: 0; border-bottom: 1px solid var(--line); } - .metric-section-head { flex-direction: column; align-items: flex-start; } } @media (max-width: 640px) { .main, .sidebar { padding: 16px; } diff --git a/scripts/evaluation/eval_framework/static/eval_web.js b/scripts/evaluation/eval_framework/static/eval_web.js index ec93f38..a3555c6 100644 --- a/scripts/evaluation/eval_framework/static/eval_web.js +++ b/scripts/evaluation/eval_framework/static/eval_web.js @@ -9,80 +9,106 @@ function fmtNumber(value, digits = 3) { return Number(value).toFixed(digits); } -function metricSections(metrics) { - const groups = [ +function metricColumns(metrics) { + const defs = [ + { title: "NDCG", keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"] }, + { title: "ERR", keys: ["ERR@5", "ERR@10", "ERR@20", "ERR@50"] }, { - title: "Primary Ranking", - keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"], - description: "Graded ranking quality across the four relevance tiers.", + title: "Top slot", + keys: [ + "Exact_Precision@5", + "Exact_Precision@10", + "Strong_Precision@5", + "Strong_Precision@10", + "Strong_Precision@20", + ], }, { - title: "Top Slot Quality", - keys: ["Exact_Precision@5", "Exact_Precision@10", "Strong_Precision@5", "Strong_Precision@10", "Strong_Precision@20"], - description: "How much of the visible top rank is exact or strong business relevance.", + title: "Recall", + keys: [ + "Useful_Precision@10", + "Useful_Precision@20", + "Useful_Precision@50", + "Gain_Recall@10", + "Gain_Recall@20", + "Gain_Recall@50", + ], }, { - title: "Recall Coverage", - keys: ["Useful_Precision@10", "Useful_Precision@20", "Useful_Precision@50", "Gain_Recall@10", "Gain_Recall@20", "Gain_Recall@50"], - description: "How much judged relevance is captured in the returned list.", - }, - { - title: "First Good Result", - keys: ["Exact_Success@5", "Exact_Success@10", "Strong_Success@5", "Strong_Success@10", "MRR_Exact@10", "MRR_Strong@10", "Avg_Grade@10"], - description: "Whether users see a good result early and how good the top page feels overall.", + title: "First good", + keys: [ + "Exact_Success@5", + "Exact_Success@10", + "Strong_Success@5", + "Strong_Success@10", + "MRR_Exact@10", + "MRR_Strong@10", + "Avg_Grade@10", + ], }, ]; const seen = new Set(); - return groups - .map((group) => { - const items = group.keys + const columns = defs + .map((col) => { + const rows = col.keys .filter((key) => metrics && Object.prototype.hasOwnProperty.call(metrics, key)) .map((key) => { seen.add(key); return [key, metrics[key]]; }); - return { ...group, items }; + return { title: col.title, rows }; }) - .filter((group) => group.items.length) - .concat( - (() => { - const rest = Object.entries(metrics || {}).filter(([key]) => !seen.has(key)); - return rest.length - ? [{ title: "Other Metrics", description: "", items: rest }] - : []; - })() - ); + .filter((col) => col.rows.length); + const rest = Object.keys(metrics || {}) + .filter((key) => !seen.has(key)) + .sort() + .map((key) => [key, metrics[key]]); + if (rest.length) columns.push({ title: "Other", rows: rest }); + return columns; } function renderMetrics(metrics, metricContext) { const root = document.getElementById("metrics"); root.innerHTML = ""; const ctx = document.getElementById("metricContext"); - const gainScheme = metricContext && metricContext.gain_scheme; - const primary = metricContext && metricContext.primary_metric; - ctx.textContent = primary - ? `Primary metric: ${primary}. Gain scheme: ${Object.entries(gainScheme || {}).map(([label, gain]) => `${label}=${gain}`).join(", ")}.` - : ""; + const parts = []; + if (metricContext && metricContext.primary_metric) { + parts.push(`Primary: ${metricContext.primary_metric}`); + } + if (metricContext && metricContext.gain_scheme) { + parts.push( + `NDCG gains: ${Object.entries(metricContext.gain_scheme) + .map(([label, gain]) => `${label}=${gain}`) + .join(", ")}` + ); + } + if (metricContext && metricContext.stop_prob_scheme) { + parts.push( + `ERR P(stop): ${Object.entries(metricContext.stop_prob_scheme) + .map(([label, p]) => `${label}=${p}`) + .join(", ")}` + ); + } + ctx.textContent = parts.length ? `${parts.join(". ")}.` : ""; - metricSections(metrics || {}).forEach((section) => { - const wrap = document.createElement("section"); - wrap.className = "metric-section"; - wrap.innerHTML = ` -
-

${section.title}

- ${section.description ? `

${section.description}

` : ""} -
-
- `; - const grid = wrap.querySelector(".metric-grid"); - section.items.forEach(([key, value]) => { - const card = document.createElement("div"); - card.className = "metric"; - card.innerHTML = `
${key}
${fmtNumber(value)}
`; - grid.appendChild(card); + const bar = document.createElement("div"); + bar.className = "metrics-columns"; + metricColumns(metrics || {}).forEach((col) => { + const column = document.createElement("div"); + column.className = "metric-column"; + const h = document.createElement("h4"); + h.className = "metric-column-title"; + h.textContent = col.title; + column.appendChild(h); + col.rows.forEach(([key, value]) => { + const row = document.createElement("div"); + row.className = "metric-row"; + row.innerHTML = `${key}: ${fmtNumber(value)}`; + column.appendChild(row); }); - root.appendChild(wrap); + bar.appendChild(column); }); + root.appendChild(bar); } function labelBadgeClass(label) { @@ -148,6 +174,7 @@ function historySummaryHtml(meta) { const parts = []; if (nq != null) parts.push(`Queries ${nq}`); if (m && m["NDCG@10"] != null) parts.push(`NDCG@10 ${fmtNumber(m["NDCG@10"])}`); + if (m && m["ERR@10"] != null) parts.push(`ERR@10 ${fmtNumber(m["ERR@10"])}`); if (m && m["Strong_Precision@10"] != null) parts.push(`Strong@10 ${fmtNumber(m["Strong_Precision@10"])}`); if (m && m["Gain_Recall@50"] != null) parts.push(`Gain Recall@50 ${fmtNumber(m["Gain_Recall@50"])}`); if (!parts.length) return ""; diff --git a/scripts/evaluation/eval_framework/static/index.html b/scripts/evaluation/eval_framework/static/index.html index 974945d..d4831cb 100644 --- a/scripts/evaluation/eval_framework/static/index.html +++ b/scripts/evaluation/eval_framework/static/index.html @@ -31,7 +31,7 @@

Metrics

-
+

Top Results

diff --git a/tests/test_eval_metrics.py b/tests/test_eval_metrics.py new file mode 100644 index 0000000..19981ea --- /dev/null +++ b/tests/test_eval_metrics.py @@ -0,0 +1,26 @@ +"""Tests for search evaluation ranking metrics (NDCG, ERR).""" + +from scripts.evaluation.eval_framework.constants import ( + RELEVANCE_EXACT, + RELEVANCE_HIGH, + RELEVANCE_IRRELEVANT, + RELEVANCE_LOW, +) +from scripts.evaluation.eval_framework.metrics import compute_query_metrics + + +def test_err_matches_documented_three_item_examples(): + # Model A: [Exact, Irrelevant, High] -> ERR ≈ 0.992667 + m_a = compute_query_metrics( + [RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH], + ideal_labels=[RELEVANCE_EXACT], + ) + assert abs(m_a["ERR@5"] - (0.99 + (1.0 / 3.0) * 0.8 * 0.01)) < 1e-5 + + # Model B: [High, Low, Exact] -> ERR ≈ 0.8694 + m_b = compute_query_metrics( + [RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_EXACT], + ideal_labels=[RELEVANCE_EXACT], + ) + expected_b = 0.8 + 0.5 * 0.1 * 0.2 + (1.0 / 3.0) * 0.99 * 0.18 + assert abs(m_b["ERR@5"] - expected_b) < 1e-5 -- libgit2 0.21.2