c81b0fc1
tangwang
scripts/evaluatio...
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
import requests
from .constants import VALID_LABELS
from .prompts import (
classify_batch_complex_prompt,
classify_batch_simple_prompt,
extract_query_profile_prompt,
)
from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps
class SearchServiceClient:
    """HTTP client for the search API, bound to one base URL and tenant."""

    def __init__(self, base_url: str, tenant_id: str) -> None:
        """Store the connection settings and create a reusable session.

        Args:
            base_url: Root URL of the search service. Trailing slashes are
                removed so paths can be appended with a single ``/``.
            tenant_id: Tenant identifier; coerced to ``str``.
        """
        self.base_url = base_url.rstrip("/")
        self.tenant_id = str(tenant_id)
        # One session per client so repeated calls can reuse connections.
        self.session = requests.Session()
|
167f33b4
tangwang
eval框架前端
|
24
25
26
27
28
29
30
31
32
|
def search(self, query: str, size: int, from_: int = 0, language: str = "en", *, debug: bool = False) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"query": query,
"size": size,
"from": from_,
"language": language,
}
if debug:
payload["debug"] = True
|
c81b0fc1
tangwang
scripts/evaluatio...
|
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
timeout=120,
)
response.raise_for_status()
return response.json()
class RerankServiceClient:
    """HTTP client for the reranker service."""

    def __init__(self, service_url: str) -> None:
        """Store the service URL and create a reusable session.

        Args:
            service_url: Full URL that rerank requests are POSTed to.
                Trailing slashes are removed.
        """
        self.service_url = service_url.rstrip("/")
        self.session = requests.Session()

    def rerank(
        self,
        query: str,
        docs: Sequence[str],
        normalize: bool = False,
        top_n: Optional[int] = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """POST a query and candidate documents, returning scores and metadata.

        Args:
            query: Query text sent under the ``"query"`` key.
            docs: Document texts; sent as a list under ``"docs"``.
            normalize: Forwarded unchanged under ``"normalize"``.
            top_n: When not ``None``, sent as an ``int`` under ``"top_n"``;
                otherwise the key is omitted from the request.

        Returns:
            A ``(scores, meta)`` pair built from the response's ``"scores"``
            and ``"meta"`` fields. A missing or falsy field becomes an empty
            list or dict respectively.

        Raises:
            Whatever ``response.raise_for_status()`` raises for an error
            status (``requests.HTTPError``).
        """
        payload: Dict[str, Any] = {
            "query": query,
            "docs": list(docs),
            "normalize": normalize,
        }
        if top_n is not None:
            payload["top_n"] = int(top_n)

        response = self.session.post(self.service_url, json=payload, timeout=180)
        response.raise_for_status()
        data = response.json()

        # Copy into fresh containers so callers never alias the parsed body.
        scores = list(data.get("scores") or [])
        meta = dict(data.get("meta") or {})
        return scores, meta
class DashScopeLabelClient:
    """Chat-completions client used for relevance labeling of search results."""

    def __init__(self, model: str, base_url: str, api_key: str, batch_size: int = 40) -> None:
        """Store the chat endpoint settings and create a reusable session.

        Args:
            model: Model name sent with every chat request.
            base_url: Root URL of the chat API; trailing slashes are removed.
            api_key: Bearer token sent in the ``Authorization`` header.
            batch_size: Coerced to ``int`` and stored on the instance; no
                method of this class reads it.
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = int(batch_size)
        self.session = requests.Session()

    @staticmethod
    def _message_content(data: Any) -> str:
        """Return the stripped content of the first choice's message.

        Any missing or malformed level (no choices, a non-dict choice or
        message, empty content) yields ``""`` instead of raising, so callers
        fail through their own output validation.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        first = choices[0] if isinstance(choices, list) and choices else None
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        return str(content or "").strip()

    @staticmethod
    def _valid_labels(items: Sequence[Any], limit: int, *, accept_bare: bool) -> List[str]:
        """Collect valid labels from the first ``limit`` entries of ``items``.

        Dict entries contribute their ``"label"`` value. Other entries are
        stringified when ``accept_bare`` is true and skipped otherwise.
        Entries whose label is not in ``VALID_LABELS`` are dropped.
        """
        labels: List[str] = []
        for item in items[:limit]:
            if isinstance(item, dict):
                label = str(item.get("label") or "").strip()
            elif accept_bare:
                label = str(item).strip()
            else:
                continue
            if label in VALID_LABELS:
                labels.append(label)
        return labels

    def _chat(self, prompt: str) -> Tuple[str, str]:
        """Send ``prompt`` as a single user message and return the reply.

        Returns:
            ``(content, raw)`` where ``content`` is the stripped text of the
            first choice and ``raw`` is the whole response body serialized
            with ``safe_json_dumps``.

        Raises:
            Whatever ``response.raise_for_status()`` raises for an error
            status (``requests.HTTPError``).
        """
        response = self.session.post(
            f"{self.base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0,
                "top_p": 0.1,
            },
            timeout=180,
        )
        response.raise_for_status()
        data = response.json()
        return self._message_content(data), safe_json_dumps(data)

    def classify_batch_simple(
        self,
        query: str,
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label ``docs`` for ``query`` using the simple prompt.

        The reply is first read as one label per line. If that does not
        yield exactly one label per document, the reply is parsed as JSON
        with a ``"labels"`` list whose entries are either bare labels or
        dicts carrying a ``"label"`` key.

        Returns:
            ``(labels, raw_response)`` with one label per document.

        Raises:
            ValueError: If neither form yields ``len(docs)`` valid labels.
        """
        numbered_docs = [
            build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)
        ]
        prompt = classify_batch_simple_prompt(query, numbered_docs)
        content, raw_response = self._chat(prompt)

        stripped_lines = (line.strip() for line in str(content or "").splitlines())
        labels = [label for label in stripped_lines if label in VALID_LABELS]

        if len(labels) != len(docs):
            payload = extract_json_blob(content)
            if isinstance(payload, dict) and isinstance(payload.get("labels"), list):
                labels = self._valid_labels(payload["labels"], len(docs), accept_bare=True)

        # Only validated labels are ever collected, so the count is the
        # single remaining thing to verify.
        if len(labels) != len(docs):
            raise ValueError(f"unexpected simple label output: {content!r}")
        return labels, raw_response

    def extract_query_profile(
        self,
        query: str,
        parser_hints: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], str]:
        """Ask the model for a structured profile of ``query``.

        Returns:
            ``(profile, raw_response)``. Keys absent from the model's JSON
            are filled with defaults: ``normalized_query_en`` falls back to
            ``query``, ``primary_category`` to ``""``, and
            ``allowed_categories``, ``required_attributes`` and ``notes`` to
            empty lists. Keys the model did return are left as they are.

        Raises:
            ValueError: If the reply does not parse to a JSON object.
        """
        prompt = extract_query_profile_prompt(query, parser_hints)
        content, raw_response = self._chat(prompt)
        payload = extract_json_blob(content)
        if not isinstance(payload, dict):
            raise ValueError(f"unexpected query profile payload: {content!r}")
        payload.setdefault("normalized_query_en", query)
        payload.setdefault("primary_category", "")
        payload.setdefault("allowed_categories", [])
        payload.setdefault("required_attributes", [])
        payload.setdefault("notes", [])
        return payload, raw_response

    def classify_batch_complex(
        self,
        query: str,
        query_profile: Dict[str, Any],
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label ``docs`` for ``query`` using the profile-aware prompt.

        The reply must be JSON with a ``"labels"`` list of dicts, each
        carrying a ``"label"`` key; non-dict entries are ignored.

        Returns:
            ``(labels, raw_response)`` with one label per document.

        Raises:
            ValueError: If the reply is not a JSON object with a ``"labels"``
                list, or it does not yield ``len(docs)`` valid labels.
        """
        numbered_docs = [
            build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)
        ]
        prompt = classify_batch_complex_prompt(query, query_profile, numbered_docs)
        content, raw_response = self._chat(prompt)

        payload = extract_json_blob(content)
        if not isinstance(payload, dict) or not isinstance(payload.get("labels"), list):
            raise ValueError(f"unexpected label payload: {content!r}")

        labels = self._valid_labels(payload["labels"], len(docs), accept_bare=False)
        if len(labels) != len(docs):
            raise ValueError(f"unexpected label output: {content!r}")
        return labels, raw_response
|