# datasets.py
"""Evaluation dataset registry helpers and artifact path conventions."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence

from config.loader import get_app_config
from config.schema import SearchEvaluationDatasetConfig

from .utils import ensure_dir, sha1_text


@dataclass(frozen=True)
class EvalDatasetSnapshot:
    """Immutable metadata describing the dataset used by one evaluation run."""

    # Identity and provenance.
    dataset_id: str
    display_name: str
    description: str
    query_file: Path
    # Execution context for the queries.
    tenant_id: str
    language: str
    enabled: bool
    # Resolved query set plus integrity metadata.
    queries: tuple[str, ...]
    query_count: int
    query_sha1: str
    # Either "registry" (configured dataset) or "adhoc" (explicit file).
    source: str = "registry"

    def summary(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary dict (the queries themselves are omitted)."""
        ordered_fields = (
            "dataset_id",
            "display_name",
            "description",
            "query_file",
            "tenant_id",
            "language",
            "enabled",
            "query_count",
            "query_sha1",
            "source",
        )
        payload: Dict[str, Any] = {}
        for name in ordered_fields:
            value = getattr(self, name)
            # query_file is a Path; serialize it as a string.
            payload[name] = str(value) if name == "query_file" else value
        return payload


def read_queries_file(path: Path) -> List[str]:
    """Load evaluation queries from *path*: one query per line, skipping
    blank lines and lines whose first non-space character is '#'."""
    queries: List[str] = []
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if stripped and not stripped.startswith("#"):
            queries.append(stripped)
    return queries


def query_sha1(queries: Sequence[str]) -> str:
    """Stable SHA-1 fingerprint of the normalized, newline-joined query list.

    Each entry is coerced to str and stripped; empty entries are dropped, so
    whitespace-only differences do not change the fingerprint.
    """
    normalized = [str(entry).strip() for entry in queries]
    return sha1_text("\n".join(entry for entry in normalized if entry))


def _enabled_datasets(datasets: Iterable[SearchEvaluationDatasetConfig]) -> List[SearchEvaluationDatasetConfig]:
    return [item for item in datasets if item.enabled]


def list_registered_datasets(enabled_only: bool = False) -> List[SearchEvaluationDatasetConfig]:
    """Return the datasets declared in the app config registry.

    With *enabled_only* set, disabled datasets are filtered out.
    """
    registry = list(get_app_config().search_evaluation.datasets)
    if enabled_only:
        return _enabled_datasets(registry)
    return registry


def resolve_registered_dataset(dataset_id: str) -> SearchEvaluationDatasetConfig:
    """Look up a dataset by id in the registry.

    Raises:
        KeyError: if no registered dataset has *dataset_id*.
    """
    matches = (
        candidate
        for candidate in list_registered_datasets(enabled_only=False)
        if candidate.dataset_id == dataset_id
    )
    found = next(matches, None)
    if found is None:
        raise KeyError(f"unknown evaluation dataset: {dataset_id}")
    return found


def resolve_dataset(
    *,
    dataset_id: Optional[str] = None,
    query_file: Optional[Path] = None,
    tenant_id: Optional[str] = None,
    language: Optional[str] = None,
    require_enabled: bool = False,
) -> EvalDatasetSnapshot:
    """Resolve the dataset snapshot for one evaluation run.

    Selection precedence:
      1. explicit *dataset_id* (KeyError if unknown);
      2. a *query_file* whose resolved path matches a registered dataset;
      3. otherwise the configured default dataset id.
    A *query_file* that matches no registered dataset produces an "adhoc"
    snapshot.  *tenant_id* / *language* arguments override the dataset's own
    values, which in turn override the configured defaults.

    Raises:
        KeyError: when *dataset_id* (or the default id) is not registered.
        ValueError: when *require_enabled* is set and the dataset is disabled.
    """
    se = get_app_config().search_evaluation
    selected: Optional[SearchEvaluationDatasetConfig] = None

    if dataset_id:
        selected = resolve_registered_dataset(dataset_id)
    elif query_file is None:
        selected = resolve_registered_dataset(se.default_dataset_id)
    else:
        wanted = query_file.resolve()
        selected = next(
            (
                item
                for item in list_registered_datasets(enabled_only=False)
                if item.query_file.resolve() == wanted
            ),
            None,
        )

    if selected is not None:
        if require_enabled and not selected.enabled:
            raise ValueError(f"evaluation dataset is disabled: {selected.dataset_id}")
        loaded = tuple(read_queries_file(selected.query_file))
        return EvalDatasetSnapshot(
            dataset_id=selected.dataset_id,
            display_name=selected.display_name,
            description=selected.description,
            query_file=selected.query_file.resolve(),
            tenant_id=str(tenant_id or selected.tenant_id or se.default_tenant_id),
            language=str(language or selected.language or se.default_language),
            enabled=selected.enabled,
            queries=loaded,
            query_count=len(loaded),
            query_sha1=query_sha1(loaded),
            source="registry",
        )

    # Ad-hoc fallback: the query file is not part of the registry.
    path = (query_file or se.queries_file).resolve()
    loaded = tuple(read_queries_file(path))
    return EvalDatasetSnapshot(
        dataset_id=dataset_id or f"adhoc_{sha1_text(str(path))[:12]}",
        display_name=path.name,
        description="Ad-hoc evaluation dataset from explicit query file",
        query_file=path,
        tenant_id=str(tenant_id or se.default_tenant_id),
        language=str(language or se.default_language),
        enabled=True,
        queries=loaded,
        query_count=len(loaded),
        query_sha1=query_sha1(loaded),
        source="adhoc",
    )


def infer_dataset_id_from_queries(queries: Sequence[str]) -> Optional[str]:
    """Return the id of the registered dataset whose query set matches *queries*.

    Matching uses the normalized SHA-1 from `query_sha1`, so per-line
    whitespace differences are ignored but the ordered sequence of non-empty
    queries must be identical.  Returns None when no registered dataset
    matches.

    Datasets whose query file cannot be read are skipped rather than aborting
    the whole scan; the original implementation built a full snapshot via
    resolve_dataset() per dataset, so one missing file raised and masked
    every other dataset.
    """
    target_sha = query_sha1(queries)
    for item in list_registered_datasets(enabled_only=False):
        try:
            # Hash the file contents directly instead of resolving a full
            # snapshot: avoids redundant config/tenant resolution work and
            # isolates per-dataset read failures.
            candidate_sha = query_sha1(read_queries_file(item.query_file))
        except OSError:
            continue
        if candidate_sha == target_sha:
            return item.dataset_id
    return None


def artifact_dataset_root(artifact_root: Path, dataset_id: str) -> Path:
    """Return (creating if needed) `<artifact_root>/datasets/<dataset_id>`."""
    dataset_dir = artifact_root / "datasets" / dataset_id
    return ensure_dir(dataset_dir)


def query_builds_dir(artifact_root: Path, dataset_id: str) -> Path:
    """Return (creating if needed) the dataset's `query_builds` directory."""
    dataset_dir = artifact_dataset_root(artifact_root, dataset_id)
    return ensure_dir(dataset_dir / "query_builds")


def batch_reports_root(artifact_root: Path, dataset_id: str) -> Path:
    """Return (creating if needed) the dataset's `batch_reports` directory."""
    dataset_dir = artifact_dataset_root(artifact_root, dataset_id)
    return ensure_dir(dataset_dir / "batch_reports")


def batch_report_run_dir(artifact_root: Path, dataset_id: str, batch_id: str) -> Path:
    """Return (creating if needed) the per-batch report directory."""
    reports_root = batch_reports_root(artifact_root, dataset_id)
    return ensure_dir(reports_root / batch_id)


def audits_dir(artifact_root: Path, dataset_id: str) -> Path:
    """Return (creating if needed) the dataset's `audits` directory."""
    dataset_dir = artifact_dataset_root(artifact_root, dataset_id)
    return ensure_dir(dataset_dir / "audits")