# clients.py 5.9 KB
"""HTTP clients for search API, reranker, and DashScope chat (relevance labeling)."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

import requests

from .constants import VALID_LABELS
from .prompts import (
    classify_batch_complex_prompt,
    classify_batch_simple_prompt,
    extract_query_profile_prompt,
)
from .utils import build_label_doc_line, extract_json_blob, safe_json_dumps


class SearchServiceClient:
    """Client for the tenant-scoped ``/search/`` HTTP endpoint.

    One ``requests.Session`` is created in the constructor and used for every request.
    """

    def __init__(self, base_url: str, tenant_id: str):
        """Remember where the service lives and which tenant to query.

        Args:
            base_url: Root URL of the search service; trailing ``/`` characters are removed.
            tenant_id: Tenant identifier, coerced to ``str`` and sent as ``X-Tenant-ID``.
        """
        self.base_url = base_url.rstrip("/")
        self.tenant_id = str(tenant_id)
        self.session = requests.Session()

    def search(
        self,
        query: str,
        size: int,
        from_: int = 0,
        language: str = "en",
    ) -> Dict[str, Any]:
        """POST one search request and return the decoded JSON body.

        Args:
            query: Search text, sent as ``query``.
            size: Sent as ``size``.
            from_: Sent as ``from`` (the trailing underscore only avoids the keyword).
            language: Sent as ``language``.

        Returns:
            Whatever ``json()`` yields for the reply.

        Raises:
            Any exception ``raise_for_status()`` raises for the reply.
        """
        endpoint = f"{self.base_url}/search/"
        request_headers = {
            "Content-Type": "application/json",
            "X-Tenant-ID": self.tenant_id,
        }
        request_body = {
            "query": query,
            "size": size,
            "from": from_,
            "language": language,
        }
        reply = self.session.post(
            endpoint,
            headers=request_headers,
            json=request_body,
            timeout=120,
        )
        # Fail on error statuses before attempting to decode the body.
        reply.raise_for_status()
        return reply.json()


class RerankServiceClient:
    """Client for a rerank HTTP service reachable at a single URL.

    One ``requests.Session`` is created in the constructor and used for every request.
    """

    def __init__(self, service_url: str):
        """Store the service URL with trailing ``/`` characters removed.

        Args:
            service_url: Full URL that rerank requests are POSTed to.
        """
        self.service_url = service_url.rstrip("/")
        self.session = requests.Session()

    def rerank(
        self,
        query: str,
        docs: Sequence[str],
        normalize: bool = False,
        top_n: Optional[int] = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """POST ``docs`` for scoring against ``query`` and unpack the reply.

        Args:
            query: Sent as ``query``.
            docs: Document texts; copied into a list and sent as ``docs``.
            normalize: Sent as ``normalize``.
            top_n: When not ``None``, coerced to ``int`` and sent as ``top_n``;
                otherwise the key is left out of the request.

        Returns:
            ``(scores, meta)`` taken from the reply's ``scores`` and ``meta``
            fields. A missing or falsy field becomes ``[]`` / ``{}``.

        Raises:
            Any exception ``raise_for_status()`` raises for the reply.
        """
        request_body: Dict[str, Any] = dict(query=query, docs=list(docs), normalize=normalize)
        if top_n is not None:
            request_body.update(top_n=int(top_n))

        reply = self.session.post(self.service_url, json=request_body, timeout=180)
        reply.raise_for_status()
        parsed = reply.json()

        # ``or`` fallbacks cover both absent keys and explicit nulls in the reply.
        raw_scores = parsed.get("scores") or []
        raw_meta = parsed.get("meta") or {}
        return list(raw_scores), dict(raw_meta)


class DashScopeLabelClient:
    """Chat-completions client used to label documents against a query.

    Every public method builds a prompt, sends it as a single user message via
    ``_chat``, and parses the model's text reply. Each returns the parsed
    result together with the serialized raw response.
    """

    def __init__(self, model: str, base_url: str, api_key: str, batch_size: int = 40):
        """Store connection settings and open a ``requests.Session``.

        Args:
            model: Model name sent with every chat request.
            base_url: API root; trailing ``/`` characters are removed.
            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
            batch_size: Coerced to ``int`` and kept on the instance; no method
                of this class reads it.
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.batch_size = int(batch_size)
        self.session = requests.Session()

    def _chat(self, prompt: str) -> Tuple[str, str]:
        """Send ``prompt`` as one user message and return the reply text.

        Returns:
            ``(content, raw)``: ``content`` is the stripped ``message.content``
            of the first choice (``""`` when choices, message or content are
            missing or falsy); ``raw`` is the whole decoded response passed
            through ``safe_json_dumps``.

        Raises:
            Any exception ``raise_for_status()`` raises for the reply.
        """
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        request_body = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0,
            "top_p": 0.1,
        }
        reply = self.session.post(
            f"{self.base_url}/chat/completions",
            headers=request_headers,
            json=request_body,
            timeout=180,
        )
        reply.raise_for_status()
        data = reply.json()

        # Walk choices[0].message.content, tolerating missing/null levels.
        choices = data.get("choices") or [{}]
        message = choices[0].get("message") or {}
        text = str(message.get("content") or "").strip()
        return text, safe_json_dumps(data)

    def classify_batch_simple(
        self,
        query: str,
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label ``docs`` for ``query`` using the simple prompt.

        The reply is first read line by line, keeping lines that are exactly a
        member of ``VALID_LABELS``. If that does not give one label per doc,
        the reply is re-read as JSON with a ``labels`` list whose items are
        either ``{"label": ...}`` objects or bare values.

        Returns:
            ``(labels, raw_response)`` with ``len(labels) == len(docs)``.

        Raises:
            ValueError: Neither reading yields exactly one valid label per doc.
        """
        # Docs are numbered from 1 when rendered for the prompt.
        doc_lines = [
            build_label_doc_line(position, doc)
            for position, doc in enumerate(docs, start=1)
        ]
        content, raw_response = self._chat(classify_batch_simple_prompt(query, doc_lines))
        expected = len(docs)

        # First attempt: one label per line; anything not a valid label is skipped.
        stripped_lines = (line.strip() for line in str(content or "").splitlines())
        labels = [text for text in stripped_lines if text in VALID_LABELS]

        if len(labels) != expected:
            # Fallback: a JSON object carrying a "labels" list.
            payload = extract_json_blob(content)
            entries = payload.get("labels") if isinstance(payload, dict) else None
            if isinstance(entries, list):
                candidates = (
                    str(entry.get("label") or "").strip()
                    if isinstance(entry, dict)
                    else str(entry).strip()
                    for entry in entries[:expected]
                )
                labels = [text for text in candidates if text in VALID_LABELS]

        if len(labels) == expected and all(text in VALID_LABELS for text in labels):
            return labels, raw_response
        raise ValueError(f"unexpected simple label output: {content!r}")

    def extract_query_profile(
        self,
        query: str,
        parser_hints: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], str]:
        """Ask the model for a query profile and fill in missing keys.

        Returns:
            ``(profile, raw_response)``. ``profile`` is the JSON object parsed
            from the reply; keys absent from it are added with defaults
            (``normalized_query_en`` falls back to ``query``).

        Raises:
            ValueError: The reply does not parse to a JSON object.
        """
        content, raw_response = self._chat(extract_query_profile_prompt(query, parser_hints))
        profile = extract_json_blob(content)
        if not isinstance(profile, dict):
            raise ValueError(f"unexpected query profile payload: {content!r}")

        # Built per call so each profile gets its own fresh list defaults.
        fallbacks = (
            ("normalized_query_en", query),
            ("primary_category", ""),
            ("allowed_categories", []),
            ("required_attributes", []),
            ("notes", []),
        )
        for key, fallback in fallbacks:
            profile.setdefault(key, fallback)
        return profile, raw_response

    def classify_batch_complex(
        self,
        query: str,
        query_profile: Dict[str, Any],
        docs: Sequence[Dict[str, Any]],
    ) -> Tuple[List[str], str]:
        """Label ``docs`` for ``query`` using the profile-aware prompt.

        The reply must parse to a JSON object with a ``labels`` list. Only the
        first ``len(docs)`` items are read, and only ``{"label": ...}`` objects
        whose label is in ``VALID_LABELS`` are kept.

        Returns:
            ``(labels, raw_response)`` with ``len(labels) == len(docs)``.

        Raises:
            ValueError: The reply has no ``labels`` list, or it does not yield
                exactly one valid label per doc.
        """
        # Docs are numbered from 1 when rendered for the prompt.
        doc_lines = [
            build_label_doc_line(position, doc)
            for position, doc in enumerate(docs, start=1)
        ]
        content, raw_response = self._chat(
            classify_batch_complex_prompt(query, query_profile, doc_lines)
        )

        payload = extract_json_blob(content)
        entries = payload.get("labels") if isinstance(payload, dict) else None
        if not isinstance(entries, list):
            raise ValueError(f"unexpected label payload: {content!r}")

        expected = len(docs)
        candidates = (
            str(entry.get("label") or "").strip()
            for entry in entries[:expected]
            if isinstance(entry, dict)
        )
        labels: List[str] = [text for text in candidates if text in VALID_LABELS]

        if len(labels) == expected and all(text in VALID_LABELS for text in labels):
            return labels, raw_response
        raise ValueError(f"unexpected label output: {content!r}")