#!/usr/bin/env python3 from __future__ import annotations import argparse import copy import json import re import subprocess import sys import time from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List import requests import yaml PROJECT_ROOT = Path(__file__).resolve().parents[2] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from scripts.evaluation.eval_framework import ( DEFAULT_ARTIFACT_ROOT, DEFAULT_QUERY_FILE, ensure_dir, utc_now_iso, utc_timestamp, ) CONFIG_PATH = PROJECT_ROOT / "config" / "config.yaml" @dataclass class ExperimentSpec: name: str description: str params: Dict[str, Any] def load_yaml(path: Path) -> Dict[str, Any]: return yaml.safe_load(path.read_text(encoding="utf-8")) def write_yaml(path: Path, payload: Dict[str, Any]) -> None: path.write_text( yaml.safe_dump(payload, sort_keys=False, allow_unicode=True), encoding="utf-8", ) def set_nested_value(payload: Dict[str, Any], dotted_path: str, value: Any) -> None: current = payload parts = dotted_path.split(".") for part in parts[:-1]: current = current[part] current[parts[-1]] = value def apply_params(base_config: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]: candidate = copy.deepcopy(base_config) for dotted_path, value in params.items(): set_nested_value(candidate, dotted_path, value) return candidate def wait_for_backend(base_url: str, timeout_sec: float = 300.0) -> Dict[str, Any]: deadline = time.time() + timeout_sec last_error = None while time.time() < deadline: try: response = requests.get(f"{base_url.rstrip('/')}/health", timeout=10) response.raise_for_status() payload = response.json() if str(payload.get("status")) == "healthy": return payload last_error = payload except Exception as exc: # noqa: BLE001 last_error = str(exc) time.sleep(2.0) raise RuntimeError(f"backend did not become healthy: {last_error}") def run_restart() -> None: subprocess.run(["./restart.sh", "backend"], cwd=PROJECT_ROOT, check=True, timeout=600) def read_queries(path: Path) -> List[str]: return [ line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip() and not line.strip().startswith("#") ] def run_batch_eval( *, tenant_id: str, queries_file: Path, top_k: int, language: str, force_refresh_labels: bool, ) -> Dict[str, Any]: cmd = [ str(PROJECT_ROOT / ".venv" / "bin" / "python"), "scripts/evaluation/build_annotation_set.py", "batch", "--tenant-id", str(tenant_id), "--queries-file", str(queries_file), "--top-k", str(top_k), "--language", language, ] if force_refresh_labels: cmd.append("--force-refresh-labels") completed = subprocess.run( cmd, cwd=PROJECT_ROOT, check=True, capture_output=True, text=True, timeout=7200, ) output = (completed.stdout or "") + "\n" + (completed.stderr or "") match = re.search(r"batch_id=([A-Za-z0-9_]+)\s+aggregate_metrics=(\{.*\})", output) if not match: raise RuntimeError(f"failed to parse batch output: {output[-2000:]}") batch_id = match.group(1) aggregate_metrics = json.loads(match.group(2).replace("'", '"')) return { "batch_id": batch_id, "aggregate_metrics": aggregate_metrics, "raw_output": output, } def render_markdown(summary: Dict[str, Any]) -> str: lines = [ "# Fusion Tuning Report", "", f"- Created at: {summary['created_at']}", f"- Tenant ID: {summary['tenant_id']}", f"- Query count: {summary['query_count']}", f"- Top K: {summary['top_k']}", f"- Score metric: {summary['score_metric']}", "", "## Experiments", "", "| Rank | Name | Score | MAP_3 | MAP_2_3 | P@5 | P@10 | Config |", "|---|---|---:|---:|---:|---:|---:|---|", ] for idx, item in enumerate(summary["experiments"], start=1): metrics = item["aggregate_metrics"] lines.append( "| " + " | ".join( [ str(idx), item["name"], str(item["score"]), str(metrics.get("MAP_3", "")), str(metrics.get("MAP_2_3", "")), str(metrics.get("P@5", "")), str(metrics.get("P@10", "")), item["config_snapshot_path"], ] ) + " |" ) lines.extend(["", "## Details", ""]) for item in summary["experiments"]: lines.append(f"### {item['name']}") lines.append("") lines.append(f"- Description: {item['description']}") lines.append(f"- Score: {item['score']}") lines.append(f"- Params: `{json.dumps(item['params'], ensure_ascii=False, sort_keys=True)}`") lines.append(f"- Batch report: {item['batch_report_path']}") lines.append("") return "\n".join(lines) def load_experiments(path: Path) -> List[ExperimentSpec]: payload = json.loads(path.read_text(encoding="utf-8")) items = payload["experiments"] if isinstance(payload, dict) else payload experiments: List[ExperimentSpec] = [] for item in items: experiments.append( ExperimentSpec( name=str(item["name"]), description=str(item.get("description") or ""), params=dict(item.get("params") or {}), ) ) return experiments def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run fusion tuning experiments against the live backend") parser.add_argument("--tenant-id", default="163") parser.add_argument("--queries-file", default=str(DEFAULT_QUERY_FILE)) parser.add_argument("--top-k", type=int, default=100) parser.add_argument("--language", default="en") parser.add_argument("--experiments-file", required=True) parser.add_argument("--search-base-url", default="http://127.0.0.1:6002") parser.add_argument("--score-metric", default="MAP_3") parser.add_argument("--apply-best", action="store_true") parser.add_argument("--force-refresh-labels-first-pass", action="store_true") return parser def main() -> None: args = build_parser().parse_args() queries_file = Path(args.queries_file) queries = read_queries(queries_file) base_config_text = CONFIG_PATH.read_text(encoding="utf-8") base_config = load_yaml(CONFIG_PATH) experiments = load_experiments(Path(args.experiments_file)) tuning_dir = ensure_dir(DEFAULT_ARTIFACT_ROOT / "tuning_runs") run_id = f"tuning_{utc_timestamp()}" run_dir = ensure_dir(tuning_dir / run_id) results: List[Dict[str, Any]] = [] try: for experiment in experiments: candidate = apply_params(base_config, experiment.params) write_yaml(CONFIG_PATH, candidate) candidate_config_path = run_dir / f"{experiment.name}_config.yaml" write_yaml(candidate_config_path, candidate) run_restart() health = wait_for_backend(args.search_base_url) batch_result = run_batch_eval( tenant_id=args.tenant_id, queries_file=queries_file, top_k=args.top_k, language=args.language, force_refresh_labels=bool(args.force_refresh_labels_first_pass and not results), ) aggregate_metrics = dict(batch_result["aggregate_metrics"]) results.append( { "name": experiment.name, "description": experiment.description, "params": experiment.params, "aggregate_metrics": aggregate_metrics, "score": float(aggregate_metrics.get(args.score_metric, 0.0)), "batch_id": batch_result["batch_id"], "batch_report_path": str( DEFAULT_ARTIFACT_ROOT / "batch_reports" / f"{batch_result['batch_id']}.md" ), "config_snapshot_path": str(candidate_config_path), "backend_health": health, "batch_stdout": batch_result["raw_output"], } ) print( f"[tune] {experiment.name} score={aggregate_metrics.get(args.score_metric)} " f"metrics={aggregate_metrics}" ) finally: if args.apply_best and results: best = max(results, key=lambda item: item["score"]) best_config = apply_params(base_config, best["params"]) write_yaml(CONFIG_PATH, best_config) run_restart() wait_for_backend(args.search_base_url) else: CONFIG_PATH.write_text(base_config_text, encoding="utf-8") run_restart() wait_for_backend(args.search_base_url) results.sort(key=lambda item: item["score"], reverse=True) summary = { "run_id": run_id, "created_at": utc_now_iso(), "tenant_id": args.tenant_id, "query_count": len(queries), "top_k": args.top_k, "score_metric": args.score_metric, "experiments": results, } summary_json_path = run_dir / "summary.json" summary_md_path = run_dir / "summary.md" summary_json_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") summary_md_path.write_text(render_markdown(summary), encoding="utf-8") print(f"[done] summary_json={summary_json_path}") print(f"[done] summary_md={summary_md_path}") if __name__ == "__main__": main()