"""Tests for search evaluation ranking metrics (NDCG, ERR)."""

import math

import pytest

from scripts.evaluation.eval_framework.constants import (
    RELEVANCE_LV0,
    RELEVANCE_LV1,
    RELEVANCE_LV2,
    RELEVANCE_LV3,
    STOP_PROB_MAP,
)
from scripts.evaluation.eval_framework.metrics import compute_query_metrics

pytestmark = [pytest.mark.eval, pytest.mark.regression]


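# NOTE (assumption): the exact contents of STOP_PROB_MAP are not pinned down
# here. The standard ERR stop-probability mapping (Chapelle et al. 2009) is
# p(g) = (2**g - 1) / 2**g_max; for four levels with g_max = 3 that yields
# 0, 1/8, 3/8, and 7/8 for LV0 through LV3. The equality assertions below do
# not depend on these exact values, because _expected_err reads
# STOP_PROB_MAP directly.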
def _expected_err(labels):
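    """Textbook cascade ERR: the sum over ranks i of
    (1/i) * p_i * prod_{j<i}(1 - p_j), where p_i is the stop probability
    STOP_PROB_MAP[label_i] at rank i.

    No cutoff is applied here, so callers must pass at most k labels when
    comparing against an ERR@k value.
    """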
    err = 0.0
    product = 1.0
    for i, label in enumerate(labels, start=1):
        p = STOP_PROB_MAP[label]
        err += (1.0 / i) * p * product
        product *= 1.0 - p
    return err

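
def _expected_ndcg(labels, ideal, k):
    """Hand-check reference for NDCG@k; intentionally not asserted below.

    ASSUMPTION: this sketch uses the common exponential-gain DCG
    (gain = 2**label - 1, discount = log2(rank + 1)) and treats the
    RELEVANCE_LV* constants as the integer grades 0..3. If
    compute_query_metrics defines gains differently, update this helper
    to match before wiring it into an assertion.
    """

    def dcg(ranked):
        return sum(
            (2**label - 1) / math.log2(rank + 1)
            for rank, label in enumerate(ranked[:k], start=1)
        )

    ideal_dcg = dcg(sorted(ideal, reverse=True))
    return dcg(labels) / ideal_dcg if ideal_dcg else 0.0
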

def test_err_matches_cascade_formula_on_four_level_labels():
    """ERR@k must equal the textbook cascade formula against the four-level label set.

    The metric is the primary ranking signal (see `PRIMARY_METRIC_KEYS` in
    `eval_framework.metrics`); any regression here invalidates the whole
    evaluation pipeline.
    """

    ranked_a = [RELEVANCE_LV3, RELEVANCE_LV0, RELEVANCE_LV2]
    ranked_b = [RELEVANCE_LV2, RELEVANCE_LV1, RELEVANCE_LV3]

    m_a = compute_query_metrics(ranked_a, ideal_labels=[RELEVANCE_LV3])
    m_b = compute_query_metrics(ranked_b, ideal_labels=[RELEVANCE_LV3])

    assert math.isclose(m_a["ERR@5"], _expected_err(ranked_a), abs_tol=1e-5)
    assert math.isclose(m_b["ERR@5"], _expected_err(ranked_b), abs_tol=1e-5)
    assert m_a["ERR@5"] > m_b["ERR@5"]


def test_ndcg_at_k_is_1_when_actual_equals_ideal():
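    """NDCG normalizes by the ideal DCG, so ranking the results in exactly
    the ideal order must score 1.0 at every cutoff."""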
    labels = [RELEVANCE_LV3, RELEVANCE_LV2, RELEVANCE_LV1]
    metrics = compute_query_metrics(labels, ideal_labels=labels)
    assert math.isclose(metrics["NDCG@5"], 1.0, abs_tol=1e-9)
    assert math.isclose(metrics["NDCG@20"], 1.0, abs_tol=1e-9)


def test_all_irrelevant_zeroes_out_primary_signals():
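    """A result list containing only irrelevant documents must zero out
    every primary ranking signal, including the aggregated
    Primary_Metric_Score."""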
    labels = [RELEVANCE_LV0, RELEVANCE_LV0, RELEVANCE_LV0]
    metrics = compute_query_metrics(labels, ideal_labels=[RELEVANCE_LV3])
    assert metrics["ERR@10"] == 0.0
    assert metrics["NDCG@20"] == 0.0
    assert metrics["Strong_Precision@10"] == 0.0
    assert metrics["Primary_Metric_Score"] == 0.0