07cf5a93
tangwang
START_EMBEDDING=...
|
1
2
3
4
|
"""TEI text embedding backend client."""
from __future__ import annotations
|
54ccf28c
tangwang
tei
|
5
|
from typing import Any, List, Union
|
07cf5a93
tangwang
START_EMBEDDING=...
|
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
import numpy as np
import requests
class TEITextModel:
    """Text embedding client backed by a Hugging Face TEI HTTP service.

    The remote service is expected to expose:
        POST {base_url}/embed
        body: {"inputs": ["text1", "text2", ...]}
        response: [[...], [...], ...]
    """

    def __init__(self, base_url: str, timeout_sec: int = 60):
        """Store connection settings and verify the TEI server is reachable.

        Raises:
            ValueError: if ``base_url`` is empty or whitespace-only.
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        self.timeout_sec = int(timeout_sec)
        self._health_check()

    def _health_check(self) -> None:
        """Hit /health, then round-trip one tiny embedding through /embed."""
        health_url = f"{self.base_url}/health"
        requests.get(health_url, timeout=5).raise_for_status()
        # Probe one tiny embedding at startup so runtime requests do not fail
        # later with opaque "Invalid TEI embedding" errors.
        probe = requests.post(
            self.endpoint,
            json={"inputs": ["health check"]},
            timeout=min(self.timeout_sec, 15),
        )
        probe.raise_for_status()
        self._parse_payload(probe.json(), expected_len=1)

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return *embedding* scaled to unit L2 norm; reject degenerate vectors."""
        magnitude = np.linalg.norm(embedding)
        if magnitude <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / magnitude

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """Embed one string or a list of strings (sentence-transformers-style API)."""
        batch = [sentences] if isinstance(sentences, str) else sentences
        return self.encode_batch(
            texts=batch,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = "cuda",
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """POST *texts* to the TEI /embed endpoint and return their vectors.

        ``batch_size`` and ``device`` are accepted only for interface
        compatibility: TEI performs its own batching, and this is a pure
        HTTP client with no device placement.

        Returns:
            An object-dtype numpy array of 1-D float32 vectors (empty array
            for empty input).

        Raises:
            ValueError: if any input is not a non-blank string.
            RuntimeError: if the response payload is malformed.
        """
        if texts is None or len(texts) == 0:
            return np.array([], dtype=object)
        # Reject non-string / blank inputs up front with a precise index.
        for idx, candidate in enumerate(texts):
            if not isinstance(candidate, str) or not candidate.strip():
                raise ValueError(f"Invalid input text at index {idx}: {candidate!r}")
        resp = requests.post(
            self.endpoint,
            json={"inputs": texts},
            timeout=self.timeout_sec,
        )
        resp.raise_for_status()
        parsed = self._parse_payload(resp.json(), expected_len=len(texts))
        if normalize_embeddings:
            parsed = [self._normalize(vector) for vector in parsed]
        return np.array(parsed, dtype=object)

    def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]:
        """Validate a raw TEI response and convert it to float32 vectors.

        Accepts either a list of raw vectors or a list of
        ``{"embedding": ...}`` dicts; each vector must be 1-D, non-empty,
        and fully finite.

        Raises:
            RuntimeError: on length mismatch or any malformed embedding.
        """
        if not isinstance(payload, list) or len(payload) != expected_len:
            if payload is None:
                got = 0
            elif isinstance(payload, list):
                got = len(payload)
            else:
                got = "non-list"
            raise RuntimeError(
                f"TEI response length mismatch: expected {expected_len}, got {got}. "
                f"Response type={type(payload).__name__}"
            )
        results: List[np.ndarray] = []
        for pos, entry in enumerate(payload):
            raw = entry.get("embedding") if isinstance(entry, dict) else entry
            try:
                arr = np.asarray(raw, dtype=np.float32)
            except (TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {pos}: cannot convert to float array "
                    f"(item_type={type(entry).__name__})"
                ) from exc
            if arr.ndim != 1 or arr.size == 0:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {pos}: shape={arr.shape}, size={arr.size}"
                )
            if not np.isfinite(arr).all():
                preview = arr[:8].tolist()
                raise RuntimeError(
                    f"Invalid TEI embedding at index {pos}: contains non-finite values, "
                    f"preview={preview}. This often indicates TEI backend/model runtime issues "
                    f"(for example an incompatible dtype or model config)."
                )
            results.append(arr)
        return results
|
07cf5a93
tangwang
START_EMBEDDING=...
|
|
|