Blame view

providers/rerank.py 2.23 KB
ed948666   tangwang   tidy
1
  """Rerank provider - HTTP service."""
42e3aea6   tangwang   tidy
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
  
  from __future__ import annotations
  
  import logging
  from typing import Any, Dict, List, Optional, Tuple
  
  import requests
  
  from config.services_config import get_rerank_config, get_rerank_service_url
  
  logger = logging.getLogger(__name__)
  
  
  class HttpRerankProvider:
      """Rerank documents by POSTing them to an HTTP rerank service.

      Failures (non-200 status, timeout, transport/decoding error, malformed
      payload) are logged and reported to the caller as ``(None, None)``
      instead of being raised.
      """

      def __init__(self, service_url: str):
          """Store the endpoint URL.

          Args:
              service_url: URL the rerank request is POSTed to. A falsy value
                  is normalised to ``""``; trailing slashes are stripped.
          """
          self.service_url = (service_url or "").rstrip("/")

      def rerank(
          self,
          query: str,
          docs: List[str],
          timeout_sec: float,
          top_n: Optional[int] = None,
      ) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
          """Request rerank scores for ``docs`` against ``query``.

          Args:
              query: Query text; ``None`` is treated as ``""`` and surrounding
                  whitespace is stripped before sending.
              docs: Document texts to score. An empty list short-circuits
                  without any HTTP call.
              timeout_sec: Timeout forwarded to ``requests.post``.
              top_n: When given and positive, forwarded as ``top_n`` in the
                  request payload; otherwise the key is omitted.

          Returns:
              ``([], {})`` when ``docs`` is empty; ``(scores, meta)`` on
              success, where ``meta`` falls back to ``{}`` when absent/falsy;
              ``(None, None)`` on any failure.
          """
          if not docs:
              return [], {}
          try:
              payload: Dict[str, Any] = {"query": (query or "").strip(), "docs": docs}
              if top_n is not None and int(top_n) > 0:
                  payload["top_n"] = int(top_n)
              response = requests.post(self.service_url, json=payload, timeout=timeout_sec)
              if response.status_code != 200:
                  logger.warning(
                      "Rerank service HTTP %s: %s",
                      response.status_code,
                      (response.text or "")[:200],
                  )
                  return None, None
              data = response.json()
              # A non-dict body has no "scores"; treat it like a missing field
              # rather than relying on an AttributeError reaching the broad handler.
              scores = data.get("scores") if isinstance(data, dict) else None
              if not isinstance(scores, list):
                  # Log this path too so a malformed 200 response is diagnosable
                  # like every other failure mode of this method.
                  logger.warning(
                      "Rerank service returned unexpected payload (body=%s, scores=%s)",
                      type(data).__name__,
                      type(scores).__name__,
                  )
                  return None, None
              return scores, data.get("meta") or {}
          except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as exc:
              logger.warning(
                  "Rerank request timed out after %.1fs (docs=%d); returning ES order. %s",
                  timeout_sec,
                  len(docs),
                  exc,
              )
              return None, None
          except Exception as exc:
              # Deliberate best-effort boundary: never propagate to the caller.
              logger.warning("Rerank request failed: %s", exc, exc_info=True)
              return None, None
  
  
  def create_rerank_provider() -> HttpRerankProvider:
      """Build the rerank provider selected by the services config.

      The configured provider name defaults to ``"http"`` when falsy and is
      compared after stripping whitespace and lower-casing.

      Raises:
          ValueError: If the configured provider is anything other than "http".
      """
      configured = get_rerank_config().provider
      kind = (configured or "http").strip().lower()
      if kind == "http":
          # The service URL is only looked up once the provider kind is accepted.
          return HttpRerankProvider(service_url=get_rerank_service_url())
      raise ValueError(f"Unsupported rerank provider: {kind}")