# tei_model.py
"""TEI text embedding backend client."""
from __future__ import annotations
from typing import Any, List, Union
import numpy as np
import requests
class TEITextModel:
    """
    Text embedding backend implemented via Hugging Face TEI HTTP API.
    Expected TEI endpoint:
    POST {base_url}/embed
    body: {"inputs": ["text1", "text2", ...]}
    response: [[...], [...], ...]
    """

    def __init__(self, base_url: str, timeout_sec: int = 60):
        """Create a client bound to one TEI server and verify it is usable.

        Args:
            base_url: Root URL of the TEI server; trailing slashes are stripped.
            timeout_sec: Per-request timeout in seconds; must be positive.

        Raises:
            ValueError: If ``base_url`` is empty/blank or ``timeout_sec`` is not positive.
            requests.HTTPError: If the startup health check or probe request fails.
            RuntimeError: If the probe embedding is malformed.
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        self.timeout_sec = int(timeout_sec)
        # A zero/negative timeout would make every request fail in a confusing
        # way inside `requests`; fail fast at construction instead.
        if self.timeout_sec <= 0:
            raise ValueError(f"timeout_sec must be positive, got {timeout_sec!r}")
        self._health_check()

    def _health_check(self) -> None:
        """Check server liveness and that /embed returns a parseable vector."""
        health_url = f"{self.base_url}/health"
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()
        # Probe one tiny embedding at startup so runtime requests do not fail later
        # with opaque "Invalid TEI embedding" errors.
        probe_resp = requests.post(
            self.endpoint,
            json={"inputs": ["health check"]},
            timeout=min(self.timeout_sec, 15),
        )
        probe_resp.raise_for_status()
        self._parse_payload(probe_resp.json(), expected_len=1)

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return ``embedding`` scaled to unit L2 norm.

        Raises:
            RuntimeError: If the vector has zero norm (cannot be normalized).
        """
        norm = np.linalg.norm(embedding)
        if norm <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / norm

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """Sentence-transformers-compatible entry point.

        A single string is wrapped into a one-element list; all work is
        delegated to :meth:`encode_batch`.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        return self.encode_batch(
            texts=sentences,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = "cuda",
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """Embed ``texts`` via the TEI HTTP API.

        Args:
            texts: Non-empty strings to embed.
            batch_size: Maximum number of texts per HTTP request. TEI servers
                enforce a ``max-client-batch-size``; chunking client-side keeps
                large inputs from being rejected in a single oversized request.
            device: Ignored; the HTTP backend has no device concept.
            normalize_embeddings: If True, L2-normalize each vector.

        Returns:
            Object-dtype ndarray of per-text float32 vectors (empty array for
            empty input).

        Raises:
            ValueError: If any input is not a non-blank string.
            requests.HTTPError: On HTTP-level failure.
            RuntimeError: On malformed TEI responses.
        """
        del device  # Not used by HTTP backend.
        if texts is None or len(texts) == 0:
            return np.array([], dtype=object)
        for i, t in enumerate(texts):
            if not isinstance(t, str) or not t.strip():
                raise ValueError(f"Invalid input text at index {i}: {t!r}")
        # Honor batch_size: one request per chunk, results concatenated in order.
        step = max(1, int(batch_size))
        vectors: List[np.ndarray] = []
        for start in range(0, len(texts), step):
            chunk = texts[start:start + step]
            vectors.extend(self._embed_chunk(chunk))
        if normalize_embeddings:
            vectors = [self._normalize(vec) for vec in vectors]
        return np.array(vectors, dtype=object)

    def _embed_chunk(self, chunk: List[str]) -> List[np.ndarray]:
        """POST one chunk of texts to /embed and return its parsed vectors."""
        response = requests.post(
            self.endpoint,
            json={"inputs": chunk},
            timeout=self.timeout_sec,
        )
        response.raise_for_status()
        try:
            payload = response.json()
        except ValueError as exc:
            # requests raises a JSONDecodeError (a ValueError subclass) on
            # non-JSON bodies; add context so the failure is diagnosable.
            raise RuntimeError(
                f"TEI returned non-JSON response (status {response.status_code})"
            ) from exc
        return self._parse_payload(payload, expected_len=len(chunk))

    def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]:
        """Validate a TEI /embed response and convert it to float32 vectors.

        Accepts either raw vectors (``[[...], ...]``) or dict items carrying an
        ``"embedding"`` key. Each vector must be a non-empty 1-D finite array.
        """
        if not isinstance(payload, list) or len(payload) != expected_len:
            got = 0 if payload is None else (len(payload) if isinstance(payload, list) else "non-list")
            raise RuntimeError(
                f"TEI response length mismatch: expected {expected_len}, got {got}. "
                f"Response type={type(payload).__name__}"
            )
        vectors: List[np.ndarray] = []
        for i, item in enumerate(payload):
            emb = item.get("embedding") if isinstance(item, dict) else item
            try:
                vec = np.asarray(emb, dtype=np.float32)
            except (TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: cannot convert to float array "
                    f"(item_type={type(item).__name__})"
                ) from exc
            if vec.ndim != 1 or vec.size == 0:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: shape={vec.shape}, size={vec.size}"
                )
            if not np.isfinite(vec).all():
                preview = vec[:8].tolist()
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: contains non-finite values, "
                    f"preview={preview}. This often indicates TEI backend/model runtime issues "
                    f"(for example an incompatible dtype or model config)."
                )
            vectors.append(vec)
        return vectors