Blame view

embeddings/redis_embedding_cache.py 3.96 KB
4a37d233   tangwang   1. embedding cach...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
  """
  Shared Redis cache for embedding vectors (text + image).
  
  Value format: BF16 bytes (see embeddings/bf16.py)
  Expiration: setex on write + sliding expire on successful read.
  
  This module is intentionally small and dependency-light, so both indexer/search paths
  can reuse it safely.
  """
  
  from __future__ import annotations
  
  import logging
  from datetime import timedelta
7214c2e7   tangwang   mplemented**
15
  from typing import Any, Optional
4a37d233   tangwang   1. embedding cach...
16
17
  
  import numpy as np
7214c2e7   tangwang   mplemented**
18
19
20
21
  try:
      import redis
  except ImportError:  # pragma: no cover - runtime fallback for minimal envs
      redis = None  # type: ignore[assignment]
4a37d233   tangwang   1. embedding cach...
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  
  from config.env_config import REDIS_CONFIG
  from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis
  
  logger = logging.getLogger(__name__)
  
  
  class RedisEmbeddingCache:
      """Best-effort Redis cache mapping string keys to 1-D embedding vectors.

      Keys are ``<key_prefix>:<namespace>:<raw_key>``, or ``<key_prefix>:<raw_key>``
      when no namespace is set. Values are written with ``encode_embedding_for_redis``
      and read back with ``decode_embedding_from_redis``. Every entry is written with
      a TTL (``SETEX``), and the TTL is reset on each valid read (sliding expiration).

      When Redis is unavailable (package missing, or client construction/ping
      failed), ``redis_client`` is ``None``. In that state ``get`` always misses and
      ``set`` always returns ``False``. Errors inside ``get``/``set`` are logged and
      never raised.
      """

      def __init__(
          self,
          *,
          key_prefix: str,
          namespace: str = "",
          expire_time: Optional[timedelta] = None,
          redis_client: Optional["redis.Redis"] = None,
      ):
          """Configure the key layout and TTL, then attach or create a Redis client.

          Args:
              key_prefix: First key segment. A blank or whitespace-only value falls
                  back to ``"embedding"``.
              namespace: Optional second key segment; surrounding whitespace is stripped.
              expire_time: TTL applied on write and refreshed on read. ``None`` (or a
                  zero ``timedelta``) falls back to ``REDIS_CONFIG["cache_expire_days"]``
                  days (default 180).
              redis_client: Pre-built client, used as-is with no ping. When omitted, a
                  client is built from ``REDIS_CONFIG`` and pinged. On any failure the
                  cache is disabled (``redis_client`` is set to ``None``).
          """
          self.key_prefix = (key_prefix or "").strip() or "embedding"
          self.namespace = (namespace or "").strip()
          # ``or`` (rather than ``is None``) also replaces a zero timedelta, which
          # Redis would reject as a SETEX TTL.
          self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))

          if redis_client is not None:
              self.redis_client = redis_client
              return

          if redis is None:
              logger.warning("redis package is not installed, continuing without embedding cache")
              self.redis_client = None
              return

          try:
              client = redis.Redis(
                  host=REDIS_CONFIG.get("host", "localhost"),
                  # NOTE(review): 6479 differs from Redis' standard port 6379 -- confirm intended.
                  port=REDIS_CONFIG.get("port", 6479),
                  password=REDIS_CONFIG.get("password"),
                  # Values are raw encoded bytes, so responses must not be decoded to str.
                  decode_responses=False,
                  socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
                  socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
                  retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
                  health_check_interval=10,
              )
              # Probe the server now, so an unreachable server disables the cache up front.
              client.ping()
              self.redis_client = client
          except Exception as e:
              logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
              self.redis_client = None

      def make_key(self, raw_key: str) -> str:
          """Return the full Redis key for ``raw_key`` (prefix, optional namespace, key)."""
          if self.namespace:
              return f"{self.key_prefix}:{self.namespace}:{raw_key}"
          return f"{self.key_prefix}:{raw_key}"

      @staticmethod
      def _is_valid_vector(vec: np.ndarray) -> bool:
          """Return True for a non-empty, 1-D, all-finite vector (the only form cached)."""
          return vec.ndim == 1 and vec.size > 0 and bool(np.isfinite(vec).all())

      def get(self, raw_key: str) -> Optional[np.ndarray]:
          """Return the cached vector for ``raw_key``, or ``None`` on a miss.

          A stored value that does not decode to a non-empty, finite 1-D vector is
          deleted (best-effort) and reported as a miss. On a valid hit, the key's TTL
          is reset to ``expire_time``. A failed TTL refresh is logged but does not
          discard the hit. Any other error is logged and treated as a miss.
          """
          if not self.redis_client:
              return None
          key = self.make_key(raw_key)
          try:
              raw = self.redis_client.get(key)
              if not raw:
                  return None
              vec = decode_embedding_from_redis(raw)
              if not self._is_valid_vector(vec):
                  try:
                      self.redis_client.delete(key)
                  except Exception as e:  # cleanup is best-effort; the miss is what matters
                      logger.debug("Failed to delete invalid cached embedding: %s", e)
                  return None
          except Exception as e:
              logger.warning("Error retrieving embedding from cache: %s", e)
              return None

          # Sliding expiration: refresh TTL on hit. The vector has already been read and
          # validated, so a failed EXPIRE must not turn this hit into a miss.
          try:
              self.redis_client.expire(key, self.expire_time)
          except Exception as e:
              logger.warning("Error refreshing embedding cache TTL: %s", e)
          return vec

      def set(self, raw_key: str, embedding: np.ndarray) -> bool:
          """Store ``embedding`` under ``raw_key`` with TTL ``expire_time``.

          The input is converted to float32 before encoding. Vectors that ``get``
          would reject (not 1-D, empty, or containing NaN/inf) are not written.

          Returns:
              True if the value was written. False when the cache is disabled, the
              vector is invalid, or any error occurs (errors are logged, not raised).
          """
          if not self.redis_client:
              return False
          key = self.make_key(raw_key)
          try:
              vec = np.asarray(embedding, dtype=np.float32)
              if not self._is_valid_vector(vec):
                  # ``get`` would delete this entry on read, so skip the wasted write.
                  logger.warning("Refusing to cache invalid embedding (shape=%s)", vec.shape)
                  return False
              raw = encode_embedding_for_redis(vec)
              self.redis_client.setex(key, self.expire_time, raw)
              return True
          except Exception as e:
              # ``redis`` is None when a client was injected without the package
              # installed, so ``redis.exceptions`` must only be touched behind a guard.
              if redis is not None and isinstance(e, redis.exceptions.RedisError):
                  logger.warning("Redis error storing embedding in cache: %s", e)
              else:
                  logger.warning("Error storing embedding in cache: %s", e)
              return False