Blame view

scripts/evaluation/eval_framework/clients.py 5.9 KB
c81b0fc1   tangwang   scripts/evaluatio...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
  """HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""
  
  from __future__ import annotations
  
  from typing import Any, Dict, List, Optional, Sequence, Tuple
  
  import requests
  
  from .constants import VALID_LABELS
  from .prompts import (
      classify_batch_complex_prompt,
      classify_batch_simple_prompt,
      extract_query_profile_prompt,
  )
  from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps
  
  
  class SearchServiceClient:
      """Thin HTTP client for the search service's ``/search/`` endpoint.

      Every request is sent through one shared ``requests.Session`` and carries
      the tenant id in the ``X-Tenant-ID`` header.
      """

      # Fallback request timeout (seconds) for instances that never set their own,
      # e.g. subclasses whose ``__init__`` does not call this one.
      timeout: float = 120

      def __init__(self, base_url: str, tenant_id: str, *, timeout: float = 120):
          """Store connection settings and open an HTTP session.

          Args:
              base_url: Root URL of the search service; trailing slashes are stripped.
              tenant_id: Tenant identifier; coerced to ``str`` for use as a header value.
              timeout: Timeout in seconds handed to ``requests`` for each call.
          """
          self.base_url = base_url.rstrip("/")
          self.tenant_id = str(tenant_id)
          self.timeout = timeout
          self.session = requests.Session()

      def search(self, query: str, size: int, from_: int = 0, language: str = "en") -> Dict[str, Any]:
          """POST a search request and return the decoded JSON response.

          Args:
              query: Query text, sent as ``query``.
              size: Sent as ``size`` in the request body.
              from_: Sent as ``from`` in the request body.
              language: Sent as ``language`` in the request body.

          Returns:
              The response body parsed by ``response.json()``.

          Raises:
              requests.HTTPError: If the service answers with an error status
                  (via ``raise_for_status``).
          """
          request_headers = {"Content-Type": "application/json", "X-Tenant-ID": self.tenant_id}
          request_body = {"query": query, "size": size, "from": from_, "language": language}
          response = self.session.post(
              f"{self.base_url}/search/",
              headers=request_headers,
              json=request_body,
              timeout=self.timeout,
          )
          response.raise_for_status()
          return response.json()
  
  
  class RerankServiceClient:
      """HTTP client that posts a query and candidate documents to a rerank service."""

      # Fallback request timeout (seconds) for instances that never set their own,
      # e.g. subclasses whose ``__init__`` does not call this one.
      timeout: float = 180

      def __init__(self, service_url: str, *, timeout: float = 180):
          """Store the endpoint and open an HTTP session.

          Args:
              service_url: Full URL requests are posted to; trailing slashes are stripped.
              timeout: Timeout in seconds handed to ``requests`` for each call.
          """
          self.service_url = service_url.rstrip("/")
          self.timeout = timeout
          self.session = requests.Session()

      def rerank(self, query: str, docs: Sequence[str], normalize: bool = False, top_n: Optional[int] = None) -> Tuple[List[float], Dict[str, Any]]:
          """Send one rerank request and return its scores and metadata.

          Args:
              query: Query text, sent as ``query``.
              docs: Document strings, sent as a list under ``docs``.
              normalize: Sent as ``normalize`` in the request body.
              top_n: When not ``None``, sent as an integer under ``top_n``;
                  otherwise the key is omitted from the request.

          Returns:
              ``(scores, meta)`` taken from the response's ``scores`` and ``meta``
              keys; a missing or empty value yields ``[]`` / ``{}`` respectively.

          Raises:
              requests.HTTPError: If the service answers with an error status
                  (via ``raise_for_status``).
          """
          payload: Dict[str, Any] = {
              "query": query,
              "docs": list(docs),
              "normalize": normalize,
          }
          if top_n is not None:
              payload["top_n"] = int(top_n)
          response = self.session.post(self.service_url, json=payload, timeout=self.timeout)
          response.raise_for_status()
          data = response.json()
          scores = list(data.get("scores") or [])
          meta = dict(data.get("meta") or {})
          return scores, meta
  
  
  class DashScopeLabelClient:
      """Chat-completions client used to obtain relevance labels from an LLM.

      Prompts are built by the helpers in ``.prompts``, posted to
      ``{base_url}/chat/completions`` and the reply is parsed into labels that
      must all belong to ``VALID_LABELS``.
      """

      # Fallback request timeout (seconds) for instances that never set their own,
      # e.g. subclasses whose ``__init__`` does not call this one.
      timeout: float = 180

      def __init__(self, model: str, base_url: str, api_key: str, batch_size: int = 40, *, timeout: float = 180):
          """Store model/endpoint settings and open an HTTP session.

          Args:
              model: Model name sent as ``model`` in every chat request.
              base_url: API root; trailing slashes are stripped.
              api_key: Token sent as ``Authorization: Bearer <api_key>``.
              batch_size: Stored as an ``int`` on the instance; not read by the
                  methods of this class.
              timeout: Timeout in seconds handed to ``requests`` for each call.
          """
          self.model = model
          self.base_url = base_url.rstrip("/")
          self.api_key = api_key
          self.batch_size = int(batch_size)
          self.timeout = timeout
          self.session = requests.Session()

      @staticmethod
      def _extract_content(data: Any) -> str:
          """Return the stripped first-choice message content, or ``""``.

          Any missing key or unexpected shape (non-dict body, empty or non-list
          ``choices``, non-dict choice or message) yields an empty string instead
          of raising, so callers can report the problem through their own
          validation errors.
          """
          if not isinstance(data, dict):
              return ""
          choices = data.get("choices")
          first = choices[0] if isinstance(choices, list) and choices else None
          message = first.get("message") if isinstance(first, dict) else None
          content = message.get("content") if isinstance(message, dict) else None
          return str(content or "").strip()

      @staticmethod
      def _labels_from_items(items: Sequence[Any], limit: int, *, allow_bare: bool) -> List[str]:
          """Collect valid labels from the first ``limit`` entries of ``items``.

          Dict entries contribute their ``label`` value. Non-dict entries are
          stringified when ``allow_bare`` is true and skipped otherwise. Anything
          not in ``VALID_LABELS`` is dropped, so the result may be shorter than
          ``limit``; callers compare lengths to detect that.
          """
          labels: List[str] = []
          for item in items[:limit]:
              if isinstance(item, dict):
                  label = str(item.get("label") or "").strip()
              elif allow_bare:
                  label = str(item).strip()
              else:
                  continue
              if label in VALID_LABELS:
                  labels.append(label)
          return labels

      def _chat(self, prompt: str) -> Tuple[str, str]:
          """Send ``prompt`` as a single user message and return the reply.

          Returns:
              ``(content, raw_response)`` where ``content`` is the stripped text of
              the first choice (``""`` when absent or malformed) and
              ``raw_response`` is the whole JSON body serialized with
              ``safe_json_dumps``.

          Raises:
              requests.HTTPError: If the API answers with an error status
                  (via ``raise_for_status``).
          """
          response = self.session.post(
              f"{self.base_url}/chat/completions",
              headers={
                  "Authorization": f"Bearer {self.api_key}",
                  "Content-Type": "application/json",
              },
              json={
                  "model": self.model,
                  "messages": [{"role": "user", "content": prompt}],
                  "temperature": 0,
                  "top_p": 0.1,
              },
              timeout=self.timeout,
          )
          response.raise_for_status()
          data = response.json()
          return self._extract_content(data), safe_json_dumps(data)

      def classify_batch_simple(
          self,
          query: str,
          docs: Sequence[Dict[str, Any]],
      ) -> Tuple[List[str], str]:
          """Label ``docs`` for ``query`` using the simple batch prompt.

          The reply is first read as one label per line. If that does not give
          exactly one label per document, a JSON object with a ``labels`` list is
          tried as a fallback.

          Returns:
              ``(labels, raw_response)`` with one label per document.

          Raises:
              ValueError: If neither parse yields exactly ``len(docs)`` valid labels.
          """
          numbered_docs = [build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)]
          prompt = classify_batch_simple_prompt(query, numbered_docs)
          content, raw_response = self._chat(prompt)
          # Lines that are not exactly a valid label (blank lines, commentary) are
          # skipped; only the resulting count is checked against the documents.
          stripped_lines = (line.strip() for line in str(content or "").splitlines())
          labels = [label for label in stripped_lines if label in VALID_LABELS]
          if len(labels) != len(docs):
              payload = extract_json_blob(content)
              if isinstance(payload, dict) and isinstance(payload.get("labels"), list):
                  labels = self._labels_from_items(payload["labels"], len(docs), allow_bare=True)
          # Only valid labels are ever collected, so a length match is sufficient.
          if len(labels) != len(docs):
              raise ValueError(f"unexpected simple label output: {content!r}")
          return labels, raw_response

      def extract_query_profile(
          self,
          query: str,
          parser_hints: Dict[str, Any],
      ) -> Tuple[Dict[str, Any], str]:
          """Ask the model for a structured profile of ``query``.

          Keys missing from the model's JSON object are filled with defaults:
          ``normalized_query_en`` falls back to ``query``, ``primary_category`` to
          ``""`` and ``allowed_categories`` / ``required_attributes`` / ``notes``
          to empty lists.

          Returns:
              ``(profile, raw_response)``.

          Raises:
              ValueError: If the reply does not contain a JSON object.
          """
          prompt = extract_query_profile_prompt(query, parser_hints)
          content, raw_response = self._chat(prompt)
          payload = extract_json_blob(content)
          if not isinstance(payload, dict):
              raise ValueError(f"unexpected query profile payload: {content!r}")
          payload.setdefault("normalized_query_en", query)
          payload.setdefault("primary_category", "")
          payload.setdefault("allowed_categories", [])
          payload.setdefault("required_attributes", [])
          payload.setdefault("notes", [])
          return payload, raw_response

      def classify_batch_complex(
          self,
          query: str,
          query_profile: Dict[str, Any],
          docs: Sequence[Dict[str, Any]],
      ) -> Tuple[List[str], str]:
          """Label ``docs`` for ``query`` using the profile-aware batch prompt.

          The reply must be a JSON object whose ``labels`` list holds dicts with a
          ``label`` key; non-dict entries are ignored.

          Returns:
              ``(labels, raw_response)`` with one label per document.

          Raises:
              ValueError: If the reply is not such an object, or does not yield
                  exactly ``len(docs)`` valid labels.
          """
          numbered_docs = [build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)]
          prompt = classify_batch_complex_prompt(query, query_profile, numbered_docs)
          content, raw_response = self._chat(prompt)
          payload = extract_json_blob(content)
          if not isinstance(payload, dict) or not isinstance(payload.get("labels"), list):
              raise ValueError(f"unexpected label payload: {content!r}")
          labels = self._labels_from_items(payload["labels"], len(docs), allow_bare=False)
          # Only valid labels are ever collected, so a length match is sufficient.
          if len(labels) != len(docs):
              raise ValueError(f"unexpected label output: {content!r}")
          return labels, raw_response