cli.py 7.57 KB
"""CLI: build annotations, batch eval, audit, serve web UI."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from .constants import DEFAULT_LABELER_MODE, DEFAULT_QUERY_FILE
from .framework import SearchEvaluationFramework
from .utils import ensure_dir, utc_now_iso, utc_timestamp
from .web_app import create_web_app


def build_cli_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``build``/``batch``/``audit``/``serve`` sub-commands.

    Every sub-command accepts ``--tenant-id``, ``--queries-file`` and
    ``--labeler-mode``. These shared options are registered through local
    helpers so their defaults and allowed choices cannot drift between
    sub-commands.

    Returns:
        A parser whose parsed namespace carries the chosen sub-command name
        in ``args.command`` (a sub-command is required).
    """
    parser = argparse.ArgumentParser(description="Search evaluation annotation builder and web UI")
    sub = parser.add_subparsers(dest="command", required=True)

    def add_target_args(command: argparse.ArgumentParser) -> None:
        # Tenant and query list: the first options of every sub-command.
        command.add_argument("--tenant-id", default="163")
        command.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE))

    def add_labeler_mode_arg(command: argparse.ArgumentParser) -> None:
        # Labeler mode: the last option of every sub-command, with one shared choice list.
        command.add_argument("--labeler-mode", default=DEFAULT_LABELER_MODE, choices=["simple", "complex"])

    build = sub.add_parser("build", help="Build pooled annotation set for queries")
    add_target_args(build)
    build.add_argument("--search-depth", type=int, default=1000)
    build.add_argument("--rerank-depth", type=int, default=10000)
    build.add_argument("--annotate-search-top-k", type=int, default=120)
    build.add_argument("--annotate-rerank-top-k", type=int, default=200)
    build.add_argument("--language", default="en")
    build.add_argument("--force-refresh-rerank", action="store_true")
    build.add_argument("--force-refresh-labels", action="store_true")
    add_labeler_mode_arg(build)

    batch = sub.add_parser("batch", help="Run batch evaluation against live search")
    add_target_args(batch)
    batch.add_argument("--top-k", type=int, default=100)
    batch.add_argument("--language", default="en")
    batch.add_argument("--force-refresh-labels", action="store_true")
    add_labeler_mode_arg(batch)

    audit = sub.add_parser("audit", help="Audit annotation quality for queries")
    add_target_args(audit)
    audit.add_argument("--top-k", type=int, default=100)
    audit.add_argument("--language", default="en")
    audit.add_argument("--limit-suspicious", type=int, default=5)
    audit.add_argument("--force-refresh-labels", action="store_true")
    add_labeler_mode_arg(audit)

    # The port is configurable via --port; 6010 is only the default.
    serve = sub.add_parser("serve", help="Serve evaluation web UI (default port 6010)")
    add_target_args(serve)
    serve.add_argument("--host", default="0.0.0.0")
    serve.add_argument("--port", type=int, default=6010)
    add_labeler_mode_arg(serve)

    return parser


def run_build(args: argparse.Namespace) -> None:
    """Build a pooled annotation set for every query listed in ``--queries-file``.

    For each query this calls ``framework.build_query_annotation_set`` with the
    depth / top-k / refresh options from ``args`` and prints a one-line report.
    A JSON list with one summary row per query is then written to
    ``<artifact_root>/query_builds/build_summary_<timestamp>.json``, and that
    path is printed.
    """
    framework = SearchEvaluationFramework(tenant_id=args.tenant_id, labeler_mode=args.labeler_mode)
    # The same keyword options are used for every query; collect them once.
    build_options = dict(
        search_depth=args.search_depth,
        rerank_depth=args.rerank_depth,
        annotate_search_top_k=args.annotate_search_top_k,
        annotate_rerank_top_k=args.annotate_rerank_top_k,
        language=args.language,
        force_refresh_rerank=args.force_refresh_rerank,
        force_refresh_labels=args.force_refresh_labels,
    )
    rows = []
    for query_text in framework.queries_from_file(Path(args.queries_file)):
        built = framework.build_query_annotation_set(query=query_text, **build_options)
        rows.append(
            {
                "query": built.query,
                "search_total": built.search_total,
                "search_depth": built.search_depth,
                "rerank_corpus_size": built.rerank_corpus_size,
                "annotated_count": built.annotated_count,
                "output_json_path": str(built.output_json_path),
            }
        )
        print(
            f"[build] query={built.query!r} search_total={built.search_total} "
            f"search_depth={built.search_depth} corpus={built.rerank_corpus_size} "
            f"annotated={built.annotated_count} output={built.output_json_path}"
        )
    summary_dir = ensure_dir(framework.artifact_root / "query_builds")
    summary_path = summary_dir / f"build_summary_{utc_timestamp()}.json"
    summary_path.write_text(json.dumps(rows, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[done] summary={summary_path}")


def run_batch(args: argparse.Namespace) -> None:
    """Run one batch evaluation of live search over the queries in ``--queries-file``.

    Delegates to ``framework.batch_evaluate`` with ``auto_annotate=True`` and the
    ``--top-k``, ``--language`` and ``--force-refresh-labels`` options. It then
    prints the returned ``batch_id`` and ``aggregate_metrics``.
    """
    framework = SearchEvaluationFramework(tenant_id=args.tenant_id, labeler_mode=args.labeler_mode)
    query_list = framework.queries_from_file(Path(args.queries_file))
    batch_result = framework.batch_evaluate(
        queries=query_list,
        top_k=args.top_k,
        auto_annotate=True,
        language=args.language,
        force_refresh_labels=args.force_refresh_labels,
    )
    batch_id = batch_result["batch_id"]
    metrics = batch_result["aggregate_metrics"]
    print(f"[done] batch_id={batch_id} aggregate_metrics={metrics}")


def _refresh_live_labels(
    framework: SearchEvaluationFramework,
    query: str,
    *,
    top_k: int,
    language: str,
) -> None:
    """Force re-labeling of the current live top-``top_k`` results for ``query``."""
    # Request at least 100 results from live search, but only re-label the top_k that get audited.
    live_payload = framework.search_client.search(query=query, size=max(top_k, 100), from_=0, language=language)
    docs = list(live_payload.get("results") or [])[:top_k]
    framework.annotate_missing_labels(query=query, docs=docs, force_refresh=True)


def run_audit(args: argparse.Namespace) -> None:
    """Audit label quality of live search results for every query in ``--queries-file``.

    For each query this audits the live top-``--top-k`` results via
    ``framework.audit_live_query`` and prints a one-line summary. It then writes
    a JSON report under ``<artifact_root>/audits/``. The report holds per-query
    metrics, the label distribution, the suspicious-result count and up to
    ``--limit-suspicious`` example suspicious results.

    With ``--force-refresh-labels`` the live top-k results are re-labeled
    (``force_refresh=True``) *before* the audit. The audit then runs with
    auto-annotation disabled, so it sees exactly the refreshed labels.
    Previously an extra audit ran before the refresh and its result was
    discarded.
    """
    framework = SearchEvaluationFramework(tenant_id=args.tenant_id, labeler_mode=args.labeler_mode)
    queries = framework.queries_from_file(Path(args.queries_file))
    audit_items = []
    for query in queries:
        if args.force_refresh_labels:
            _refresh_live_labels(framework, query, top_k=args.top_k, language=args.language)
        item = framework.audit_live_query(
            query=query,
            top_k=args.top_k,
            language=args.language,
            # Labels were just refreshed explicitly; do not let the audit annotate on its own.
            auto_annotate=not args.force_refresh_labels,
        )
        suspicious = item["suspicious"]
        audit_items.append(
            {
                "query": query,
                "metrics": item["metrics"],
                "distribution": item["distribution"],
                "suspicious_count": len(suspicious),
                "suspicious_examples": suspicious[: args.limit_suspicious],
            }
        )
        print(
            f"[audit] query={query!r} suspicious={len(suspicious)} metrics={item['metrics']}"
        )

    summary = {
        "created_at": utc_now_iso(),
        "tenant_id": args.tenant_id,
        "top_k": args.top_k,
        "query_count": len(queries),
        "total_suspicious": sum(entry["suspicious_count"] for entry in audit_items),
        "queries": audit_items,
    }
    out_path = ensure_dir(framework.artifact_root / "audits") / f"audit_{utc_timestamp()}.json"
    out_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[done] audit={out_path}")


def run_serve(args: argparse.Namespace) -> None:
    """Serve the evaluation web UI with uvicorn on ``--host``/``--port``.

    The framework and app are created first. ``uvicorn`` is imported inside
    this function rather than at module level, so only the ``serve`` command
    needs it.
    """
    framework = SearchEvaluationFramework(tenant_id=args.tenant_id, labeler_mode=args.labeler_mode)
    web_app = create_web_app(framework, Path(args.queries_file))
    import uvicorn

    server_options = {"host": args.host, "port": args.port, "log_level": "info"}
    uvicorn.run(web_app, **server_options)


def main() -> None:
    """Parse the command line and dispatch to the matching ``run_*`` handler.

    Raises:
        SystemExit: if the parsed command has no registered handler.
    """
    args = build_cli_parser().parse_args()
    handlers = {
        "build": run_build,
        "batch": run_batch,
        "audit": run_audit,
        "serve": run_serve,
    }
    handler = handlers.get(args.command)
    if handler is None:
        # Defensive only: the sub-parsers are required, so argparse already rejects unknown commands.
        raise SystemExit(f"unknown command: {args.command}")
    handler(args)