"""Text embedding client for the local embedding HTTP service.""" import logging from datetime import timedelta from typing import Any, List, Optional, Union import numpy as np import requests logger = logging.getLogger(__name__) from config.services_config import get_embedding_text_base_url from embeddings.cache_keys import build_text_cache_key from embeddings.redis_embedding_cache import RedisEmbeddingCache # Try to import REDIS_CONFIG, but allow import to fail from config.env_config import REDIS_CONFIG class TextEmbeddingEncoder: """ Text embedding encoder using network service. """ def __init__(self, service_url: Optional[str] = None): resolved_url = service_url or get_embedding_text_base_url() self.service_url = str(resolved_url).rstrip("/") self.endpoint = f"{self.service_url}/embed/text" self.expire_time = timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) self.cache = RedisEmbeddingCache( key_prefix=self.cache_prefix, namespace="", expire_time=self.expire_time, ) def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: """ Call the embedding service API. Args: request_data: List of texts Returns: List of embeddings (list[float]) or nulls (None), aligned to input order """ try: response = requests.post( self.endpoint, params={"normalize": "true" if normalize_embeddings else "false"}, json=request_data, timeout=60 ) response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: logger.error(f"TextEmbeddingEncoder service request failed: {e}", exc_info=True) raise def encode( self, sentences: Union[str, List[str]], normalize_embeddings: bool = True, device: str = 'cpu', batch_size: int = 32 ) -> np.ndarray: """ Encode text into embeddings via network service with Redis caching. Args: sentences: Single string or list of strings to encode normalize_embeddings: Whether to request normalized embeddings from service device: Device parameter ignored for service compatibility batch_size: Batch size for processing (used for service requests) Returns: numpy array of dtype=object,元素均为有效 np.ndarray 向量。 若任一输入无法生成向量,将直接抛出异常。 """ # Convert single string to list if isinstance(sentences, str): sentences = [sentences] # Check cache first uncached_indices: List[int] = [] uncached_texts: List[str] = [] embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) for i, text in enumerate(sentences): cached = self._get_cached_embedding(text, normalize_embeddings=normalize_embeddings) if cached is not None: embeddings[i] = cached else: uncached_indices.append(i) uncached_texts.append(text) # Prepare request data for uncached texts (after cache check) request_data = list(uncached_texts) # If there are uncached texts, call service if uncached_texts: response_data = self._call_service(request_data, normalize_embeddings=normalize_embeddings) # Process response for i, text in enumerate(uncached_texts): original_idx = uncached_indices[i] if response_data and i < len(response_data): embedding = response_data[i] else: embedding = None if embedding is not None: embedding_array = np.array(embedding, dtype=np.float32) if self._is_valid_embedding(embedding_array): embeddings[original_idx] = embedding_array self._set_cached_embedding( text, embedding_array, normalize_embeddings=normalize_embeddings, ) else: raise ValueError( f"Invalid embedding returned from service for text index {original_idx}" ) else: raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...") # 返回 numpy 数组(dtype=object),元素均为有效 np.ndarray 向量 return np.array(embeddings, dtype=object) def _is_valid_embedding(self, embedding: np.ndarray) -> bool: """ Check if embedding is valid (not None, correct shape, no NaN/Inf). Args: embedding: Embedding array to validate Returns: True if valid, False otherwise """ if embedding is None: return False if not isinstance(embedding, np.ndarray): return False if embedding.size == 0: return False # Check for NaN or Inf values if not np.isfinite(embedding).all(): return False return True def _get_cached_embedding( self, query: str, *, normalize_embeddings: bool, ) -> Optional[np.ndarray]: """Get embedding from cache if exists (with sliding expiration).""" cache_key = build_text_cache_key(query, normalize=normalize_embeddings) embedding = self.cache.get(cache_key) if embedding is not None: logger.debug( "Cache hit for text embedding | normalize=%s query=%s key=%s", normalize_embeddings, query, cache_key, ) return embedding def _set_cached_embedding( self, query: str, embedding: np.ndarray, *, normalize_embeddings: bool, ) -> bool: """Store embedding in cache.""" ok = self.cache.set(build_text_cache_key(query, normalize=normalize_embeddings), embedding) if ok: logger.debug( "Successfully cached text embedding | normalize=%s query=%s", normalize_embeddings, query, ) return ok