# text_encoder.py 7.85 KB
"""Text embedding client for the local embedding HTTP service."""

import logging
from datetime import timedelta
from typing import Any, List, Optional, Union

import numpy as np
import requests

logger = logging.getLogger(__name__)

from config.loader import get_app_config
from config.services_config import get_embedding_text_base_url
from embeddings.cache_keys import build_text_cache_key
from embeddings.redis_embedding_cache import RedisEmbeddingCache
from request_log_context import build_downstream_request_headers, build_request_log_extra


class TextEmbeddingEncoder:
    """
    Text embedding encoder backed by the embedding HTTP service.

    Each text is looked up in a ``RedisEmbeddingCache`` first. Only cache
    misses are sent to the service (``POST {service_url}/embed/text``), and
    every valid vector returned by the service is written back to the cache.
    """

    # Seconds to wait for the embedding service before giving up on a request.
    REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, service_url: Optional[str] = None):
        """
        Resolve the service endpoint and create the embedding cache.

        Args:
            service_url: Base URL of the embedding service. When omitted (or
                empty) the URL from ``get_embedding_text_base_url()`` is used.
        """
        resolved_url = service_url or get_embedding_text_base_url()
        redis_config = get_app_config().infrastructure.redis
        # Strip trailing slashes so the endpoint never contains "//".
        self.service_url = str(resolved_url).rstrip("/")
        self.endpoint = f"{self.service_url}/embed/text"
        self.expire_time = timedelta(days=redis_config.cache_expire_days)
        # A blank prefix in the config falls back to "embedding".
        self.cache_prefix = str(redis_config.embedding_cache_prefix).strip() or "embedding"
        logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url)

        self.cache = RedisEmbeddingCache(
            key_prefix=self.cache_prefix,
            namespace="",
            expire_time=self.expire_time,
        )

    def _call_service(
        self,
        request_data: List[str],
        normalize_embeddings: bool = True,
        priority: int = 0,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> List[Any]:
        """
        Call the embedding service API.

        Args:
            request_data: List of texts, sent as the JSON request body.
            normalize_embeddings: Sent as the ``normalize`` query parameter
                ("true"/"false").
            priority: Sent as the ``priority`` query parameter; negative
                values are clamped to 0.
            request_id: Forwarded to the downstream request headers and to
                the log record on failure.
            user_id: Forwarded to the downstream request headers and to the
                log record on failure.

        Returns:
            The decoded JSON body. Expected to be a list of embeddings
            (list[float]) or nulls (None), aligned to input order.

        Raises:
            requests.exceptions.RequestException: On transport errors or a
                non-2xx status. The failure is logged, then re-raised.
            ValueError: If the response body is not valid JSON (raised as a
                plain ``ValueError`` by older Requests releases). Logged,
                then re-raised.
        """
        # Built outside the ``try`` so that a bad ``priority`` or a header
        # construction failure is not misreported as a service failure.
        params = {
            "normalize": "true" if normalize_embeddings else "false",
            "priority": max(0, int(priority)),
        }
        headers = build_downstream_request_headers(request_id=request_id, user_id=user_id)

        response = None
        try:
            response = requests.post(
                self.endpoint,
                params=params,
                json=request_data,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            return response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ``response`` is still None when the request never completed.
            body_preview = ""
            if response is not None:
                try:
                    body_preview = (response.text or "")[:300]
                except Exception:
                    # Best effort only: never let the preview mask the real error.
                    body_preview = ""
            logger.error(
                "TextEmbeddingEncoder service request failed | status=%s body=%s error=%s",
                getattr(response, "status_code", "n/a"),
                body_preview,
                e,
                exc_info=True,
                extra=build_request_log_extra(request_id=request_id, user_id=user_id),
            )
            raise

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        priority: int = 0,
        device: str = 'cpu',
        batch_size: int = 32,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> np.ndarray:
        """
        Encode text into embeddings via network service with Redis caching.

        Args:
            sentences: Single string or list of strings to encode.
            normalize_embeddings: Whether to request normalized embeddings
                from the service. Also part of the cache key.
            priority: Request priority forwarded to the service.
            device: Ignored; accepted for interface compatibility only.
            batch_size: Ignored; accepted for interface compatibility only.
                All cache misses are sent to the service in one request.
            request_id: Request id forwarded to the service call.
            user_id: User id forwarded to the service call.

        Returns:
            numpy array of dtype=object holding one embedding per input, in
            input order.

        Raises:
            ValueError: If the service response is not a list, or if any
                input has no embedding or an invalid one. Vectors validated
                before the failing entry have already been cached.
            requests.exceptions.RequestException: If the service call fails.
        """
        # Convert single string to list
        if isinstance(sentences, str):
            sentences = [sentences]

        embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
        # Parallel lists: input position and text of every cache miss.
        pending_indices: List[int] = []
        pending_texts: List[str] = []

        # Check cache first
        for idx, text in enumerate(sentences):
            cached = self._get_cached_embedding(text, normalize_embeddings=normalize_embeddings)
            if cached is not None:
                embeddings[idx] = cached
            else:
                pending_indices.append(idx)
                pending_texts.append(text)

        # If there are uncached texts, call service
        if pending_texts:
            response_data = self._call_service(
                pending_texts,
                normalize_embeddings=normalize_embeddings,
                priority=priority,
                request_id=request_id,
                user_id=user_id,
            )
            # Anything but a JSON array (e.g. an error object) cannot be
            # aligned with the request; fail with a clear error instead of a
            # KeyError/TypeError from indexing it.
            if not isinstance(response_data, list):
                raise ValueError(
                    "Unexpected embedding service response type: "
                    f"{type(response_data).__name__} (expected a list)"
                )

            for position, (original_idx, text) in enumerate(zip(pending_indices, pending_texts)):
                # A response shorter than the request counts as missing entries.
                raw = response_data[position] if position < len(response_data) else None
                if raw is None:
                    raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...")

                vector = self._coerce_embedding(raw, original_idx)
                embeddings[original_idx] = vector
                self._set_cached_embedding(
                    text,
                    vector,
                    normalize_embeddings=normalize_embeddings,
                )

        # NOTE(review): when every element is an array of the same length,
        # np.array(..., dtype=object) yields a 2-D object array (n, dim), not
        # a 1-D array of ndarray objects - confirm callers rely on that shape.
        return np.array(embeddings, dtype=object)

    def _coerce_embedding(self, raw: Any, text_index: int) -> np.ndarray:
        """
        Convert one service result into a validated float32 vector.

        Args:
            raw: One entry of the service response (expected: list of floats).
            text_index: Position of the text in the caller's input; used in
                the error message only.

        Returns:
            1-D float32 array with only finite values.

        Raises:
            ValueError: If ``raw`` cannot be converted to float32 or the
                result fails ``_is_valid_embedding``.
        """
        try:
            vector = np.array(raw, dtype=np.float32)
        except (TypeError, ValueError) as exc:
            # Non-numeric or ragged payloads; report them like any other
            # invalid embedding.
            raise ValueError(
                f"Invalid embedding returned from service for text index {text_index}"
            ) from exc
        if not self._is_valid_embedding(vector):
            raise ValueError(
                f"Invalid embedding returned from service for text index {text_index}"
            )
        return vector

    def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
        """
        Check if embedding is valid.

        Valid means: an ``np.ndarray`` that is one-dimensional, non-empty,
        numeric, and free of NaN/Inf values.

        Args:
            embedding: Embedding array to validate

        Returns:
            True if valid, False otherwise
        """
        # Also rejects None.
        if not isinstance(embedding, np.ndarray):
            return False
        # A single embedding is a flat vector; scalars (0-d) and nested
        # results (2-d or more) are malformed.
        if embedding.ndim != 1 or embedding.size == 0:
            return False
        try:
            # Check for NaN or Inf values
            return bool(np.isfinite(embedding).all())
        except TypeError:
            # np.isfinite is not defined for non-numeric dtypes (str, object).
            return False

    def _get_cached_embedding(
        self,
        query: str,
        *,
        normalize_embeddings: bool,
    ) -> Optional[np.ndarray]:
        """
        Look up the embedding for ``query`` in the cache.

        Args:
            query: Text whose embedding is requested.
            normalize_embeddings: Normalization flag; part of the cache key.

        Returns:
            Whatever ``self.cache.get`` returns for the key, or None on a miss.
        """
        cache_key = build_text_cache_key(query, normalize=normalize_embeddings)
        embedding = self.cache.get(cache_key)
        if embedding is not None:
            logger.debug(
                "Cache hit for text embedding | normalize=%s query=%s key=%s",
                normalize_embeddings,
                query,
                cache_key,
            )
        return embedding

    def _set_cached_embedding(
        self,
        query: str,
        embedding: np.ndarray,
        *,
        normalize_embeddings: bool,
    ) -> bool:
        """
        Store the embedding for ``query`` in the cache.

        Args:
            query: Text the embedding belongs to.
            embedding: Vector to store.
            normalize_embeddings: Normalization flag; part of the cache key.

        Returns:
            The result of ``self.cache.set`` (truthy when the write succeeded).
        """
        cache_key = build_text_cache_key(query, normalize=normalize_embeddings)
        ok = self.cache.set(cache_key, embedding)
        if ok:
            logger.debug(
                "Successfully cached text embedding | normalize=%s query=%s",
                normalize_embeddings,
                query,
            )
        return ok