Commit 30b490e159692c99f66d7eae3d60bf4c8ee6b4d1
1 parent 7ddd4cb3
Add ERR (Expected Reciprocal Rank) evaluation metric
Showing 8 changed files with 162 additions and 64 deletions
Show diff stats
scripts/evaluation/eval_framework/constants.py
| ... | ... | @@ -26,11 +26,20 @@ RELEVANCE_GRADE_MAP = { |
| 26 | 26 | RELEVANCE_LOW: 1, |
| 27 | 27 | RELEVANCE_IRRELEVANT: 0, |
| 28 | 28 | } |
| 29 | + | |
| 29 | 30 | RELEVANCE_GAIN_MAP = { |
| 30 | 31 | label: (2 ** grade) - 1 |
| 31 | 32 | for label, grade in RELEVANCE_GRADE_MAP.items() |
| 32 | 33 | } |
| 33 | 34 | |
| 35 | +# P(stop | relevance) for ERR (Expected Reciprocal Rank); cascade model (Chapelle et al., 2009). | |
| 36 | +STOP_PROB_MAP = { | |
| 37 | + RELEVANCE_EXACT: 0.99, | |
| 38 | + RELEVANCE_HIGH: 0.8, | |
| 39 | + RELEVANCE_LOW: 0.1, | |
| 40 | + RELEVANCE_IRRELEVANT: 0.0, | |
| 41 | +} | |
| 42 | + | |
| 34 | 43 | _LEGACY_LABEL_MAP = { |
| 35 | 44 | "Exact": RELEVANCE_EXACT, |
| 36 | 45 | "Partial": RELEVANCE_HIGH, | ... | ... |
scripts/evaluation/eval_framework/framework.py
| ... | ... | @@ -28,6 +28,7 @@ from .constants import ( |
| 28 | 28 | RELEVANCE_EXACT, |
| 29 | 29 | RELEVANCE_GAIN_MAP, |
| 30 | 30 | RELEVANCE_HIGH, |
| 31 | + STOP_PROB_MAP, | |
| 31 | 32 | RELEVANCE_IRRELEVANT, |
| 32 | 33 | RELEVANCE_LOW, |
| 33 | 34 | RELEVANCE_NON_IRRELEVANT, |
| ... | ... | @@ -55,8 +56,10 @@ def _metric_context_payload() -> Dict[str, Any]: |
| 55 | 56 | return { |
| 56 | 57 | "primary_metric": "NDCG@10", |
| 57 | 58 | "gain_scheme": dict(RELEVANCE_GAIN_MAP), |
| 59 | + "stop_prob_scheme": dict(STOP_PROB_MAP), | |
| 58 | 60 | "notes": [ |
| 59 | 61 | "NDCG uses graded gains derived from the four relevance labels.", |
| 62 | + "ERR (Expected Reciprocal Rank) uses per-grade stop probabilities in stop_prob_scheme.", | |
| 60 | 63 | "Strong metrics treat Exact Match and High Relevant as strong business positives.", |
| 61 | 64 | "Useful metrics treat any non-irrelevant item as useful recall coverage.", |
| 62 | 65 | ], | ... | ... |
scripts/evaluation/eval_framework/metrics.py
| ... | ... | @@ -14,6 +14,7 @@ from .constants import ( |
| 14 | 14 | RELEVANCE_LOW, |
| 15 | 15 | RELEVANCE_NON_IRRELEVANT, |
| 16 | 16 | RELEVANCE_STRONG, |
| 17 | + STOP_PROB_MAP, | |
| 17 | 18 | ) |
| 18 | 19 | |
| 19 | 20 | |
| ... | ... | @@ -77,6 +78,19 @@ def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> fl |
| 77 | 78 | return dcg / idcg |
| 78 | 79 | |
| 79 | 80 | |
| 81 | +def _err_at_k(labels: Sequence[str], k: int) -> float: | |
| 82 | + """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list).""" | |
| 83 | + if k <= 0: | |
| 84 | + return 0.0 | |
| 85 | + err = 0.0 | |
| 86 | + product = 1.0 | |
| 87 | + for i, label in enumerate(labels[:k], start=1): | |
| 88 | + p_stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0)) | |
| 89 | + err += (1.0 / float(i)) * p_stop * product | |
| 90 | + product *= 1.0 - p_stop | |
| 91 | + return err | |
| 92 | + | |
| 93 | + | |
| 80 | 94 | def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float: |
| 81 | 95 | ideal_total_gain = sum(_gains_for_labels(ideal_labels)) |
| 82 | 96 | if ideal_total_gain <= 0.0: |
| ... | ... | @@ -115,6 +129,7 @@ def compute_query_metrics( |
| 115 | 129 | |
| 116 | 130 | for k in (5, 10, 20, 50): |
| 117 | 131 | metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6) |
| 132 | + metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6) | |
| 118 | 133 | for k in (5, 10, 20): |
| 119 | 134 | metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6) |
| 120 | 135 | metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6) | ... | ... |
scripts/evaluation/eval_framework/reports.py
| ... | ... | @@ -8,7 +8,19 @@ from .constants import RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RE |
| 8 | 8 | |
| 9 | 9 | |
| 10 | 10 | def _append_metric_block(lines: list[str], metrics: Dict[str, Any]) -> None: |
| 11 | - primary_keys = ("NDCG@5", "NDCG@10", "NDCG@20", "Exact_Precision@10", "Strong_Precision@10", "Gain_Recall@50") | |
| 11 | + primary_keys = ( | |
| 12 | + "NDCG@5", | |
| 13 | + "NDCG@10", | |
| 14 | + "NDCG@20", | |
| 15 | + "NDCG@50", | |
| 16 | + "ERR@5", | |
| 17 | + "ERR@10", | |
| 18 | + "ERR@20", | |
| 19 | + "ERR@50", | |
| 20 | + "Exact_Precision@10", | |
| 21 | + "Strong_Precision@10", | |
| 22 | + "Gain_Recall@50", | |
| 23 | + ) | |
| 12 | 24 | included = set() |
| 13 | 25 | for key in primary_keys: |
| 14 | 26 | if key in metrics: |
| ... | ... | @@ -38,7 +50,8 @@ def render_batch_report_markdown(payload: Dict[str, Any]) -> str: |
| 38 | 50 | lines.extend( |
| 39 | 51 | [ |
| 40 | 52 | f"- Primary metric: {metric_context.get('primary_metric', 'N/A')}", |
| 41 | - f"- Gain scheme: {metric_context.get('gain_scheme', {})}", | |
| 53 | + f"- Gain scheme (NDCG): {metric_context.get('gain_scheme', {})}", | |
| 54 | + f"- Stop probabilities (ERR): {metric_context.get('stop_prob_scheme', {})}", | |
| 42 | 55 | "", |
| 43 | 56 | ] |
| 44 | 57 | ) | ... | ... |
scripts/evaluation/eval_framework/static/eval_web.css
| ... | ... | @@ -30,15 +30,21 @@ |
| 30 | 30 | button { border: 0; background: var(--accent); color: white; padding: 12px 16px; border-radius: 14px; cursor: pointer; font-weight: 600; } |
| 31 | 31 | button.secondary { background: #d9e6e3; color: #12433d; } |
| 32 | 32 | .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 12px; margin-bottom: 16px; } |
| 33 | - .metric-context { margin: 0 0 12px; line-height: 1.5; } | |
| 34 | - .metric-section { margin-bottom: 18px; } | |
| 35 | - .metric-section-head { display: flex; align-items: baseline; justify-content: space-between; gap: 12px; margin-bottom: 10px; } | |
| 36 | - .metric-section-head h3 { margin: 0; font-size: 14px; color: #12433d; } | |
| 37 | - .metric-section-head p { margin: 0; color: var(--muted); font-size: 12px; } | |
| 38 | - .metric-grid { margin-bottom: 0; } | |
| 39 | - .metric { background: var(--panel); border: 1px solid var(--line); border-radius: 16px; padding: 14px; } | |
| 40 | - .metric .label { font-size: 12px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.04em; } | |
| 41 | - .metric .value { font-size: 24px; font-weight: 700; margin-top: 4px; } | |
| 33 | + .metric-context { margin: 0 0 12px; line-height: 1.5; font-size: 12px; color: var(--muted); } | |
| 34 | + .metrics-columns { | |
| 35 | + display: flex; flex-wrap: wrap; gap: 12px 18px; align-items: flex-start; margin-bottom: 8px; | |
| 36 | + } | |
| 37 | + .metric-column { | |
| 38 | + min-width: 132px; flex: 0 1 auto; padding: 10px 12px; | |
| 39 | + background: var(--panel); border: 1px solid var(--line); border-radius: 12px; | |
| 40 | + } | |
| 41 | + .metric-column-title { | |
| 42 | + margin: 0 0 8px; font-size: 11px; font-weight: 700; color: #12433d; | |
| 43 | + text-transform: uppercase; letter-spacing: 0.05em; | |
| 44 | + } | |
| 45 | + .metric-row { font-size: 13px; line-height: 1.5; font-variant-numeric: tabular-nums; } | |
| 46 | + .metric-row-name { color: var(--muted); } | |
| 47 | + .metric-row-value { font-weight: 600; color: var(--ink); margin-left: 2px; } | |
| 42 | 48 | .results { display: grid; gap: 10px; } |
| 43 | 49 | .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } |
| 44 | 50 | .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } |
| ... | ... | @@ -101,7 +107,6 @@ |
| 101 | 107 | @media (max-width: 960px) { |
| 102 | 108 | .app { grid-template-columns: 1fr; } |
| 103 | 109 | .sidebar { border-right: 0; border-bottom: 1px solid var(--line); } |
| 104 | - .metric-section-head { flex-direction: column; align-items: flex-start; } | |
| 105 | 110 | } |
| 106 | 111 | @media (max-width: 640px) { |
| 107 | 112 | .main, .sidebar { padding: 16px; } | ... | ... |
scripts/evaluation/eval_framework/static/eval_web.js
| ... | ... | @@ -9,80 +9,106 @@ function fmtNumber(value, digits = 3) { |
| 9 | 9 | return Number(value).toFixed(digits); |
| 10 | 10 | } |
| 11 | 11 | |
| 12 | -function metricSections(metrics) { | |
| 13 | - const groups = [ | |
| 12 | +function metricColumns(metrics) { | |
| 13 | + const defs = [ | |
| 14 | + { title: "NDCG", keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"] }, | |
| 15 | + { title: "ERR", keys: ["ERR@5", "ERR@10", "ERR@20", "ERR@50"] }, | |
| 14 | 16 | { |
| 15 | - title: "Primary Ranking", | |
| 16 | - keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"], | |
| 17 | - description: "Graded ranking quality across the four relevance tiers.", | |
| 17 | + title: "Top slot", | |
| 18 | + keys: [ | |
| 19 | + "Exact_Precision@5", | |
| 20 | + "Exact_Precision@10", | |
| 21 | + "Strong_Precision@5", | |
| 22 | + "Strong_Precision@10", | |
| 23 | + "Strong_Precision@20", | |
| 24 | + ], | |
| 18 | 25 | }, |
| 19 | 26 | { |
| 20 | - title: "Top Slot Quality", | |
| 21 | - keys: ["Exact_Precision@5", "Exact_Precision@10", "Strong_Precision@5", "Strong_Precision@10", "Strong_Precision@20"], | |
| 22 | - description: "How much of the visible top rank is exact or strong business relevance.", | |
| 27 | + title: "Recall", | |
| 28 | + keys: [ | |
| 29 | + "Useful_Precision@10", | |
| 30 | + "Useful_Precision@20", | |
| 31 | + "Useful_Precision@50", | |
| 32 | + "Gain_Recall@10", | |
| 33 | + "Gain_Recall@20", | |
| 34 | + "Gain_Recall@50", | |
| 35 | + ], | |
| 23 | 36 | }, |
| 24 | 37 | { |
| 25 | - title: "Recall Coverage", | |
| 26 | - keys: ["Useful_Precision@10", "Useful_Precision@20", "Useful_Precision@50", "Gain_Recall@10", "Gain_Recall@20", "Gain_Recall@50"], | |
| 27 | - description: "How much judged relevance is captured in the returned list.", | |
| 28 | - }, | |
| 29 | - { | |
| 30 | - title: "First Good Result", | |
| 31 | - keys: ["Exact_Success@5", "Exact_Success@10", "Strong_Success@5", "Strong_Success@10", "MRR_Exact@10", "MRR_Strong@10", "Avg_Grade@10"], | |
| 32 | - description: "Whether users see a good result early and how good the top page feels overall.", | |
| 38 | + title: "First good", | |
| 39 | + keys: [ | |
| 40 | + "Exact_Success@5", | |
| 41 | + "Exact_Success@10", | |
| 42 | + "Strong_Success@5", | |
| 43 | + "Strong_Success@10", | |
| 44 | + "MRR_Exact@10", | |
| 45 | + "MRR_Strong@10", | |
| 46 | + "Avg_Grade@10", | |
| 47 | + ], | |
| 33 | 48 | }, |
| 34 | 49 | ]; |
| 35 | 50 | const seen = new Set(); |
| 36 | - return groups | |
| 37 | - .map((group) => { | |
| 38 | - const items = group.keys | |
| 51 | + const columns = defs | |
| 52 | + .map((col) => { | |
| 53 | + const rows = col.keys | |
| 39 | 54 | .filter((key) => metrics && Object.prototype.hasOwnProperty.call(metrics, key)) |
| 40 | 55 | .map((key) => { |
| 41 | 56 | seen.add(key); |
| 42 | 57 | return [key, metrics[key]]; |
| 43 | 58 | }); |
| 44 | - return { ...group, items }; | |
| 59 | + return { title: col.title, rows }; | |
| 45 | 60 | }) |
| 46 | - .filter((group) => group.items.length) | |
| 47 | - .concat( | |
| 48 | - (() => { | |
| 49 | - const rest = Object.entries(metrics || {}).filter(([key]) => !seen.has(key)); | |
| 50 | - return rest.length | |
| 51 | - ? [{ title: "Other Metrics", description: "", items: rest }] | |
| 52 | - : []; | |
| 53 | - })() | |
| 54 | - ); | |
| 61 | + .filter((col) => col.rows.length); | |
| 62 | + const rest = Object.keys(metrics || {}) | |
| 63 | + .filter((key) => !seen.has(key)) | |
| 64 | + .sort() | |
| 65 | + .map((key) => [key, metrics[key]]); | |
| 66 | + if (rest.length) columns.push({ title: "Other", rows: rest }); | |
| 67 | + return columns; | |
| 55 | 68 | } |
| 56 | 69 | |
| 57 | 70 | function renderMetrics(metrics, metricContext) { |
| 58 | 71 | const root = document.getElementById("metrics"); |
| 59 | 72 | root.innerHTML = ""; |
| 60 | 73 | const ctx = document.getElementById("metricContext"); |
| 61 | - const gainScheme = metricContext && metricContext.gain_scheme; | |
| 62 | - const primary = metricContext && metricContext.primary_metric; | |
| 63 | - ctx.textContent = primary | |
| 64 | - ? `Primary metric: ${primary}. Gain scheme: ${Object.entries(gainScheme || {}).map(([label, gain]) => `${label}=${gain}`).join(", ")}.` | |
| 65 | - : ""; | |
| 74 | + const parts = []; | |
| 75 | + if (metricContext && metricContext.primary_metric) { | |
| 76 | + parts.push(`Primary: ${metricContext.primary_metric}`); | |
| 77 | + } | |
| 78 | + if (metricContext && metricContext.gain_scheme) { | |
| 79 | + parts.push( | |
| 80 | + `NDCG gains: ${Object.entries(metricContext.gain_scheme) | |
| 81 | + .map(([label, gain]) => `${label}=${gain}`) | |
| 82 | + .join(", ")}` | |
| 83 | + ); | |
| 84 | + } | |
| 85 | + if (metricContext && metricContext.stop_prob_scheme) { | |
| 86 | + parts.push( | |
| 87 | + `ERR P(stop): ${Object.entries(metricContext.stop_prob_scheme) | |
| 88 | + .map(([label, p]) => `${label}=${p}`) | |
| 89 | + .join(", ")}` | |
| 90 | + ); | |
| 91 | + } | |
| 92 | + ctx.textContent = parts.length ? `${parts.join(". ")}.` : ""; | |
| 66 | 93 | |
| 67 | - metricSections(metrics || {}).forEach((section) => { | |
| 68 | - const wrap = document.createElement("section"); | |
| 69 | - wrap.className = "metric-section"; | |
| 70 | - wrap.innerHTML = ` | |
| 71 | - <div class="metric-section-head"> | |
| 72 | - <h3>${section.title}</h3> | |
| 73 | - ${section.description ? `<p>${section.description}</p>` : ""} | |
| 74 | - </div> | |
| 75 | - <div class="grid metric-grid"></div> | |
| 76 | - `; | |
| 77 | - const grid = wrap.querySelector(".metric-grid"); | |
| 78 | - section.items.forEach(([key, value]) => { | |
| 79 | - const card = document.createElement("div"); | |
| 80 | - card.className = "metric"; | |
| 81 | - card.innerHTML = `<div class="label">${key}</div><div class="value">${fmtNumber(value)}</div>`; | |
| 82 | - grid.appendChild(card); | |
| 94 | + const bar = document.createElement("div"); | |
| 95 | + bar.className = "metrics-columns"; | |
| 96 | + metricColumns(metrics || {}).forEach((col) => { | |
| 97 | + const column = document.createElement("div"); | |
| 98 | + column.className = "metric-column"; | |
| 99 | + const h = document.createElement("h4"); | |
| 100 | + h.className = "metric-column-title"; | |
| 101 | + h.textContent = col.title; | |
| 102 | + column.appendChild(h); | |
| 103 | + col.rows.forEach(([key, value]) => { | |
| 104 | + const row = document.createElement("div"); | |
| 105 | + row.className = "metric-row"; | |
| 106 | + row.innerHTML = `<span class="metric-row-name">${key}:</span> <span class="metric-row-value">${fmtNumber(value)}</span>`; | |
| 107 | + column.appendChild(row); | |
| 83 | 108 | }); |
| 84 | - root.appendChild(wrap); | |
| 109 | + bar.appendChild(column); | |
| 85 | 110 | }); |
| 111 | + root.appendChild(bar); | |
| 86 | 112 | } |
| 87 | 113 | |
| 88 | 114 | function labelBadgeClass(label) { |
| ... | ... | @@ -148,6 +174,7 @@ function historySummaryHtml(meta) { |
| 148 | 174 | const parts = []; |
| 149 | 175 | if (nq != null) parts.push(`<span>Queries</span> ${nq}`); |
| 150 | 176 | if (m && m["NDCG@10"] != null) parts.push(`<span>NDCG@10</span> ${fmtNumber(m["NDCG@10"])}`); |
| 177 | + if (m && m["ERR@10"] != null) parts.push(`<span>ERR@10</span> ${fmtNumber(m["ERR@10"])}`); | |
| 151 | 178 | if (m && m["Strong_Precision@10"] != null) parts.push(`<span>Strong@10</span> ${fmtNumber(m["Strong_Precision@10"])}`); |
| 152 | 179 | if (m && m["Gain_Recall@50"] != null) parts.push(`<span>Gain Recall@50</span> ${fmtNumber(m["Gain_Recall@50"])}`); |
| 153 | 180 | if (!parts.length) return ""; | ... | ... |
scripts/evaluation/tests/test_ranking_metrics.py  (NOTE: original header read "static/index.html", but the hunk below is a new 26-line Python test module — the header appears mislabeled in this extraction)
| ... | ... | @@ -0,0 +1,26 @@ |
| 1 | +"""Tests for search evaluation ranking metrics (NDCG, ERR).""" | |
| 2 | + | |
| 3 | +from scripts.evaluation.eval_framework.constants import ( | |
| 4 | + RELEVANCE_EXACT, | |
| 5 | + RELEVANCE_HIGH, | |
| 6 | + RELEVANCE_IRRELEVANT, | |
| 7 | + RELEVANCE_LOW, | |
| 8 | +) | |
| 9 | +from scripts.evaluation.eval_framework.metrics import compute_query_metrics | |
| 10 | + | |
| 11 | + | |
| 12 | +def test_err_matches_documented_three_item_examples(): | |
| 13 | + # Model A: [Exact, Irrelevant, High] -> ERR ≈ 0.992667 | |
| 14 | + m_a = compute_query_metrics( | |
| 15 | + [RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH], | |
| 16 | + ideal_labels=[RELEVANCE_EXACT], | |
| 17 | + ) | |
| 18 | + assert abs(m_a["ERR@5"] - (0.99 + (1.0 / 3.0) * 0.8 * 0.01)) < 1e-5 | |
| 19 | + | |
| 20 | + # Model B: [High, Low, Exact] -> ERR ≈ 0.8694 | |
| 21 | + m_b = compute_query_metrics( | |
| 22 | + [RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_EXACT], | |
| 23 | + ideal_labels=[RELEVANCE_EXACT], | |
| 24 | + ) | |
| 25 | + expected_b = 0.8 + 0.5 * 0.1 * 0.2 + (1.0 / 3.0) * 0.99 * 0.18 | |
| 26 | + assert abs(m_b["ERR@5"] - expected_b) < 1e-5 | ... | ... |