Commit 30b490e159692c99f66d7eae3d60bf4c8ee6b4d1

Authored by tangwang
1 parent 7ddd4cb3

添加ERR评估指标

scripts/evaluation/eval_framework/constants.py
... ... @@ -26,11 +26,20 @@ RELEVANCE_GRADE_MAP = {
26 26 RELEVANCE_LOW: 1,
27 27 RELEVANCE_IRRELEVANT: 0,
28 28 }
  29 +
29 30 RELEVANCE_GAIN_MAP = {
30 31 label: (2 ** grade) - 1
31 32 for label, grade in RELEVANCE_GRADE_MAP.items()
32 33 }
33 34  
  35 +# P(stop | relevance) for ERR (Expected Reciprocal Rank); cascade model (Chapelle et al., 2009).
  36 +STOP_PROB_MAP = {
  37 + RELEVANCE_EXACT: 0.99,
  38 + RELEVANCE_HIGH: 0.8,
  39 + RELEVANCE_LOW: 0.1,
  40 + RELEVANCE_IRRELEVANT: 0.0,
  41 +}
  42 +
34 43 _LEGACY_LABEL_MAP = {
35 44 "Exact": RELEVANCE_EXACT,
36 45 "Partial": RELEVANCE_HIGH,
... ...
scripts/evaluation/eval_framework/framework.py
... ... @@ -28,6 +28,7 @@ from .constants import (
28 28 RELEVANCE_EXACT,
29 29 RELEVANCE_GAIN_MAP,
30 30 RELEVANCE_HIGH,
  31 + STOP_PROB_MAP,
31 32 RELEVANCE_IRRELEVANT,
32 33 RELEVANCE_LOW,
33 34 RELEVANCE_NON_IRRELEVANT,
... ... @@ -55,8 +56,10 @@ def _metric_context_payload() -> Dict[str, Any]:
55 56 return {
56 57 "primary_metric": "NDCG@10",
57 58 "gain_scheme": dict(RELEVANCE_GAIN_MAP),
  59 + "stop_prob_scheme": dict(STOP_PROB_MAP),
58 60 "notes": [
59 61 "NDCG uses graded gains derived from the four relevance labels.",
  62 + "ERR (Expected Reciprocal Rank) uses per-grade stop probabilities in stop_prob_scheme.",
60 63 "Strong metrics treat Exact Match and High Relevant as strong business positives.",
61 64 "Useful metrics treat any non-irrelevant item as useful recall coverage.",
62 65 ],
... ...
scripts/evaluation/eval_framework/metrics.py
... ... @@ -14,6 +14,7 @@ from .constants import (
14 14 RELEVANCE_LOW,
15 15 RELEVANCE_NON_IRRELEVANT,
16 16 RELEVANCE_STRONG,
  17 + STOP_PROB_MAP,
17 18 )
18 19  
19 20  
... ... @@ -77,6 +78,19 @@ def _ndcg_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> fl
77 78 return dcg / idcg
78 79  
79 80  
  81 +def _err_at_k(labels: Sequence[str], k: int) -> float:
  82 + """Expected Reciprocal Rank on the first ``k`` positions (truncated ranked list)."""
  83 + if k <= 0:
  84 + return 0.0
  85 + err = 0.0
  86 + product = 1.0
  87 + for i, label in enumerate(labels[:k], start=1):
  88 + p_stop = float(STOP_PROB_MAP.get(_normalize_label(label), 0.0))
  89 + err += (1.0 / float(i)) * p_stop * product
  90 + product *= 1.0 - p_stop
  91 + return err
  92 +
  93 +
80 94 def _gain_recall_at_k(labels: Sequence[str], ideal_labels: Sequence[str], k: int) -> float:
81 95 ideal_total_gain = sum(_gains_for_labels(ideal_labels))
82 96 if ideal_total_gain <= 0.0:
... ... @@ -115,6 +129,7 @@ def compute_query_metrics(
115 129  
116 130 for k in (5, 10, 20, 50):
117 131 metrics[f"NDCG@{k}"] = round(_ndcg_at_k(labels, ideal, k), 6)
  132 + metrics[f"ERR@{k}"] = round(_err_at_k(labels, k), 6)
118 133 for k in (5, 10, 20):
119 134 metrics[f"Exact_Precision@{k}"] = round(_precision_at_k_from_hits(exact_hits, k), 6)
120 135 metrics[f"Strong_Precision@{k}"] = round(_precision_at_k_from_hits(strong_hits, k), 6)
... ...
scripts/evaluation/eval_framework/reports.py
... ... @@ -8,7 +8,19 @@ from .constants import RELEVANCE_EXACT, RELEVANCE_HIGH, RELEVANCE_IRRELEVANT, RE
8 8  
9 9  
10 10 def _append_metric_block(lines: list[str], metrics: Dict[str, Any]) -> None:
11   - primary_keys = ("NDCG@5", "NDCG@10", "NDCG@20", "Exact_Precision@10", "Strong_Precision@10", "Gain_Recall@50")
  11 + primary_keys = (
  12 + "NDCG@5",
  13 + "NDCG@10",
  14 + "NDCG@20",
  15 + "NDCG@50",
  16 + "ERR@5",
  17 + "ERR@10",
  18 + "ERR@20",
  19 + "ERR@50",
  20 + "Exact_Precision@10",
  21 + "Strong_Precision@10",
  22 + "Gain_Recall@50",
  23 + )
12 24 included = set()
13 25 for key in primary_keys:
14 26 if key in metrics:
... ... @@ -38,7 +50,8 @@ def render_batch_report_markdown(payload: Dict[str, Any]) -&gt; str:
38 50 lines.extend(
39 51 [
40 52 f"- Primary metric: {metric_context.get('primary_metric', 'N/A')}",
41   - f"- Gain scheme: {metric_context.get('gain_scheme', {})}",
  53 + f"- Gain scheme (NDCG): {metric_context.get('gain_scheme', {})}",
  54 + f"- Stop probabilities (ERR): {metric_context.get('stop_prob_scheme', {})}",
42 55 "",
43 56 ]
44 57 )
... ...
scripts/evaluation/eval_framework/static/eval_web.css
... ... @@ -30,15 +30,21 @@
30 30 button { border: 0; background: var(--accent); color: white; padding: 12px 16px; border-radius: 14px; cursor: pointer; font-weight: 600; }
31 31 button.secondary { background: #d9e6e3; color: #12433d; }
32 32 .grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(170px, 1fr)); gap: 12px; margin-bottom: 16px; }
33   - .metric-context { margin: 0 0 12px; line-height: 1.5; }
34   - .metric-section { margin-bottom: 18px; }
35   - .metric-section-head { display: flex; align-items: baseline; justify-content: space-between; gap: 12px; margin-bottom: 10px; }
36   - .metric-section-head h3 { margin: 0; font-size: 14px; color: #12433d; }
37   - .metric-section-head p { margin: 0; color: var(--muted); font-size: 12px; }
38   - .metric-grid { margin-bottom: 0; }
39   - .metric { background: var(--panel); border: 1px solid var(--line); border-radius: 16px; padding: 14px; }
40   - .metric .label { font-size: 12px; color: var(--muted); text-transform: uppercase; letter-spacing: 0.04em; }
41   - .metric .value { font-size: 24px; font-weight: 700; margin-top: 4px; }
  33 + .metric-context { margin: 0 0 12px; line-height: 1.5; font-size: 12px; color: var(--muted); }
  34 + .metrics-columns {
  35 + display: flex; flex-wrap: wrap; gap: 12px 18px; align-items: flex-start; margin-bottom: 8px;
  36 + }
  37 + .metric-column {
  38 + min-width: 132px; flex: 0 1 auto; padding: 10px 12px;
  39 + background: var(--panel); border: 1px solid var(--line); border-radius: 12px;
  40 + }
  41 + .metric-column-title {
  42 + margin: 0 0 8px; font-size: 11px; font-weight: 700; color: #12433d;
  43 + text-transform: uppercase; letter-spacing: 0.05em;
  44 + }
  45 + .metric-row { font-size: 13px; line-height: 1.5; font-variant-numeric: tabular-nums; }
  46 + .metric-row-name { color: var(--muted); }
  47 + .metric-row-value { font-weight: 600; color: var(--ink); margin-left: 2px; }
42 48 .results { display: grid; gap: 10px; }
43 49 .result { display: grid; grid-template-columns: 110px 100px 1fr; gap: 14px; align-items: center; background: var(--panel); border: 1px solid var(--line); border-radius: 18px; padding: 12px; }
44 50 .badge { display: inline-block; padding: 8px 10px; border-radius: 999px; color: white; font-weight: 700; text-align: center; }
... ... @@ -101,7 +107,6 @@
101 107 @media (max-width: 960px) {
102 108 .app { grid-template-columns: 1fr; }
103 109 .sidebar { border-right: 0; border-bottom: 1px solid var(--line); }
104   - .metric-section-head { flex-direction: column; align-items: flex-start; }
105 110 }
106 111 @media (max-width: 640px) {
107 112 .main, .sidebar { padding: 16px; }
... ...
scripts/evaluation/eval_framework/static/eval_web.js
... ... @@ -9,80 +9,106 @@ function fmtNumber(value, digits = 3) {
9 9 return Number(value).toFixed(digits);
10 10 }
11 11  
12   -function metricSections(metrics) {
13   - const groups = [
  12 +function metricColumns(metrics) {
  13 + const defs = [
  14 + { title: "NDCG", keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"] },
  15 + { title: "ERR", keys: ["ERR@5", "ERR@10", "ERR@20", "ERR@50"] },
14 16 {
15   - title: "Primary Ranking",
16   - keys: ["NDCG@5", "NDCG@10", "NDCG@20", "NDCG@50"],
17   - description: "Graded ranking quality across the four relevance tiers.",
  17 + title: "Top slot",
  18 + keys: [
  19 + "Exact_Precision@5",
  20 + "Exact_Precision@10",
  21 + "Strong_Precision@5",
  22 + "Strong_Precision@10",
  23 + "Strong_Precision@20",
  24 + ],
18 25 },
19 26 {
20   - title: "Top Slot Quality",
21   - keys: ["Exact_Precision@5", "Exact_Precision@10", "Strong_Precision@5", "Strong_Precision@10", "Strong_Precision@20"],
22   - description: "How much of the visible top rank is exact or strong business relevance.",
  27 + title: "Recall",
  28 + keys: [
  29 + "Useful_Precision@10",
  30 + "Useful_Precision@20",
  31 + "Useful_Precision@50",
  32 + "Gain_Recall@10",
  33 + "Gain_Recall@20",
  34 + "Gain_Recall@50",
  35 + ],
23 36 },
24 37 {
25   - title: "Recall Coverage",
26   - keys: ["Useful_Precision@10", "Useful_Precision@20", "Useful_Precision@50", "Gain_Recall@10", "Gain_Recall@20", "Gain_Recall@50"],
27   - description: "How much judged relevance is captured in the returned list.",
28   - },
29   - {
30   - title: "First Good Result",
31   - keys: ["Exact_Success@5", "Exact_Success@10", "Strong_Success@5", "Strong_Success@10", "MRR_Exact@10", "MRR_Strong@10", "Avg_Grade@10"],
32   - description: "Whether users see a good result early and how good the top page feels overall.",
  38 + title: "First good",
  39 + keys: [
  40 + "Exact_Success@5",
  41 + "Exact_Success@10",
  42 + "Strong_Success@5",
  43 + "Strong_Success@10",
  44 + "MRR_Exact@10",
  45 + "MRR_Strong@10",
  46 + "Avg_Grade@10",
  47 + ],
33 48 },
34 49 ];
35 50 const seen = new Set();
36   - return groups
37   - .map((group) => {
38   - const items = group.keys
  51 + const columns = defs
  52 + .map((col) => {
  53 + const rows = col.keys
39 54 .filter((key) => metrics && Object.prototype.hasOwnProperty.call(metrics, key))
40 55 .map((key) => {
41 56 seen.add(key);
42 57 return [key, metrics[key]];
43 58 });
44   - return { ...group, items };
  59 + return { title: col.title, rows };
45 60 })
46   - .filter((group) => group.items.length)
47   - .concat(
48   - (() => {
49   - const rest = Object.entries(metrics || {}).filter(([key]) => !seen.has(key));
50   - return rest.length
51   - ? [{ title: "Other Metrics", description: "", items: rest }]
52   - : [];
53   - })()
54   - );
  61 + .filter((col) => col.rows.length);
  62 + const rest = Object.keys(metrics || {})
  63 + .filter((key) => !seen.has(key))
  64 + .sort()
  65 + .map((key) => [key, metrics[key]]);
  66 + if (rest.length) columns.push({ title: "Other", rows: rest });
  67 + return columns;
55 68 }
56 69  
57 70 function renderMetrics(metrics, metricContext) {
58 71 const root = document.getElementById("metrics");
59 72 root.innerHTML = "";
60 73 const ctx = document.getElementById("metricContext");
61   - const gainScheme = metricContext && metricContext.gain_scheme;
62   - const primary = metricContext && metricContext.primary_metric;
63   - ctx.textContent = primary
64   - ? `Primary metric: ${primary}. Gain scheme: ${Object.entries(gainScheme || {}).map(([label, gain]) => `${label}=${gain}`).join(", ")}.`
65   - : "";
  74 + const parts = [];
  75 + if (metricContext && metricContext.primary_metric) {
  76 + parts.push(`Primary: ${metricContext.primary_metric}`);
  77 + }
  78 + if (metricContext && metricContext.gain_scheme) {
  79 + parts.push(
  80 + `NDCG gains: ${Object.entries(metricContext.gain_scheme)
  81 + .map(([label, gain]) => `${label}=${gain}`)
  82 + .join(", ")}`
  83 + );
  84 + }
  85 + if (metricContext && metricContext.stop_prob_scheme) {
  86 + parts.push(
  87 + `ERR P(stop): ${Object.entries(metricContext.stop_prob_scheme)
  88 + .map(([label, p]) => `${label}=${p}`)
  89 + .join(", ")}`
  90 + );
  91 + }
  92 + ctx.textContent = parts.length ? `${parts.join(". ")}.` : "";
66 93  
67   - metricSections(metrics || {}).forEach((section) => {
68   - const wrap = document.createElement("section");
69   - wrap.className = "metric-section";
70   - wrap.innerHTML = `
71   - <div class="metric-section-head">
72   - <h3>${section.title}</h3>
73   - ${section.description ? `<p>${section.description}</p>` : ""}
74   - </div>
75   - <div class="grid metric-grid"></div>
76   - `;
77   - const grid = wrap.querySelector(".metric-grid");
78   - section.items.forEach(([key, value]) => {
79   - const card = document.createElement("div");
80   - card.className = "metric";
81   - card.innerHTML = `<div class="label">${key}</div><div class="value">${fmtNumber(value)}</div>`;
82   - grid.appendChild(card);
  94 + const bar = document.createElement("div");
  95 + bar.className = "metrics-columns";
  96 + metricColumns(metrics || {}).forEach((col) => {
  97 + const column = document.createElement("div");
  98 + column.className = "metric-column";
  99 + const h = document.createElement("h4");
  100 + h.className = "metric-column-title";
  101 + h.textContent = col.title;
  102 + column.appendChild(h);
  103 + col.rows.forEach(([key, value]) => {
  104 + const row = document.createElement("div");
  105 + row.className = "metric-row";
  106 + row.innerHTML = `<span class="metric-row-name">${key}:</span> <span class="metric-row-value">${fmtNumber(value)}</span>`;
  107 + column.appendChild(row);
83 108 });
84   - root.appendChild(wrap);
  109 + bar.appendChild(column);
85 110 });
  111 + root.appendChild(bar);
86 112 }
87 113  
88 114 function labelBadgeClass(label) {
... ... @@ -148,6 +174,7 @@ function historySummaryHtml(meta) {
148 174 const parts = [];
149 175 if (nq != null) parts.push(`<span>Queries</span> ${nq}`);
150 176 if (m && m["NDCG@10"] != null) parts.push(`<span>NDCG@10</span> ${fmtNumber(m["NDCG@10"])}`);
  177 + if (m && m["ERR@10"] != null) parts.push(`<span>ERR@10</span> ${fmtNumber(m["ERR@10"])}`);
151 178 if (m && m["Strong_Precision@10"] != null) parts.push(`<span>Strong@10</span> ${fmtNumber(m["Strong_Precision@10"])}`);
152 179 if (m && m["Gain_Recall@50"] != null) parts.push(`<span>Gain Recall@50</span> ${fmtNumber(m["Gain_Recall@50"])}`);
153 180 if (!parts.length) return "";
... ...
scripts/evaluation/eval_framework/static/index.html
... ... @@ -31,7 +31,7 @@
31 31 <section class="section">
32 32 <h2>Metrics</h2>
33 33 <p id="metricContext" class="muted metric-context"></p>
34   - <div id="metrics" class="grid"></div>
  34 + <div id="metrics"></div>
35 35 </section>
36 36 <section class="section">
37 37 <h2>Top Results</h2>
... ...
tests/test_eval_metrics.py 0 → 100644
... ... @@ -0,0 +1,26 @@
  1 +"""Tests for search evaluation ranking metrics (NDCG, ERR)."""
  2 +
  3 +from scripts.evaluation.eval_framework.constants import (
  4 + RELEVANCE_EXACT,
  5 + RELEVANCE_HIGH,
  6 + RELEVANCE_IRRELEVANT,
  7 + RELEVANCE_LOW,
  8 +)
  9 +from scripts.evaluation.eval_framework.metrics import compute_query_metrics
  10 +
  11 +
  12 +def test_err_matches_documented_three_item_examples():
  13 + # Model A: [Exact, Irrelevant, High] -> ERR ≈ 0.992667
  14 + m_a = compute_query_metrics(
  15 + [RELEVANCE_EXACT, RELEVANCE_IRRELEVANT, RELEVANCE_HIGH],
  16 + ideal_labels=[RELEVANCE_EXACT],
  17 + )
  18 + assert abs(m_a["ERR@5"] - (0.99 + (1.0 / 3.0) * 0.8 * 0.01)) < 1e-5
  19 +
  20 + # Model B: [High, Low, Exact] -> ERR ≈ 0.8694
  21 + m_b = compute_query_metrics(
  22 + [RELEVANCE_HIGH, RELEVANCE_LOW, RELEVANCE_EXACT],
  23 + ideal_labels=[RELEVANCE_EXACT],
  24 + )
  25 + expected_b = 0.8 + 0.5 * 0.1 * 0.2 + (1.0 / 3.0) * 0.99 * 0.18
  26 + assert abs(m_b["ERR@5"] - expected_b) < 1e-5
... ...