# dashscope_rerank.py (15.6 KB)
"""
DashScope cloud reranker backend (OpenAI-compatible reranks API).

Reference:
- https://dashscope.aliyuncs.com/compatible-api/v1/reranks
- Use region-specific domains when needed:
  - China:     https://dashscope.aliyuncs.com
  - Singapore: https://dashscope-intl.aliyuncs.com
  - US:        https://dashscope-us.aliyuncs.com
"""

from __future__ import annotations

import json
import logging
import math
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple
from urllib import error as urllib_error
from urllib import request as urllib_request

from reranker.backends.batching_utils import deduplicate_with_positions, iter_batches

# Module logger under a fixed, explicit name (not derived from __name__).
logger = logging.getLogger(name="reranker.backends.dashscope_rerank")


class DashScopeRerankBackend:
    """
    DashScope cloud reranker backend.

    Config from services.rerank.backends.dashscope_rerank:
      - model_name: str, default "qwen3-rerank"
      - endpoint: str, default "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
      - api_key_env: str, required env var name for this backend key
      - timeout_sec: float, default 15.0
      - top_n_cap: int, optional cap; 0 means use all docs in request
      - batchsize: int, optional; 0 disables batching; >0 enables concurrent small-batch scheduling
      - batch_concurrency: int, default 8; max number of batch requests in flight at once
      - instruct: optional str
      - max_retries: int, default 1; total attempts per request (1 means "no retry")
      - retry_backoff_sec: float, default 0.2; sleep before attempt k+1 is value * k

    Env overrides:
      - RERANK_DASHSCOPE_ENDPOINT
      - RERANK_DASHSCOPE_MODEL
      - RERANK_DASHSCOPE_TIMEOUT_SEC
      - RERANK_DASHSCOPE_TOP_N_CAP
      - RERANK_DASHSCOPE_BATCHSIZE

    Retry policy: network errors, unexpected errors (e.g. a non-JSON body),
    HTTP 408/429 and HTTP 5xx are retried. Any other HTTP status (bad request,
    auth failure, ...) fails immediately, because repeating it cannot succeed.
    """

    # Non-5xx HTTP statuses that are worth retrying: request timeout and rate limiting.
    _RETRYABLE_HTTP_STATUSES = frozenset({408, 429})

    def __init__(self, config: Dict[str, Any]) -> None:
        """Resolve settings (env override > config > default) and validate them.

        Raises:
            ValueError: a required setting is missing or a numeric setting is out of range.
        """
        self._config = config or {}
        self._model_name = str(
            os.getenv("RERANK_DASHSCOPE_MODEL")
            or self._config.get("model_name")
            or "qwen3-rerank"
        )
        self._endpoint = str(
            os.getenv("RERANK_DASHSCOPE_ENDPOINT")
            or self._config.get("endpoint")
            or "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
        ).strip()
        self._api_key_env = str(self._config.get("api_key_env") or "").strip()
        raw_key = os.getenv(self._api_key_env) if self._api_key_env else None
        # Tolerate keys that were pasted with surrounding quotes.
        self._api_key = str(raw_key or "").strip().strip('"').strip("'")
        self._timeout_sec = float(
            os.getenv("RERANK_DASHSCOPE_TIMEOUT_SEC")
            or self._config.get("timeout_sec")
            or 15.0
        )
        self._top_n_cap = int(
            os.getenv("RERANK_DASHSCOPE_TOP_N_CAP")
            or self._config.get("top_n_cap")
            or 0
        )
        self._batchsize = int(
            os.getenv("RERANK_DASHSCOPE_BATCHSIZE")
            or self._config.get("batchsize")
            or 0
        )
        self._batch_concurrency = int(self._config.get("batch_concurrency") or 8)
        self._instruct = str(self._config.get("instruct") or "").strip()
        # An explicit `None` in config means "use the default" instead of crashing int()/float().
        max_retries_cfg = self._config.get("max_retries")
        self._max_retries = 1 if max_retries_cfg is None else int(max_retries_cfg)
        backoff_cfg = self._config.get("retry_backoff_sec")
        self._retry_backoff_sec = 0.2 if backoff_cfg is None else float(backoff_cfg)

        if not self._endpoint:
            raise ValueError("dashscope_rerank endpoint is required")
        if not self._api_key_env:
            raise ValueError("dashscope_rerank api_key_env is required")
        if not self._api_key:
            raise ValueError(
                f"dashscope_rerank api key is required (set env {self._api_key_env})"
            )
        if self._timeout_sec <= 0:
            raise ValueError(f"dashscope_rerank timeout_sec must be > 0, got {self._timeout_sec}")
        if self._top_n_cap < 0:
            raise ValueError(f"dashscope_rerank top_n_cap must be >= 0, got {self._top_n_cap}")
        if self._batchsize < 0:
            raise ValueError(f"dashscope_rerank batchsize must be >= 0, got {self._batchsize}")
        if self._batch_concurrency <= 0:
            raise ValueError(
                f"dashscope_rerank batch_concurrency must be > 0, got {self._batch_concurrency}"
            )
        if self._max_retries <= 0:
            raise ValueError(f"dashscope_rerank max_retries must be > 0, got {self._max_retries}")
        if self._retry_backoff_sec < 0:
            raise ValueError(
                f"dashscope_rerank retry_backoff_sec must be >= 0, got {self._retry_backoff_sec}"
            )

        logger.info(
            "DashScope reranker ready | endpoint=%s model=%s timeout_sec=%s top_n_cap=%s batchsize=%s",
            self._endpoint,
            self._model_name,
            self._timeout_sec,
            self._top_n_cap,
            self._batchsize,
        )

    def _http_post_json(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to the endpoint and return the decoded JSON object.

        Raises:
            urllib.error.HTTPError / URLError: propagated from ``urlopen``.
            RuntimeError: the body is not valid JSON or is not a JSON object.
        """
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = urllib_request.Request(
            url=self._endpoint,
            method="POST",
            data=body,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )
        with urllib_request.urlopen(req, timeout=self._timeout_sec) as resp:
            raw = resp.read().decode("utf-8", errors="replace")
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise RuntimeError(f"DashScope response is not valid JSON: {raw[:512]}") from exc
        if not isinstance(data, dict):
            raise RuntimeError(f"DashScope response must be JSON object, got: {type(data).__name__}")
        return data

    @staticmethod
    def _read_error_body(exc: urllib_error.HTTPError) -> str:
        """Best-effort read of an HTTP error body; returns "" if it cannot be read."""
        try:
            return exc.read().decode("utf-8", errors="replace")
        except Exception:
            return ""

    @classmethod
    def _is_retryable_http_status(cls, code: Any) -> bool:
        """True for statuses that may succeed on retry (408, 429, 5xx)."""
        return isinstance(code, int) and (code in cls._RETRYABLE_HTTP_STATUSES or code >= 500)

    def _post_rerank(self, query: str, docs: List[str], top_n: int) -> Dict[str, Any]:
        """Send one rerank request, retrying transient failures.

        Makes up to ``max_retries`` attempts, sleeping ``retry_backoff_sec * attempt``
        between them. Non-retryable HTTP statuses stop after the first attempt.

        Raises:
            RuntimeError: describes the last failure and is chained (``__cause__``)
                to the underlying exception.
        """
        payload: Dict[str, Any] = {
            "model": self._model_name,
            "query": query,
            "documents": docs,
            "top_n": top_n,
        }
        if self._instruct:
            payload["instruct"] = self._instruct

        last_error: RuntimeError | None = None
        last_cause: BaseException | None = None
        for attempt in range(1, self._max_retries + 1):
            retryable = True
            try:
                return self._http_post_json(payload)
            except urllib_error.HTTPError as exc:
                body = self._read_error_body(exc)
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank HTTP {exc.code} (attempt {attempt}/{self._max_retries}): {body[:512]}"
                )
                retryable = self._is_retryable_http_status(exc.code)
            except urllib_error.URLError as exc:
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank network error (attempt {attempt}/{self._max_retries}): {exc}"
                )
            except Exception as exc:  # pragma: no cover - defensive
                last_cause = exc
                last_error = RuntimeError(
                    f"DashScope rerank unexpected error (attempt {attempt}/{self._max_retries}): {exc}"
                )

            if not retryable:
                # e.g. 400/401/403: repeating the identical request cannot succeed.
                break
            if attempt < self._max_retries and self._retry_backoff_sec > 0:
                time.sleep(self._retry_backoff_sec * attempt)

        if last_error is None:
            raise RuntimeError("DashScope rerank failed with unknown error")
        raise last_error from last_cause

    def _apply_results(
        self,
        results: List[Dict[str, Any]],
        positions: List[int],
        scores: List[float],
        normalize: bool,
    ) -> None:
        """Write provider ``results`` into ``scores`` in place.

        Each result's ``index`` (falling back to its rank in the list) addresses
        ``positions``, which maps request-local indices to indices in ``scores``.
        Results with a non-integer or out-of-range index are ignored.
        """
        for rank, item in enumerate(results):
            raw_idx = item.get("index", rank)
            try:
                local_idx = int(raw_idx)
            except (TypeError, ValueError):
                continue
            if not 0 <= local_idx < len(positions):
                continue
            raw_score = item.get("relevance_score", item.get("score"))
            scores[positions[local_idx]] = self._coerce_score(raw_score, normalize=normalize)

    def _score_single_request(
        self,
        query: str,
        unique_texts: List[str],
        normalize: bool,
        top_n: int,
    ) -> Tuple[List[float], int]:
        """Score all ``unique_texts`` in one request.

        Returns ``(scores, response_result_count)``. Docs the response does not
        mention keep a score of 0.0.
        """
        response = self._post_rerank(query=query, docs=unique_texts, top_n=top_n)
        results = self._extract_results(response)
        unique_scores: List[float] = [0.0] * len(unique_texts)
        self._apply_results(results, list(range(len(unique_texts))), unique_scores, normalize)
        return unique_scores, len(results)

    def _score_batched_concurrent(
        self,
        query: str,
        unique_texts: List[str],
        normalize: bool,
    ) -> Tuple[List[float], Dict[str, int]]:
        """
        Concurrent batch scoring.

        We intentionally request full local scores in each batch (top_n=len(batch)),
        then apply global top_n/top_n_cap truncation after merge if needed.
        At most ``batch_concurrency`` batch requests run at once. If any batch fails,
        batches that have not started yet are cancelled and a RuntimeError is raised.
        """
        indices = list(range(len(unique_texts)))
        batches = list(iter_batches(indices, batch_size=self._batchsize))
        num_batches = len(batches)
        max_workers = min(self._batch_concurrency, num_batches) if num_batches > 0 else 1
        unique_scores: List[float] = [0.0] * len(unique_texts)
        response_results = 0

        def _run_one(batch_no: int, batch_indices: List[int]) -> Tuple[int, List[int], Dict[str, Any], float]:
            docs = [unique_texts[i] for i in batch_indices]
            # Ask each batch for all docs to avoid local truncation.
            start_ts = time.perf_counter()
            data = self._post_rerank(query=query, docs=docs, top_n=len(docs))
            elapsed_ms = round((time.perf_counter() - start_ts) * 1000.0, 3)
            return batch_no, batch_indices, data, elapsed_ms

        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            future_to_batch = {ex.submit(_run_one, i + 1, b): b for i, b in enumerate(batches)}
            for fut in as_completed(future_to_batch):
                batch_indices = future_to_batch[fut]
                try:
                    batch_no, _, data, batch_elapsed_ms = fut.result()
                except Exception as exc:
                    # The overall result is lost anyway: drop batches that have not
                    # started. In-flight requests still finish before the executor exits.
                    for pending in future_to_batch:
                        pending.cancel()
                    raise RuntimeError(
                        f"DashScope rerank batch failed | batch_size={len(batch_indices)} error={exc}"
                    ) from exc
                results = self._extract_results(data)
                logger.info(
                    "DashScope batch response | batch=%d/%d docs=%d elapsed_ms=%s results=%d query=%r",
                    batch_no,
                    num_batches,
                    len(batch_indices),
                    batch_elapsed_ms,
                    len(results),
                    query[:80],
                )
                response_results += len(results)
                self._apply_results(results, batch_indices, unique_scores, normalize)

        return unique_scores, {
            "batches": num_batches,
            "batch_concurrency": max_workers,
            "response_results": response_results,
        }

    @staticmethod
    def _extract_results(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the dict entries of the response's result list, or [] if absent."""
        # Compatible API style: {"results":[...]}
        results = data.get("results")
        if isinstance(results, list):
            return [x for x in results if isinstance(x, dict)]

        # Native style fallback: {"output":{"results":[...]}}
        output = data.get("output")
        if isinstance(output, dict):
            output_results = output.get("results")
            if isinstance(output_results, list):
                return [x for x in output_results if isinstance(x, dict)]

        return []

    @staticmethod
    def _coerce_score(raw_score: Any, normalize: bool) -> float:
        """Convert a provider score to float.

        Unparseable or NaN scores become 0.0. With ``normalize``, values already
        in [0, 1] are kept and anything else is mapped through a sigmoid.
        """
        try:
            score = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        if math.isnan(score):
            # NaN compares False with everything and would corrupt top-n ranking.
            return 0.0

        if not normalize:
            return score
        # DashScope relevance_score is typically already in [0,1]; keep it.
        if 0.0 <= score <= 1.0:
            return score
        # Fallback when provider returns logits/raw scores; clamp extremes explicitly.
        if score > 60:
            return 1.0
        if score < -60:
            return 0.0
        return 1.0 / (1.0 + math.exp(-score))

    @staticmethod
    def _keep_top_n(scores: List[float], top_n: int) -> None:
        """Zero, in place, every score outside the ``top_n`` highest (ties favor lower index)."""
        if top_n >= len(scores):
            return
        order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
        for i in order[top_n:]:
            scores[i] = 0.0

    def score_with_meta_topn(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
        top_n: int | None = None,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score ``docs`` against ``query``; return ``(scores, meta)``.

        ``scores`` is aligned with ``docs``. ``None`` or blank docs score 0.0, and
        every doc scores 0.0 when the query is blank. Duplicate texts are sent once
        and share a score. ``top_n`` (``None`` or <= 0 means "all"), further limited
        by ``top_n_cap``, bounds how many unique docs get a score. Docs outside that
        bound report 0.0. In single-request mode this relies on the provider honoring
        ``top_n``. ``meta`` carries counts, timing and batching details. It has the
        same keys on every path.
        """
        start_ts = time.perf_counter()
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs
        requested_top_n = int(top_n) if top_n is not None else None

        query = "" if query is None else str(query).strip()
        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed.append((i, text))

        if not query or not indexed:
            return output_scores, {
                "input_docs": total_docs,
                "usable_docs": len(indexed),
                "unique_docs": 0,
                "dedup_ratio": 0.0,
                "elapsed_ms": round((time.perf_counter() - start_ts) * 1000.0, 3),
                "model": self._model_name,
                "backend": "dashscope_rerank",
                "normalize": normalize,
                "top_n": 0,
                "requested_top_n": requested_top_n,
                "response_results": 0,
                "batchsize": self._batchsize,
                "batches": 0,
                "batch_concurrency": 0,
                "endpoint": self._endpoint,
            }

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)

        top_n_effective = len(unique_texts)
        if requested_top_n is not None and requested_top_n > 0:
            top_n_effective = min(top_n_effective, requested_top_n)
        if self._top_n_cap > 0:
            top_n_effective = min(top_n_effective, self._top_n_cap)

        if self._batchsize > 0 and len(unique_texts) > self._batchsize:
            unique_scores, batch_meta = self._score_batched_concurrent(
                query=query,
                unique_texts=unique_texts,
                normalize=normalize,
            )
            # Batches were scored in full; apply the global top_n only after merging.
            self._keep_top_n(unique_scores, top_n_effective)
            response_results = int(batch_meta["response_results"])
            batches = int(batch_meta["batches"])
            batch_concurrency = int(batch_meta["batch_concurrency"])
        else:
            unique_scores, response_results = self._score_single_request(
                query=query,
                unique_texts=unique_texts,
                normalize=normalize,
                top_n=top_n_effective,
            )
            batches = 1
            batch_concurrency = 1

        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        elapsed_ms = (time.perf_counter() - start_ts) * 1000.0
        dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))

        return output_scores, {
            "input_docs": total_docs,
            "usable_docs": len(indexed),
            "unique_docs": len(unique_texts),
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "dashscope_rerank",
            "normalize": normalize,
            "top_n": top_n_effective,
            "requested_top_n": requested_top_n,
            "response_results": response_results,
            "batchsize": self._batchsize,
            "batches": batches,
            "batch_concurrency": batch_concurrency,
            "endpoint": self._endpoint,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """Score every doc (no top-n limit beyond ``top_n_cap``); see ``score_with_meta_topn``."""
        return self.score_with_meta_topn(query=query, docs=docs, normalize=normalize, top_n=None)