# clients.py
"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""

from __future__ import annotations

import io
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple

import requests

from .constants import VALID_LABELS
from .prompts import classify_prompt
from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps


def _canonicalize_judge_label(raw: str) -> str | None:
    """Map a raw judge output token onto one of ``VALID_LABELS``.

    Strips surrounding whitespace and quote characters, then tries an
    exact match first and a case-insensitive match second. Returns
    ``None`` when the token matches no known label.
    """
    candidate = str(raw or "").strip().strip('"').strip("'")
    if candidate in VALID_LABELS:
        return candidate
    folded = candidate.lower()
    return next((label for label in VALID_LABELS if label.lower() == folded), None)


class SearchServiceClient:
    """Thin client for the search API's ``POST /search/`` endpoint."""

    def __init__(self, base_url: str, tenant_id: str):
        # Trailing slash is stripped so endpoint paths can be appended cleanly.
        self.base_url = base_url.rstrip("/")
        self.tenant_id = str(tenant_id)
        # Pooled HTTP session reused across calls.
        self.session = requests.Session()

    def search(self, query: str, size: int, from_: int = 0, language: str = "en", *, debug: bool = False) -> Dict[str, Any]:
        """Run one search and return the decoded JSON response.

        The tenant is conveyed via the ``X-Tenant-ID`` header; the ``debug``
        flag is only included in the payload when truthy. Raises
        ``requests.HTTPError`` on a non-2xx response.
        """
        body: Dict[str, Any] = {"query": query, "size": size, "from": from_, "language": language}
        if debug:
            body["debug"] = True
        resp = self.session.post(
            f"{self.base_url}/search/",
            headers={"Content-Type": "application/json", "X-Tenant-ID": self.tenant_id},
            json=body,
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()


class RerankServiceClient:
    """Client for a reranker endpoint that scores docs against a query."""

    def __init__(self, service_url: str):
        # Trailing slash stripped; the URL is used as-is as the POST target.
        self.service_url = service_url.rstrip("/")
        self.session = requests.Session()

    def rerank(self, query: str, docs: Sequence[str], normalize: bool = False, top_n: Optional[int] = None) -> Tuple[List[float], Dict[str, Any]]:
        """POST ``query`` + ``docs`` to the reranker and return ``(scores, meta)``.

        ``top_n`` is only sent when provided. Missing ``"scores"``/``"meta"``
        fields in the response degrade to an empty list / empty dict rather
        than raising. Raises ``requests.HTTPError`` on a non-2xx response.
        """
        body: Dict[str, Any] = {"query": query, "docs": list(docs), "normalize": normalize}
        if top_n is not None:
            body["top_n"] = int(top_n)
        resp = self.session.post(self.service_url, json=body, timeout=180)
        resp.raise_for_status()
        payload = resp.json()
        scores = list(payload.get("scores") or [])
        meta = dict(payload.get("meta") or {})
        return scores, meta


class DashScopeLabelClient:
    """DashScope OpenAI-compatible chat: synchronous or Batch File API (JSONL job).

    Batch flow: https://help.aliyun.com/zh/model-studio/batch-interfaces-compatible-with-openai/

    Some regional endpoints (e.g. ``dashscope-us`` compatible-mode) do not implement ``/batches``;
    on HTTP 404 from batch calls we fall back to synchronous ``/chat/completions`` and stop using batch
    for subsequent requests on this client.
    """

    def __init__(
        self,
        model: str,
        base_url: str,
        api_key: str,
        batch_size: int = 40,
        *,
        batch_completion_window: str = "24h",
        batch_poll_interval_sec: float = 10.0,
        enable_thinking: bool = True,
        use_batch: bool = False,
    ):
        """Configure the client; no network traffic happens here.

        Args:
            model: Chat model name sent in every request body.
            base_url: OpenAI-compatible base URL; a trailing "/" is stripped.
            api_key: Bearer token used in the Authorization header.
            batch_size: Stored but not read anywhere in this class --
                presumably the number of docs labeled per prompt by callers.
                TODO confirm against the labeling pipeline.
            batch_completion_window: Forwarded as ``completion_window`` to
                batches.create.
            batch_poll_interval_sec: Sleep between batch status polls.
            enable_thinking: Forwarded as ``enable_thinking`` in request bodies.
            use_batch: Start in Batch File API mode; flipped to False
                permanently on the first HTTP 404 from a batch call.
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = int(batch_size)
        self.batch_completion_window = str(batch_completion_window)
        self.batch_poll_interval_sec = float(batch_poll_interval_sec)
        self.enable_thinking = bool(enable_thinking)
        self.use_batch = bool(use_batch)
        # One pooled HTTP session shared by uploads, batch management and chat.
        self.session = requests.Session()

    def _auth_headers(self) -> Dict[str, str]:
        """Bearer-token Authorization header used on every DashScope call."""
        return {"Authorization": f"Bearer {self.api_key}"}

    def _completion_body(self, prompt: str) -> Dict[str, Any]:
        """Build the OpenAI-style chat body for a single user *prompt*.

        temperature=0 / top_p=0.1 push the model toward deterministic
        labeling output.
        """
        body: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0,
            "top_p": 0.1,
            "enable_thinking": self.enable_thinking,
        }
        return body

    def _chat_sync(self, prompt: str) -> Tuple[str, str]:
        """One synchronous ``/chat/completions`` call.

        Returns ``(content, raw)``: the stripped assistant message text and
        the full JSON response serialized via ``safe_json_dumps``. Raises
        ``requests.HTTPError`` on a non-2xx response.
        """
        response = self.session.post(
            f"{self.base_url}/chat/completions",
            headers={**self._auth_headers(), "Content-Type": "application/json"},
            json=self._completion_body(prompt),
            timeout=180,
        )
        response.raise_for_status()
        data = response.json()
        # Defensive extraction: missing choices/message/content degrade to "".
        content = str(((data.get("choices") or [{}])[0].get("message") or {}).get("content") or "").strip()
        return content, safe_json_dumps(data)

    def _chat_batch(self, prompt: str) -> Tuple[str, str]:
        """One chat completion via Batch File API (single-line JSONL job).

        Flow: upload a one-line JSONL input file, create a batch referencing
        it, poll until the batch reaches a terminal status, then read back the
        row matching our ``custom_id`` from the output (or error) file.
        Returns ``(content, raw_row)`` like ``_chat_sync``; raises
        ``RuntimeError`` on any non-success batch outcome.
        """
        # Random id lets us find our row in the result file.
        custom_id = uuid.uuid4().hex
        body = self._completion_body(prompt)
        line_obj = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": body,
        }
        jsonl = json.dumps(line_obj, ensure_ascii=False, separators=(",", ":")) + "\n"
        auth = self._auth_headers()

        # Step 1: upload the JSONL input file with purpose=batch.
        up = self.session.post(
            f"{self.base_url}/files",
            headers=auth,
            files={
                "file": (
                    "eval_batch_input.jsonl",
                    io.BytesIO(jsonl.encode("utf-8")),
                    "application/octet-stream",
                )
            },
            data={"purpose": "batch"},
            timeout=300,
        )
        up.raise_for_status()
        file_id = (up.json() or {}).get("id")
        if not file_id:
            raise RuntimeError(f"DashScope file upload returned no id: {up.text!r}")

        # Step 2: create the batch job over the uploaded file.
        cr = self.session.post(
            f"{self.base_url}/batches",
            headers={**auth, "Content-Type": "application/json"},
            json={
                "input_file_id": file_id,
                "endpoint": "/v1/chat/completions",
                "completion_window": self.batch_completion_window,
            },
            timeout=120,
        )
        cr.raise_for_status()
        batch_payload = cr.json() or {}
        batch_id = batch_payload.get("id")
        if not batch_id:
            raise RuntimeError(f"DashScope batches.create returned no id: {cr.text!r}")

        # Step 3: poll until a terminal status is reached.
        # NOTE(review): no overall deadline here -- a batch stuck in a
        # non-terminal status would poll forever (bounded only by the
        # server-side completion window). Confirm this is acceptable.
        terminal = frozenset({"completed", "failed", "expired", "cancelled"})
        batch: Dict[str, Any] = dict(batch_payload)
        status = str(batch.get("status") or "")
        while status not in terminal:
            time.sleep(self.batch_poll_interval_sec)
            br = self.session.get(f"{self.base_url}/batches/{batch_id}", headers=auth, timeout=120)
            br.raise_for_status()
            batch = br.json() or {}
            status = str(batch.get("status") or "")

        if status != "completed":
            raise RuntimeError(
                f"DashScope batch {batch_id} ended with status={status!r} errors={batch.get('errors')!r}"
            )

        # Step 4: locate our row in the output file; fall back to the error
        # file to surface a per-line failure message.
        out_id = batch.get("output_file_id")
        err_id = batch.get("error_file_id")

        row = self._find_batch_line_for_custom_id(out_id, custom_id, auth)
        if row is None:
            err_row = self._find_batch_line_for_custom_id(err_id, custom_id, auth)
            if err_row is not None:
                raise RuntimeError(f"DashScope batch request failed: {err_row!r}")
            raise RuntimeError(f"DashScope batch output missing custom_id={custom_id!r}")

        resp = row.get("response") or {}
        sc = resp.get("status_code")
        # A present status_code other than 200 marks a failed line.
        if sc is not None and int(sc) != 200:
            raise RuntimeError(f"DashScope batch line error: {row!r}")

        data = resp.get("body") or {}
        content = str(((data.get("choices") or [{}])[0].get("message") or {}).get("content") or "").strip()
        return content, safe_json_dumps(row)

    def _chat(self, prompt: str) -> Tuple[str, str]:
        """Route to batch or sync mode.

        On HTTP 404 from a batch call (endpoint does not implement
        ``/batches``), permanently disable batch mode for this client and
        retry the same prompt synchronously.
        """
        if not self.use_batch:
            return self._chat_sync(prompt)
        try:
            return self._chat_batch(prompt)
        except requests.exceptions.HTTPError as e:
            resp = getattr(e, "response", None)
            if resp is not None and resp.status_code == 404:
                self.use_batch = False
                return self._chat_sync(prompt)
            raise

    def _find_batch_line_for_custom_id(
        self,
        file_id: Optional[str],
        custom_id: str,
        auth: Dict[str, str],
    ) -> Optional[Dict[str, Any]]:
        """Download a batch result/error file and return the JSONL row whose
        ``custom_id`` matches, or ``None``.

        ``None``/empty/"null" file ids and undecodable lines are skipped
        silently so a missing row is reported by the caller, not here.
        """
        if not file_id or str(file_id) in ("null", ""):
            return None
        r = self.session.get(f"{self.base_url}/files/{file_id}/content", headers=auth, timeout=300)
        r.raise_for_status()
        for raw in r.text.splitlines():
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if str(obj.get("custom_id")) == custom_id:
                return obj
        return None

    def classify_batch(
        self,
        query: str,
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label every doc in *docs* against *query* with one LLM prompt.

        Docs are numbered 1..N via ``build_label_doc_line`` and rendered into
        ``classify_prompt``. The reply is parsed two ways: first as one label
        per line, then -- if the count does not match -- as a JSON blob with a
        ``"labels"`` list (items may be plain strings or ``{"label": ...}``
        dicts). Returns ``(labels, raw_response)``.

        Raises:
            ValueError: if the parsed labels do not line up one-per-doc or
                contain anything outside ``VALID_LABELS``.
        """
        numbered_docs = [build_label_doc_line(idx + 1, doc) for idx, doc in enumerate(docs)]
        prompt = classify_prompt(query, numbered_docs)
        content, raw_response = self._chat(prompt)
        labels: List[str] = []
        # Primary parse: one canonical label per output line.
        for line in str(content or "").splitlines():
            canon = _canonicalize_judge_label(line)
            if canon is not None:
                labels.append(canon)
        if len(labels) != len(docs):
            # Fallback parse: a JSON object carrying a "labels" array.
            payload = extract_json_blob(content)
            if isinstance(payload, dict) and isinstance(payload.get("labels"), list):
                labels = []
                for item in payload["labels"][: len(docs)]:
                    if isinstance(item, dict):
                        raw_l = str(item.get("label") or "").strip()
                    else:
                        raw_l = str(item).strip()
                    canon = _canonicalize_judge_label(raw_l)
                    if canon is not None:
                        labels.append(canon)
        # Hard validation: exactly one known label per doc, or fail loudly.
        if len(labels) != len(docs) or any(label not in VALID_LABELS for label in labels):
            raise ValueError(f"unexpected classify output: {content!r}")
        return labels, raw_response