tei_model.py
3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""TEI text embedding backend client."""
from __future__ import annotations
from typing import List, Union
import numpy as np
import requests
class TEITextModel:
    """
    Text embedding backend implemented via the Hugging Face TEI HTTP API.

    Expected TEI endpoint:
        POST {base_url}/embed
        body: {"inputs": ["text1", "text2", ...]}
        response: [[...], [...], ...]
    """

    def __init__(self, base_url: str, timeout_sec: int = 60):
        """Create a client bound to a TEI server and verify it is reachable.

        Args:
            base_url: Root URL of the TEI server (trailing slashes stripped).
            timeout_sec: Timeout in seconds for embedding requests.

        Raises:
            ValueError: if ``base_url`` is empty or whitespace-only.
            requests.HTTPError / requests.ConnectionError: if the health
                check fails (propagated from ``_health_check``).
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        self.timeout_sec = int(timeout_sec)
        # Fail fast at construction time if the server is unreachable.
        self._health_check()

    def _health_check(self) -> None:
        """Probe ``{base_url}/health``; raise on a non-2xx or unreachable server."""
        health_url = f"{self.base_url}/health"
        # Short fixed timeout: the health endpoint should respond quickly
        # regardless of the (possibly long) embedding timeout.
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return ``embedding`` scaled to unit L2 norm.

        Raises:
            RuntimeError: if the vector has zero (or non-positive) norm,
                which would make normalization undefined.
        """
        norm = np.linalg.norm(embedding)
        if norm <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / norm

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """Encode one string or a list of strings; delegates to ``encode_batch``.

        A single string is wrapped into a one-element list, so the return is
        always an object-dtype array of per-text embedding vectors.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        return self.encode_batch(
            texts=sentences,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = "cuda",
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """Embed ``texts`` via the TEI ``/embed`` endpoint.

        Args:
            texts: Non-empty strings to embed.
            batch_size: Ignored — TEI performs its own batching.
            device: Ignored — not meaningful for an HTTP backend.
            normalize_embeddings: If True, L2-normalize each vector.

        Returns:
            Object-dtype numpy array of float32 1-D embedding vectors
            (empty object array when ``texts`` is empty or None).

        Raises:
            ValueError: if any input is not a non-blank string.
            RuntimeError: on a malformed TEI response (length mismatch,
                non-1-D / empty / non-finite vector, zero-norm vector).
            requests.HTTPError: on a non-2xx response status.
        """
        del batch_size  # TEI performs its own batching.
        del device  # Not used by HTTP backend.
        if not texts:
            return np.array([], dtype=object)
        for i, t in enumerate(texts):
            if not isinstance(t, str) or not t.strip():
                raise ValueError(f"Invalid input text at index {i}: {t!r}")
        response = requests.post(
            self.endpoint,
            json={"inputs": texts},
            timeout=self.timeout_sec,
        )
        response.raise_for_status()
        payload = response.json()
        if not isinstance(payload, list) or len(payload) != len(texts):
            # Guard: len() on an unsized payload (e.g. a JSON number decoded
            # from an error body) would raise TypeError here and mask the
            # real problem, so only measure actual lists.
            got = len(payload) if isinstance(payload, list) else 0
            raise RuntimeError(
                f"TEI response length mismatch: expected {len(texts)}, "
                f"got {got}"
            )
        vectors: List[np.ndarray] = []
        for i, emb in enumerate(payload):
            vec = np.asarray(emb, dtype=np.float32)
            # Each embedding must be a non-empty 1-D vector of finite floats.
            if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
                raise RuntimeError(f"Invalid TEI embedding at index {i}")
            if normalize_embeddings:
                vec = self._normalize(vec)
            vectors.append(vec)
        # dtype=object preserved for backward compatibility with callers
        # that expect a 1-D object array of vectors rather than a 2-D matrix.
        return np.array(vectors, dtype=object)