text_embedding_tei.py
5.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""TEI text embedding backend client."""
from __future__ import annotations
import logging
from typing import Any, List, Union
import numpy as np
import requests
logger = logging.getLogger(__name__)
class TEITextModel:
    """
    Text embedding backend implemented via the Hugging Face TEI HTTP API.

    Expected TEI endpoint:
        POST {base_url}/embed
        body: {"inputs": ["text1", "text2", ...]}
        response: [[...], [...], ...]  (each item may also be {"embedding": [...]})
    """

    def __init__(self, base_url: str, timeout_sec: int = 60, max_client_batch_size: int = 24):
        """
        Args:
            base_url: Root URL of a running TEI server (trailing slash tolerated).
            timeout_sec: Per-request timeout for /embed calls, in seconds.
            max_client_batch_size: Client-side cap; larger encode() calls are
                split into chunks of this size before being sent to TEI.

        Raises:
            ValueError: If base_url is empty or blank.
            requests.HTTPError: If the startup health check fails.
            RuntimeError: If the startup probe embedding is malformed.
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        self.timeout_sec = int(timeout_sec)
        # Clamp to >= 1 so the chunking loop in encode() always makes progress.
        self.max_client_batch_size = max(1, int(max_client_batch_size))
        self._health_check()

    def _health_check(self) -> None:
        """Fail fast at construction time if the TEI server is unreachable or broken."""
        health_url = f"{self.base_url}/health"
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()
        # Probe one tiny embedding at startup so runtime requests do not fail later
        # with opaque "Invalid TEI embedding" errors.
        probe_resp = requests.post(
            self.endpoint,
            json={"inputs": ["health check"]},
            timeout=min(self.timeout_sec, 15),
        )
        probe_resp.raise_for_status()
        self._parse_payload(probe_resp.json(), expected_len=1)

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return the L2-normalized copy of *embedding*; reject zero-norm vectors."""
        norm = np.linalg.norm(embedding)
        if norm <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / norm

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """
        Encode a single sentence or a collection of sentences.

        The TEI HTTP backend is inherently a batch interface, so single and
        batch inputs are both handled through encode(); no separate
        encode_batch method is exposed.

        Args:
            sentences: One string, or an iterable of non-empty strings
                (lists, tuples, and generators are all accepted).
            normalize_embeddings: If True, L2-normalize every vector.
            device: Accepted for interface compatibility with other embedding
                backends; unused here — the TEI server controls placement.
            batch_size: Accepted for interface compatibility; unused —
                client-side chunking uses max_client_batch_size instead.

        Returns:
            np.ndarray with dtype=object holding one embedding per input,
            or an empty dtype=object array for empty/None input.

        Raises:
            ValueError: If any input item is not a non-empty string.
            RuntimeError: If a TEI response is malformed (wrong length,
                non-numeric, empty, or non-finite embeddings).
            requests.HTTPError: On non-2xx HTTP responses.
        """
        if isinstance(sentences, str):
            texts: List[str] = [sentences]
        elif sentences is None:
            texts = []
        else:
            # Materialize so generators and tuples are supported as well as
            # lists (generalization; list input behaves exactly as before).
            texts = list(sentences)
        if len(texts) == 0:
            return np.array([], dtype=object)
        for i, t in enumerate(texts):
            if not isinstance(t, str) or not t.strip():
                raise ValueError(f"Invalid input text at index {i}: {t!r}")
        if len(texts) > self.max_client_batch_size:
            logger.info(
                "TEI batch split | total_inputs=%d chunk_size=%d chunks=%d",
                len(texts),
                self.max_client_batch_size,
                (len(texts) + self.max_client_batch_size - 1) // self.max_client_batch_size,
            )
        vectors: List[np.ndarray] = []
        for start in range(0, len(texts), self.max_client_batch_size):
            batch = texts[start : start + self.max_client_batch_size]
            response = requests.post(
                self.endpoint,
                json={"inputs": batch},
                timeout=self.timeout_sec,
            )
            response.raise_for_status()
            payload = response.json()
            parsed = self._parse_payload(payload, expected_len=len(batch))
            if normalize_embeddings:
                parsed = [self._normalize(vec) for vec in parsed]
            vectors.extend(parsed)
        return np.array(vectors, dtype=object)

    def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]:
        """
        Validate a TEI /embed response and convert it to float32 vectors.

        Args:
            payload: Decoded JSON body; must be a list of exactly
                *expected_len* items, each either a numeric list or a dict
                with an "embedding" key.
            expected_len: Number of embeddings that were requested.

        Returns:
            List of 1-D np.float32 arrays, one per input.

        Raises:
            RuntimeError: On length mismatch, non-convertible items, wrong
                shape, empty vectors, or non-finite values.
        """
        if not isinstance(payload, list) or len(payload) != expected_len:
            got = 0 if payload is None else (len(payload) if isinstance(payload, list) else "non-list")
            raise RuntimeError(
                f"TEI response length mismatch: expected {expected_len}, got {got}. "
                f"Response type={type(payload).__name__}"
            )
        vectors: List[np.ndarray] = []
        for i, item in enumerate(payload):
            # TEI may return either a bare vector or {"embedding": [...]}.
            emb = item.get("embedding") if isinstance(item, dict) else item
            try:
                vec = np.asarray(emb, dtype=np.float32)
            except (TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: cannot convert to float array "
                    f"(item_type={type(item).__name__})"
                ) from exc
            if vec.ndim != 1 or vec.size == 0:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: shape={vec.shape}, size={vec.size}"
                )
            if not np.isfinite(vec).all():
                preview = vec[:8].tolist()
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: contains non-finite values, "
                    f"preview={preview}. This often indicates TEI backend/model runtime issues "
                    f"(for example an incompatible dtype or model config)."
                )
            vectors.append(vec)
        return vectors