redis_embedding_cache.py
3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Shared Redis cache for embedding vectors (text + image).
Value format: BF16 bytes (see embeddings/bf16.py)
Expiration: setex on write + sliding expire on successful read.
This module is intentionally small and dependency-light, so both indexer/search paths
can reuse it safely.
"""
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Any, Optional

import numpy as np

try:
    import redis
except ImportError:  # pragma: no cover - runtime fallback for minimal envs
    redis = None  # type: ignore[assignment]

from config.loader import get_app_config
from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis
logger = logging.getLogger(__name__)


class RedisEmbeddingCache:
    """Best-effort Redis cache mapping string keys to embedding vectors.

    Keys are laid out as ``<key_prefix>:<namespace>:<raw_key>``, or as
    ``<key_prefix>:<raw_key>`` when no namespace is set. Values are written
    with ``encode_embedding_for_redis`` and read back with
    ``decode_embedding_from_redis``. Every write sets a TTL (``SETEX``), and
    every successful read refreshes it (sliding expiration).

    Failures are logged and swallowed rather than raised. When no Redis client
    is available, ``get`` returns ``None`` and ``set`` returns ``False``.
    """

    def __init__(
        self,
        *,
        key_prefix: str,
        namespace: str = "",
        expire_time: Optional[timedelta] = None,
        redis_client: Optional[redis.Redis] = None,
    ) -> None:
        """Configure the key layout and TTL, then attach or create a Redis client.

        Args:
            key_prefix: Leading key segment. A blank or whitespace-only value
                falls back to ``"embedding"``.
            namespace: Optional middle key segment. Blank means the key has
                no namespace segment.
            expire_time: TTL applied on write and refreshed on read. ``None``
                (or any falsy value, e.g. ``timedelta(0)``) uses
                ``cache_expire_days`` from the app's Redis config.
            redis_client: Client to use as-is, with no ping. When omitted, a
                client is built from the app's Redis config and pinged. If the
                ``redis`` package is missing or that setup fails, the cache is
                disabled (``self.redis_client`` is ``None``).
        """
        self.key_prefix = (key_prefix or "").strip() or "embedding"
        self.namespace = (namespace or "").strip()

        # Load the app config only when something below actually needs it, so
        # callers that inject both a client and a TTL do not depend on it.
        redis_config = None
        if not expire_time:
            redis_config = get_app_config().infrastructure.redis
            expire_time = timedelta(days=redis_config.cache_expire_days)
        self.expire_time = expire_time

        if redis_client is not None:
            self.redis_client = redis_client
            return
        if redis is None:
            logger.warning("redis package is not installed, continuing without embedding cache")
            self.redis_client = None
            return

        if redis_config is None:
            redis_config = get_app_config().infrastructure.redis
        try:
            client = redis.Redis(
                host=redis_config.host,
                port=redis_config.port,
                # NOTE(review): reuses the snapshot DB index; confirm this is intended.
                db=redis_config.snapshot_db,
                password=redis_config.password,
                # Values are raw BF16 bytes, so responses must not be decoded to str.
                decode_responses=False,
                socket_timeout=redis_config.socket_timeout,
                socket_connect_timeout=redis_config.socket_connect_timeout,
                retry_on_timeout=redis_config.retry_on_timeout,
                health_check_interval=10,
            )
            client.ping()
            self.redis_client = client
        except Exception as e:
            logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
            self.redis_client = None

    def make_key(self, raw_key: str) -> str:
        """Return the full Redis key: prefix, optional namespace, then *raw_key*."""
        if self.namespace:
            return f"{self.key_prefix}:{self.namespace}:{raw_key}"
        return f"{self.key_prefix}:{raw_key}"

    def get(self, raw_key: str) -> Optional[np.ndarray]:
        """Fetch and decode the cached embedding for *raw_key*.

        Returns:
            The decoded vector on a hit. Returns ``None`` in any of these cases:
            a miss; the cache is disabled; a Redis or decoding error occurs; or
            the stored value is not a non-empty, all-finite 1-D array. Invalid
            entries are deleted on a best-effort basis.
        """
        if not self.redis_client:
            return None
        key = self.make_key(raw_key)
        try:
            raw = self.redis_client.get(key)
            if not raw:
                return None
            vec = decode_embedding_from_redis(raw)
            if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
                # Corrupt or unusable entry: drop it so it is recomputed next time.
                try:
                    self.redis_client.delete(key)
                except Exception as e:
                    logger.debug("Failed to delete invalid cached embedding %s: %s", key, e)
                return None
        except Exception as e:
            logger.warning("Error retrieving embedding from cache: %s", e)
            return None

        # Sliding expiration is best-effort. A failed TTL refresh must not turn
        # an already-decoded, valid hit into a miss.
        try:
            self.redis_client.expire(key, self.expire_time)
        except Exception as e:
            logger.debug("Failed to refresh TTL for cached embedding %s: %s", key, e)
        return vec

    def set(self, raw_key: str, embedding: np.ndarray) -> bool:
        """Store *embedding* under *raw_key* with the configured TTL.

        The vector is converted to float32 and encoded with
        ``encode_embedding_for_redis``. Empty or non-finite vectors are not
        written, because ``get`` would reject and delete them anyway.

        Returns:
            ``True`` if the ``SETEX`` succeeded. ``False`` if the cache is
            disabled, the vector was rejected, or encoding or Redis failed
            (failures are logged).
        """
        if not self.redis_client:
            return False
        key = self.make_key(raw_key)
        try:
            vec = np.asarray(embedding, dtype=np.float32)
            if vec.size == 0 or not np.isfinite(vec).all():
                logger.debug("Refusing to cache empty or non-finite embedding for %s", key)
                return False
            raw = encode_embedding_for_redis(vec)
            self.redis_client.setex(key, self.expire_time, raw)
            return True
        except Exception as e:
            # Check the redis module before touching `redis.exceptions`: with an
            # injected client and no redis package, `redis` is None here.
            if redis is not None and isinstance(e, redis.exceptions.RedisError):
                logger.warning("Redis error storing embedding in cache: %s", e)
            else:
                logger.warning("Error storing embedding in cache: %s", e)
            return False