# analyze_eval_index_regression.py (stray directory-listing header from a paste; kept as a comment so the file parses)
#!/usr/bin/env python3
"""
Analyze search evaluation regressions between two batch reports and trace them back
to document field changes across two Elasticsearch indices.

Typical usage:
  ./.venv/bin/python scripts/inspect/analyze_eval_index_regression.py \
    --current-report artifacts/search_evaluation/batch_reports/batch_20260417T073901Z_00b6a8aa3d.json \
    --backup-report artifacts/search_evaluation/batch_reports/batch_20260417T074717Z_00b6a8aa3d.json \
    --current-index search_products_tenant_163 \
    --backup-index search_products_tenant_163_backup_20260415_1438
"""

from __future__ import annotations

import argparse
import json
import statistics
import sys
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

# Make the repository root importable so `utils` resolves when this script is
# run directly; parents[2] assumes the file lives two levels below the root
# (e.g. scripts/inspect/ per the usage example in the module docstring).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from utils.es_client import get_es_client_from_env


# Source fields requested from Elasticsearch for each inspected document
# (see SourceFetcher.fetch); a superset of the fields actually diffed below.
SEARCHABLE_SOURCE_FIELDS: Sequence[str] = (
    "title",
    "keywords",
    "qanchors",
    "enriched_tags",
    "enriched_attributes",
    "option1_values",
    "option2_values",
    "option3_values",
    "tags",
    "category_path",
    "category_name_text",
)

# Fields compared between current and backup documents in _changed_fields;
# excludes the category_* fields that are fetched only for context.
CORE_FIELDS_TO_COMPARE: Sequence[str] = (
    "title",
    "keywords",
    "qanchors",
    "enriched_tags",
    "enriched_attributes",
    "option1_values",
    "option2_values",
    "option3_values",
    "tags",
)

# Relevance labels treated as "strong"; only backup-ranked docs carrying one
# of these labels are considered regression candidates in _iter_regressed_docs.
STRONG_LABELS = {"Fully Relevant", "Mostly Relevant"}


def _load_report(path: str) -> Dict[str, Any]:
    return json.loads(Path(path).read_text())


def _rank_map(rows: Sequence[Dict[str, Any]]) -> Dict[str, int]:
    return {str(row["spu_id"]): int(row["rank"]) for row in rows}


def _label_map(rows: Sequence[Dict[str, Any]]) -> Dict[str, str]:
    return {str(row["spu_id"]): str(row["label"]) for row in rows}


def _count_items(value: Any) -> int:
    if isinstance(value, list):
        return len(value)
    if isinstance(value, str):
        return len([x for x in value.split(",") if x.strip()])
    return 0


def _json_short(value: Any, max_len: int = 220) -> str:
    payload = json.dumps(value, ensure_ascii=False, sort_keys=True)
    if len(payload) <= max_len:
        return payload
    return payload[: max_len - 3] + "..."


class SourceFetcher:
    """Fetches and memoizes per-SPU source documents from Elasticsearch."""

    def __init__(self) -> None:
        self.es = get_es_client_from_env().client
        # Memoized lookups keyed by (index_name, spu_id); None means not found.
        self._cache: Dict[Tuple[str, str], Optional[Dict[str, Any]]] = {}

    def fetch(self, index_name: str, spu_id: str) -> Optional[Dict[str, Any]]:
        """Return the searchable-source doc for *spu_id* in *index_name*, or None.

        Results (including misses) are cached, so repeated lookups for the same
        (index, spu) pair cost a single search request.
        """
        cache_key = (index_name, spu_id)
        if cache_key not in self._cache:
            response = self.es.search(
                index=index_name,
                body={
                    "size": 1,
                    "query": {"term": {"spu_id": spu_id}},
                    "_source": ["spu_id", *SEARCHABLE_SOURCE_FIELDS],
                },
            )
            hits = response["hits"]["hits"]
            self._cache[cache_key] = hits[0]["_source"] if hits else None
        return self._cache[cache_key]


def _changed_fields(current_doc: Dict[str, Any], backup_doc: Dict[str, Any]) -> List[str]:
    """Return the core fields whose values differ between the two documents."""
    diffs: List[str] = []
    for field in CORE_FIELDS_TO_COMPARE:
        if current_doc.get(field) != backup_doc.get(field):
            diffs.append(field)
    return diffs


def _iter_regressed_docs(
    current_report: Dict[str, Any],
    backup_report: Dict[str, Any],
    rank_gap_threshold: int,
    scan_depth: int,
) -> Iterable[Dict[str, Any]]:
    """Yield strong-relevant docs that dropped in rank on regressed queries.

    A query is regressed when its Primary_Metric_Score is lower in the current
    report than in the backup report.  Within such a query, a doc is yielded
    when it carries a strong label in the backup ranking (within *scan_depth*)
    and its current rank is worse than backup rank + *rank_gap_threshold*.
    Docs missing from the current top results get a sentinel rank of 999.
    """
    current_per_query = {row["query"]: row for row in current_report["per_query"]}
    backup_per_query = {row["query"]: row for row in backup_report["per_query"]}
    for query, current_case in current_per_query.items():
        backup_case = backup_per_query[query]
        delta = (
            float(current_case["metrics"]["Primary_Metric_Score"])
            - float(backup_case["metrics"]["Primary_Metric_Score"])
        )
        if delta >= 0:
            continue  # query did not regress
        current_ranks = _rank_map(current_case["top_results"])
        current_labels = _label_map(current_case["top_results"])
        for row in backup_case["top_results"][:scan_depth]:
            if row["label"] not in STRONG_LABELS:
                continue
            # BUGFIX: _rank_map/_label_map keys are stringified spu_ids, so the
            # lookup key must be stringified too — report JSON may carry spu_id
            # as an int, in which case the raw-key lookup always missed.
            spu_id = str(row["spu_id"])
            current_rank = current_ranks.get(spu_id, 999)
            if current_rank <= int(row["rank"]) + rank_gap_threshold:
                continue
            yield {
                "query": query,
                "delta_primary": delta,
                "spu_id": spu_id,
                "backup_rank": int(row["rank"]),
                "backup_label": str(row["label"]),
                "current_rank": current_rank,
                "current_label": current_labels.get(spu_id),
            }


def _print_metric_summary(current_report: Dict[str, Any], backup_report: Dict[str, Any], top_n: int) -> None:
    current_per_query = {row["query"]: row for row in current_report["per_query"]}
    backup_per_query = {row["query"]: row for row in backup_report["per_query"]}
    deltas: List[Tuple[str, float, Dict[str, Any], Dict[str, Any]]] = []
    for query, current_case in current_per_query.items():
        backup_case = backup_per_query[query]
        deltas.append(
            (
                query,
                float(current_case["metrics"]["Primary_Metric_Score"])
                - float(backup_case["metrics"]["Primary_Metric_Score"]),
                current_case,
                backup_case,
            )
        )
    worse = sum(1 for _, delta, _, _ in deltas if delta < 0)
    better = sum(1 for _, delta, _, _ in deltas if delta > 0)
    print("Overall Query Delta")
    print("=" * 80)
    print(f"worse: {worse} | better: {better} | total: {len(deltas)}")
    print(
        "aggregate primary:"
        f" current={current_report['aggregate_metrics']['Primary_Metric_Score']:.6f}"
        f" backup={backup_report['aggregate_metrics']['Primary_Metric_Score']:.6f}"
        f" delta={current_report['aggregate_metrics']['Primary_Metric_Score'] - backup_report['aggregate_metrics']['Primary_Metric_Score']:+.6f}"
    )
    print()
    print(f"Worst {top_n} Queries By Primary_Metric_Score Delta")
    print("=" * 80)
    for query, delta, current_case, backup_case in sorted(deltas, key=lambda x: x[1])[:top_n]:
        print(
            f"{delta:+.4f}\t{query}\t"
            f"NDCG@20 {current_case['metrics']['NDCG@20'] - backup_case['metrics']['NDCG@20']:+.4f}\t"
            f"ERR@10 {current_case['metrics']['ERR@10'] - backup_case['metrics']['ERR@10']:+.4f}\t"
            f"SP@10 {current_case['metrics']['Strong_Precision@10'] - backup_case['metrics']['Strong_Precision@10']:+.2f}"
        )
    print()


def _print_field_change_summary(
    regressed_rows: Sequence[Dict[str, Any]],
    fetcher: SourceFetcher,
    current_index: str,
    backup_index: str,
) -> None:
    field_counter: Counter[str] = Counter()
    qanchor_counts_en: List[Tuple[int, int]] = []
    qanchor_counts_zh: List[Tuple[int, int]] = []
    tag_counts_en: List[Tuple[int, int]] = []
    tag_counts_zh: List[Tuple[int, int]] = []

    for row in regressed_rows:
        current_doc = fetcher.fetch(current_index, row["spu_id"])
        backup_doc = fetcher.fetch(backup_index, row["spu_id"])
        if not current_doc or not backup_doc:
            continue
        for field in _changed_fields(current_doc, backup_doc):
            field_counter[field] += 1

        current_qanchors = current_doc.get("qanchors") or {}
        backup_qanchors = backup_doc.get("qanchors") or {}
        current_tags = current_doc.get("enriched_tags") or {}
        backup_tags = backup_doc.get("enriched_tags") or {}
        qanchor_counts_en.append((_count_items(current_qanchors.get("en")), _count_items(backup_qanchors.get("en"))))
        qanchor_counts_zh.append((_count_items(current_qanchors.get("zh")), _count_items(backup_qanchors.get("zh"))))
        tag_counts_en.append((_count_items(current_tags.get("en")), _count_items(backup_tags.get("en"))))
        tag_counts_zh.append((_count_items(current_tags.get("zh")), _count_items(backup_tags.get("zh"))))

    print("Affected Strong-Relevant Docs")
    print("=" * 80)
    print(f"count: {len(regressed_rows)}")
    print("changed field frequency:")
    for field, count in field_counter.most_common():
        print(f"  {field}: {count}")
    print()

    def summarize_counts(name: str, pairs: Sequence[Tuple[int, int]]) -> None:
        if not pairs:
            return
        current_counts = [current for current, _ in pairs]
        backup_counts = [backup for _, backup in pairs]
        print(
            f"{name}: current_avg={statistics.mean(current_counts):.3f} "
            f"backup_avg={statistics.mean(backup_counts):.3f} "
            f"delta={statistics.mean(current - backup for current, backup in pairs):+.3f} "
            f"backup_more={sum(1 for current, backup in pairs if backup > current)} "
            f"current_more={sum(1 for current, backup in pairs if current > backup)}"
        )

    print("phrase/tag density on affected docs:")
    summarize_counts("qanchors.en", qanchor_counts_en)
    summarize_counts("qanchors.zh", qanchor_counts_zh)
    summarize_counts("enriched_tags.en", tag_counts_en)
    summarize_counts("enriched_tags.zh", tag_counts_zh)
    print()


def _print_query_details(
    current_report: Dict[str, Any],
    backup_report: Dict[str, Any],
    regressed_rows: Sequence[Dict[str, Any]],
    fetcher: SourceFetcher,
    current_index: str,
    backup_index: str,
    top_queries: int,
    max_docs_per_query: int,
) -> None:
    current_per_query = {row["query"]: row for row in current_report["per_query"]}
    backup_per_query = {row["query"]: row for row in backup_report["per_query"]}
    grouped: Dict[str, List[Dict[str, Any]]] = {}
    for row in regressed_rows:
        grouped.setdefault(row["query"], []).append(row)

    ordered_queries = sorted(grouped, key=lambda q: current_per_query[q]["metrics"]["Primary_Metric_Score"] - backup_per_query[q]["metrics"]["Primary_Metric_Score"])

    print(f"Detailed Query Samples (top {top_queries})")
    print("=" * 80)
    for query in ordered_queries[:top_queries]:
        current_case = current_per_query[query]
        backup_case = backup_per_query[query]
        delta = current_case["metrics"]["Primary_Metric_Score"] - backup_case["metrics"]["Primary_Metric_Score"]
        print(f"\n## {query}")
        print(
            f"delta_primary={delta:+.6f} | current_top10={current_case['top_label_sequence_top10']} | "
            f"backup_top10={backup_case['top_label_sequence_top10']}"
        )
        for row in sorted(grouped[query], key=lambda item: item["backup_rank"])[:max_docs_per_query]:
            current_doc = fetcher.fetch(current_index, row["spu_id"])
            backup_doc = fetcher.fetch(backup_index, row["spu_id"])
            if not current_doc or not backup_doc:
                print(
                    f"  - spu={row['spu_id']} backup_rank={row['backup_rank']} current_rank={row['current_rank']} "
                    "(missing source)"
                )
                continue
            changed = _changed_fields(current_doc, backup_doc)
            print(
                f"  - spu={row['spu_id']} backup_rank={row['backup_rank']} ({row['backup_label']}) "
                f"-> current_rank={row['current_rank']} ({row['current_label']})"
            )
            print(f"    changed_fields: {', '.join(changed) if changed else '(none)'}")
            for field in changed[:4]:
                print(f"    {field}.current: {_json_short(current_doc.get(field))}")
                print(f"    {field}.backup : {_json_short(backup_doc.get(field))}")


def main() -> None:
    """CLI entry point: load both reports, collect regressed strong-relevant
    docs, then print the metric summary, field-change summary, and details."""
    parser = argparse.ArgumentParser(description="Analyze eval regressions between two indices")
    argument_specs = (
        ("--current-report", {"required": True, "help": "Report JSON for the worse/current index"}),
        ("--backup-report", {"required": True, "help": "Report JSON for the better/reference index"}),
        ("--current-index", {"required": True, "help": "Current/worse index name"}),
        ("--backup-index", {"required": True, "help": "Reference/better index name"}),
        ("--rank-gap-threshold", {"type": int, "default": 5, "help": "Treat a strong-relevant doc as regressed when current rank > backup rank + this gap"}),
        ("--scan-depth", {"type": int, "default": 20, "help": "Only inspect backup strong-relevant docs within this depth"}),
        ("--top-worst-queries", {"type": int, "default": 12, "help": "How many worst queries to print in the metric summary"}),
        ("--detail-queries", {"type": int, "default": 6, "help": "How many regressed queries to print detailed field diffs for"}),
        ("--detail-docs-per-query", {"type": int, "default": 3, "help": "How many regressed docs to print per detailed query"}),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    current_report = _load_report(args.current_report)
    backup_report = _load_report(args.backup_report)
    fetcher = SourceFetcher()
    regressed_rows = list(
        _iter_regressed_docs(
            current_report,
            backup_report,
            args.rank_gap_threshold,
            args.scan_depth,
        )
    )

    _print_metric_summary(current_report, backup_report, args.top_worst_queries)
    _print_field_change_summary(
        regressed_rows,
        fetcher,
        args.current_index,
        args.backup_index,
    )
    _print_query_details(
        current_report,
        backup_report,
        regressed_rows,
        fetcher,
        args.current_index,
        args.backup_index,
        args.detail_queries,
        args.detail_docs_per_query,
    )


if __name__ == "__main__":  # script entry point when run directly
    main()