"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar

import requests

from .constants import VALID_LABELS
from .prompts import (
    classify_batch_complex_prompt,
    classify_batch_simple_prompt,
    extract_query_profile_prompt,
)
from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps

_ClientT = TypeVar("_ClientT", bound="_SessionClient")


class _SessionClient:
    """Mixin giving a client explicit ownership of its ``requests.Session``.

    Subclasses assign ``self.session`` in ``__init__``. Instances can be closed
    explicitly with ``close()`` or used as ``with Client(...) as client: ...`` so the
    session's pooled connections are released deterministically.
    """

    # Set by each subclass constructor.
    session: requests.Session

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self: _ClientT) -> _ClientT:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Returning None lets any in-flight exception propagate after cleanup.
        self.close()


def _number_docs(docs: Sequence[Dict[str, Any]]) -> List[str]:
    """Render ``docs`` as prompt lines numbered from 1 using ``build_label_doc_line``."""
    return [build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)]


def _first_message_content(data: Any) -> str:
    """Return the stripped ``choices[0].message.content`` of a chat response body.

    Missing or malformed pieces (non-dict body, empty/non-list ``choices``, a non-dict
    choice or message, null content) yield ``""`` so that the labeling methods report a
    ``ValueError`` about the output instead of an ``AttributeError``/``KeyError``.
    """
    if not isinstance(data, dict):
        return ""
    choices = data.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first = choices[0]
    message = first.get("message") if isinstance(first, dict) else None
    if not isinstance(message, dict):
        return ""
    return str(message.get("content") or "").strip()


def _labels_from_lines(content: str) -> List[str]:
    """Collect every line of ``content`` that, once stripped, is exactly a valid label."""
    return [
        label
        for label in (line.strip() for line in content.splitlines())
        if label in VALID_LABELS
    ]


def _labels_from_payload(items: List[Any], limit: int, *, accept_bare: bool) -> List[str]:
    """Extract valid labels from the first ``limit`` entries of a JSON ``labels`` list.

    Dict entries contribute their ``"label"`` field. Other entries are stringified when
    ``accept_bare`` is true and skipped otherwise. Invalid labels are dropped, so a
    result shorter than ``limit`` tells the caller the response was malformed.
    """
    labels: List[str] = []
    for item in items[:limit]:
        if isinstance(item, dict):
            label = str(item.get("label") or "").strip()
        elif accept_bare:
            label = str(item).strip()
        else:
            continue
        if label in VALID_LABELS:
            labels.append(label)
    return labels


class SearchServiceClient(_SessionClient):
    """Client for the search service's ``POST {base_url}/search/`` endpoint.

    Args:
        base_url: Service root; a trailing ``/`` is stripped.
        tenant_id: Sent as the ``X-Tenant-ID`` header (coerced to ``str``).
        timeout: Per-request timeout in seconds passed to ``requests``.
    """

    def __init__(self, base_url: str, tenant_id: str, timeout: float = 120):
        self.base_url = base_url.rstrip("/")
        self.tenant_id = str(tenant_id)
        self.timeout = timeout
        self.session = requests.Session()

    def search(self, query: str, size: int, from_: int = 0, language: str = "en") -> Dict[str, Any]:
        """Run a search and return the decoded JSON response body.

        ``from_`` is sent as the ``"from"`` field (``from`` is a Python keyword).

        Raises:
            requests.HTTPError: on a non-2xx status (via ``raise_for_status``).
        """
        response = self.session.post(
            f"{self.base_url}/search/",
            headers={"Content-Type": "application/json", "X-Tenant-ID": self.tenant_id},
            json={"query": query, "size": size, "from": from_, "language": language},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()


class RerankServiceClient(_SessionClient):
    """Client that POSTs rerank requests directly to ``service_url``.

    Args:
        service_url: Full endpoint URL; a trailing ``/`` is stripped.
        timeout: Per-request timeout in seconds passed to ``requests``.
    """

    def __init__(self, service_url: str, timeout: float = 180):
        self.service_url = service_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()

    def rerank(
        self,
        query: str,
        docs: Sequence[str],
        normalize: bool = False,
        top_n: Optional[int] = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query`` with the rerank service.

        ``top_n`` is only sent when given. Returns ``(scores, meta)`` taken from the
        response's ``"scores"`` and ``"meta"`` fields (empty when absent). The scores are
        presumably aligned with ``docs``, but their length is not validated here.

        Raises:
            requests.HTTPError: on a non-2xx status (via ``raise_for_status``).
        """
        payload: Dict[str, Any] = {"query": query, "docs": list(docs), "normalize": normalize}
        if top_n is not None:
            payload["top_n"] = int(top_n)
        response = self.session.post(self.service_url, json=payload, timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        return list(data.get("scores") or []), dict(data.get("meta") or {})


class DashScopeLabelClient(_SessionClient):
    """LLM relevance labeler backed by a ``{base_url}/chat/completions`` endpoint.

    Args:
        model: Model name sent in each request.
        base_url: API root; a trailing ``/`` is stripped.
        api_key: Sent as a ``Bearer`` token.
        batch_size: Stored as an int. It is not read by the methods below; presumably
            callers use it to size the ``docs`` batches they pass in.
        timeout: Per-request timeout in seconds passed to ``requests``.
    """

    def __init__(
        self,
        model: str,
        base_url: str,
        api_key: str,
        batch_size: int = 40,
        timeout: float = 180,
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = int(batch_size)
        self.timeout = timeout
        self.session = requests.Session()

    def _chat(self, prompt: str) -> Tuple[str, str]:
        """Send ``prompt`` as a single user message.

        Returns ``(content, raw_response)``: the stripped first-choice message text
        (``""`` when the response lacks one) and the full body serialized with
        ``safe_json_dumps`` for auditing.

        Raises:
            requests.HTTPError: on a non-2xx status (via ``raise_for_status``).
        """
        response = self.session.post(
            f"{self.base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                # Keep labeling as deterministic as the API allows.
                "temperature": 0,
                "top_p": 0.1,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        return _first_message_content(data), safe_json_dumps(data)

    def classify_batch_simple(
        self,
        query: str,
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label each doc for ``query`` using the simple prompt.

        First accepts one bare label per line. If that does not produce exactly one
        label per doc, it falls back to a JSON ``{"labels": [...]}`` blob whose items
        may be dicts with a ``"label"`` key or bare values.

        Returns:
            ``(labels, raw_response)`` with ``len(labels) == len(docs)``.

        Raises:
            ValueError: when neither format yields one valid label per doc.
        """
        prompt = classify_batch_simple_prompt(query, _number_docs(docs))
        content, raw_response = self._chat(prompt)
        labels = _labels_from_lines(content)
        if len(labels) != len(docs):
            payload = extract_json_blob(content)
            if isinstance(payload, dict) and isinstance(payload.get("labels"), list):
                labels = _labels_from_payload(payload["labels"], len(docs), accept_bare=True)
        # Helpers only emit labels in VALID_LABELS, so the count is the only check left.
        if len(labels) != len(docs):
            raise ValueError(f"unexpected simple label output: {content!r}")
        return labels, raw_response

    def extract_query_profile(
        self,
        query: str,
        parser_hints: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], str]:
        """Ask the model for a structured profile of ``query``.

        Missing or ``null`` fields are filled with defaults:
        ``normalized_query_en`` -> ``query``, ``primary_category`` -> ``""``, and
        ``allowed_categories`` / ``required_attributes`` / ``notes`` -> ``[]``.

        Returns:
            ``(profile, raw_response)``.

        Raises:
            ValueError: when the reply does not contain a JSON object.
        """
        prompt = extract_query_profile_prompt(query, parser_hints)
        content, raw_response = self._chat(prompt)
        payload = extract_json_blob(content)
        if not isinstance(payload, dict):
            raise ValueError(f"unexpected query profile payload: {content!r}")
        # Built per call so the list defaults are never shared between profiles.
        defaults: Dict[str, Any] = {
            "normalized_query_en": query,
            "primary_category": "",
            "allowed_categories": [],
            "required_attributes": [],
            "notes": [],
        }
        for key, default in defaults.items():
            # setdefault would keep an explicit JSON null; treat null as missing.
            if payload.get(key) is None:
                payload[key] = default
        return payload, raw_response

    def classify_batch_complex(
        self,
        query: str,
        query_profile: Dict[str, Any],
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label each doc for ``query`` using the profile-aware prompt.

        The reply must contain a JSON ``{"labels": [...]}`` blob whose items are dicts
        with a ``"label"`` key. Non-dict items are ignored, which leads to a count
        mismatch.

        Returns:
            ``(labels, raw_response)`` with ``len(labels) == len(docs)``.

        Raises:
            ValueError: when the payload is malformed or the label count is wrong.
        """
        prompt = classify_batch_complex_prompt(query, query_profile, _number_docs(docs))
        content, raw_response = self._chat(prompt)
        payload = extract_json_blob(content)
        if not isinstance(payload, dict) or not isinstance(payload.get("labels"), list):
            raise ValueError(f"unexpected label payload: {content!r}")
        labels = _labels_from_payload(payload["labels"], len(docs), accept_bare=False)
        if len(labels) != len(docs):
            raise ValueError(f"unexpected label output: {content!r}")
        return labels, raw_response