# datasets.py
"""Evaluation dataset registry helpers and artifact path conventions."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
from config.loader import get_app_config
from config.schema import SearchEvaluationDatasetConfig
from .utils import ensure_dir, sha1_text
@dataclass(frozen=True)
class EvalDatasetSnapshot:
    """Immutable metadata describing the dataset behind one evaluation run."""
    dataset_id: str
    display_name: str
    description: str
    query_file: Path
    tenant_id: str
    language: str
    enabled: bool
    queries: tuple[str, ...]
    query_count: int
    query_sha1: str
    source: str = "registry"
    def summary(self) -> Dict[str, Any]:
        """Return a JSON-friendly view of the snapshot (query texts omitted)."""
        keys = (
            "dataset_id",
            "display_name",
            "description",
            "query_file",
            "tenant_id",
            "language",
            "enabled",
            "query_count",
            "query_sha1",
            "source",
        )
        # query_file is a Path; stringify it so the summary is serializable.
        return {
            key: (str(self.query_file) if key == "query_file" else getattr(self, key))
            for key in keys
        }
def read_queries_file(path: Path) -> List[str]:
    """Load query lines from a UTF-8 text file.

    Blank lines and lines starting with ``#`` (comments) are skipped;
    surrounding whitespace is trimmed from each query.
    """
    queries: List[str] = []
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        queries.append(stripped)
    return queries
def query_sha1(queries: Sequence[str]) -> str:
    """Stable SHA-1 fingerprint of the normalized, newline-joined query list."""
    normalized = (str(query).strip() for query in queries)
    return sha1_text("\n".join(query for query in normalized if query))
def _enabled_datasets(datasets: Iterable[SearchEvaluationDatasetConfig]) -> List[SearchEvaluationDatasetConfig]:
return [item for item in datasets if item.enabled]
def list_registered_datasets(enabled_only: bool = False) -> List[SearchEvaluationDatasetConfig]:
    """Return all datasets registered in app config, optionally enabled-only."""
    registered = list(get_app_config().search_evaluation.datasets)
    if enabled_only:
        return _enabled_datasets(registered)
    return registered
def resolve_registered_dataset(dataset_id: str) -> SearchEvaluationDatasetConfig:
    """Look up a registered dataset by id.

    Raises:
        KeyError: if no dataset with *dataset_id* is registered.
    """
    match = next(
        (item for item in list_registered_datasets(enabled_only=False) if item.dataset_id == dataset_id),
        None,
    )
    if match is None:
        raise KeyError(f"unknown evaluation dataset: {dataset_id}")
    return match
def resolve_dataset(
    *,
    dataset_id: Optional[str] = None,
    query_file: Optional[Path] = None,
    tenant_id: Optional[str] = None,
    language: Optional[str] = None,
    require_enabled: bool = False,
) -> EvalDatasetSnapshot:
    """Resolve an evaluation dataset snapshot from the registry or a query file.

    Selection order: an explicit ``dataset_id`` wins (KeyError if unknown);
    otherwise an explicit ``query_file`` is matched against registered
    datasets by resolved path; otherwise the configured default dataset is
    used. A ``query_file`` that matches no registry entry yields an ad-hoc
    snapshot. ``tenant_id``/``language`` override the dataset's own values.

    Raises:
        ValueError: if ``require_enabled`` is set and the selected registered
            dataset is disabled.
    """
    se = get_app_config().search_evaluation
    registered = list_registered_datasets(enabled_only=False)
    selected: Optional[SearchEvaluationDatasetConfig] = None
    if dataset_id:
        selected = resolve_registered_dataset(dataset_id)
    elif query_file is not None:
        target = query_file.resolve()
        selected = next(
            (item for item in registered if item.query_file.resolve() == target),
            None,
        )
    else:
        selected = resolve_registered_dataset(se.default_dataset_id)
    if selected is None:
        # Explicit query file with no registry match: build an ad-hoc snapshot.
        path = (query_file or se.queries_file).resolve()
        queries = tuple(read_queries_file(path))
        return EvalDatasetSnapshot(
            dataset_id=dataset_id or f"adhoc_{sha1_text(str(path))[:12]}",
            display_name=path.name,
            description="Ad-hoc evaluation dataset from explicit query file",
            query_file=path,
            tenant_id=str(tenant_id or se.default_tenant_id),
            language=str(language or se.default_language),
            enabled=True,
            queries=queries,
            query_count=len(queries),
            query_sha1=query_sha1(queries),
            source="adhoc",
        )
    if require_enabled and not selected.enabled:
        raise ValueError(f"evaluation dataset is disabled: {selected.dataset_id}")
    queries = tuple(read_queries_file(selected.query_file))
    return EvalDatasetSnapshot(
        dataset_id=selected.dataset_id,
        display_name=selected.display_name,
        description=selected.description,
        query_file=selected.query_file.resolve(),
        tenant_id=str(tenant_id or selected.tenant_id or se.default_tenant_id),
        language=str(language or selected.language or se.default_language),
        enabled=selected.enabled,
        queries=queries,
        query_count=len(queries),
        query_sha1=query_sha1(queries),
        source="registry",
    )
def infer_dataset_id_from_queries(queries: Sequence[str]) -> Optional[str]:
    """Return the id of the registered dataset whose query fingerprint matches *queries*, if any."""
    target_sha = query_sha1(queries)
    matches = (
        item.dataset_id
        for item in list_registered_datasets(enabled_only=False)
        if resolve_dataset(dataset_id=item.dataset_id).query_sha1 == target_sha
    )
    return next(matches, None)
def artifact_dataset_root(artifact_root: Path, dataset_id: str) -> Path:
    """Per-dataset artifact directory (created if missing)."""
    dataset_dir = artifact_root / "datasets" / dataset_id
    return ensure_dir(dataset_dir)
def query_builds_dir(artifact_root: Path, dataset_id: str) -> Path:
    """Directory for query build artifacts of one dataset (created if missing)."""
    target = artifact_dataset_root(artifact_root, dataset_id) / "query_builds"
    return ensure_dir(target)
def batch_reports_root(artifact_root: Path, dataset_id: str) -> Path:
    """Root directory for batch reports of one dataset (created if missing)."""
    target = artifact_dataset_root(artifact_root, dataset_id) / "batch_reports"
    return ensure_dir(target)
def batch_report_run_dir(artifact_root: Path, dataset_id: str, batch_id: str) -> Path:
    """Directory for one batch run's reports (created if missing)."""
    run_dir = batch_reports_root(artifact_root, dataset_id) / batch_id
    return ensure_dir(run_dir)
def audits_dir(artifact_root: Path, dataset_id: str) -> Path:
    """Directory for audit artifacts of one dataset (created if missing)."""
    target = artifact_dataset_root(artifact_root, dataset_id) / "audits"
    return ensure_dir(target)