#!/usr/bin/env python3

from __future__ import annotations

import argparse
import ast
import copy
import json
import re
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List

import requests
import yaml

# Make the repository root importable so the `scripts.evaluation` package
# resolves when this file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from scripts.evaluation.eval_framework import (
    DEFAULT_ARTIFACT_ROOT,
    DEFAULT_QUERY_FILE,
    ensure_dir,
    utc_now_iso,
    utc_timestamp,
)


# Live backend configuration file; temporarily overwritten for each experiment
# and restored (or replaced by the best candidate) in main()'s finally block.
CONFIG_PATH = PROJECT_ROOT / "config" / "config.yaml"


@dataclass
class ExperimentSpec:
    """One tuning experiment: a named set of dotted-path config overrides."""

    name: str  # unique identifier; used in artifact file names and the report
    description: str  # free-text summary shown in the Markdown report
    params: Dict[str, Any]  # dotted config path (e.g. "a.b.c") -> override value


def load_yaml(path: Path) -> Dict[str, Any]:
    """Read the YAML file at *path* (UTF-8) and return its parsed content."""
    text = path.read_text(encoding="utf-8")
    return yaml.safe_load(text)


def write_yaml(path: Path, payload: Dict[str, Any]) -> None:
    """Serialize *payload* as YAML (insertion order kept, unicode allowed) to *path*."""
    rendered = yaml.safe_dump(payload, sort_keys=False, allow_unicode=True)
    path.write_text(rendered, encoding="utf-8")


def set_nested_value(payload: Dict[str, Any], dotted_path: str, value: Any) -> None:
    """Set *value* inside *payload* at *dotted_path* (e.g. ``"fusion.bm25.weight"``).

    Missing intermediate mappings are created on demand, so overriding a path
    whose parents do not exist yet no longer raises ``KeyError``. Existing
    paths behave exactly as before (backward compatible).

    Args:
        payload: Mutable nested mapping to modify in place.
        dotted_path: ``"."``-separated key path; the last segment is assigned.
        value: Value to store at the final key.
    """
    current = payload
    parts = dotted_path.split(".")
    for part in parts[:-1]:
        # Create intermediate dicts instead of failing on absent keys.
        current = current.setdefault(part, {})
    current[parts[-1]] = value


def apply_params(base_config: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of *base_config* with every dotted-path override applied.

    The input mapping is never mutated; each key in *params* is a
    ``"."``-separated path into the nested config, each value replaces the
    entry at that path.
    """
    candidate = copy.deepcopy(base_config)
    for dotted_path, override in params.items():
        keys = dotted_path.split(".")
        node = candidate
        for key in keys[:-1]:
            node = node[key]
        node[keys[-1]] = override
    return candidate


def wait_for_backend(base_url: str, timeout_sec: float = 300.0) -> Dict[str, Any]:
    """Poll ``{base_url}/health`` every 2s until it reports ``status == "healthy"``.

    Returns the health payload on success. Raises RuntimeError containing the
    last observed error string or payload once *timeout_sec* elapses.
    """
    health_url = f"{base_url.rstrip('/')}/health"
    deadline = time.time() + timeout_sec
    last_seen: Any = None
    while time.time() < deadline:
        try:
            resp = requests.get(health_url, timeout=10)
            resp.raise_for_status()
            body = resp.json()
            if str(body.get("status")) == "healthy":
                return body
            # Reachable but not healthy yet; remember the payload for the error.
            last_seen = body
        except Exception as exc:  # noqa: BLE001
            last_seen = str(exc)
        time.sleep(2.0)
    raise RuntimeError(f"backend did not become healthy: {last_seen}")


def run_restart() -> None:
    """Restart the backend via ./restart.sh; raises on non-zero exit or 10-minute timeout."""
    cmd = ["./restart.sh", "backend"]
    subprocess.run(cmd, cwd=PROJECT_ROOT, check=True, timeout=600)


def read_queries(path: Path) -> List[str]:
    """Return the stripped lines of *path*, skipping blanks and ``#`` comments."""
    queries: List[str] = []
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if stripped and not stripped.startswith("#"):
            queries.append(stripped)
    return queries


def run_batch_eval(
    *,
    tenant_id: str,
    queries_file: Path,
    top_k: int,
    language: str,
    force_refresh_labels: bool,
) -> Dict[str, Any]:
    """Run ``build_annotation_set.py batch`` in a subprocess and parse its result line.

    Args:
        tenant_id: Tenant whose index is evaluated.
        queries_file: File of evaluation queries (one per line).
        top_k: Retrieval depth passed to the batch script.
        language: Query language code.
        force_refresh_labels: Pass ``--force-refresh-labels`` to the script.

    Returns:
        Dict with ``batch_id``, ``aggregate_metrics`` (parsed from the script's
        printed dict) and ``raw_output`` (combined stdout + stderr).

    Raises:
        RuntimeError: if the ``batch_id=... aggregate_metrics={...}`` line is
            missing from the output.
        subprocess.CalledProcessError / subprocess.TimeoutExpired: on failure.
    """
    cmd = [
        str(PROJECT_ROOT / ".venv" / "bin" / "python"),
        "scripts/evaluation/build_annotation_set.py",
        "batch",
        "--tenant-id",
        str(tenant_id),
        "--queries-file",
        str(queries_file),
        "--top-k",
        str(top_k),
        "--language",
        language,
    ]
    if force_refresh_labels:
        cmd.append("--force-refresh-labels")
    completed = subprocess.run(
        cmd,
        cwd=PROJECT_ROOT,
        check=True,
        capture_output=True,
        text=True,
        timeout=7200,
    )
    output = (completed.stdout or "") + "\n" + (completed.stderr or "")
    match = re.search(r"batch_id=([A-Za-z0-9_]+)\s+aggregate_metrics=(\{.*\})", output)
    if not match:
        raise RuntimeError(f"failed to parse batch output: {output[-2000:]}")
    batch_id = match.group(1)
    # The metrics are printed as a Python dict repr (single quotes).
    # ast.literal_eval parses that directly; the previous
    # .replace("'", '"') + json.loads approach corrupted any key or value
    # containing an apostrophe or embedded quote.
    aggregate_metrics = ast.literal_eval(match.group(2))
    return {
        "batch_id": batch_id,
        "aggregate_metrics": aggregate_metrics,
        "raw_output": output,
    }


def render_markdown(summary: Dict[str, Any]) -> str:
    """Render the tuning-run *summary* mapping as a Markdown report string.

    Produces a metadata header, a ranked experiments table, and a per-
    experiment details section. ``summary["experiments"]`` entries must carry
    ``name``, ``description``, ``score``, ``params``, ``aggregate_metrics``,
    ``batch_report_path`` and ``config_snapshot_path``.
    """
    header = [
        "# Fusion Tuning Report",
        "",
        f"- Created at: {summary['created_at']}",
        f"- Tenant ID: {summary['tenant_id']}",
        f"- Query count: {summary['query_count']}",
        f"- Top K: {summary['top_k']}",
        f"- Score metric: {summary['score_metric']}",
        "",
        "## Experiments",
        "",
        "| Rank | Name | Score | MAP_3 | MAP_2_3 | P@5 | P@10 | Config |",
        "|---|---|---:|---:|---:|---:|---:|---|",
    ]
    table_rows: List[str] = []
    for rank, experiment in enumerate(summary["experiments"], start=1):
        metrics = experiment["aggregate_metrics"]
        cells = [
            str(rank),
            experiment["name"],
            str(experiment["score"]),
            str(metrics.get("MAP_3", "")),
            str(metrics.get("MAP_2_3", "")),
            str(metrics.get("P@5", "")),
            str(metrics.get("P@10", "")),
            experiment["config_snapshot_path"],
        ]
        table_rows.append("| " + " | ".join(cells) + " |")
    details = ["", "## Details", ""]
    for experiment in summary["experiments"]:
        params_json = json.dumps(experiment["params"], ensure_ascii=False, sort_keys=True)
        details.extend(
            [
                f"### {experiment['name']}",
                "",
                f"- Description: {experiment['description']}",
                f"- Score: {experiment['score']}",
                f"- Params: `{params_json}`",
                f"- Batch report: {experiment['batch_report_path']}",
                "",
            ]
        )
    return "\n".join(header + table_rows + details)


def load_experiments(path: Path) -> List[ExperimentSpec]:
    """Load experiment specs from a JSON file.

    The file may be either a bare list of experiment objects or a mapping
    with an ``"experiments"`` list. ``description`` and ``params`` are
    optional per item and default to ``""`` / ``{}``.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        items = payload["experiments"]
    else:
        items = payload
    return [
        ExperimentSpec(
            name=str(item["name"]),
            description=str(item.get("description") or ""),
            params=dict(item.get("params") or {}),
        )
        for item in items
    ]


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the fusion tuning runner."""
    parser = argparse.ArgumentParser(
        description="Run fusion tuning experiments against the live backend"
    )
    add = parser.add_argument
    add("--tenant-id", default="163")
    add("--queries-file", default=str(DEFAULT_QUERY_FILE))
    add("--top-k", type=int, default=100)
    add("--language", default="en")
    add("--experiments-file", required=True)
    add("--search-base-url", default="http://127.0.0.1:6002")
    add("--score-metric", default="MAP_3")
    add("--apply-best", action="store_true")
    add("--force-refresh-labels-first-pass", action="store_true")
    return parser


def main() -> None:
    """Run every experiment end-to-end and write summary.json / summary.md.

    For each experiment: patch config.yaml with the experiment's overrides,
    restart the backend, wait for it to become healthy, run the batch
    evaluation, and record its aggregate metrics. The finally block always
    leaves the backend in a known state: the best-scoring config when
    --apply-best is set and at least one result exists, otherwise the
    original config text restored byte-for-byte.
    """
    args = build_parser().parse_args()
    queries_file = Path(args.queries_file)
    queries = read_queries(queries_file)
    # Keep both the raw text (for byte-exact restore) and the parsed mapping
    # (as the base onto which each experiment's params are applied).
    base_config_text = CONFIG_PATH.read_text(encoding="utf-8")
    base_config = load_yaml(CONFIG_PATH)
    experiments = load_experiments(Path(args.experiments_file))

    tuning_dir = ensure_dir(DEFAULT_ARTIFACT_ROOT / "tuning_runs")
    run_id = f"tuning_{utc_timestamp()}"
    run_dir = ensure_dir(tuning_dir / run_id)
    results: List[Dict[str, Any]] = []

    try:
        for experiment in experiments:
            # Write the candidate config both to the live location and as a
            # per-experiment artifact snapshot.
            candidate = apply_params(base_config, experiment.params)
            write_yaml(CONFIG_PATH, candidate)
            candidate_config_path = run_dir / f"{experiment.name}_config.yaml"
            write_yaml(candidate_config_path, candidate)

            run_restart()
            health = wait_for_backend(args.search_base_url)
            batch_result = run_batch_eval(
                tenant_id=args.tenant_id,
                queries_file=queries_file,
                top_k=args.top_k,
                language=args.language,
                # Labels are refreshed at most once, on the very first pass
                # (results is still empty only for the first experiment).
                force_refresh_labels=bool(args.force_refresh_labels_first_pass and not results),
            )
            aggregate_metrics = dict(batch_result["aggregate_metrics"])
            results.append(
                {
                    "name": experiment.name,
                    "description": experiment.description,
                    "params": experiment.params,
                    "aggregate_metrics": aggregate_metrics,
                    # A missing score metric counts as 0.0 so ranking stays total.
                    "score": float(aggregate_metrics.get(args.score_metric, 0.0)),
                    "batch_id": batch_result["batch_id"],
                    "batch_report_path": str(
                        DEFAULT_ARTIFACT_ROOT / "batch_reports" / f"{batch_result['batch_id']}.md"
                    ),
                    "config_snapshot_path": str(candidate_config_path),
                    "backend_health": health,
                    "batch_stdout": batch_result["raw_output"],
                }
            )
            print(
                f"[tune] {experiment.name} score={aggregate_metrics.get(args.score_metric)} "
                f"metrics={aggregate_metrics}"
            )
    finally:
        # Runs even if an experiment raised: restore a known config and
        # restart so the backend is never left on a half-tested candidate.
        if args.apply_best and results:
            best = max(results, key=lambda item: item["score"])
            best_config = apply_params(base_config, best["params"])
            write_yaml(CONFIG_PATH, best_config)
            run_restart()
            wait_for_backend(args.search_base_url)
        else:
            CONFIG_PATH.write_text(base_config_text, encoding="utf-8")
            run_restart()
            wait_for_backend(args.search_base_url)

    results.sort(key=lambda item: item["score"], reverse=True)
    summary = {
        "run_id": run_id,
        "created_at": utc_now_iso(),
        "tenant_id": args.tenant_id,
        "query_count": len(queries),
        "top_k": args.top_k,
        "score_metric": args.score_metric,
        "experiments": results,
    }
    summary_json_path = run_dir / "summary.json"
    summary_md_path = run_dir / "summary.md"
    summary_json_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    summary_md_path.write_text(render_markdown(summary), encoding="utf-8")
    print(f"[done] summary_json={summary_json_path}")
    print(f"[done] summary_md={summary_md_path}")


# Script entry point; keeps the module importable without side effects.
if __name__ == "__main__":
    main()