Blame view

embeddings/redis_embedding_cache.py 3.65 KB
4a37d233   tangwang   1. embedding cach...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  """
  Shared Redis cache for embedding vectors (text + image).
  
  Value format: BF16 bytes (see embeddings/bf16.py)
  Expiration: setex on write + sliding expire on successful read.
  
  This module is intentionally small and dependency-light, so both indexer/search paths
  can reuse it safely.
  """
  
  from __future__ import annotations
  
  import logging
  from datetime import timedelta
  from typing import Optional
  
  import numpy as np
  import redis
  
  from config.env_config import REDIS_CONFIG
  from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis
  
  logger = logging.getLogger(__name__)
  
  
  class RedisEmbeddingCache:
      """Best-effort Redis cache for 1-D embedding vectors.

      Vectors are serialized with ``encode_embedding_for_redis`` on write and
      restored with ``decode_embedding_from_redis`` on read. Every entry is
      written with a TTL (``setex``) and the TTL is refreshed on each
      successful read (sliding expiration).

      The cache never raises to its callers: when Redis is unreachable or an
      operation fails, ``get`` returns ``None`` and ``set`` returns ``False``.
      """

      def __init__(
          self,
          *,
          key_prefix: str,
          namespace: str = "",
          expire_time: Optional[timedelta] = None,
          redis_client: Optional[redis.Redis] = None,
      ):
          """Configure key layout, TTL and the Redis connection.

          Args:
              key_prefix: First segment of every Redis key. Blank or empty
                  values fall back to ``"embedding"``.
              namespace: Optional second key segment; omitted from keys when
                  blank.
              expire_time: TTL applied on write and refreshed on read. A falsy
                  value falls back to ``REDIS_CONFIG["cache_expire_days"]``
                  (180 days when the setting is absent).
              redis_client: Pre-built client to use as-is. When omitted, a
                  client is created from ``REDIS_CONFIG`` and pinged; if that
                  fails the cache is disabled (``redis_client`` is ``None``).
          """
          self.key_prefix = (key_prefix or "").strip() or "embedding"
          self.namespace = (namespace or "").strip()
          self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))

          if redis_client is not None:
              # Caller-supplied client: trust it, skip connection setup and ping.
              self.redis_client = redis_client
              return

          try:
              client = redis.Redis(
                  host=REDIS_CONFIG.get("host", "localhost"),
                  # NOTE(review): fallback port is 6479, not Redis's standard
                  # 6379 - presumably the project's deployment port; confirm.
                  port=REDIS_CONFIG.get("port", 6479),
                  password=REDIS_CONFIG.get("password"),
                  # Values are raw bytes, so responses must not be decoded to str.
                  decode_responses=False,
                  socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
                  socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
                  retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
                  health_check_interval=10,
              )
              # Fail fast here so later get/set calls can simply skip the cache.
              client.ping()
              self.redis_client = client
          except Exception as e:
              logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
              self.redis_client = None

      def make_key(self, raw_key: str) -> str:
          """Return the full Redis key: ``prefix[:namespace]:raw_key``."""
          if self.namespace:
              return f"{self.key_prefix}:{self.namespace}:{raw_key}"
          return f"{self.key_prefix}:{raw_key}"

      def get(self, raw_key: str) -> Optional[np.ndarray]:
          """Fetch the cached vector for ``raw_key``.

          Returns:
              The decoded vector on a hit. ``None`` when the cache is disabled,
              the key is missing, the read or decode fails, or the stored
              vector is not a non-empty, finite 1-D array (such an entry is
              also deleted).
          """
          if self.redis_client is None:
              return None
          key = self.make_key(raw_key)
          try:
              raw = self.redis_client.get(key)
              if not raw:
                  return None
              vec = decode_embedding_from_redis(raw)
              usable = vec.ndim == 1 and vec.size > 0 and bool(np.isfinite(vec).all())
          except Exception as e:
              logger.warning("Error retrieving embedding from cache: %s", e)
              return None

          if not usable:
              self._discard(key)
              return None

          # Sliding expiration: refresh TTL on hit. This is best-effort - a
          # failed refresh must not turn an already-decoded hit into a miss.
          self._refresh_ttl(key)
          return vec

      def set(self, raw_key: str, embedding: np.ndarray) -> bool:
          """Store ``embedding`` under ``raw_key`` with the configured TTL.

          Returns:
              ``True`` when the value was written; ``False`` when the cache is
              disabled or encoding/writing failed.
          """
          if self.redis_client is None:
              return False
          key = self.make_key(raw_key)
          try:
              vec = np.asarray(embedding, dtype=np.float32)
              raw = encode_embedding_for_redis(vec)
              self.redis_client.setex(key, self.expire_time, raw)
              return True
          except redis.exceptions.RedisError as e:
              # RedisError is the base class of BusyLoadingError, ConnectionError
              # and TimeoutError, so a single clause covers all of them.
              logger.warning("Redis error storing embedding in cache: %s", e)
              return False
          except Exception as e:
              logger.warning("Error storing embedding in cache: %s", e)
              return False

      def _discard(self, key: str) -> None:
          """Delete an unusable cache entry; failures are logged, never raised."""
          try:
              self.redis_client.delete(key)
          except Exception as e:
              logger.debug("Failed to delete invalid embedding cache entry %s: %s", key, e)

      def _refresh_ttl(self, key: str) -> None:
          """Reset the TTL of ``key``; failures are logged, never raised."""
          try:
              self.redis_client.expire(key, self.expire_time)
          except Exception as e:
              logger.warning("Failed to refresh TTL for cached embedding %s: %s", key, e)