redis_embedding_cache.py 3.96 KB
"""
Shared Redis cache for embedding vectors (text + image).

Value format: BF16 bytes (see embeddings/bf16.py)
Expiration: setex on write + sliding expire on successful read.

This module is intentionally small and dependency-light so that both the indexing
and search paths can reuse it safely.
"""

from __future__ import annotations

import logging
from datetime import timedelta
from typing import Any, Optional

import numpy as np
try:
    import redis
except ImportError:  # pragma: no cover - runtime fallback for minimal envs
    redis = None  # type: ignore[assignment]

from config.env_config import REDIS_CONFIG
from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis

logger = logging.getLogger(__name__)


class RedisEmbeddingCache:
    """Best-effort Redis cache mapping string keys to embedding vectors.

    Keys are built as ``<key_prefix>:<raw_key>`` or, when a namespace is set,
    ``<key_prefix>:<namespace>:<raw_key>``. Values are serialized with
    ``encode_embedding_for_redis`` and read back with
    ``decode_embedding_from_redis``. Every write sets a TTL of ``expire_time``
    and every valid hit refreshes it (sliding expiration).

    Failures never propagate to callers: ``get`` returns ``None`` and ``set``
    returns ``False``, so the cache can be treated as optional.
    """

    def __init__(
        self,
        *,
        key_prefix: str,
        namespace: str = "",
        expire_time: Optional[timedelta] = None,
        redis_client: Optional[redis.Redis] = None,
    ):
        """Configure key layout and TTL, and obtain a Redis client.

        Args:
            key_prefix: First key segment; blank or ``None`` falls back to
                ``"embedding"``.
            namespace: Optional middle key segment; omitted from keys when blank.
            expire_time: TTL applied on write and refreshed on read. ``None`` or a
                zero timedelta falls back to ``REDIS_CONFIG["cache_expire_days"]``
                days (180 if unset).
            redis_client: Pre-built client, used as-is (no ping). When omitted, a
                client is built from ``REDIS_CONFIG`` and pinged; if the ``redis``
                package is missing or construction/ping fails, the cache is
                disabled (``self.redis_client`` is ``None``).
        """
        self.key_prefix = (key_prefix or "").strip() or "embedding"
        self.namespace = (namespace or "").strip()
        self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))

        if redis_client is not None:
            self.redis_client = redis_client
            return

        if redis is None:
            logger.warning("redis package is not installed, continuing without embedding cache")
            self.redis_client = None
            return

        try:
            client = redis.Redis(
                host=REDIS_CONFIG.get("host", "localhost"),
                # NOTE(review): default 6479 differs from Redis's standard 6379 —
                # confirm this is intentional for this deployment.
                port=REDIS_CONFIG.get("port", 6479),
                password=REDIS_CONFIG.get("password"),
                decode_responses=False,  # values are raw BF16 bytes
                socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
                socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
                retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
                health_check_interval=10,
            )
            # Fail fast at construction so a dead server disables the cache
            # instead of timing out on every get/set.
            client.ping()
            self.redis_client = client
        except Exception as e:
            logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
            self.redis_client = None

    @staticmethod
    def _is_nonempty_finite(vec: np.ndarray) -> bool:
        """Return True if *vec* has at least one element and every element is finite."""
        return vec.size > 0 and bool(np.isfinite(vec).all())

    def make_key(self, raw_key: str) -> str:
        """Return the full Redis key for *raw_key* (prefix, optional namespace, key)."""
        if self.namespace:
            return f"{self.key_prefix}:{self.namespace}:{raw_key}"
        return f"{self.key_prefix}:{raw_key}"

    def get(self, raw_key: str) -> Optional[np.ndarray]:
        """Fetch and decode the embedding stored under *raw_key*.

        Returns:
            The decoded 1-D vector, or ``None`` when the cache is disabled, on a
            miss, on any Redis/decoding error, or when the stored value is not a
            non-empty finite 1-D vector (that entry is deleted best-effort).
        """
        if not self.redis_client:
            return None
        key = self.make_key(raw_key)
        try:
            raw = self.redis_client.get(key)
            if not raw:
                return None
            vec = decode_embedding_from_redis(raw)
            is_valid = vec.ndim == 1 and self._is_nonempty_finite(vec)
        except Exception as e:
            logger.warning("Error retrieving embedding from cache: %s", e)
            return None

        if not is_valid:
            # Unusable entry (wrong rank, empty, or non-finite): drop it so it is
            # not served again.
            try:
                self.redis_client.delete(key)
            except Exception as e:
                logger.debug("Failed to delete invalid cached embedding %r: %s", key, e)
            return None

        # Sliding expiration: refresh TTL on hit. Best-effort, so a failure here
        # must not throw away a valid, already-decoded vector.
        try:
            self.redis_client.expire(key, self.expire_time)
        except Exception as e:
            logger.warning("Failed to refresh TTL for cached embedding: %s", e)
        return vec

    def set(self, raw_key: str, embedding: np.ndarray) -> bool:
        """Encode *embedding* as float32 and store it under *raw_key* with the TTL.

        Returns:
            ``True`` if the value was written. ``False`` when the cache is
            disabled, the embedding is empty or contains non-finite values (``get``
            would reject such an entry anyway), or any encoding/Redis error occurs.
        """
        if not self.redis_client:
            return False
        key = self.make_key(raw_key)
        try:
            vec = np.asarray(embedding, dtype=np.float32)
            if not self._is_nonempty_finite(vec):
                logger.warning("Refusing to cache empty or non-finite embedding (shape=%s)", vec.shape)
                return False
            raw = encode_embedding_for_redis(vec)
            self.redis_client.setex(key, self.expire_time, raw)
            return True
        except Exception as e:
            # ``redis`` may be None (package not installed) while an injected
            # client is in use, so never dereference ``redis.exceptions`` unguarded.
            # BusyLoading/Connection/Timeout errors all subclass RedisError.
            if redis is not None and isinstance(e, redis.exceptions.RedisError):
                logger.warning("Redis error storing embedding in cache: %s", e)
            else:
                logger.warning("Error storing embedding in cache: %s", e)
            return False