# analyze_eval_regressions.py — 10.3 KB (stray file-listing header, not Python code)
#!/usr/bin/env python3
"""
Analyze per-query regressions between two batch evaluation JSON reports and
attribute likely causes by inspecting ES documents from two indices.

Outputs:
- Top regressions by Primary_Metric_Score delta
- For each regressed query:
  - metric deltas
  - top-10 SPU overlap and swapped-in SPUs
  - for swapped-in SPUs, show which search fields contain the query term

This is a heuristic attribution tool (string containment), but it's fast and
usually enough to pinpoint regressions caused by missing/noisy fields such as
qanchors/keywords/title in different languages.

Usage:
  set -a; source .env; set +a
  ./.venv/bin/python scripts/inspect/analyze_eval_regressions.py \
    --old-report artifacts/search_evaluation/batch_reports/batch_...073901....json \
    --new-report artifacts/search_evaluation/batch_reports/batch_...074717....json \
    --old-index search_products_tenant_163 \
    --new-index search_products_tenant_163_backup_20260415_1438 \
    --top-n 10
"""

from __future__ import annotations

import argparse
import json
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from elasticsearch import Elasticsearch


def load_json(path: str) -> Dict[str, Any]:
    """Load a JSON report from *path*.

    Reads explicitly as UTF-8: ``read_text()`` without an encoding uses the
    locale's preferred encoding, which breaks on non-ASCII report content
    (e.g. CJK queries) on machines with a non-UTF-8 locale.
    """
    return json.loads(Path(path).read_text(encoding="utf-8"))


def norm_str(x: Any) -> str:
    """Coerce any value to a string; ``None`` becomes the empty string."""
    if isinstance(x, str):
        return x
    return "" if x is None else str(x)


def is_cjk(s: str) -> bool:
    """True when *s* contains at least one CJK Unified Ideograph (U+4E00–U+9FFF)."""
    return re.search(r"[\u4e00-\u9fff]", s) is not None


def flatten_text_values(v: Any) -> List[str]:
    """Collect string representations from nested dicts/lists (best-effort).

    Strings pass through; scalars are stringified; dict values and list
    items are flattened recursively. Lists are capped at their first 20
    items to keep huge arrays cheap.
    """
    if v is None:
        return []
    if isinstance(v, str):
        return [v]
    if isinstance(v, (int, float, bool)):
        return [str(v)]
    if isinstance(v, dict):
        children: Iterable[Any] = v.values()
    elif isinstance(v, list):
        children = v[:20]
    else:
        # Unknown type: fall back to its string form.
        return [str(v)]
    collected: List[str] = []
    for child in children:
        collected.extend(flatten_text_values(child))
    return collected


def get_lang_obj(src: Dict[str, Any], field: str, lang: str) -> Any:
    """Return ``src[field][lang]`` when ``src[field]`` is a dict, else ``None``."""
    candidate = src.get(field)
    return candidate.get(lang) if isinstance(candidate, dict) else None


def contains_query(val: Any, query: str) -> bool:
    """True if *query* occurs as a substring in any text extracted from *val*.

    CJK queries are matched case-sensitively (case has no meaning there);
    everything else is lower-cased on both sides first.
    """
    needle = query.strip()
    if not needle:
        return False
    haystacks = flatten_text_values(val)
    if is_cjk(needle):
        return any(needle in text for text in haystacks)
    needle_lower = needle.lower()
    return any(needle_lower in (text or "").lower() for text in haystacks)


@dataclass
class PerQuery:
    """One query's record from a batch report's ``per_query`` section."""

    # The literal query string (also the key in per_query_map's output).
    query: str
    # Metric name -> numeric score; non-numeric values are filtered out upstream.
    metrics: Dict[str, float]
    # Raw top-result dicts as emitted by the evaluator (each may carry "spu_id").
    top_results: List[Dict[str, Any]]
    # Correlation id for the search request, if the report recorded one.
    request_id: Optional[str]


def per_query_map(report: Dict[str, Any]) -> Dict[str, PerQuery]:
    """Index a batch report's ``per_query`` records by their query string.

    Records without a query are skipped; only numeric metric values are kept.
    """
    result: Dict[str, PerQuery] = {}
    for record in report.get("per_query") or []:
        query = record.get("query")
        if not query:
            continue
        numeric_metrics = {
            name: float(value)
            for name, value in (record.get("metrics") or {}).items()
            if isinstance(value, (int, float))
        }
        result[query] = PerQuery(
            query=query,
            metrics=numeric_metrics,
            top_results=list(record.get("top_results") or []),
            request_id=record.get("request_id"),
        )
    return result


def top_spus(pq: PerQuery, n: int = 10) -> List[str]:
    """Return SPU ids (stringified) for the first *n* results of *pq*."""
    return [
        str(result["spu_id"])
        for result in pq.top_results[:n]
        if result.get("spu_id") is not None
    ]


def build_es() -> Elasticsearch:
    """Build an ES client from the environment.

    Host comes from ``ES`` or ``ES_HOST`` (default local 9200); optional
    basic auth from ``ES_AUTH`` formatted as ``user:password``.
    """
    host = os.environ.get("ES") or os.environ.get("ES_HOST") or "http://127.0.0.1:9200"
    credentials = os.environ.get("ES_AUTH")
    if not credentials or ":" not in credentials:
        return Elasticsearch(hosts=[host])
    user, _, password = credentials.partition(":")
    return Elasticsearch(hosts=[host], basic_auth=(user, password))


def mget_sources(es: Elasticsearch, index: str, ids: Sequence[str]) -> Dict[str, Dict[str, Any]]:
    """Fetch ``_source`` docs by id via mget, keyed by stringified ``_id``.

    Missing docs and docs whose ``_source`` is not a dict are skipped.
    """
    sources: Dict[str, Dict[str, Any]] = {}
    response = es.mget(index=index, body={"ids": list(ids)})
    for doc in response.get("docs") or []:
        if not doc.get("found"):
            continue
        doc_id = doc.get("_id")
        src = doc.get("_source")
        if doc_id and isinstance(src, dict):
            sources[str(doc_id)] = src
    return sources


def non_empty(v: Any) -> bool:
    """Whether *v* carries any usable content.

    ``None`` and blank strings are empty; containers are empty when they hold
    nothing (dicts recurse into their values); any other value counts as present.
    """
    if v is None:
        return False
    if isinstance(v, str):
        return v.strip() != ""
    if isinstance(v, dict):
        return any(non_empty(item) for item in v.values())
    if isinstance(v, (list, tuple, set)):
        return bool(v)
    return True


def summarize_field(src: Dict[str, Any], field: str, lang: Optional[str]) -> Dict[str, Any]:
    """Summarize presence plus a small sample of ``src[field]``.

    When *lang* is given and the field is a per-language dict, the summary
    applies to the language-specific sub-value. Samples are truncated:
    80 chars for strings, 3 items for lists, 3 keys for dicts.
    """
    value = src.get(field)
    if lang and isinstance(value, dict):
        value = value.get(lang)
    if isinstance(value, str):
        sample: Any = value[:80]
    elif isinstance(value, list):
        sample = value[:3]
    elif isinstance(value, dict):
        sample = {k: value.get(k) for k in list(value.keys())[:3]}
    else:
        sample = None
    return {"present": non_empty(value), "sample": sample}


def main() -> int:
    """CLI entry point.

    Ranks per-query metric deltas (new - old) between two batch reports,
    prints the worst regressions, then — for each regressed query — fetches
    the swapped-in SPU documents from both ES indices and prints
    field-presence diffs and substring-match hints as heuristic attribution.

    Returns:
        0 (always; used as the process exit code).
    """
    ap = argparse.ArgumentParser(description="Analyze regressions between two eval batch reports.")
    ap.add_argument("--old-report", required=True, help="Older/worse/baseline batch JSON path")
    ap.add_argument("--new-report", required=True, help="Newer candidate batch JSON path")
    ap.add_argument("--old-index", required=True, help="ES index used by old report")
    ap.add_argument("--new-index", required=True, help="ES index used by new report")
    ap.add_argument("--top-n", type=int, default=10, help="How many worst regressions to analyze (default 10)")
    ap.add_argument("--metric", default="Primary_Metric_Score", help="Metric to rank regressions by")
    ap.add_argument("--topk", type=int, default=10, help="Top-K results to compare per query (default 10)")
    args = ap.parse_args()

    old = load_json(args.old_report)
    new = load_json(args.new_report)
    old_map = per_query_map(old)
    new_map = per_query_map(new)

    metric = args.metric
    # Prefer the new report's query list; fall back to the old report's.
    queries = list(new.get("queries") or old.get("queries") or [])

    # delta = new - old, so regressions are the most negative entries.
    deltas: List[Tuple[str, float]] = []
    for q in queries:
        o = old_map.get(q)
        n = new_map.get(q)
        if not o or not n:
            # Skip queries present in only one of the two reports.
            continue
        d = float(n.metrics.get(metric, 0.0)) - float(o.metrics.get(metric, 0.0))
        deltas.append((q, d))

    # Ascending sort puts the worst (most negative) deltas first.
    deltas.sort(key=lambda x: x[1])
    worst = deltas[: args.top_n]

    print("=" * 100)
    print(f"Top {len(worst)} regressions by {metric} (new - old)")
    print("=" * 100)
    for q, d in worst:
        o = old_map[q]
        n = new_map[q]
        print(f"- {q}: {d:+.4f}  old={o.metrics.get(metric, 0.0):.4f} -> new={n.metrics.get(metric, 0.0):.4f}")

    es = build_es()

    # Fields that matter according to config.yaml
    # (keep it aligned with multilingual_fields + best_fields/phrase_fields)
    inspect_fields = [
        "title",
        "keywords",
        "qanchors",
        "category_name_text",
        "vendor",
        "tags",
        "option1_values",
        "option2_values",
        "option3_values",
    ]

    print("\n" + "=" * 100)
    print("Heuristic attribution for worst regressions")
    print("=" * 100)

    for q, d in worst:
        o = old_map[q]
        n = new_map[q]
        old_spus = top_spus(o, args.topk)
        new_spus = top_spus(n, args.topk)
        old_set, new_set = set(old_spus), set(new_spus)
        # List comprehensions (not set ops) so ranking order is preserved.
        swapped_in = [s for s in new_spus if s not in old_set]
        swapped_out = [s for s in old_spus if s not in new_set]

        print("\n" + "-" * 100)
        print(f"Query: {q}")
        print(f"Delta {metric}: {d:+.4f}")
        # show a few key metrics
        for m in ["NDCG@20", "Strong_Precision@10", "Gain_Recall@20", "ERR@10"]:
            if m in o.metrics and m in n.metrics:
                print(f"  {m}: {n.metrics[m]-o.metrics[m]:+.4f}  (old {o.metrics[m]:.4f} -> new {n.metrics[m]:.4f})")
        print(f"  old request_id={o.request_id}  new request_id={n.request_id}")
        print(f"  top{args.topk} overlap: {len(old_set & new_set)}/{args.topk}")
        print(f"  swapped_in (new only): {swapped_in[:10]}")
        print(f"  swapped_out (old only): {swapped_out[:10]}")

        # Fetch swapped_in docs from both indices to spot index-field differences.
        if not swapped_in:
            continue
        docs_new = mget_sources(es, args.new_index, swapped_in)
        docs_old = mget_sources(es, args.old_index, swapped_in)

        # Crude language guess: any CJK character in the query -> "zh", else "en".
        lang = "zh" if is_cjk(q) else "en"
        print(f"  language_guess: {lang}")
        for spu in swapped_in[:8]:
            # Docs missing from either index degrade gracefully to empty sources.
            src_new = docs_new.get(spu) or {}
            src_old = docs_old.get(spu) or {}

            # Title is for display only; fall back to English if the guessed language is absent.
            title = get_lang_obj(src_new, "title", lang) or get_lang_obj(src_new, "title", "en") or ""
            print(f"    - spu={spu} title≈{norm_str(title)[:60]!r}")

            presence_new = {f: summarize_field(src_new, f, lang) for f in inspect_fields}
            presence_old = {f: summarize_field(src_old, f, lang) for f in inspect_fields}

            # Fields populated in exactly one index are the prime regression suspects.
            new_only = [f for f in inspect_fields if presence_new[f]["present"] and not presence_old[f]["present"]]
            old_only = [f for f in inspect_fields if presence_old[f]["present"] and not presence_new[f]["present"]]
            if new_only or old_only:
                print(f"      field_presence_diff: new_only={new_only} old_only={old_only}")

            # still report exact-substring match where it exists (often useful for English)
            hits = []
            for f in inspect_fields:
                # Prefer the language-specific sub-value; fall back to the raw field.
                v = get_lang_obj(src_new, f, lang)
                if v is None:
                    v = src_new.get(f)
                if contains_query(v, q):
                    hits.append(f)
            if hits:
                print(f"      exact_substring_matched_fields: {hits}")

            # compact samples for the most likely culprits
            for f in ["qanchors", "keywords", "title"]:
                pn = presence_new.get(f)
                po = presence_old.get(f)
                if pn and po and (pn["present"] or po["present"]):
                    print(
                        f"      {f}: new.present={pn['present']} old.present={po['present']}  "
                        f"new.sample={pn['sample']}  old.sample={po['sample']}"
                    )

    return 0


if __name__ == "__main__":
    # SystemExit carries main()'s return value as the process exit status.
    raise SystemExit(main())