# text_embedding_tei.py (5.04 KB)
"""TEI text embedding backend client."""

from __future__ import annotations

import logging
from typing import Any, List, Union

import numpy as np
import requests

logger = logging.getLogger(__name__)


class TEITextModel:
    """
    Text embedding backend implemented via Hugging Face TEI HTTP API.

    Expected TEI endpoint:
      POST {base_url}/embed
      body: {"inputs": ["text1", "text2", ...], "normalize": true | false}
      response: [[...], [...], ...]   (items may also be {"embedding": [...]})

    The ``normalize`` field is sent explicitly so that
    ``encode(..., normalize_embeddings=False)`` is honored; otherwise TEI's
    ``/embed`` applies its own default, which is to normalize.
    """

    def __init__(self, base_url: str, timeout_sec: int = 60, max_client_batch_size: int = 24):
        """
        Configure the client and verify that the TEI server is usable.

        Args:
            base_url: Root URL of the TEI server; trailing ``/`` is stripped.
            timeout_sec: Timeout in seconds for each ``/embed`` request.
            max_client_batch_size: Maximum number of texts per HTTP request;
                values below 1 are clamped to 1.

        Raises:
            ValueError: If ``base_url`` is empty or whitespace-only.
            requests.RequestException: If the startup health check fails.
            RuntimeError: If the startup probe embedding is malformed.
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        self.timeout_sec = int(timeout_sec)
        self.max_client_batch_size = max(1, int(max_client_batch_size))
        self._health_check()

    def _health_check(self) -> None:
        """
        Fail fast at construction time if the TEI server is unusable.

        Calls ``GET {base_url}/health`` (5 s timeout), then embeds one tiny
        probe text. A broken model setup therefore surfaces here rather than
        as an opaque "Invalid TEI embedding" error on the first real request.
        """
        response = requests.get(f"{self.base_url}/health", timeout=5)
        response.raise_for_status()
        payload = self._post_embed(
            ["health check"],
            normalize=True,
            timeout=min(self.timeout_sec, 15),
        )
        self._parse_payload(payload, expected_len=1)

    @staticmethod
    def _build_request_body(inputs: List[str], normalize: bool) -> dict:
        """Build the JSON body for ``POST /embed``, forwarding the normalize flag."""
        return {"inputs": list(inputs), "normalize": bool(normalize)}

    def _post_embed(self, inputs: List[str], normalize: bool, timeout: float) -> Any:
        """
        POST one batch to ``/embed`` and return the decoded JSON payload.

        Raises:
            requests.RequestException: On transport errors or non-2xx status.
        """
        response = requests.post(
            self.endpoint,
            json=self._build_request_body(inputs, normalize),
            timeout=timeout,
        )
        if not response.ok:
            # raise_for_status() only reports status/reason/URL; any error
            # detail the server returned is in the body, so log it first.
            logger.error(
                "TEI request failed | status=%d body=%s",
                response.status_code,
                response.text[:500],
            )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return ``embedding`` scaled to unit L2 norm; reject zero-norm vectors."""
        norm = np.linalg.norm(embedding)
        if norm <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / norm

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """
        Encode a single sentence or a list of sentences.

        The TEI HTTP backend is inherently batched, so ``encode`` handles both
        single and batched input; no separate ``encode_batch`` is exposed.

        Args:
            sentences: One string, or any iterable of strings (list, tuple,
                generator, numpy string array). ``None`` or an empty iterable
                yields an empty array.
            normalize_embeddings: If True, vectors are L2-normalized (the flag
                is forwarded to TEI and vectors are re-normalized client-side).
                If False, raw server vectors are returned.
            device: Unused by this HTTP backend (presumably kept for interface
                compatibility with local models).
            batch_size: Unused; request size is governed by
                ``max_client_batch_size``.

        Returns:
            ``np.array(vectors, dtype=object)`` with one row per input text,
            where each vector was parsed as float32.

        Raises:
            ValueError: If any input is not a non-blank string.
            RuntimeError: If a response is malformed or a vector has zero norm.
            requests.RequestException: On HTTP/transport failures.
        """
        if sentences is None:
            return np.array([], dtype=object)
        # Materialize so generators / numpy arrays work with len(), slicing and
        # JSON serialization (an ndarray slice is not JSON-serializable).
        texts: List[str] = [sentences] if isinstance(sentences, str) else list(sentences)
        if not texts:
            return np.array([], dtype=object)
        for i, t in enumerate(texts):
            if not isinstance(t, str) or not t.strip():
                raise ValueError(f"Invalid input text at index {i}: {t!r}")

        chunk_size = self.max_client_batch_size
        if len(texts) > chunk_size:
            logger.info(
                "TEI batch split | total_inputs=%d chunk_size=%d chunks=%d",
                len(texts),
                chunk_size,
                (len(texts) + chunk_size - 1) // chunk_size,
            )

        vectors: List[np.ndarray] = []
        for start in range(0, len(texts), chunk_size):
            batch = texts[start : start + chunk_size]
            payload = self._post_embed(batch, normalize_embeddings, self.timeout_sec)
            parsed = self._parse_payload(payload, expected_len=len(batch))
            if normalize_embeddings:
                # Client-side pass guarantees unit norm regardless of server behavior.
                parsed = [self._normalize(vec) for vec in parsed]
            vectors.extend(parsed)
        return np.array(vectors, dtype=object)

    def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]:
        """
        Validate a TEI ``/embed`` response and convert it to float32 vectors.

        Each item may be a bare list of numbers or a dict with an
        ``"embedding"`` key.

        Raises:
            RuntimeError: If the payload is not a list of ``expected_len``
                items, or any item is not a non-empty, finite 1-D vector.
        """
        if not isinstance(payload, list) or len(payload) != expected_len:
            got = len(payload) if isinstance(payload, list) else "non-list"
            raise RuntimeError(
                f"TEI response length mismatch: expected {expected_len}, got {got}. "
                f"Response type={type(payload).__name__}"
            )

        vectors: List[np.ndarray] = []
        for i, item in enumerate(payload):
            emb = item.get("embedding") if isinstance(item, dict) else item
            try:
                vec = np.asarray(emb, dtype=np.float32)
            except (TypeError, ValueError) as exc:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: cannot convert to float array "
                    f"(item_type={type(item).__name__})"
                ) from exc

            # A dict without "embedding" yields a 0-d array and is rejected here.
            if vec.ndim != 1 or vec.size == 0:
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: shape={vec.shape}, size={vec.size}"
                )
            if not np.isfinite(vec).all():
                preview = vec[:8].tolist()
                raise RuntimeError(
                    f"Invalid TEI embedding at index {i}: contains non-finite values, "
                    f"preview={preview}. This often indicates TEI backend/model runtime issues "
                    f"(for example an incompatible dtype or model config)."
                )
            vectors.append(vec)
        return vectors