Commit 30b490e159692c99f66d7eae3d60bf4c8ee6b4d1

Authored by tangwang
1 parent 7ddd4cb3

添加ERR评估指标

scripts/evaluation/eval_framework/constants.py
@@ -26,11 +26,20 @@ RELEVANCE_GRADE_MAP = { @@ -26,11 +26,20 @@ RELEVANCE_GRADE_MAP = {
26 RELEVANCE_LOW: 1, 26 RELEVANCE_LOW: 1,
27 RELEVANCE_IRRELEVANT: 0, 27 RELEVANCE_IRRELEVANT: 0,
28 } 28 }
  29 +
29 RELEVANCE_GAIN_MAP = { 30 RELEVANCE_GAIN_MAP = {
30 label: (2 ** grade) - 1 31 label: (2 ** grade) - 1
31 for label, grade in RELEVANCE_GRADE_MAP.items() 32 for label, grade in RELEVANCE_GRADE_MAP.items()
32 } 33 }
33 34
  35 +# P(stop | relevance) for ERR (Expected Reciprocal Rank); cascade model (Chapelle et al., 2009).
  36 +STOP_PROB_MAP = {
  37 + RELEVANCE_EXACT: 0.99,
  38 + RELEVANCE_HIGH: 0.8,
  39 + RELEVANCE_LOW: 0.1,
  40 + RELEVANCE_IRRELEVANT: 0.0,
  41 +}
  42 +
34 _LEGACY_LABEL_MAP = { 43 _LEGACY_LABEL_MAP = {
35 "Exact": RELEVANCE_EXACT, 44 "Exact": RELEVANCE_EXACT,
36 "Partial": RELEVANCE_HIGH, 45 "Partial": RELEVANCE_HIGH,
scripts/evaluation/eval_framework/framework.py
@@ -28,6 +28,7 @@ from .constants import ( @@ -28,6 +28,7 @@ from .constants import (
28 RELEVANCE_EXACT, 28 RELEVANCE_EXACT,
29 RELEVANCE_GAIN_MAP, 29 RELEVANCE_GAIN_MAP,
30 RELEVANCE_HIGH, 30 RELEVANCE_HIGH,
  31 + STOP_PROB_MAP,
31 RELEVANCE_IRRELEVANT, 32 RELEVANCE_IRRELEVANT,
32 RELEVANCE_LOW, 33 RELEVANCE_LOW,
33 RELEVANCE_NON_IRRELEVANT, 34 RELEVANCE_NON_IRRELEVANT,
@@ -55,8 +56,10 @@ def _metric_context_payload() -> Dict[str, Any]: @@ -55,8 +56,10 @@ def _metric_context_payload() -> Dict[str, Any]:
55 return { 56 return {
56 "primary_metric": "NDCG@10", 57 "primary_metric": "NDCG@10",
57 "gain_scheme": dict(RELEVANCE_GAIN_MAP), 58 "gain_scheme": dict(RELEVANCE_GAIN_MAP),
  59 + "stop_prob_scheme": dict(STOP_PROB_MAP),
58 "notes": [ 60 "notes": [
59 "NDCG uses graded gains derived from the four relevance labels.", 61 "NDCG uses graded gains derived from the four relevance labels.",
  62 + "ERR (Expected Reciprocal Rank) uses per-grade stop probabilities in stop_prob_scheme.",
60 "Strong metrics treat Exact Match and High Relevant as strong business positives.", 63 "Strong metrics treat Exact Match and High Relevant as strong business positives.",
61 "Useful metrics treat any non-irrelevant item as useful recall coverage.", 64 "Useful metrics treat any non-irrelevant item as useful recall coverage.",
62 ], 65 ],
scripts/evaluation/eval_framework/metrics.py
@@ -14,6 +14,7 @@ from .constants import ( @@ -14,6 +14,7 @@ from .constants import (
14 RELEVANCE_LOW, 14 RELEVANCE_LOW,
15 RELEVANCE_NON_IRRELEVANT, 15 RELEVANCE_NON_IRRELEVANT,
16 RELEVANCE_STRONG, 16 RELEVANCE_STRONG,
  17 + STOP_PROB_MAP,
17 ) 18 )
18 19
19 20
@@ -77,6 +78,19 @@ def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> fl @@ -77,6 +78,19 @@ def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> fl
77 return dcg / idcg 78 return dcg / idcg
78 79
79 80
  81 +def _err_at_k(labels: Sequence[str], k: int) -> float:
  82 + """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list)."""
  83 + if k <= 0:
  84 + return 0.0
  85 + err = 0.0
  86 + product = 1.0
  87 + for i, label in enumerate(labels[:k], start=1):
  88 + p_stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0))
  89 + err += (1.0 / float(i)) * p_stop * product
  90 + product *= 1.0 - p_stop
  91 + return err
  92 +
  93 +
80 def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float: 94 def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
81 ideal_total_gain = sum(_gains_for_labels(ideal_labels)) 95 ideal_total_gain = sum(_gains_for_labels(ideal_labels))
82 if ideal_total_gain <= 0.0: 96 if ideal_total_gain <= 0.0:
@@ -115,6 +129,7 @@ def compute_query_metrics( @@ -115,6 +129,7 @@ def compute_query_metrics(
115 129
116 for k in (5, 10, 20, 50): 130 for k in (5, 10, 20, 50):
117 metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6) 131 metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6)
  132 + metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6)
118 for k in (5, 10, 20): 133 for k in (5, 10, 20):
119 metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6) 134 metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6)
120 metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6) 135 metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6)
scripts/evaluation/eval_framework/reports.py
@@ -8,7 +8,19 @@ from .constants import RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RE @@ -8,7 +8,19 @@ from .constants import RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RE
8 8
9 9
10 def _append_metric_block(lines: list[str], metrics: Dict[str, Any]) -> None: 10 def _append_metric_block(lines: list[str], metrics: Dict[str, Any]) -> None:
11 - primary_keys = ("NDCG@5", "NDCG@10", "NDCG@20", "Exact_Precision@10", "Strong_Precision@10", "Gain_Recall@50") 11 + primary_keys = (
  12 + "NDCG@5",
  13 + "NDCG@10",
  14 + "NDCG@20",
  15 + "NDCG@50",
  16 + "ERR@5",
  17 + "ERR@10",
  18 + "ERR@20",
  19 + "ERR@50",
  20 + "Exact_Precision@10",
  21 + "Strong_Precision@10",
  22 + "Gain_Recall@50",
  23 + )
12 included = set() 24 included = set()
13 for key in primary_keys: 25 for key in primary_keys:
14 if key in metrics: 26 if key in metrics:
@@ -38,7 +50,8 @@ def render_batch_report_markdown(payload: Dict[str, Any]) -> str: @@ -38,7 +50,8 @@ def render_batch_report_markdown(payload: Dict[str, Any]) -> str:
38 lines.extend( 50 lines.extend(
39 [ 51 [
40 f"- Primary metric: {metric_context.get('primary_metric', 'N/A')}", 52 f"- Primary metric: {metric_context.get('primary_metric', 'N/A')}",
41 - f"- Gain scheme: {metric_context.get('gain_scheme', {})}", 53 + f"- Gain scheme (NDCG): {metric_context.get('gain_scheme', {})}",
  54 + f"- Stop probabilities (ERR): {metric_context.get('stop_prob_scheme', {})}",
42 "", 55 "",
43 ] 56 ]
44 ) 57 )
scripts/evaluation/eval_framework/static/eval_web.css
@@ -30,15 +30,21 @@ @@ -30,15 +30,21 @@
30 button { border: 0; background: var(--accent); color: white; padding: 12px 16px; border-radius: 14px; cursor: pointer; font-weight: 600; } 30 button { border: 0; background: var(--accent); color: white; padding: 12px 16px; border-radius: 14px; cursor: pointer; font-weight: 600; }
31 button.secondary { background: #d9e6e3; color: #12433d; } 31 button.secondary { background: #d9e6e3; color: #12433d; }
32 .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 12px; margin-bottom: 16px; } 32 .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 12px; margin-bottom: 16px; }
33 - .metric-context { margin: 0 0 12px; line-height: 1.5; }  
34 - .metric-section { margin-bottom: 18px; }  
35 - .metric-section-head { display: flex; align-items: baseline; justify-content: space-between; gap: 12px; margin-bottom: 10px; }  
36 - .metric-section-head h3 { margin: 0; font-size: 14px; color: #12433d; }  
37 - .metric-section-head p { margin: 0; color: var(--muted); font-size: 12px; }  
38 - .metric-grid { margin-bottom: 0; }  
39 - .metric { background: var(--panel); border: 1px solid var(--line); border-radius: 16px; padding: 14px; }  
40 - .metric .label { font-size: 12px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.04em; }  
41 - .metric .value { font-size: 24px; font-weight: 700; margin-top: 4px; } 33 + .metric-context { margin: 0 0 12px; line-height: 1.5; font-size: 12px; color: var(--muted); }
  34 + .metrics-columns {
  35 + display: flex; flex-wrap: wrap; gap: 12px 18px; align-items: flex-start; margin-bottom: 8px;
  36 + }
  37 + .metric-column {
  38 + min-width: 132px; flex: 0 1 auto; padding: 10px 12px;
  39 + background: var(--panel); border: 1px solid var(--line); border-radius: 12px;
  40 + }
  41 + .metric-column-title {
  42 + margin: 0 0 8px; font-size: 11px; font-weight: 700; color: #12433d;
  43 + text-transform: uppercase; letter-spacing: 0.05em;
  44 + }
  45 + .metric-row { font-size: 13px; line-height: 1.5; font-variant-numeric: tabular-nums; }
  46 + .metric-row-name { color: var(--muted); }
  47 + .metric-row-value { font-weight: 600; color: var(--ink); margin-left: 2px; }
42 .results { display: grid; gap: 10px; } 48 .results { display: grid; gap: 10px; }
43 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; } 49 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; }
44 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; } 50 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; }
@@ -101,7 +107,6 @@ @@ -101,7 +107,6 @@
101 @media (max-width: 960px) { 107 @media (max-width: 960px) {
102 .app { grid-template-columns: 1fr; } 108 .app { grid-template-columns: 1fr; }
103 .sidebar { border-right: 0; border-bottom: 1px solid var(--line); } 109 .sidebar { border-right: 0; border-bottom: 1px solid var(--line); }
104 - .metric-section-head { flex-direction: column; align-items: flex-start; }  
105 } 110 }
106 @media (max-width: 640px) { 111 @media (max-width: 640px) {
107 .main, .sidebar { padding: 16px; } 112 .main, .sidebar { padding: 16px; }
scripts/evaluation/eval_framework/static/eval_web.js
@@ -9,80 +9,106 @@ function fmtNumber(value, digits = 3) { @@ -9,80 +9,106 @@ function fmtNumber(value, digits = 3) {
9 return Number(value).toFixed(digits); 9 return Number(value).toFixed(digits);
10 } 10 }
11 11
12 -function metricSections(metrics) {  
13 - const groups = [ 12 +function metricColumns(metrics) {
  13 + const defs = [
  14 + { title: "NDCG", keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"] },
  15 + { title: "ERR", keys: ["ERR@5", "ERR@10", "ERR@20", "ERR@50"] },
14 { 16 {
15 - title: "Primary Ranking",  
16 - keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"],  
17 - description: "Graded ranking quality across the four relevance tiers.", 17 + title: "Top slot",
  18 + keys: [
  19 + "Exact_Precision@5",
  20 + "Exact_Precision@10",
  21 + "Strong_Precision@5",
  22 + "Strong_Precision@10",
  23 + "Strong_Precision@20",
  24 + ],
18 }, 25 },
19 { 26 {
20 - title: "Top Slot Quality",  
21 - keys: ["Exact_Precision@5", "Exact_Precision@10", "Strong_Precision@5", "Strong_Precision@10", "Strong_Precision@20"],  
22 - description: "How much of the visible top rank is exact or strong business relevance.", 27 + title: "Recall",
  28 + keys: [
  29 + "Useful_Precision@10",
  30 + "Useful_Precision@20",
  31 + "Useful_Precision@50",
  32 + "Gain_Recall@10",
  33 + "Gain_Recall@20",
  34 + "Gain_Recall@50",
  35 + ],
23 }, 36 },
24 { 37 {
25 - title: "Recall Coverage",  
26 - keys: ["Useful_Precision@10", "Useful_Precision@20", "Useful_Precision@50", "Gain_Recall@10", "Gain_Recall@20", "Gain_Recall@50"],  
27 - description: "How much judged relevance is captured in the returned list.",  
28 - },  
29 - {  
30 - title: "First Good Result",  
31 - keys: ["Exact_Success@5", "Exact_Success@10", "Strong_Success@5", "Strong_Success@10", "MRR_Exact@10", "MRR_Strong@10", "Avg_Grade@10"],  
32 - description: "Whether users see a good result early and how good the top page feels overall.", 38 + title: "First good",
  39 + keys: [
  40 + "Exact_Success@5",
  41 + "Exact_Success@10",
  42 + "Strong_Success@5",
  43 + "Strong_Success@10",
  44 + "MRR_Exact@10",
  45 + "MRR_Strong@10",
  46 + "Avg_Grade@10",
  47 + ],
33 }, 48 },
34 ]; 49 ];
35 const seen = new Set(); 50 const seen = new Set();
36 - return groups  
37 - .map((group) => {  
38 - const items = group.keys 51 + const columns = defs
  52 + .map((col) => {
  53 + const rows = col.keys
39 .filter((key) => metrics && Object.prototype.hasOwnProperty.call(metrics, key)) 54 .filter((key) => metrics && Object.prototype.hasOwnProperty.call(metrics, key))
40 .map((key) => { 55 .map((key) => {
41 seen.add(key); 56 seen.add(key);
42 return [key, metrics[key]]; 57 return [key, metrics[key]];
43 }); 58 });
44 - return { ...group, items }; 59 + return { title: col.title, rows };
45 }) 60 })
46 - .filter((group) => group.items.length)  
47 - .concat(  
48 - (() => {  
49 - const rest = Object.entries(metrics || {}).filter(([key]) => !seen.has(key));  
50 - return rest.length  
51 - ? [{ title: "Other Metrics", description: "", items: rest }]  
52 - : [];  
53 - })()  
54 - ); 61 + .filter((col) => col.rows.length);
  62 + const rest = Object.keys(metrics || {})
  63 + .filter((key) => !seen.has(key))
  64 + .sort()
  65 + .map((key) => [key, metrics[key]]);
  66 + if (rest.length) columns.push({ title: "Other", rows: rest });
  67 + return columns;
55 } 68 }
56 69
57 function renderMetrics(metrics, metricContext) { 70 function renderMetrics(metrics, metricContext) {
58 const root = document.getElementById("metrics"); 71 const root = document.getElementById("metrics");
59 root.innerHTML = ""; 72 root.innerHTML = "";
60 const ctx = document.getElementById("metricContext"); 73 const ctx = document.getElementById("metricContext");
61 - const gainScheme = metricContext && metricContext.gain_scheme;  
62 - const primary = metricContext && metricContext.primary_metric;  
63 - ctx.textContent = primary  
64 - ? `Primary metric: ${primary}. Gain scheme: ${Object.entries(gainScheme || {}).map(([label, gain]) => `${label}=${gain}`).join(", ")}.`  
65 - : ""; 74 + const parts = [];
  75 + if (metricContext && metricContext.primary_metric) {
  76 + parts.push(`Primary: ${metricContext.primary_metric}`);
  77 + }
  78 + if (metricContext && metricContext.gain_scheme) {
  79 + parts.push(
  80 + `NDCG gains: ${Object.entries(metricContext.gain_scheme)
  81 + .map(([label, gain]) => `${label}=${gain}`)
  82 + .join(", ")}`
  83 + );
  84 + }
  85 + if (metricContext && metricContext.stop_prob_scheme) {
  86 + parts.push(
  87 + `ERR P(stop): ${Object.entries(metricContext.stop_prob_scheme)
  88 + .map(([label, p]) => `${label}=${p}`)
  89 + .join(", ")}`
  90 + );
  91 + }
  92 + ctx.textContent = parts.length ? `${parts.join(". ")}.` : "";
66 93
67 - metricSections(metrics || {}).forEach((section) => {  
68 - const wrap = document.createElement("section");  
69 - wrap.className = "metric-section";  
70 - wrap.innerHTML = `  
71 - <div class="metric-section-head">  
72 - <h3>${section.title}</h3>  
73 - ${section.description ? `<p>${section.description}</p>` : ""}  
74 - </div>  
75 - <div class="grid metric-grid"></div>  
76 - `;  
77 - const grid = wrap.querySelector(".metric-grid");  
78 - section.items.forEach(([key, value]) => {  
79 - const card = document.createElement("div");  
80 - card.className = "metric";  
81 - card.innerHTML = `<div class="label">${key}</div><div class="value">${fmtNumber(value)}</div>`;  
82 - grid.appendChild(card); 94 + const bar = document.createElement("div");
  95 + bar.className = "metrics-columns";
  96 + metricColumns(metrics || {}).forEach((col) => {
  97 + const column = document.createElement("div");
  98 + column.className = "metric-column";
  99 + const h = document.createElement("h4");
  100 + h.className = "metric-column-title";
  101 + h.textContent = col.title;
  102 + column.appendChild(h);
  103 + col.rows.forEach(([key, value]) => {
  104 + const row = document.createElement("div");
  105 + row.className = "metric-row";
  106 + row.innerHTML = `<span class="metric-row-name">${key}:</span> <span class="metric-row-value">${fmtNumber(value)}</span>`;
  107 + column.appendChild(row);
83 }); 108 });
84 - root.appendChild(wrap); 109 + bar.appendChild(column);
85 }); 110 });
  111 + root.appendChild(bar);
86 } 112 }
87 113
88 function labelBadgeClass(label) { 114 function labelBadgeClass(label) {
@@ -148,6 +174,7 @@ function historySummaryHtml(meta) { @@ -148,6 +174,7 @@ function historySummaryHtml(meta) {
148 const parts = []; 174 const parts = [];
149 if (nq != null) parts.push(`<span>Queries</span> ${nq}`); 175 if (nq != null) parts.push(`<span>Queries</span> ${nq}`);
150 if (m && m["NDCG@10"] != null) parts.push(`<span>NDCG@10</span> ${fmtNumber(m["NDCG@10"])}`); 176 if (m && m["NDCG@10"] != null) parts.push(`<span>NDCG@10</span> ${fmtNumber(m["NDCG@10"])}`);
  177 + if (m && m["ERR@10"] != null) parts.push(`<span>ERR@10</span> ${fmtNumber(m["ERR@10"])}`);
151 if (m && m["Strong_Precision@10"] != null) parts.push(`<span>Strong@10</span> ${fmtNumber(m["Strong_Precision@10"])}`); 178 if (m && m["Strong_Precision@10"] != null) parts.push(`<span>Strong@10</span> ${fmtNumber(m["Strong_Precision@10"])}`);
152 if (m && m["Gain_Recall@50"] != null) parts.push(`<span>Gain Recall@50</span> ${fmtNumber(m["Gain_Recall@50"])}`); 179 if (m && m["Gain_Recall@50"] != null) parts.push(`<span>Gain Recall@50</span> ${fmtNumber(m["Gain_Recall@50"])}`);
153 if (!parts.length) return ""; 180 if (!parts.length) return "";
scripts/evaluation/eval_framework/static/index.html
@@ -31,7 +31,7 @@ @@ -31,7 +31,7 @@
31 <section class="section"> 31 <section class="section">
32 <h2>Metrics</h2> 32 <h2>Metrics</h2>
33 <p id="metricContext" class="muted metric-context"></p> 33 <p id="metricContext" class="muted metric-context"></p>
34 - <div id="metrics" class="grid"></div> 34 + <div id="metrics"></div>
35 </section> 35 </section>
36 <section class="section"> 36 <section class="section">
37 <h2>Top Results</h2> 37 <h2>Top Results</h2>
tests/test_eval_metrics.py 0 → 100644
@@ -0,0 +1,26 @@ @@ -0,0 +1,26 @@
  1 +"""Tests for search evaluation ranking metrics (NDCG, ERR)."""
  2 +
  3 +from scripts.evaluation.eval_framework.constants import (
  4 + RELEVANCE_EXACT,
  5 + RELEVANCE_HIGH,
  6 + RELEVANCE_IRRELEVANT,
  7 + RELEVANCE_LOW,
  8 +)
  9 +from scripts.evaluation.eval_framework.metrics import compute_query_metrics
  10 +
  11 +
  12 +def test_err_matches_documented_three_item_examples():
  13 + # Model A: [Exact, Irrelevant, High] -> ERR ≈ 0.992667
  14 + m_a = compute_query_metrics(
  15 + [RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH],
  16 + ideal_labels=[RELEVANCE_EXACT],
  17 + )
  18 + assert abs(m_a["ERR@5"] - (0.99 + (1.0 / 3.0) * 0.8 * 0.01)) < 1e-5
  19 +
  20 + # Model B: [High, Low, Exact] -> ERR ≈ 0.8694
  21 + m_b = compute_query_metrics(
  22 + [RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_EXACT],
  23 + ideal_labels=[RELEVANCE_EXACT],
  24 + )
  25 + expected_b = 0.8 + 0.5 * 0.1 * 0.2 + (1.0 / 3.0) * 0.99 * 0.18
  26 + assert abs(m_b["ERR@5"] - expected_b) < 1e-5