"""CLI: build annotations, batch eval, audit, serve web UI."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict

from .constants import (
    DEFAULT_LABELER_MODE,
    DEFAULT_QUERY_FILE,
    DEFAULT_REBUILD_IRRELEVANT_STOP_RATIO,
    DEFAULT_REBUILD_IRRELEVANT_STOP_STREAK,
    DEFAULT_REBUILD_LLM_BATCH_SIZE,
    DEFAULT_REBUILD_MAX_LLM_BATCHES,
    DEFAULT_REBUILD_MIN_LLM_BATCHES,
    DEFAULT_RERANK_HIGH_SKIP_COUNT,
    DEFAULT_RERANK_HIGH_THRESHOLD,
    DEFAULT_SEARCH_RECALL_TOP_K,
)
from .framework import SearchEvaluationFramework
from .utils import ensure_dir, utc_now_iso, utc_timestamp
from .web_app import create_web_app


def add_judge_llm_args(p: argparse.ArgumentParser) -> None:
    """Attach the judge-LLM flags shared by every subcommand."""
    p.add_argument(
        "--judge-model",
        default=None,
        metavar="MODEL",
        help="Judge LLM model (default: eval_framework.constants.DEFAULT_JUDGE_MODEL).",
    )
    p.add_argument(
        "--enable-thinking",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="enable_thinking for DashScope (default: DEFAULT_JUDGE_ENABLE_THINKING).",
    )
    p.add_argument(
        "--dashscope-batch",
        action=argparse.BooleanOptionalAction,
        default=None,
        help="DashScope Batch File API vs sync chat (default: DEFAULT_JUDGE_DASHSCOPE_BATCH).",
    )


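# BooleanOptionalAction generates paired flags (--enable-thinking /
# --no-enable-thinking), and default=None lets the helper below tell an
# explicit True/False apart from "flag not given". For example (model name
# illustrative), `--judge-model qwen-max --no-enable-thinking` maps to
# {"judge_model": "qwen-max", "enable_thinking": False}; unset flags stay
# out of the dict so SearchEvaluationFramework falls back to its defaults.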
def framework_kwargs_from_args(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect judge-LLM overrides, forwarding only flags the user actually set."""
    kw: Dict[str, Any] = {}
    if args.judge_model is not None:
        kw["judge_model"] = args.judge_model
    if args.enable_thinking is not None:
        kw["enable_thinking"] = args.enable_thinking
    if args.dashscope_batch is not None:
        kw["use_dashscope_batch"] = args.dashscope_batch
    return kw


def build_cli_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Search evaluation annotation builder and web UI")
    sub = parser.add_subparsers(dest="command", required=True)

    build = sub.add_parser("build", help="Build pooled annotation set for queries")
    build.add_argument("--tenant-id", default="163")
    build.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE))
    build.add_argument("--search-depth", type=int, default=1000)
    build.add_argument("--rerank-depth", type=int, default=10000)
    build.add_argument("--annotate-search-top-k", type=int, default=120)
    build.add_argument("--annotate-rerank-top-k", type=int, default=200)
    build.add_argument(
        "--search-recall-top-k",
        type=int,
        default=None,
        help="Rebuild mode only: top-K search hits enter recall pool with score 1 (default when --force-refresh-labels: 500).",
    )
    build.add_argument(
        "--rerank-high-threshold",
        type=float,
        default=None,
        help="Rebuild only: count rerank scores above this on non-pool docs (default 0.5).",
    )
    build.add_argument(
        "--rerank-high-skip-count",
        type=int,
        default=None,
        help="Rebuild only: skip query if more than this many non-pool docs have rerank score > threshold (default 1000).",
    )
    build.add_argument("--rebuild-llm-batch-size", type=int, default=None, help="Rebuild only: LLM batch size (default 50).")
    build.add_argument("--rebuild-min-batches", type=int, default=None, help="Rebuild only: min LLM batches before early stop (default 20).")
    build.add_argument("--rebuild-max-batches", type=int, default=None, help="Rebuild only: max LLM batches (default 40).")
    build.add_argument(
        "--rebuild-irrelevant-stop-ratio",
        type=float,
        default=None,
        help="Rebuild only: irrelevant ratio above this counts toward early-stop streak (default 0.92).",
    )
    build.add_argument(
        "--rebuild-irrelevant-stop-streak",
        type=int,
        default=None,
        help="Rebuild only: stop after this many consecutive batches above irrelevant ratio (default 3).",
    )
    build.add_argument("--language", default="en")
    build.add_argument("--force-refresh-rerank", action="store_true")
    build.add_argument("--force-refresh-labels", action="store_true")
    build.add_argument("--labeler-mode", default=DEFAULT_LABELER_MODE, choices=["simple", "complex"])
    add_judge_llm_args(build)

    batch = sub.add_parser("batch", help="Run batch evaluation against live search")
    batch.add_argument("--tenant-id", default="163")
    batch.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE))
    batch.add_argument("--top-k", type=int, default=100)
    batch.add_argument("--language", default="en")
    batch.add_argument("--force-refresh-labels", action="store_true")
    batch.add_argument("--labeler-mode", default=DEFAULT_LABELER_MODE, choices=["simple", "complex"])
    add_judge_llm_args(batch)

    audit = sub.add_parser("audit", help="Audit annotation quality for queries")
    audit.add_argument("--tenant-id", default="163")
    audit.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE))
    audit.add_argument("--top-k", type=int, default=100)
    audit.add_argument("--language", default="en")
    audit.add_argument("--limit-suspicious", type=int, default=5)
    audit.add_argument("--force-refresh-labels", action="store_true")
    audit.add_argument("--labeler-mode", default=DEFAULT_LABELER_MODE, choices=["simple", "complex"])
    add_judge_llm_args(audit)

    serve = sub.add_parser("serve", help="Serve the evaluation web UI (default port 6010)")
    serve.add_argument("--tenant-id", default="163")
    serve.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE))
    serve.add_argument("--host", default="0.0.0.0")
    serve.add_argument("--port", type=int, default=6010)
    serve.add_argument("--labeler-mode", default=DEFAULT_LABELER_MODE, choices=["simple", "complex"])
    add_judge_llm_args(serve)

    return parser


def run_build(args: argparse.Namespace) -> None:
    framework = SearchEvaluationFramework(
        tenant_id=args.tenant_id, labeler_mode=args.labeler_mode, **framework_kwargs_from_args(args)
    )
    queries = framework.queries_from_file(Path(args.queries_file))
    summary = []

    def _coalesce(value: Any, default: Any) -> Any:
        """Return the CLI-supplied value, falling back to the constant default."""
        return default if value is None else value

    rebuild_kwargs: Dict[str, Any] = {}
    if args.force_refresh_labels:
        rebuild_kwargs = {
            "search_recall_top_k": _coalesce(args.search_recall_top_k, DEFAULT_SEARCH_RECALL_TOP_K),
            "rerank_high_threshold": _coalesce(args.rerank_high_threshold, DEFAULT_RERANK_HIGH_THRESHOLD),
            "rerank_high_skip_count": _coalesce(args.rerank_high_skip_count, DEFAULT_RERANK_HIGH_SKIP_COUNT),
            "rebuild_llm_batch_size": _coalesce(args.rebuild_llm_batch_size, DEFAULT_REBUILD_LLM_BATCH_SIZE),
            "rebuild_min_batches": _coalesce(args.rebuild_min_batches, DEFAULT_REBUILD_MIN_LLM_BATCHES),
            "rebuild_max_batches": _coalesce(args.rebuild_max_batches, DEFAULT_REBUILD_MAX_LLM_BATCHES),
            "rebuild_irrelevant_stop_ratio": _coalesce(args.rebuild_irrelevant_stop_ratio, DEFAULT_REBUILD_IRRELEVANT_STOP_RATIO),
            "rebuild_irrelevant_stop_streak": _coalesce(args.rebuild_irrelevant_stop_streak, DEFAULT_REBUILD_IRRELEVANT_STOP_STREAK),
        }
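        # Semantics per the CLI help: the rebuild labeler runs at least
        # rebuild_min_batches LLM batches, stops early once
        # rebuild_irrelevant_stop_streak consecutive batches exceed
        # rebuild_irrelevant_stop_ratio, and never exceeds rebuild_max_batches.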
    for query in queries:
        result = framework.build_query_annotation_set(
            query=query,
            search_depth=args.search_depth,
            rerank_depth=args.rerank_depth,
            annotate_search_top_k=args.annotate_search_top_k,
            annotate_rerank_top_k=args.annotate_rerank_top_k,
            language=args.language,
            force_refresh_rerank=args.force_refresh_rerank,
            force_refresh_labels=args.force_refresh_labels,
            **rebuild_kwargs,
        )
        summary.append(
            {
                "query": result.query,
                "search_total": result.search_total,
                "search_depth": result.search_depth,
                "rerank_corpus_size": result.rerank_corpus_size,
                "annotated_count": result.annotated_count,
                "output_json_path": str(result.output_json_path),
            }
        )
        print(
            f"[build] query={result.query!r} search_total={result.search_total} "
            f"search_depth={result.search_depth} corpus={result.rerank_corpus_size} "
            f"annotated={result.annotated_count} output={result.output_json_path}"
        )
    out_path = ensure_dir(framework.artifact_root / "query_builds") / f"build_summary_{utc_timestamp()}.json"
    out_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[done] summary={out_path}")


def run_batch(args: argparse.Namespace) -> None:
    framework = SearchEvaluationFramework(
        tenant_id=args.tenant_id, labeler_mode=args.labeler_mode, **framework_kwargs_from_args(args)
    )
    queries = framework.queries_from_file(Path(args.queries_file))
    payload = framework.batch_evaluate(
        queries=queries,
        top_k=args.top_k,
        auto_annotate=True,
        language=args.language,
        force_refresh_labels=args.force_refresh_labels,
    )
    print(f"[done] batch_id={payload['batch_id']} aggregate_metrics={payload['aggregate_metrics']}")


def run_audit(args: argparse.Namespace) -> None:
    framework = SearchEvaluationFramework(
        tenant_id=args.tenant_id, labeler_mode=args.labeler_mode, **framework_kwargs_from_args(args)
    )
    queries = framework.queries_from_file(Path(args.queries_file))
    audit_items = []
    for query in queries:
        if args.force_refresh_labels:
            # Re-label the live top-K first, then audit without auto-annotation
            # so the metrics reflect the freshly refreshed labels. (The original
            # code also ran a preliminary audit here and discarded its result.)
            live_payload = framework.search_client.search(
                query=query, size=max(args.top_k, 100), from_=0, language=args.language
            )
            framework.annotate_missing_labels(
                query=query,
                docs=list(live_payload.get("results") or [])[: args.top_k],
                force_refresh=True,
            )
            item = framework.audit_live_query(
                query=query,
                top_k=args.top_k,
                language=args.language,
                auto_annotate=False,
            )
        else:
            item = framework.audit_live_query(
                query=query,
                top_k=args.top_k,
                language=args.language,
                auto_annotate=True,
            )
        audit_items.append(
            {
                "query": query,
                "metrics": item["metrics"],
                "distribution": item["distribution"],
                "suspicious_count": len(item["suspicious"]),
                "suspicious_examples": item["suspicious"][: args.limit_suspicious],
            }
        )
        print(
            f"[audit] query={query!r} suspicious={len(item['suspicious'])} metrics={item['metrics']}"
        )

    summary = {
        "created_at": utc_now_iso(),
        "tenant_id": args.tenant_id,
        "top_k": args.top_k,
        "query_count": len(queries),
        "total_suspicious": sum(item["suspicious_count"] for item in audit_items),
        "queries": audit_items,
    }
    out_path = ensure_dir(framework.artifact_root / "audits") / f"audit_{utc_timestamp()}.json"
    out_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"[done] audit={out_path}")


def run_serve(args: argparse.Namespace) -> None:
    framework = SearchEvaluationFramework(
        tenant_id=args.tenant_id, labeler_mode=args.labeler_mode, **framework_kwargs_from_args(args)
    )
    app = create_web_app(framework, Path(args.queries_file))
    # Deferred import so the other subcommands do not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


def main() -> None:
    parser = build_cli_parser()
    args = parser.parse_args()
    handlers = {
        "build": run_build,
        "batch": run_batch,
        "audit": run_audit,
        "serve": run_serve,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise SystemExit(f"unknown command: {args.command}")
    handler(args)
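

# Allow direct execution (e.g. `python -m eval_framework.cli`); an installed
# console-script entry point, if present, would invoke main() the same way.
if __name__ == "__main__":
    main()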