Blame view

embeddings/tei_model.py 3.02 KB
07cf5a93   tangwang   START_EMBEDDING=...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
  """TEI text embedding backend client."""
  
  from __future__ import annotations
  
  from typing import List, Union
  
  import numpy as np
  import requests
  
  
  class TEITextModel:
      """
      Text embedding backend implemented via Hugging Face TEI HTTP API.
  
      Expected TEI endpoint:
        POST {base_url}/embed
        body: {"inputs": ["text1", "text2", ...]}
        response: [[...], [...], ...]
      """
  
      def __init__(self, base_url: str, timeout_sec: int = 60):
          if not base_url or not str(base_url).strip():
              raise ValueError("TEI base_url must not be empty")
          self.base_url = str(base_url).rstrip("/")
          self.endpoint = f"{self.base_url}/embed"
          self.timeout_sec = int(timeout_sec)
          self._health_check()
  
      def _health_check(self) -> None:
          health_url = f"{self.base_url}/health"
          response = requests.get(health_url, timeout=5)
          response.raise_for_status()
  
      @staticmethod
      def _normalize(embedding: np.ndarray) -> np.ndarray:
          norm = np.linalg.norm(embedding)
          if norm <= 0:
              raise RuntimeError("TEI returned zero-norm embedding")
          return embedding / norm
  
      def encode(
          self,
          sentences: Union[str, List[str]],
          normalize_embeddings: bool = True,
          device: str = "cuda",
          batch_size: int = 32,
      ) -> np.ndarray:
          if isinstance(sentences, str):
              sentences = [sentences]
          return self.encode_batch(
              texts=sentences,
              batch_size=batch_size,
              device=device,
              normalize_embeddings=normalize_embeddings,
          )
  
      def encode_batch(
          self,
          texts: List[str],
          batch_size: int = 32,
          device: str = "cuda",
          normalize_embeddings: bool = True,
      ) -> np.ndarray:
          del batch_size  # TEI performs its own batching.
          del device      # Not used by HTTP backend.
  
          if texts is None or len(texts) == 0:
              return np.array([], dtype=object)
          for i, t in enumerate(texts):
              if not isinstance(t, str) or not t.strip():
                  raise ValueError(f"Invalid input text at index {i}: {t!r}")
  
          response = requests.post(
              self.endpoint,
              json={"inputs": texts},
              timeout=self.timeout_sec,
          )
          response.raise_for_status()
          payload = response.json()
  
          if not isinstance(payload, list) or len(payload) != len(texts):
              raise RuntimeError(
                  f"TEI response length mismatch: expected {len(texts)}, "
                  f"got {0 if payload is None else len(payload)}"
              )
  
          vectors: List[np.ndarray] = []
          for i, emb in enumerate(payload):
              vec = np.asarray(emb, dtype=np.float32)
              if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
                  raise RuntimeError(f"Invalid TEI embedding at index {i}")
              if normalize_embeddings:
                  vec = self._normalize(vec)
              vectors.append(vec)
          return np.array(vectors, dtype=object)