Blame view

embeddings/text_embedding_tei.py 5.04 KB
07cf5a93   tangwang   START_EMBEDDING=...
1
2
3
4
  """TEI text embedding backend client."""
  
  from __future__ import annotations
  
4650fcec   tangwang   日志优化、日志串联(uid rqid)
5
  import logging
54ccf28c   tangwang   tei
6
  from typing import Any, List, Union
07cf5a93   tangwang   START_EMBEDDING=...
7
8
9
10
  
  import numpy as np
  import requests
  
4650fcec   tangwang   日志优化、日志串联(uid rqid)
11
12
  # Module-level logger, named after this module via logging.getLogger(__name__).
  logger = logging.getLogger(__name__)
  
07cf5a93   tangwang   START_EMBEDDING=...
13
14
15
16
17
18
19
20
21
22
23
  
  class TEITextModel:
      """
      Text embedding backend implemented via Hugging Face TEI HTTP API.

      Expected TEI endpoint:
        POST {base_url}/embed
        body: {"inputs": ["text1", "text2", ...]}
        response: [[...], [...], ...]

      Each response element may also be a dict carrying the vector under an
      ``"embedding"`` key; ``_parse_payload`` accepts both shapes.
      """

      # Maximum number of characters of an HTTP error body copied into the log.
      _ERROR_BODY_PREVIEW_CHARS = 500

      def __init__(self, base_url: str, timeout_sec: int = 60, max_client_batch_size: int = 24):
          """
          Create a client and verify the TEI server before returning.

          Args:
              base_url: Root URL of the TEI server; trailing "/" characters are stripped.
              timeout_sec: Per-request timeout in seconds used by ``encode``. Stored as
                  a float so fractional values survive (``int(0.5)`` would become 0,
                  which ``requests`` rejects).
              max_client_batch_size: Maximum number of texts sent per POST; values
                  below 1 are clamped to 1.

          Raises:
              ValueError: If ``base_url`` is empty/blank or ``timeout_sec`` is not positive.
              requests.RequestException / RuntimeError: If the startup health check fails.
          """
          if not base_url or not str(base_url).strip():
              raise ValueError("TEI base_url must not be empty")
          self.base_url = str(base_url).rstrip("/")
          self.endpoint = f"{self.base_url}/embed"
          self.timeout_sec = float(timeout_sec)
          if self.timeout_sec <= 0:
              raise ValueError(f"TEI timeout_sec must be positive, got {timeout_sec!r}")
          self.max_client_batch_size = max(1, int(max_client_batch_size))
          self._health_check()

      def _health_check(self) -> None:
          """
          Fail fast at construction time if the TEI server is unreachable or unusable.

          Calls GET ``{base_url}/health``, then embeds one tiny probe text. This makes a
          misconfigured backend fail here instead of producing opaque
          "Invalid TEI embedding" errors on later runtime requests.
          """
          health_resp = requests.get(f"{self.base_url}/health", timeout=5)
          self._raise_for_status(health_resp)
          self._embed_batch(["health check"], timeout=min(self.timeout_sec, 15))

      def _embed_batch(self, batch: List[str], timeout: float) -> List[np.ndarray]:
          """POST one batch to the TEI ``/embed`` endpoint and return its validated vectors."""
          response = requests.post(self.endpoint, json={"inputs": batch}, timeout=timeout)
          self._raise_for_status(response)
          return self._parse_payload(response.json(), expected_len=len(batch))

      def _raise_for_status(self, response: Any) -> None:
          """
          Call ``response.raise_for_status()``, logging a body preview first on HTTP errors.

          ``raise_for_status`` only reports the status line. Any error detail the
          server sends is in the body, so a bounded preview is logged before raising.
          """
          if not response.ok:
              logger.warning(
                  "TEI request failed | url=%s status=%s body=%s",
                  response.url,
                  response.status_code,
                  response.text[: self._ERROR_BODY_PREVIEW_CHARS],
              )
          response.raise_for_status()

      @staticmethod
      def _normalize(embedding: np.ndarray) -> np.ndarray:
          """Return ``embedding`` scaled to unit L2 norm; raise RuntimeError on a zero vector."""
          norm = np.linalg.norm(embedding)
          if norm <= 0:
              raise RuntimeError("TEI returned zero-norm embedding")
          return embedding / norm

      def encode(
          self,
          sentences: Union[str, List[str]],
          normalize_embeddings: bool = True,
          device: str = "cuda",
          batch_size: int = 32,
      ) -> np.ndarray:
          """
          Encode a single sentence or a list of sentences.

          The TEI HTTP backend is inherently batch-oriented, so ``encode`` handles both
          single and batched input; no separate ``encode_batch`` is exposed.

          Args:
              sentences: One string, or a sequence of strings. Any iterable of strings
                  (tuple, generator, ...) is accepted and materialized into a list.
              normalize_embeddings: If True, each vector is scaled to unit L2 norm.
              device: Unused here (TEI runs server-side); presumably kept so the
                  signature matches other embedding backends.
              batch_size: Unused here; requests are chunked by
                  ``max_client_batch_size`` instead.

          Returns:
              ``np.array(vectors, dtype=object)`` built from the per-text float32
              vectors in input order, or an empty 1-D object array for ``None`` or
              empty input.

          Raises:
              ValueError: If any input is not a string or is blank.
              requests.RequestException: On transport or HTTP errors.
              RuntimeError: If TEI returns a malformed payload, or a zero-norm vector
                  while normalizing.
          """
          if sentences is None:
              return np.array([], dtype=object)
          texts: List[str] = [sentences] if isinstance(sentences, str) else list(sentences)
          if not texts:
              return np.array([], dtype=object)
          # Validate everything up front so no request is sent for a bad batch.
          for i, t in enumerate(texts):
              if not isinstance(t, str) or not t.strip():
                  raise ValueError(f"Invalid input text at index {i}: {t!r}")

          chunk = self.max_client_batch_size
          if len(texts) > chunk:
              logger.info(
                  "TEI batch split | total_inputs=%d chunk_size=%d chunks=%d",
                  len(texts),
                  chunk,
                  (len(texts) + chunk - 1) // chunk,
              )

          vectors: List[np.ndarray] = []
          for start in range(0, len(texts), chunk):
              parsed = self._embed_batch(texts[start : start + chunk], timeout=self.timeout_sec)
              if normalize_embeddings:
                  parsed = [self._normalize(vec) for vec in parsed]
              vectors.extend(parsed)
          # NOTE(review): with equal-length vectors, np.array(..., dtype=object) builds a
          # 2-D (n, dim) object array, while the empty case above is 1-D. Kept as-is
          # because callers may rely on it; confirm the intended return shape.
          return np.array(vectors, dtype=object)

      def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]:
          """
          Validate a TEI ``/embed`` JSON payload and convert it to float32 vectors.

          Args:
              payload: Decoded JSON; expected to be a list with one entry per input.
                  Each entry is either a list of numbers or a dict with an
                  ``"embedding"`` key.
              expected_len: Number of inputs sent in the request.

          Returns:
              One 1-D, non-empty, all-finite float32 array per input, in order.

          Raises:
              RuntimeError: On a length mismatch, a non-list payload, a non-numeric
                  entry, a wrongly shaped vector, or non-finite values.
          """
          if not isinstance(payload, list) or len(payload) != expected_len:
              # None is a non-list payload too; do not report it as length 0.
              got = len(payload) if isinstance(payload, list) else "non-list"
              raise RuntimeError(
                  f"TEI response length mismatch: expected {expected_len}, got {got}. "
                  f"Response type={type(payload).__name__}"
              )

          vectors: List[np.ndarray] = []
          for i, item in enumerate(payload):
              emb = item.get("embedding") if isinstance(item, dict) else item
              try:
                  vec = np.asarray(emb, dtype=np.float32)
              except (TypeError, ValueError) as exc:
                  raise RuntimeError(
                      f"Invalid TEI embedding at index {i}: cannot convert to float array "
                      f"(item_type={type(item).__name__})"
                  ) from exc

              if vec.ndim != 1 or vec.size == 0:
                  raise RuntimeError(
                      f"Invalid TEI embedding at index {i}: shape={vec.shape}, size={vec.size}"
                  )
              if not np.isfinite(vec).all():
                  preview = vec[:8].tolist()
                  raise RuntimeError(
                      f"Invalid TEI embedding at index {i}: contains non-finite values, "
                      f"preview={preview}. This often indicates TEI backend/model runtime issues "
                      f"(for example an incompatible dtype or model config)."
                  )
              vectors.append(vec)
          return vectors
07cf5a93   tangwang   START_EMBEDDING=...