# dashscope_rerank.py (10.6 KB)
"""
DashScope cloud reranker backend (OpenAI-compatible reranks API).

Reference:
- https://dashscope.aliyuncs.com/compatible-api/v1/reranks
- Use region-specific domains when needed:
  - China:     https://dashscope.aliyuncs.com
  - Singapore: https://dashscope-intl.aliyuncs.com
  - US:        https://dashscope-us.aliyuncs.com
"""

from __future__ import annotations

import json
import logging
import math
import os
import time
from typing import Any, Dict, List, Tuple
from urllib import error as urllib_error
from urllib import request as urllib_request

from reranker.backends.batching_utils import deduplicate_with_positions

logger = logging.getLogger("reranker.backends.dashscope_rerank")


class DashScopeRerankBackend:
    """
    DashScope cloud reranker backend.

    Config from services.rerank.backends.dashscope_rerank:
      - model_name: str, default "qwen3-rerank"
      - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
      - api_key: optional str (or env DASHSCOPE_API_KEY)
      - timeout_sec: float, default 15.0
      - top_n_cap: int, optional cap; 0 means use all docs in request
      - instruct: optional str
      - max_retries: int, default 1
      - retry_backoff_sec: float, default 0.2

    Env overrides:
      - DASHSCOPE_API_KEY
      - RERANK_DASHSCOPE_ENDPOINT
      - RERANK_DASHSCOPE_MODEL
      - RERANK_DASHSCOPE_TIMEOUT_SEC
      - RERANK_DASHSCOPE_TOP_N_CAP
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        self._config = config or {}
        self._model_name = str(
            os.getenv("RERANK_DASHSCOPE_MODEL")
            or self._config.get("model_name")
            or "qwen3-rerank"
        )
        self._endpoint = str(
            os.getenv("RERANK_DASHSCOPE_ENDPOINT")
            or self._config.get("endpoint")
            or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
        ).strip()
        self._api_key = str(
            os.getenv("DASHSCOPE_API_KEY")
            or self._config.get("api_key")
            or ""
        ).strip().strip('"').strip("'")
        self._timeout_sec = float(
            os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC")
            or self._config.get("timeout_sec")
            or 15.0
        )
        self._top_n_cap = int(
            os.getenv("RERANK_DASHSCOPE_TOP_N_CAP")
            or self._config.get("top_n_cap")
            or 0
        )
        self._instruct = str(self._config.get("instruct") or "").strip()
        self._max_retries = int(self._config.get("max_retries", 1))
        self._retry_backoff_sec = float(self._config.get("retry_backoff_sec", 0.2))

        if not self._endpoint:
            raise ValueError("dashscope_rerank endpoint is required")
        if not self._api_key:
            raise ValueError(
                "dashscope_rerank api_key is required (set services.rerank.backends.dashscope_rerank.api_key "
                "or env DASHSCOPE_API_KEY)"
            )
        if self._timeout_sec <= 0:
            raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}")
        if self._top_n_cap < 0:
            raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}")
        if self._max_retries <= 0:
            raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}")
        if self._retry_backoff_sec < 0:
            raise ValueError(
                f"dashscope_rerank retry_backoff_sec must be >= 0, got {self._retry_backoff_sec}"
            )

        logger.info(
            "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s",
            self._endpoint,
            self._model_name,
            self._timeout_sec,
            self._top_n_cap,
        )

    def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = urllib_request.Request(
            url=self._endpoint,
            method="POST",
            data=body,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )
        with urllib_request.urlopen(req, timeout=self._timeout_sec) as resp:
            raw = resp.read().decode("utf-8", errors="replace")
            try:
                data = json.loads(raw)
            except json.JSONDecodeError as exc:
                raise RuntimeError(f"DashScope response is not valid JSON: {raw[:512]}") from exc
            if not isinstance(data, dict):
                raise RuntimeError(f"DashScope response must be JSON object, got: {type(data).__name__}")
            return data

    def _post_rerank(self, query: str, docs: List[str], top_n: int) -> Dict[str, Any]:
        payload: Dict[str, Any] = {
            "model": self._model_name,
            "query": query,
            "documents": docs,
            "top_n": top_n,
        }
        if self._instruct:
            payload["instruct"] = self._instruct

        last_exc: Exception | None = None
        for attempt in range(1, self._max_retries + 1):
            try:
                return self._http_post_json(payload)
            except urllib_error.HTTPError as exc:
                body = ""
                try:
                    body = exc.read().decode("utf-8", errors="replace")
                except Exception:
                    body = ""
                last_exc = RuntimeError(
                    f"DashScope rerank HTTP {exc.code} (attempt {attempt}/{self._max_retries}): {body[:512]}"
                )
            except urllib_error.URLError as exc:
                last_exc = RuntimeError(
                    f"DashScope rerank network error (attempt {attempt}/{self._max_retries}): {exc}"
                )
            except Exception as exc:  # pragma: no cover - defensive
                last_exc = RuntimeError(
                    f"DashScope rerank unexpected error (attempt {attempt}/{self._max_retries}): {exc}"
                )

            if attempt < self._max_retries and self._retry_backoff_sec > 0:
                time.sleep(self._retry_backoff_sec * attempt)

        raise RuntimeError(str(last_exc) if last_exc else "DashScope rerank failed with unknown error")

    @staticmethod
    def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Compatible API style: {"results":[...]}
        results = data.get("results")
        if isinstance(results, list):
            return [x for x in results if isinstance(x, dict)]

        # Native style fallback: {"output":{"results":[...]}}
        output = data.get("output")
        if isinstance(output, dict):
            output_results = output.get("results")
            if isinstance(output_results, list):
                return [x for x in output_results if isinstance(x, dict)]

        return []

    @staticmethod
    def _coerce_score(raw_score: Any, normalize: bool) -> float:
        try:
            score = float(raw_score)
        except (TypeError, ValueError):
            return 0.0

        if not normalize:
            return score
        # DashScope relevance_score is typically already in [0,1]; keep it.
        if 0.0 <= score <= 1.0:
            return score
        # Fallback when provider returns logits/raw scores.
        if score > 60:
            return 1.0
        if score < -60:
            return 0.0
        return 1.0 / (1.0 + math.exp(-score))

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        start_ts = time.time()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if not text:
                continue
            indexed.append((i, text))

        if not query or not indexed:
            elapsed_ms = (time.time() - start_ts) * 1000.0
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round(elapsed_ms, 3),
                "model": self._model_name,
                "backend": "dashscope_rerank",
                "normalize": normalize,
                "top_n": 0,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        top_n_effective = len(unique_texts)
        if top_n is not None and int(top_n) > 0:
            top_n_effective = min(top_n_effective, int(top_n))
        if self._top_n_cap > 0:
            top_n_effective = min(top_n_effective, self._top_n_cap)

        response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n_effective)
        results = self._extract_results(response)

        unique_scores: List[float] = [0.0] * len(unique_texts)
        for rank, item in enumerate(results):
            raw_idx = item.get("index", rank)
            try:
                idx = int(raw_idx)
            except (TypeError, ValueError):
                continue
            if idx < 0 or idx >= len(unique_scores):
                continue
            raw_score = item.get("relevance_score", item.get("score"))
            unique_scores[idx] = self._coerce_score(raw_score, normalize=normalize)

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.time() - start_ts) * 1000.0
        dedup_ratio = 0.0
        if indexed:
            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        return output_scores, {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "dashscope_rerank",
            "normalize": normalize,
            "top_n": top_n_effective,
            "requested_top_n": int(top_n) if top_n is not None else None,
            "response_results": len(results),
            "endpoint": self._endpoint,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        return self.score_with_meta_topn(query=query, docs=docs, normalize=normalize, top_n=None)