# tei_model.py
"""TEI text embedding backend client."""

from __future__ import annotations

from typing import List, Union

import numpy as np
import requests


class TEITextModel:
    """
    Text embedding backend implemented via Hugging Face TEI HTTP API.

    Expected TEI endpoint:
      POST {base_url}/embed
      body: {"inputs": ["text1", "text2", ...]}
      response: [[...], [...], ...]

    ``encode``/``encode_batch`` return a 2-D ``float32`` array of shape
    ``(n, d)`` (or an empty ``float32`` array for empty input).
    """

    def __init__(self, base_url: str, timeout_sec: int = 60):
        """
        Initialize the client and verify the server is reachable.

        Args:
            base_url: Root URL of the TEI server; trailing slashes are stripped.
            timeout_sec: Per-request timeout for /embed calls, in seconds.
                Accepts ints or floats (requests takes float timeouts natively).

        Raises:
            ValueError: If *base_url* is empty or whitespace-only.
            requests.HTTPError: If the /health probe fails.
        """
        if not base_url or not str(base_url).strip():
            raise ValueError("TEI base_url must not be empty")
        self.base_url = str(base_url).rstrip("/")
        self.endpoint = f"{self.base_url}/embed"
        # float() instead of int(): do not silently truncate sub-second timeouts.
        self.timeout_sec = float(timeout_sec)
        self._health_check()

    def _health_check(self) -> None:
        """Probe ``{base_url}/health``; raise ``requests.HTTPError`` on failure."""
        health_url = f"{self.base_url}/health"
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()

    @staticmethod
    def _normalize(embedding: np.ndarray) -> np.ndarray:
        """Return *embedding* scaled to unit L2 norm.

        Raises:
            RuntimeError: If the embedding has zero norm (cannot normalize).
        """
        norm = np.linalg.norm(embedding)
        if norm <= 0:
            raise RuntimeError("TEI returned zero-norm embedding")
        return embedding / norm

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = "cuda",
        batch_size: int = 32,
    ) -> np.ndarray:
        """Embed one string or a list of strings.

        A single string is wrapped into a one-element batch; the result is
        always 2-D. ``device``/``batch_size`` exist for interface parity with
        local backends and are ignored by the HTTP implementation.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        return self.encode_batch(
            texts=sentences,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = "cuda",
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """Embed a batch of texts via the TEI ``/embed`` endpoint.

        Args:
            texts: Non-empty strings to embed.
            batch_size: Ignored — TEI performs its own server-side batching.
            device: Ignored — not applicable to the HTTP backend.
            normalize_embeddings: If True, L2-normalize each embedding.

        Returns:
            ``float32`` array of shape ``(len(texts), d)``; an empty
            ``float32`` array for empty/None input.

        Raises:
            ValueError: If any input text is not a non-blank string.
            RuntimeError: If the TEI response is malformed (wrong length,
                non-1-D/non-finite vectors, or inconsistent dimensions).
            requests.HTTPError: If the HTTP request fails.
        """
        del batch_size  # TEI performs its own batching.
        del device      # Not used by HTTP backend.

        if not texts:
            # Keep the dtype consistent with the non-empty return path
            # (the original returned a dtype=object array here).
            return np.empty((0,), dtype=np.float32)
        for i, t in enumerate(texts):
            if not isinstance(t, str) or not t.strip():
                raise ValueError(f"Invalid input text at index {i}: {t!r}")

        response = requests.post(
            self.endpoint,
            json={"inputs": texts},
            timeout=self.timeout_sec,
        )
        response.raise_for_status()
        payload = response.json()

        if not isinstance(payload, list) or len(payload) != len(texts):
            # Describe non-list payloads by type: calling len() on an
            # arbitrary object could itself raise and mask the real error.
            got = len(payload) if isinstance(payload, list) else type(payload).__name__
            raise RuntimeError(
                f"TEI response length mismatch: expected {len(texts)}, "
                f"got {got}"
            )

        vectors: List[np.ndarray] = []
        expected_dim: int = -1  # dimension of the first vector; all must match
        for i, emb in enumerate(payload):
            vec = np.asarray(emb, dtype=np.float32)
            if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
                raise RuntimeError(f"Invalid TEI embedding at index {i}")
            if expected_dim < 0:
                expected_dim = vec.size
            elif vec.size != expected_dim:
                raise RuntimeError(
                    f"Inconsistent TEI embedding dim at index {i}: "
                    f"expected {expected_dim}, got {vec.size}"
                )
            if normalize_embeddings:
                vec = self._normalize(vec)
            vectors.append(vec)
        # np.stack gives a dense (n, d) float32 matrix; the original built a
        # dtype=object array of boxed floats, defeating vectorized math.
        return np.stack(vectors)