# web_app.py
"""FastAPI app for the search evaluation UI (static frontend + JSON APIs)."""

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from .api_models import BatchEvalRequest, SearchEvalRequest
from .constants import DEFAULT_QUERY_FILE
from .datasets import list_registered_datasets, resolve_dataset
from .framework import SearchEvaluationFramework

# Directory holding the bundled frontend assets (index.html, JS, CSS),
# resolved relative to this module so it works regardless of CWD.
_STATIC_DIR = Path(__file__).resolve().parent / "static"


def create_web_app(framework: SearchEvaluationFramework, initial_dataset_id: str | None = None) -> FastAPI:
    """Build the FastAPI app serving the evaluation UI and its JSON APIs.

    Args:
        framework: Evaluation framework providing live/batch evaluation and
            the persistence store (``framework.store``).
        initial_dataset_id: Dataset pre-selected in the UI; defaults to
            ``"core_queries"`` when not provided.

    Returns:
        A configured :class:`FastAPI` application with the static frontend
        mounted at ``/static`` and JSON endpoints under ``/api``.
    """
    app = FastAPI(title="Search Evaluation UI", version="1.0.0")
    # Captured by the route closures below; it is read-only after creation —
    # callers pick a different dataset per-request via dataset_id parameters.
    current_dataset_id = initial_dataset_id or "core_queries"

    app.mount(
        "/static",
        StaticFiles(directory=str(_STATIC_DIR)),
        name="static",
    )

    index_path = _STATIC_DIR / "index.html"

    @app.get("/", response_class=HTMLResponse)
    def home() -> str:
        """Serve the single-page frontend."""
        return index_path.read_text(encoding="utf-8")

    @app.get("/api/datasets")
    def api_datasets() -> Dict[str, Any]:
        """List enabled datasets with per-dataset label-coverage summaries."""
        # One stats row per labeled query for this tenant, keyed by query text.
        stats_by_query = {item["query"]: item for item in framework.store.list_query_label_stats(framework.tenant_id)}
        datasets = []
        for item in list_registered_datasets(enabled_only=True):
            snapshot = resolve_dataset(dataset_id=item.dataset_id, tenant_id=framework.tenant_id)
            # A query counts as "labeled" when it has at least one stored label.
            labeled_queries = sum(1 for query in snapshot.queries if (stats_by_query.get(query) or {}).get("total", 0) > 0)
            datasets.append(
                {
                    **snapshot.summary(),
                    "coverage_summary": {
                        "labeled_queries": labeled_queries,
                        # Guard against division by zero for empty datasets.
                        "coverage_ratio": (labeled_queries / snapshot.query_count) if snapshot.query_count else 0.0,
                    },
                }
            )
        return {"datasets": datasets, "current_dataset_id": current_dataset_id}

    @app.get("/api/datasets/{dataset_id}/queries")
    def api_dataset_queries(dataset_id: str) -> Dict[str, Any]:
        """Return the query list for an explicitly named (enabled) dataset."""
        dataset = resolve_dataset(dataset_id=dataset_id, tenant_id=framework.tenant_id, require_enabled=True)
        return {"dataset": dataset.summary(), "queries": list(dataset.queries)}

    @app.get("/api/queries")
    def api_queries(dataset_id: str | None = None) -> Dict[str, Any]:
        """Return the query list for ``dataset_id``, defaulting to the current dataset."""
        dataset = resolve_dataset(dataset_id=dataset_id or current_dataset_id, tenant_id=framework.tenant_id)
        return {"dataset": dataset.summary(), "queries": list(dataset.queries)}

    @app.post("/api/search-eval")
    def api_search_eval(request: SearchEvalRequest) -> Dict[str, Any]:
        """Evaluate a single live query against the resolved dataset."""
        dataset = resolve_dataset(
            dataset_id=request.dataset_id or current_dataset_id,
            tenant_id=framework.tenant_id,
            language=request.language,
        )
        # NOTE: the dataset's resolved language (not request.language) is
        # forwarded so evaluation matches the dataset actually used.
        return framework.evaluate_live_query(
            query=request.query,
            top_k=request.top_k,
            auto_annotate=request.auto_annotate,
            language=dataset.language,
            dataset=dataset,
        )

    @app.post("/api/batch-eval")
    def api_batch_eval(request: BatchEvalRequest) -> Dict[str, Any]:
        """Run a batch evaluation over the request's queries or the full dataset.

        Raises:
            HTTPException: 400 when neither the request nor the dataset
                supplies any queries.
        """
        dataset = resolve_dataset(
            dataset_id=request.dataset_id or current_dataset_id,
            tenant_id=framework.tenant_id,
            language=request.language,
        )
        # Fall back to the dataset's own query list when none are supplied.
        queries = request.queries or list(dataset.queries)
        if not queries:
            raise HTTPException(status_code=400, detail="No queries provided")
        return framework.batch_evaluate(
            queries=queries,
            dataset=dataset,
            top_k=request.top_k,
            auto_annotate=request.auto_annotate,
            language=dataset.language,
            force_refresh_labels=request.force_refresh_labels,
        )

    @app.get("/api/history")
    def api_history(dataset_id: str | None = None, limit: int = 20) -> Dict[str, Any]:
        """List recent batch runs for a dataset (defaults to the current one)."""
        effective_dataset_id = dataset_id or current_dataset_id
        return {
            "history": framework.store.list_batch_runs(limit=limit, dataset_id=effective_dataset_id),
            "dataset_id": effective_dataset_id,
        }

    @app.get("/api/history/{batch_id}/report")
    def api_history_report(batch_id: str) -> Dict[str, Any]:
        """Return the stored markdown report for a past batch run.

        Raises:
            HTTPException: 404 for an unknown ``batch_id`` or a missing report
                file; 403 when the stored path escapes the artifact root
                (path-traversal guard).
        """
        row = framework.store.get_batch_run(batch_id)
        if row is None:
            raise HTTPException(status_code=404, detail="Unknown batch_id")
        report_path = Path(row["report_markdown_path"]).resolve()
        root = framework.artifact_root.resolve()
        try:
            # relative_to raises ValueError when report_path is not under root.
            report_path.relative_to(root)
        except ValueError:
            # Suppress the internal ValueError context deliberately (B904):
            # it adds no information to the 403 response.
            raise HTTPException(status_code=403, detail="Report path is outside artifact root") from None
        if not report_path.is_file():
            raise HTTPException(status_code=404, detail="Report file not found")
        return {
            "batch_id": row["batch_id"],
            "created_at": row["created_at"],
            "tenant_id": row["tenant_id"],
            "dataset": row["metadata"].get("dataset") or {},
            "report_markdown_path": str(report_path),
            "markdown": report_path.read_text(encoding="utf-8"),
        }

    return app