clients.py 12.5 KB
"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""

from __future__ import annotations

import io
import json
import logging
import threading
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple

import requests

from .constants import EVAL_VERBOSE_LOG_FILE, VALID_LABELS
from .logging_setup import setup_eval_logging
from .prompts import classify_prompt, intent_analysis_prompt
from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps

# Guards lazy creation of the verbose LLM logger below.
_VERBOSE_LOGGER_LOCK = threading.Lock()
# Cached logger instance; populated on the first call to _get_eval_llm_verbose_logger().
_eval_llm_verbose_logger_singleton: logging.Logger | None = None
# Whether the verbose log location has already been announced on the "search_eval" logger.
_eval_llm_verbose_path_logged = False


def _get_eval_llm_verbose_logger() -> logging.Logger:
    """Return the file logger used for full LLM prompts/responses.

    Calls ``setup_eval_logging()`` on every invocation, then (under a lock) returns the
    cached ``search_eval.verbose_llm`` logger. On first use it creates the parent
    directory of ``EVAL_VERBOSE_LOG_FILE`` and attaches a UTF-8 ``FileHandler`` if the
    logger has no handlers yet (also disabling propagation in that case). It announces
    the resolved file path once on the ``search_eval`` logger.
    """
    global _eval_llm_verbose_logger_singleton, _eval_llm_verbose_path_logged
    setup_eval_logging()
    with _VERBOSE_LOGGER_LOCK:
        cached = _eval_llm_verbose_logger_singleton
        if cached is not None:
            return cached

        target = EVAL_VERBOSE_LOG_FILE
        target.parent.mkdir(parents=True, exist_ok=True)

        verbose_logger = logging.getLogger("search_eval.verbose_llm")
        verbose_logger.setLevel(logging.INFO)
        if not verbose_logger.handlers:
            file_handler = logging.FileHandler(target, encoding="utf-8")
            file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
            verbose_logger.addHandler(file_handler)
            # Keep the very large prompt/response dumps out of the parent loggers' output.
            verbose_logger.propagate = False

        _eval_llm_verbose_logger_singleton = verbose_logger

        if not _eval_llm_verbose_path_logged:
            _eval_llm_verbose_path_logged = True
            logging.getLogger("search_eval").info(
                "LLM verbose I/O log (full prompt + response): %s",
                target.resolve(),
            )
        return verbose_logger


def _log_eval_llm_verbose(
    *,
    phase: str,
    model: str,
    prompt: str,
    assistant_text: str,
    raw_response: str,
) -> None:
    """Append one full LLM exchange (prompt, parsed reply, raw JSON) to the verbose log.

    The exchange is written as a single log record. Previously it was written as six
    separate records, so when the client is used from several threads the sections of
    different exchanges could interleave in the file. One record keeps each exchange
    contiguous.

    Args:
        phase: Caller-supplied tag for the kind of request (e.g. ``"relevance_classify"``).
        model: Model name the request was sent to.
        prompt: The full user message.
        assistant_text: The assistant content extracted from the response.
        raw_response: The raw response serialized as a JSON string.
    """
    sep = "=" * 80
    _get_eval_llm_verbose_logger().info(
        "\n%s\n"
        "phase=%s model=%s\n"
        "%s\nFULL PROMPT (user message)\n%s\n"
        "%s\nASSISTANT CONTENT (parsed)\n%s\n"
        "%s\nRAW RESPONSE (JSON string)\n%s\n"
        "%s\n",
        sep,
        phase,
        model,
        sep,
        prompt,
        sep,
        assistant_text,
        sep,
        raw_response,
        sep,
    )


def _canonicalize_judge_label(raw: str) -> str | None:
    """Map one judge-output token to its canonical spelling from ``VALID_LABELS``.

    Surrounding whitespace and quote characters (``"``, ``'``, and backtick) are removed
    in any nesting order. Both ``"'label'"`` and ``'"label"'`` therefore unwrap, and
    whitespace left inside the quotes is trimmed too. The previous
    ``.strip('"').strip("'")`` chain only handled double quotes outside single quotes.
    Matching is exact first, then case-insensitive.

    Returns:
        The matching member of ``VALID_LABELS``, or ``None`` if the token is not a label.
    """
    s = str(raw or "").strip().strip("\"'`").strip()
    if s in VALID_LABELS:
        return s
    low = s.lower()
    for v in VALID_LABELS:
        if v.lower() == low:
            return v
    return None


class SearchServiceClient:
    """Minimal HTTP client for the search API's ``POST {base_url}/search/`` endpoint.

    Every request carries the configured tenant in the ``X-Tenant-ID`` header and reuses
    one ``requests.Session``.
    """

    def __init__(self, base_url: str, tenant_id: str):
        self.base_url = base_url.rstrip("/")
        self.tenant_id = str(tenant_id)
        self.session = requests.Session()

    def search(self, query: str, size: int, from_: int = 0, language: str = "en", *, debug: bool = False) -> Dict[str, Any]:
        """Run one search request and return the decoded JSON body.

        Args:
            query: Search text.
            size: Number of hits requested.
            from_: Offset of the first hit (sent as ``"from"``).
            language: Language code sent with the request.
            debug: When true, adds ``"debug": True`` to the request body.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        request_body: Dict[str, Any] = {
            "query": query,
            "size": size,
            "from": from_,
            "language": language,
            **({"debug": True} if debug else {}),
        }
        request_headers = {"Content-Type": "application/json", "X-Tenant-ID": self.tenant_id}
        reply = self.session.post(
            f"{self.base_url}/search/",
            headers=request_headers,
            json=request_body,
            timeout=120,
        )
        reply.raise_for_status()
        return reply.json()


class RerankServiceClient:
    """Minimal HTTP client that POSTs a query plus candidate documents to a reranker URL."""

    def __init__(self, service_url: str):
        self.service_url = service_url.rstrip("/")
        self.session = requests.Session()

    def rerank(self, query: str, docs: Sequence[str], normalize: bool = False, top_n: Optional[int] = None) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query`` using the reranker service.

        ``top_n`` is sent (as an int) only when provided.

        Returns:
            ``(scores, meta)``. A missing or null ``scores`` field becomes ``[]``, and a
            missing or null ``meta`` field becomes ``{}``.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        body: Dict[str, Any] = {"query": query, "docs": list(docs), "normalize": normalize}
        if top_n is not None:
            body["top_n"] = int(top_n)
        reply = self.session.post(self.service_url, json=body, timeout=180)
        reply.raise_for_status()
        decoded = reply.json()
        scores = list(decoded.get("scores") or [])
        meta = dict(decoded.get("meta") or {})
        return scores, meta


class DashScopeLabelClient:
    """DashScope OpenAI-compatible chat: synchronous or Batch File API (JSONL job).

    Batch flow: https://help.aliyun.com/zh/model-studio/batch-interfaces-compatible-with-openai/

    Some regional endpoints (e.g. ``dashscope-us`` compatible-mode) do not implement ``/batches``;
    on HTTP 404 from batch calls we fall back to synchronous ``/chat/completions`` and stop using batch
    for subsequent requests on this client. A warning is logged to ``search_eval`` when that happens.
    """

    # Batch statuses after which polling stops.
    _BATCH_TERMINAL_STATUSES = frozenset({"completed", "failed", "expired", "cancelled"})

    def __init__(
        self,
        model: str,
        base_url: str,
        api_key: str,
        batch_size: int = 40,
        *,
        batch_completion_window: str = "24h",
        batch_poll_interval_sec: float = 10.0,
        enable_thinking: bool = True,
        use_batch: bool = False,
    ):
        """Store connection and request settings.

        Args:
            model: Model name sent in every chat-completion body.
            base_url: Compatible-mode API root; a trailing ``/`` is removed.
            api_key: Sent as a ``Bearer`` token.
            batch_size: Stored as an int; this class does not read it itself
                (presumably callers use it to chunk documents -- TODO confirm).
            batch_completion_window: ``completion_window`` value for ``POST /batches``.
            batch_poll_interval_sec: Seconds to sleep between batch-status polls.
            enable_thinking: Forwarded verbatim as ``enable_thinking`` in the request body.
            use_batch: Start in Batch File API mode instead of synchronous calls.
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = int(batch_size)
        self.batch_completion_window = str(batch_completion_window)
        self.batch_poll_interval_sec = float(batch_poll_interval_sec)
        self.enable_thinking = bool(enable_thinking)
        self.use_batch = bool(use_batch)
        self.session = requests.Session()

    def _auth_headers(self) -> Dict[str, str]:
        """Return the ``Authorization`` header for the configured API key."""
        return {"Authorization": f"Bearer {self.api_key}"}

    def _completion_body(self, prompt: str) -> Dict[str, Any]:
        """Build a single-user-message chat-completion body with low-randomness sampling."""
        body: Dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0,
            "top_p": 0.1,
            # Non-OpenAI field: passed through as-is for DashScope.
            "enable_thinking": self.enable_thinking,
        }
        return body

    @staticmethod
    def _extract_message_content(data: Dict[str, Any]) -> str:
        """Return ``choices[0].message.content`` stripped, or ``""`` if any part is missing."""
        choices = data.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return str(message.get("content") or "").strip()

    def _chat_sync(self, prompt: str) -> Tuple[str, str]:
        """Send one synchronous ``/chat/completions`` request.

        Returns:
            ``(assistant_content, raw_response_json_string)``.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        response = self.session.post(
            f"{self.base_url}/chat/completions",
            headers={**self._auth_headers(), "Content-Type": "application/json"},
            json=self._completion_body(prompt),
            timeout=180,
        )
        response.raise_for_status()
        data = response.json()
        return self._extract_message_content(data), safe_json_dumps(data)

    def _chat_batch(self, prompt: str) -> Tuple[str, str]:
        """Run one chat completion via the Batch File API (single-line JSONL job).

        Steps: upload the JSONL input file, create a batch, poll its status every
        ``batch_poll_interval_sec`` until it reaches a terminal status (there is no
        client-side deadline), then locate this request's line in the output file, or
        in the error file if it is not in the output.

        Returns:
            ``(assistant_content, raw_output_line_json_string)``.

        Raises:
            requests.HTTPError: from any API call returning non-2xx.
            RuntimeError: on a missing file/batch id, a non-``completed`` terminal status,
                a request line found only in the error file or not at all, or a line
                whose ``status_code`` is not 200.
        """
        custom_id = uuid.uuid4().hex
        line_obj = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": self._completion_body(prompt),
        }
        jsonl = json.dumps(line_obj, ensure_ascii=False, separators=(",", ":")) + "\n"
        auth = self._auth_headers()

        up = self.session.post(
            f"{self.base_url}/files",
            headers=auth,
            files={
                "file": (
                    "eval_batch_input.jsonl",
                    io.BytesIO(jsonl.encode("utf-8")),
                    "application/octet-stream",
                )
            },
            data={"purpose": "batch"},
            timeout=300,
        )
        up.raise_for_status()
        file_id = (up.json() or {}).get("id")
        if not file_id:
            raise RuntimeError(f"DashScope file upload returned no id: {up.text!r}")

        cr = self.session.post(
            f"{self.base_url}/batches",
            headers={**auth, "Content-Type": "application/json"},
            json={
                "input_file_id": file_id,
                "endpoint": "/v1/chat/completions",
                "completion_window": self.batch_completion_window,
            },
            timeout=120,
        )
        cr.raise_for_status()
        batch: Dict[str, Any] = dict(cr.json() or {})
        batch_id = batch.get("id")
        if not batch_id:
            raise RuntimeError(f"DashScope batches.create returned no id: {cr.text!r}")

        status = str(batch.get("status") or "")
        while status not in self._BATCH_TERMINAL_STATUSES:
            time.sleep(self.batch_poll_interval_sec)
            br = self.session.get(f"{self.base_url}/batches/{batch_id}", headers=auth, timeout=120)
            br.raise_for_status()
            batch = br.json() or {}
            status = str(batch.get("status") or "")

        if status != "completed":
            raise RuntimeError(
                f"DashScope batch {batch_id} ended with status={status!r} errors={batch.get('errors')!r}"
            )

        row = self._find_batch_line_for_custom_id(batch.get("output_file_id"), custom_id, auth)
        if row is None:
            err_row = self._find_batch_line_for_custom_id(batch.get("error_file_id"), custom_id, auth)
            if err_row is not None:
                raise RuntimeError(f"DashScope batch request failed: {err_row!r}")
            raise RuntimeError(f"DashScope batch output missing custom_id={custom_id!r}")

        resp = row.get("response") or {}
        sc = resp.get("status_code")
        if sc is not None and int(sc) != 200:
            raise RuntimeError(f"DashScope batch line error: {row!r}")

        data = resp.get("body") or {}
        return self._extract_message_content(data), safe_json_dumps(row)

    def _chat(self, prompt: str, *, phase: str = "chat") -> Tuple[str, str]:
        """Send ``prompt`` (batch or sync mode), write it to the verbose log, and return the reply.

        In batch mode, an HTTP 404 from any batch-flow call permanently switches this
        client to synchronous mode and retries the prompt synchronously. Other HTTP
        errors propagate.

        Returns:
            ``(assistant_content, raw_response_json_string)``.
        """
        if self.use_batch:
            try:
                content, raw = self._chat_batch(prompt)
            except requests.exceptions.HTTPError as e:
                resp = getattr(e, "response", None)
                if resp is None or resp.status_code != 404:
                    raise
                # Make the permanent mode switch visible instead of silent.
                logging.getLogger("search_eval").warning(
                    "DashScope batch API returned 404 (%s); falling back to synchronous "
                    "chat/completions for the rest of this client's lifetime",
                    getattr(resp, "url", "?"),
                )
                self.use_batch = False
                content, raw = self._chat_sync(prompt)
        else:
            content, raw = self._chat_sync(prompt)
        _log_eval_llm_verbose(
            phase=phase,
            model=self.model,
            prompt=prompt,
            assistant_text=content,
            raw_response=raw,
        )
        return content, raw

    def _find_batch_line_for_custom_id(
        self,
        file_id: Optional[str],
        custom_id: str,
        auth: Dict[str, str],
    ) -> Optional[Dict[str, Any]]:
        """Download a batch result file and return the JSON line whose ``custom_id`` matches.

        Returns ``None`` without any request when ``file_id`` is empty or the string
        ``"null"``. Blank lines, lines that are not valid JSON, and JSON values that are
        not objects are skipped. Non-objects previously raised ``AttributeError``.
        """
        if not file_id or str(file_id) in ("null", ""):
            return None
        r = self.session.get(f"{self.base_url}/files/{file_id}/content", headers=auth, timeout=300)
        r.raise_for_status()
        for raw in r.text.splitlines():
            raw = raw.strip()
            if not raw:
                continue
            try:
                obj = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict) and str(obj.get("custom_id")) == custom_id:
                return obj
        return None

    def query_intent(self, query: str) -> Tuple[str, str]:
        """Ask the model for an intent analysis of ``query``; returns ``(content, raw_response)``."""
        prompt = intent_analysis_prompt(query)
        return self._chat(prompt, phase="query_intent")

    def classify_batch(
        self,
        query: str,
        docs: Sequence[Dict[str, Any]],
        *,
        query_intent_block: str = "",
    ) -> Tuple[List[str], str]:
        """Label each document in ``docs`` for relevance to ``query`` using one LLM call.

        Parsing: every reply line that canonicalizes to a valid label is taken in order.
        If that does not give exactly ``len(docs)`` labels, a JSON object with a
        ``labels`` list is tried instead. Its items may be strings or ``{"label": ...}``
        dicts, and at most ``len(docs)`` items are read.

        Returns:
            ``(labels, raw_response)``, with ``labels`` aligned to ``docs``. For empty
            ``docs`` no request is sent and ``([], "")`` is returned.

        Raises:
            ValueError: if neither parse yields one valid label per document.
        """
        if not docs:
            # Nothing to label. The old code still sent a prompt and could raise
            # ValueError if the model happened to emit any label line.
            return [], ""
        numbered_docs = [build_label_doc_line(idx, doc) for idx, doc in enumerate(docs, start=1)]
        prompt = classify_prompt(query, numbered_docs, query_intent_block=query_intent_block)
        content, raw_response = self._chat(prompt, phase="relevance_classify")

        labels: List[str] = []
        for line in str(content or "").splitlines():
            canon = _canonicalize_judge_label(line)
            if canon is not None:
                labels.append(canon)

        if len(labels) != len(docs):
            payload = extract_json_blob(content)
            if isinstance(payload, dict) and isinstance(payload.get("labels"), list):
                labels = []
                for item in payload["labels"][: len(docs)]:
                    if isinstance(item, dict):
                        raw_l = str(item.get("label") or "").strip()
                    else:
                        raw_l = str(item).strip()
                    canon = _canonicalize_judge_label(raw_l)
                    if canon is not None:
                        labels.append(canon)

        if len(labels) != len(docs) or any(label not in VALID_LABELS for label in labels):
            raise ValueError(f"unexpected classify output: {content!r}")
        return labels, raw_response