"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling).""" from __future__ import annotations import io import json import time import uuid from typing import Any, Dict, List, Optional, Sequence, Tuple import requests from .constants import VALID_LABELS from .prompts import ( classify_batch_complex_prompt, classify_batch_simple_prompt, extract_query_profile_prompt, ) from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps class SearchServiceClient: def __init__(self, base_url: str, tenant_id: str): self.base_url = base_url.rstrip("/") self.tenant_id = str(tenant_id) self.session = requests.Session() def search(self, query: str, size: int, from_: int = 0, language: str = "en", *, debug: bool = False) -> Dict[str, Any]: payload: Dict[str, Any] = { "query": query, "size": size, "from": from_, "language": language, } if debug: payload["debug"] = True response = self.session.post( f"{self.base_url}/search/", headers={"Content-Type": "application/json", "X-Tenant-ID": self.tenant_id}, json=payload, timeout=120, ) response.raise_for_status() return response.json() class RerankServiceClient: def __init__(self, service_url: str): self.service_url = service_url.rstrip("/") self.session = requests.Session() def rerank(self, query: str, docs: Sequence[str], normalize: bool = False, top_n: Optional[int] = None) -> Tuple[List[float], Dict[str, Any]]: payload: Dict[str, Any] = { "query": query, "docs": list(docs), "normalize": normalize, } if top_n is not None: payload["top_n"] = int(top_n) response = self.session.post(self.service_url, json=payload, timeout=180) response.raise_for_status() data = response.json() return list(data.get("scores") or []), dict(data.get("meta") or {}) class DashScopeLabelClient: """DashScope OpenAI-compatible chat: synchronous or Batch File API (JSONL job). Batch flow: https://help.aliyun.com/zh/model-studio/batch-interfaces-compatible-with-openai/ """ def __init__( self, model: str, base_url: str, api_key: str, batch_size: int = 40, *, batch_completion_window: str = "24h", batch_poll_interval_sec: float = 10.0, enable_thinking: bool = True, use_batch: bool = True, ): self.model = model self.base_url = base_url.rstrip("/") self.api_key = api_key self.batch_size = int(batch_size) self.batch_completion_window = str(batch_completion_window) self.batch_poll_interval_sec = float(batch_poll_interval_sec) self.enable_thinking = bool(enable_thinking) self.use_batch = bool(use_batch) self.session = requests.Session() def _auth_headers(self) -> Dict[str, str]: return {"Authorization": f"Bearer {self.api_key}"} def _completion_body(self, prompt: str) -> Dict[str, Any]: body: Dict[str, Any] = { "model": self.model, "messages": [{"role": "user", "content": prompt}], "temperature": 0, "top_p": 0.1, "enable_thinking": self.enable_thinking, } return body def _chat_sync(self, prompt: str) -> Tuple[str, str]: response = self.session.post( f"{self.base_url}/chat/completions", headers={**self._auth_headers(), "Content-Type": "application/json"}, json=self._completion_body(prompt), timeout=180, ) response.raise_for_status() data = response.json() content = str(((data.get("choices") or [{}])[0].get("message") or {}).get("content") or "").strip() return content, safe_json_dumps(data) def _chat_batch(self, prompt: str) -> Tuple[str, str]: """One chat completion via Batch File API (single-line JSONL job).""" custom_id = uuid.uuid4().hex body = self._completion_body(prompt) line_obj = { "custom_id": custom_id, "method": "POST", "url": "/v1/chat/completions", "body": body, } jsonl = json.dumps(line_obj, ensure_ascii=False, separators=(",", ":")) + "\n" auth = self._auth_headers() up = self.session.post( f"{self.base_url}/files", headers=auth, files={ "file": ( "eval_batch_input.jsonl", io.BytesIO(jsonl.encode("utf-8")), "application/octet-stream", ) }, data={"purpose": "batch"}, timeout=300, ) up.raise_for_status() file_id = (up.json() or {}).get("id") if not file_id: raise RuntimeError(f"DashScope file upload returned no id: {up.text!r}") cr = self.session.post( f"{self.base_url}/batches", headers={**auth, "Content-Type": "application/json"}, json={ "input_file_id": file_id, "endpoint": "/v1/chat/completions", "completion_window": self.batch_completion_window, }, timeout=120, ) cr.raise_for_status() batch_payload = cr.json() or {} batch_id = batch_payload.get("id") if not batch_id: raise RuntimeError(f"DashScope batches.create returned no id: {cr.text!r}") terminal = frozenset({"completed", "failed", "expired", "cancelled"}) batch: Dict[str, Any] = dict(batch_payload) status = str(batch.get("status") or "") while status not in terminal: time.sleep(self.batch_poll_interval_sec) br = self.session.get(f"{self.base_url}/batches/{batch_id}", headers=auth, timeout=120) br.raise_for_status() batch = br.json() or {} status = str(batch.get("status") or "") if status != "completed": raise RuntimeError( f"DashScope batch {batch_id} ended with status={status!r} errors={batch.get('errors')!r}" ) out_id = batch.get("output_file_id") err_id = batch.get("error_file_id") row = self._find_batch_line_for_custom_id(out_id, custom_id, auth) if row is None: err_row = self._find_batch_line_for_custom_id(err_id, custom_id, auth) if err_row is not None: raise RuntimeError(f"DashScope batch request failed: {err_row!r}") raise RuntimeError(f"DashScope batch output missing custom_id={custom_id!r}") resp = row.get("response") or {} sc = resp.get("status_code") if sc is not None and int(sc) != 200: raise RuntimeError(f"DashScope batch line error: {row!r}") data = resp.get("body") or {} content = str(((data.get("choices") or [{}])[0].get("message") or {}).get("content") or "").strip() return content, safe_json_dumps(row) def _chat(self, prompt: str) -> Tuple[str, str]: if self.use_batch: return self._chat_batch(prompt) return self._chat_sync(prompt) def _find_batch_line_for_custom_id( self, file_id: Optional[str], custom_id: str, auth: Dict[str, str], ) -> Optional[Dict[str, Any]]: if not file_id or str(file_id) in ("null", ""): return None r = self.session.get(f"{self.base_url}/files/{file_id}/content", headers=auth, timeout=300) r.raise_for_status() for raw in r.text.splitlines(): raw = raw.strip() if not raw: continue try: obj = json.loads(raw) except json.JSONDecodeError: continue if str(obj.get("custom_id")) == custom_id: return obj return None def classify_batch_simple( self, query: str, docs: Sequence[Dict[str, Any]], ) -> Tuple[List[str], str]: numbered_docs = [build_label_doc_line(idx + 1, doc) for idx, doc in enumerate(docs)] prompt = classify_batch_simple_prompt(query, numbered_docs) content, raw_response = self._chat(prompt) labels = [] for line in str(content or "").splitlines(): label = line.strip() if label in VALID_LABELS: labels.append(label) if len(labels) != len(docs): payload = extract_json_blob(content) if isinstance(payload, dict) and isinstance(payload.get("labels"), list): labels = [] for item in payload["labels"][: len(docs)]: if isinstance(item, dict): label = str(item.get("label") or "").strip() else: label = str(item).strip() if label in VALID_LABELS: labels.append(label) if len(labels) != len(docs) or any(label not in VALID_LABELS for label in labels): raise ValueError(f"unexpected simple label output: {content!r}") return labels, raw_response def extract_query_profile( self, query: str, parser_hints: Dict[str, Any], ) -> Tuple[Dict[str, Any], str]: prompt = extract_query_profile_prompt(query, parser_hints) content, raw_response = self._chat(prompt) payload = extract_json_blob(content) if not isinstance(payload, dict): raise ValueError(f"unexpected query profile payload: {content!r}") payload.setdefault("normalized_query_en", query) payload.setdefault("primary_category", "") payload.setdefault("allowed_categories", []) payload.setdefault("required_attributes", []) payload.setdefault("notes", []) return payload, raw_response def classify_batch_complex( self, query: str, query_profile: Dict[str, Any], docs: Sequence[Dict[str, Any]], ) -> Tuple[List[str], str]: numbered_docs = [build_label_doc_line(idx + 1, doc) for idx, doc in enumerate(docs)] prompt = classify_batch_complex_prompt(query, query_profile, numbered_docs) content, raw_response = self._chat(prompt) payload = extract_json_blob(content) if not isinstance(payload, dict) or not isinstance(payload.get("labels"), list): raise ValueError(f"unexpected label payload: {content!r}") labels_payload = payload["labels"] labels: List[str] = [] for item in labels_payload[: len(docs)]: if not isinstance(item, dict): continue label = str(item.get("label") or "").strip() if label in VALID_LABELS: labels.append(label) if len(labels) != len(docs) or any(label not in VALID_LABELS for label in labels): raise ValueError(f"unexpected label output: {content!r}") return labels, raw_response