""" Shared Redis cache for embedding vectors (text + image). Value format: BF16 bytes (see embeddings/bf16.py) Expiration: setex on write + sliding expire on successful read. This module is intentionally small and dependency-light, so both indexer/search paths can reuse it safely. """ from __future__ import annotations import logging from datetime import timedelta from typing import Any, Optional import numpy as np try: import redis except ImportError: # pragma: no cover - runtime fallback for minimal envs redis = None # type: ignore[assignment] from config.env_config import REDIS_CONFIG from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis logger = logging.getLogger(__name__) class RedisEmbeddingCache: def __init__( self, *, key_prefix: str, namespace: str = "", expire_time: Optional[timedelta] = None, redis_client: Optional[redis.Redis] = None, ): self.key_prefix = (key_prefix or "").strip() or "embedding" self.namespace = (namespace or "").strip() self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) if redis_client is not None: self.redis_client = redis_client return if redis is None: logger.warning("redis package is not installed, continuing without embedding cache") self.redis_client = None return try: client = redis.Redis( host=REDIS_CONFIG.get("host", "localhost"), port=REDIS_CONFIG.get("port", 6479), password=REDIS_CONFIG.get("password"), decode_responses=False, socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), health_check_interval=10, ) client.ping() self.redis_client = client except Exception as e: logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e) self.redis_client = None def make_key(self, raw_key: str) -> str: if self.namespace: return f"{self.key_prefix}:{self.namespace}:{raw_key}" return f"{self.key_prefix}:{raw_key}" def get(self, raw_key: str) -> Optional[np.ndarray]: if not self.redis_client: return None key = self.make_key(raw_key) try: raw = self.redis_client.get(key) if not raw: return None vec = decode_embedding_from_redis(raw) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): try: self.redis_client.delete(key) except Exception: pass return None # Sliding expiration: refresh TTL on hit. self.redis_client.expire(key, self.expire_time) return vec except Exception as e: logger.warning("Error retrieving embedding from cache: %s", e) return None def set(self, raw_key: str, embedding: np.ndarray) -> bool: if not self.redis_client: return False key = self.make_key(raw_key) try: vec = np.asarray(embedding, dtype=np.float32) raw = encode_embedding_for_redis(vec) self.redis_client.setex(key, self.expire_time, raw) return True except ( redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, redis.exceptions.RedisError, ) as e: logger.warning("Redis error storing embedding in cache: %s", e) return False except Exception as e: logger.warning("Error storing embedding in cache: %s", e) return False