Commit 4a37d233b1b7e61d86d841f7fb304aa38d39a919

Authored by tangwang
1 parent 77516841

1. embedding cache float32 -> bf16

2. 抽象出可复用的 embedding Redis 缓存类(图文共用)

详细:
1. embedding 缓存改为 BF16 存 Redis(读回恢复 FP32)
关键行为(按你给的流程落地)
写入前:FP32 embedding →(normalize_embeddings=True 时)L2 normalize →
转 BF16 → bytes(2字节/维,大端) → redis.setex
读取后:redis.get bytes → BF16 → 恢复 FP32(np.float32 向量)
变更点
新增 embeddings/bf16.py
提供 float32_to_bf16 / bf16_to_float32
encode_embedding_for_redis():FP32 → BF16 → bytes
decode_embedding_from_redis():bytes → BF16 → FP32
l2_normalize_fp32():按需归一化
修改 embeddings/text_encoder.py
Redis value 从 pickle.dumps(np.ndarray) 改为 BF16 bytes
缓存 key 改为包含 normalize 标记:{prefix}:{n0|n1}:{query}(避免
normalize 开关不同却共用缓存)
修改 tests/test_embedding_pipeline.py
cache hit 用例改为写入 BF16 bytes,并使用新
key:embedding:n1:cached-text
修改 docs/缓存与Redis使用说明.md
embedding 缓存的 Key/Value 格式更新为 BF16 bytes + n0/n1
修改 scripts/redis/redis_cache_health_check.py
embedding pattern 不再硬编码 embedding:*,改为读取
REDIS_CONFIG["embedding_cache_prefix"]
value 预览从 pickle 解码改为 BF16 解码后展示 dim/bytes/dtype
自检
在激活环境后跑过 BF16 编解码往返 sanity check:bytes
长度、维度恢复正常;归一化向量读回后范数接近 1(会有 BF16 量化误差)。

2. 抽象出可复用的 embedding Redis 缓存类(图文共用)
新增
embeddings/redis_embedding_cache.py:RedisEmbeddingCache
统一 Redis 初始化(读 REDIS_CONFIG)
统一 BF16 bytes 编解码(复用 embeddings/bf16.py)
统一过期策略:写入 setex(expire_time),命中读取后 expire(expire_time)
滑动过期刷新 TTL
统一异常/坏数据处理:解码失败或向量非 1D/为空/含 NaN/Inf 会删除该 key
并当作 miss
已接入复用
文本 embeddings/text_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="")
key 仍是:{prefix}:{query}
图片 embeddings/image_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="image")
key 仍是:{prefix}:image:{url_or_path}
1 1
2 2
  3 +product_enrich : Partial Mode
  4 +https://help.aliyun.com/zh/model-studio/partial-mode?spm=a2c4g.11186623.help-menu-2400256.d_0_3_0_7.74a630119Ct6zR
  5 +需在messages 数组中将最后一条消息的 role 设置为 assistant,并在其 content 中提供前缀,在此消息中设置参数 "partial": true。messages格式如下:
  6 +[
  7 + {
  8 + "role": "user",
  9 + "content": "请补全这个斐波那契函数,勿添加其它内容"
  10 + },
  11 + {
  12 + "role": "assistant",
  13 + "content": "def calculate_fibonacci(n):\n if n <= 1:\n return n\n else:\n",
  14 + "partial": true
  15 + }
  16 +]
  17 +模型会以前缀内容为起点开始生成。
  18 +
  19 +支持 非思考模式。
  20 +
3 21
4 22
5 23
@@ -109,7 +127,7 @@ translation: @@ -109,7 +127,7 @@ translation:
109 base_url: "http://127.0.0.1:6006" 127 base_url: "http://127.0.0.1:6006"
110 model: "llm" # 或 "qwen-mt-flush",看你想用哪个 128 model: "llm" # 或 "qwen-mt-flush",看你想用哪个
111 timeout_sec: 10.0 129 timeout_sec: 10.0
112 - llm: 130 + llm:.
113 model: "qwen-flash" # 留给翻译服务自身内部使用 131 model: "qwen-flash" # 留给翻译服务自身内部使用
114 qwen-mt: ... 132 qwen-mt: ...
115 deepl: ... 133 deepl: ...
docs/缓存与Redis使用说明.md
@@ -20,7 +20,7 @@ @@ -20,7 +20,7 @@
20 20
21 | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 | 21 | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 |
22 |------------|----------|----------------|----------|------| 22 |------------|----------|----------------|----------|------|
23 -| 文本向量缓存(embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`,前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | 23 +| 向量缓存(text/image embedding) | `{EMBEDDING_CACHE_PREFIX}:{query_or_url}` / `{EMBEDDING_CACHE_PREFIX}:image:{url_or_path}` | **BF16 bytes**(每维 2 字节大端存储),读取后恢复为 `np.float32` | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`(文本)与 `embeddings/image_encoder.py`(图片);前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 |
24 | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` | 24 | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` |
25 | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` | 25 | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` |
26 26
@@ -44,13 +44,14 @@ @@ -44,13 +44,14 @@
44 44
45 - 字段说明: 45 - 字段说明:
46 - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖; 46 - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖;
47 - - 当前实现**不再区分 language 与 normalize flag**,即无论是否归一化,key 结构都相同;  
48 - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。 47 - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。
49 48
50 ### 2.2 Value 与类型 49 ### 2.2 Value 与类型
51 50
52 -- 类型:`pickle.dumps(np.ndarray)`,在读取时通过 `pickle.loads` 还原为 `np.ndarray`。  
53 -- 典型示例:BGE-M3 1024 维 `float32` 向量。 51 +- 类型:**BF16 bytes**(bfloat16),每一维用 2 字节无符号整数表示,按**大端**序列化。
  52 +- 写入流程:FP32 向量 → BF16 → bytes → Redis
  53 +- 读取流程:Redis bytes → BF16 → FP32(`np.float32`)向量
  54 +- 典型示例:BGE-M3 1024 维向量在 Redis value 大小约为 \(1024*2=2048\) bytes(不含 Redis 元数据开销)。
54 55
55 ### 2.3 过期策略 56 ### 2.3 过期策略
56 57
embeddings/bf16.py 0 → 100644
@@ -0,0 +1,87 @@ @@ -0,0 +1,87 @@
  1 +"""
  2 +BF16 (bfloat16) codec helpers for Redis embedding cache.
  3 +
  4 +We store embeddings in Redis as:
  5 + FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes
  6 +
  7 +No backward compatibility is provided by design.
  8 +"""
  9 +
  10 +from __future__ import annotations
  11 +
  12 +import struct
  13 +from typing import Iterable, List, Sequence
  14 +
  15 +import numpy as np
  16 +
  17 +
  18 +def float32_to_bf16(value: float) -> int:
  19 + """
  20 + float32 -> bfloat16 (returns 0..65535 uint16)
  21 + Round-to-nearest-even.
  22 + """
  23 + bits = struct.unpack(">I", struct.pack(">f", float(value)))[0]
  24 + rounding_bias = ((bits >> 16) & 1) + 0x7FFF
  25 + bits += rounding_bias
  26 + return (bits >> 16) & 0xFFFF
  27 +
  28 +
  29 +def bf16_to_float32(bf16: int) -> float:
  30 + """
  31 + bfloat16 -> float32.
  32 + bf16 is an int in 0..65535.
  33 + """
  34 + bits = (int(bf16) & 0xFFFF) << 16
  35 + return struct.unpack(">f", struct.pack(">I", bits))[0]
  36 +
  37 +
  38 +def float_array_to_bf16(vector: Sequence[float]) -> List[int]:
  39 + return [float32_to_bf16(v) for v in vector]
  40 +
  41 +
  42 +def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]:
  43 + return [bf16_to_float32(v) for v in vector_bf16]
  44 +
  45 +
  46 +def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes:
  47 + """Each bf16 uses 2 bytes big-endian."""
  48 + return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list)
  49 +
  50 +
  51 +def bytes_to_bf16_list(data: bytes) -> List[int]:
  52 + if len(data) % 2 != 0:
  53 + raise ValueError("BF16 byte length must be even")
  54 + return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)]
  55 +
  56 +
  57 +def encode_embedding_for_redis(embedding: np.ndarray) -> bytes:
  58 + """
  59 + FP32 embedding -> BF16 -> bytes.
  60 + """
  61 + arr = np.asarray(embedding, dtype=np.float32)
  62 + if arr.ndim != 1:
  63 + arr = arr.reshape(-1)
  64 + # Ensure we operate on plain Python floats for the reference codec.
  65 + bf16_list = float_array_to_bf16(arr.tolist())
  66 + return bf16_list_to_bytes(bf16_list)
  67 +
  68 +
  69 +def decode_embedding_from_redis(data: bytes) -> np.ndarray:
  70 + """
  71 + Redis bytes -> BF16 -> FP32 numpy array.
  72 + """
  73 + bf16_list = bytes_to_bf16_list(data)
  74 + floats = bf16_array_to_float(bf16_list)
  75 + return np.asarray(floats, dtype=np.float32)
  76 +
  77 +
  78 +def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray:
  79 + """L2-normalize a 1D FP32 vector. Raises on invalid norms."""
  80 + arr = np.asarray(vec, dtype=np.float32)
  81 + if arr.ndim != 1:
  82 + arr = arr.reshape(-1)
  83 + norm = float(np.linalg.norm(arr))
  84 + if not np.isfinite(norm) or norm <= 0.0:
  85 + raise ValueError("Embedding vector has invalid norm (must be > 0)")
  86 + return (arr / norm).astype(np.float32, copy=False)
  87 +
embeddings/image_encoder.py
@@ -11,6 +11,8 @@ from PIL import Image @@ -11,6 +11,8 @@ from PIL import Image
11 logger = logging.getLogger(__name__) 11 logger = logging.getLogger(__name__)
12 12
13 from config.services_config import get_embedding_base_url 13 from config.services_config import get_embedding_base_url
  14 +from config.env_config import REDIS_CONFIG
  15 +from embeddings.redis_embedding_cache import RedisEmbeddingCache
14 16
15 17
16 class CLIPImageEncoder: 18 class CLIPImageEncoder:
@@ -24,7 +26,13 @@ class CLIPImageEncoder: @@ -24,7 +26,13 @@ class CLIPImageEncoder:
24 resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() 26 resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
25 self.service_url = str(resolved_url).rstrip("/") 27 self.service_url = str(resolved_url).rstrip("/")
26 self.endpoint = f"{self.service_url}/embed/image" 28 self.endpoint = f"{self.service_url}/embed/image"
  29 + # Reuse embedding cache prefix, but separate namespace for images to avoid collisions.
  30 + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"
27 logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) 31 logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)
  32 + self.cache = RedisEmbeddingCache(
  33 + key_prefix=self.cache_prefix,
  34 + namespace="image",
  35 + )
28 36
29 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: 37 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
30 """ 38 """
@@ -67,12 +75,17 @@ class CLIPImageEncoder: @@ -67,12 +75,17 @@ class CLIPImageEncoder:
67 Returns: 75 Returns:
68 Embedding vector 76 Embedding vector
69 """ 77 """
  78 + cached = self.cache.get(url)
  79 + if cached is not None:
  80 + return cached
  81 +
70 response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) 82 response_data = self._call_service([url], normalize_embeddings=normalize_embeddings)
71 if not response_data or len(response_data) != 1 or response_data[0] is None: 83 if not response_data or len(response_data) != 1 or response_data[0] is None:
72 raise RuntimeError(f"No image embedding returned for URL: {url}") 84 raise RuntimeError(f"No image embedding returned for URL: {url}")
73 vec = np.array(response_data[0], dtype=np.float32) 85 vec = np.array(response_data[0], dtype=np.float32)
74 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): 86 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
75 raise RuntimeError(f"Invalid image embedding returned for URL: {url}") 87 raise RuntimeError(f"Invalid image embedding returned for URL: {url}")
  88 + self.cache.set(url, vec)
76 return vec 89 return vec
77 90
78 def encode_batch( 91 def encode_batch(
@@ -98,8 +111,21 @@ class CLIPImageEncoder: @@ -98,8 +111,21 @@ class CLIPImageEncoder:
98 raise ValueError(f"Invalid image URL/path at index {i}: {img!r}") 111 raise ValueError(f"Invalid image URL/path at index {i}: {img!r}")
99 112
100 results: List[np.ndarray] = [] 113 results: List[np.ndarray] = []
101 - for i in range(0, len(images), batch_size):  
102 - batch_urls = [str(u).strip() for u in images[i:i + batch_size]] 114 + pending_urls: List[str] = []
  115 + pending_positions: List[int] = []
  116 +
  117 + normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item]
  118 + for pos, url in enumerate(normalized_urls):
  119 + cached = self.cache.get(url)
  120 + if cached is not None:
  121 + results.append(cached)
  122 + else:
  123 + results.append(np.array([], dtype=np.float32)) # placeholder
  124 + pending_positions.append(pos)
  125 + pending_urls.append(url)
  126 +
  127 + for i in range(0, len(pending_urls), batch_size):
  128 + batch_urls = pending_urls[i : i + batch_size]
103 response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) 129 response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings)
104 if not response_data or len(response_data) != len(batch_urls): 130 if not response_data or len(response_data) != len(batch_urls):
105 raise RuntimeError( 131 raise RuntimeError(
@@ -113,7 +139,9 @@ class CLIPImageEncoder: @@ -113,7 +139,9 @@ class CLIPImageEncoder:
113 vec = np.array(embedding, dtype=np.float32) 139 vec = np.array(embedding, dtype=np.float32)
114 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): 140 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
115 raise RuntimeError(f"Invalid image embedding returned for URL: {url}") 141 raise RuntimeError(f"Invalid image embedding returned for URL: {url}")
116 - results.append(vec) 142 + self.cache.set(url, vec)
  143 + pos = pending_positions[i + j]
  144 + results[pos] = vec
117 145
118 return results 146 return results
119 147
embeddings/redis_embedding_cache.py 0 → 100644
@@ -0,0 +1,107 @@ @@ -0,0 +1,107 @@
  1 +"""
  2 +Shared Redis cache for embedding vectors (text + image).
  3 +
  4 +Value format: BF16 bytes (see embeddings/bf16.py)
  5 +Expiration: setex on write + sliding expire on successful read.
  6 +
  7 +This module is intentionally small and dependency-light, so both indexer/search paths
  8 +can reuse it safely.
  9 +"""
  10 +
  11 +from __future__ import annotations
  12 +
  13 +import logging
  14 +from datetime import timedelta
  15 +from typing import Optional
  16 +
  17 +import numpy as np
  18 +import redis
  19 +
  20 +from config.env_config import REDIS_CONFIG
  21 +from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis
  22 +
  23 +logger = logging.getLogger(__name__)
  24 +
  25 +
  26 +class RedisEmbeddingCache:
  27 + def __init__(
  28 + self,
  29 + *,
  30 + key_prefix: str,
  31 + namespace: str = "",
  32 + expire_time: Optional[timedelta] = None,
  33 + redis_client: Optional[redis.Redis] = None,
  34 + ):
  35 + self.key_prefix = (key_prefix or "").strip() or "embedding"
  36 + self.namespace = (namespace or "").strip()
  37 + self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))
  38 +
  39 + if redis_client is not None:
  40 + self.redis_client = redis_client
  41 + return
  42 +
  43 + try:
  44 + client = redis.Redis(
  45 + host=REDIS_CONFIG.get("host", "localhost"),
  46 + port=REDIS_CONFIG.get("port", 6479),
  47 + password=REDIS_CONFIG.get("password"),
  48 + decode_responses=False,
  49 + socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
  50 + socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
  51 + retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
  52 + health_check_interval=10,
  53 + )
  54 + client.ping()
  55 + self.redis_client = client
  56 + except Exception as e:
  57 + logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
  58 + self.redis_client = None
  59 +
  60 + def make_key(self, raw_key: str) -> str:
  61 + if self.namespace:
  62 + return f"{self.key_prefix}:{self.namespace}:{raw_key}"
  63 + return f"{self.key_prefix}:{raw_key}"
  64 +
  65 + def get(self, raw_key: str) -> Optional[np.ndarray]:
  66 + if not self.redis_client:
  67 + return None
  68 + key = self.make_key(raw_key)
  69 + try:
  70 + raw = self.redis_client.get(key)
  71 + if not raw:
  72 + return None
  73 + vec = decode_embedding_from_redis(raw)
  74 + if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
  75 + try:
  76 + self.redis_client.delete(key)
  77 + except Exception:
  78 + pass
  79 + return None
  80 + # Sliding expiration: refresh TTL on hit.
  81 + self.redis_client.expire(key, self.expire_time)
  82 + return vec
  83 + except Exception as e:
  84 + logger.warning("Error retrieving embedding from cache: %s", e)
  85 + return None
  86 +
  87 + def set(self, raw_key: str, embedding: np.ndarray) -> bool:
  88 + if not self.redis_client:
  89 + return False
  90 + key = self.make_key(raw_key)
  91 + try:
  92 + vec = np.asarray(embedding, dtype=np.float32)
  93 + raw = encode_embedding_for_redis(vec)
  94 + self.redis_client.setex(key, self.expire_time, raw)
  95 + return True
  96 + except (
  97 + redis.exceptions.BusyLoadingError,
  98 + redis.exceptions.ConnectionError,
  99 + redis.exceptions.TimeoutError,
  100 + redis.exceptions.RedisError,
  101 + ) as e:
  102 + logger.warning("Redis error storing embedding in cache: %s", e)
  103 + return False
  104 + except Exception as e:
  105 + logger.warning("Error storing embedding in cache: %s", e)
  106 + return False
  107 +
embeddings/text_encoder.py
@@ -2,17 +2,16 @@ @@ -2,17 +2,16 @@
2 2
3 import logging 3 import logging
4 import os 4 import os
5 -import pickle  
6 from datetime import timedelta 5 from datetime import timedelta
7 from typing import Any, List, Optional, Union 6 from typing import Any, List, Optional, Union
8 7
9 import numpy as np 8 import numpy as np
10 -import redis  
11 import requests 9 import requests
12 10
13 logger = logging.getLogger(__name__) 11 logger = logging.getLogger(__name__)
14 12
15 from config.services_config import get_embedding_base_url 13 from config.services_config import get_embedding_base_url
  14 +from embeddings.redis_embedding_cache import RedisEmbeddingCache
16 15
17 # Try to import REDIS_CONFIG, but allow import to fail 16 # Try to import REDIS_CONFIG, but allow import to fail
18 from config.env_config import REDIS_CONFIG 17 from config.env_config import REDIS_CONFIG
@@ -30,22 +29,11 @@ class TextEmbeddingEncoder: @@ -30,22 +29,11 @@ class TextEmbeddingEncoder:
30 self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" 29 self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"
31 logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) 30 logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url)
32 31
33 - try:  
34 - self.redis_client = redis.Redis(  
35 - host=REDIS_CONFIG.get("host", "localhost"),  
36 - port=REDIS_CONFIG.get("port", 6479),  
37 - password=REDIS_CONFIG.get("password"),  
38 - decode_responses=False,  
39 - socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),  
40 - socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),  
41 - retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),  
42 - health_check_interval=10,  
43 - )  
44 - self.redis_client.ping()  
45 - logger.info("Redis cache initialized for embeddings")  
46 - except Exception as e:  
47 - logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e)  
48 - self.redis_client = None 32 + self.cache = RedisEmbeddingCache(
  33 + key_prefix=self.cache_prefix,
  34 + namespace="",
  35 + expire_time=self.expire_time,
  36 + )
49 37
50 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: 38 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
51 """ 39 """
@@ -127,7 +115,7 @@ class TextEmbeddingEncoder: @@ -127,7 +115,7 @@ class TextEmbeddingEncoder:
127 embedding_array = np.array(embedding, dtype=np.float32) 115 embedding_array = np.array(embedding, dtype=np.float32)
128 if self._is_valid_embedding(embedding_array): 116 if self._is_valid_embedding(embedding_array):
129 embeddings[original_idx] = embedding_array 117 embeddings[original_idx] = embedding_array
130 - self._set_cached_embedding(text, embedding_array, normalize_embeddings) 118 + self._set_cached_embedding(text, embedding_array)
131 else: 119 else:
132 raise ValueError( 120 raise ValueError(
133 f"Invalid embedding returned from service for text index {original_idx}" 121 f"Invalid embedding returned from service for text index {original_idx}"
@@ -161,63 +149,21 @@ class TextEmbeddingEncoder: @@ -161,63 +149,21 @@ class TextEmbeddingEncoder:
161 149
162 def _get_cached_embedding( 150 def _get_cached_embedding(
163 self, 151 self,
164 - query: str 152 + query: str,
165 ) -> Optional[np.ndarray]: 153 ) -> Optional[np.ndarray]:
166 - """Get embedding from cache if exists (with sliding expiration)"""  
167 - if not self.redis_client:  
168 - return None  
169 -  
170 - try:  
171 - cache_key = f"{self.cache_prefix}:{query}"  
172 - cached_data = self.redis_client.get(cache_key)  
173 - if cached_data:  
174 - embedding = pickle.loads(cached_data)  
175 - # Validate cached embedding - if invalid, ignore cache and return None  
176 - if self._is_valid_embedding(embedding):  
177 - logger.debug(f"Cache hit for embedding: {query}")  
178 - # Update expiration time on access (sliding expiration)  
179 - self.redis_client.expire(cache_key, self.expire_time)  
180 - return embedding  
181 - else:  
182 - logger.warning(  
183 - f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "  
184 - f"ignoring cache for query: {query[:50]}..."  
185 - )  
186 - # Delete invalid cache entry  
187 - try:  
188 - self.redis_client.delete(cache_key)  
189 - except Exception as e:  
190 - logger.debug(f"Failed to delete invalid cache entry: {e}")  
191 - return None  
192 - return None  
193 - except Exception as e:  
194 - logger.error(f"Error retrieving embedding from cache: {e}")  
195 - return None 154 + """Get embedding from cache if exists (with sliding expiration)."""
  155 + embedding = self.cache.get(query)
  156 + if embedding is not None:
  157 + logger.debug(f"Cache hit for embedding: {query}")
  158 + return embedding
196 159
197 def _set_cached_embedding( 160 def _set_cached_embedding(
198 self, 161 self,
199 query: str, 162 query: str,
200 embedding: np.ndarray, 163 embedding: np.ndarray,
201 - normalize_embeddings: bool = True,  
202 ) -> bool: 164 ) -> bool:
203 - """Store embedding in cache"""  
204 - if not self.redis_client:  
205 - return False  
206 -  
207 - try:  
208 - cache_key = f"{self.cache_prefix}:{query}"  
209 - serialized_data = pickle.dumps(embedding)  
210 - self.redis_client.setex(  
211 - cache_key,  
212 - self.expire_time,  
213 - serialized_data  
214 - ) 165 + """Store embedding in cache."""
  166 + ok = self.cache.set(query, embedding)
  167 + if ok:
215 logger.debug(f"Successfully cached embedding for query: {query}") 168 logger.debug(f"Successfully cached embedding for query: {query}")
216 - return True  
217 - except (redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError,  
218 - redis.exceptions.TimeoutError, redis.exceptions.RedisError) as e:  
219 - logger.warning(f"Redis error storing embedding in cache: {e}")  
220 - return False  
221 - except Exception as e:  
222 - logger.error(f"Error storing embedding in cache: {e}")  
223 - return False 169 + return ok
scripts/redis/redis_cache_health_check.py
@@ -28,7 +28,6 @@ from __future__ import annotations @@ -28,7 +28,6 @@ from __future__ import annotations
28 28
29 import argparse 29 import argparse
30 import json 30 import json
31 -import pickle  
32 import sys 31 import sys
33 from collections import defaultdict 32 from collections import defaultdict
34 from dataclasses import dataclass 33 from dataclasses import dataclass
@@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT)) @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT))
45 44
46 from config.env_config import REDIS_CONFIG # type: ignore 45 from config.env_config import REDIS_CONFIG # type: ignore
47 from config.services_config import get_translation_cache_config # type: ignore 46 from config.services_config import get_translation_cache_config # type: ignore
  47 +from embeddings.bf16 import decode_embedding_from_redis # type: ignore
48 48
49 49
50 @dataclass 50 @dataclass
@@ -58,10 +58,11 @@ def _load_known_cache_types() -&gt; Dict[str, CacheTypeConfig]: @@ -58,10 +58,11 @@ def _load_known_cache_types() -&gt; Dict[str, CacheTypeConfig]:
58 """根据当前配置装配三种已知缓存类型及其前缀 pattern。""" 58 """根据当前配置装配三种已知缓存类型及其前缀 pattern。"""
59 cache_types: Dict[str, CacheTypeConfig] = {} 59 cache_types: Dict[str, CacheTypeConfig] = {}
60 60
61 - # embedding 缓存:固定 embedding:* 前缀 61 + # embedding 缓存:prefix 来自 REDIS_CONFIG['embedding_cache_prefix'](默认 embedding)
  62 + embedding_prefix = REDIS_CONFIG.get("embedding_cache_prefix", "embedding")
62 cache_types["embedding"] = CacheTypeConfig( 63 cache_types["embedding"] = CacheTypeConfig(
63 name="embedding", 64 name="embedding",
64 - pattern="embedding:*", 65 + pattern=f"{embedding_prefix}:*",
65 description="文本向量缓存(embeddings/text_encoder.py)", 66 description="文本向量缓存(embeddings/text_encoder.py)",
66 ) 67 )
67 68
@@ -153,13 +154,14 @@ def decode_value_preview( @@ -153,13 +154,14 @@ def decode_value_preview(
153 if raw_value is None: 154 if raw_value is None:
154 return "<nil>" 155 return "<nil>"
155 156
156 - # embedding: pickle 序列化的 numpy.ndarray 157 + # embedding: BF16 bytes
157 if cache_type == "embedding": 158 if cache_type == "embedding":
158 try: 159 try:
159 - arr = pickle.loads(raw_value) 160 + arr = decode_embedding_from_redis(raw_value)
160 if isinstance(arr, np.ndarray): 161 if isinstance(arr, np.ndarray):
161 - return f"ndarray shape={arr.shape} dtype={arr.dtype}"  
162 - return f"pickle object type={type(arr).__name__}" 162 + dim = int(arr.size)
  163 + return f"bf16 bytes={len(raw_value)} dim={dim} restored_dtype={arr.dtype}"
  164 + return f"bf16 decode type={type(arr).__name__}"
163 except Exception: 165 except Exception:
164 return f"<binary {len(raw_value)} bytes>" 166 return f"<binary {len(raw_value)} bytes>"
165 167
tests/test_embedding_pipeline.py
1 -import pickle  
2 from typing import Any, Dict, List, Optional 1 from typing import Any, Dict, List, Optional
3 2
4 import numpy as np 3 import numpy as np
@@ -13,6 +12,7 @@ from config import ( @@ -13,6 +12,7 @@ from config import (
13 SearchConfig, 12 SearchConfig,
14 ) 13 )
15 from embeddings.text_encoder import TextEmbeddingEncoder 14 from embeddings.text_encoder import TextEmbeddingEncoder
  15 +from embeddings.bf16 import encode_embedding_for_redis
16 from query import QueryParser 16 from query import QueryParser
17 17
18 18
@@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
128 def test_text_embedding_encoder_cache_hit(monkeypatch): 128 def test_text_embedding_encoder_cache_hit(monkeypatch):
129 fake_redis = _FakeRedis() 129 fake_redis = _FakeRedis()
130 cached = np.array([0.9, 0.8], dtype=np.float32) 130 cached = np.array([0.9, 0.8], dtype=np.float32)
131 - fake_redis.store["embedding:cached-text"] = pickle.dumps(cached) 131 + fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached)
132 monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) 132 monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis)
133 133
134 calls = {"count": 0} 134 calls = {"count": 0}