Commit 4a37d233b1b7e61d86d841f7fb304aa38d39a919

Authored by tangwang
1 parent 77516841

1. embedding cache float32 -> bf16

2. 抽象出可复用的 embedding Redis 缓存类(图文共用)

详细:
1. embedding 缓存改为 BF16 存 Redis(读回恢复 FP32)
关键行为(按你给的流程落地)
写入前:FP32 embedding →(normalize_embeddings=True 时)L2 normalize →
转 BF16 → bytes(2字节/维,大端) → redis.setex
读取后:redis.get bytes → BF16 → 恢复 FP32(np.float32 向量)
变更点
新增 embeddings/bf16.py
提供 float32_to_bf16 / bf16_to_float32
encode_embedding_for_redis():FP32 → BF16 → bytes
decode_embedding_from_redis():bytes → BF16 → FP32
l2_normalize_fp32():按需归一化
修改 embeddings/text_encoder.py
Redis value 从 pickle.dumps(np.ndarray) 改为 BF16 bytes
缓存 key 改为包含 normalize 标记:{prefix}:{n0|n1}:{query}(避免
normalize 开关不同却共用缓存)
修改 tests/test_embedding_pipeline.py
cache hit 用例改为写入 BF16 bytes,并使用新
key:embedding:n1:cached-text
修改 docs/缓存与Redis使用说明.md
embedding 缓存的 Key/Value 格式更新为 BF16 bytes + n0/n1
修改 scripts/redis/redis_cache_health_check.py
embedding pattern 不再硬编码 embedding:*,改为读取
REDIS_CONFIG["embedding_cache_prefix"]
value 预览从 pickle 解码改为 BF16 解码后展示 dim/bytes/dtype
自检
在激活环境后跑过 BF16 编解码往返 sanity check:bytes
长度、维度恢复正常;归一化向量读回后范数接近 1(会有 BF16 量化误差)。

2. 抽象出可复用的 embedding Redis 缓存类(图文共用)
新增
embeddings/redis_embedding_cache.py:RedisEmbeddingCache
统一 Redis 初始化(读 REDIS_CONFIG)
统一 BF16 bytes 编解码(复用 embeddings/bf16.py)
统一过期策略:写入 setex(expire_time),命中读取后 expire(expire_time)
滑动过期刷新 TTL
统一异常/坏数据处理:解码失败或向量非 1D/为空/含 NaN/Inf 会删除该 key
并当作 miss
已接入复用
文本 embeddings/text_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="")
key 仍是:{prefix}:{query}
图片 embeddings/image_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="image")
key 仍是:{prefix}:image:{url_or_path}
docs/TODO.txt
1 1  
2 2  
  3 +product_enrich : Partial Mode
  4 +https://help.aliyun.com/zh/model-studio/partial-mode?spm=a2c4g.11186623.help-menu-2400256.d_0_3_0_7.74a630119Ct6zR
  5 +需在messages 数组中将最后一条消息的 role 设置为 assistant,并在其 content 中提供前缀,在此消息中设置参数 "partial": true。messages格式如下:
  6 +[
  7 + {
  8 + "role": "user",
  9 + "content": "请补全这个斐波那契函数,勿添加其它内容"
  10 + },
  11 + {
  12 + "role": "assistant",
  13 + "content": "def calculate_fibonacci(n):\n if n <= 1:\n return n\n else:\n",
  14 + "partial": true
  15 + }
  16 +]
  17 +模型会以前缀内容为起点开始生成。
  18 +
  19 +支持 非思考模式。
  20 +
3 21  
4 22  
5 23  
... ... @@ -109,7 +127,7 @@ translation:
109 127 base_url: "http://127.0.0.1:6006"
110 128 model: "llm" # 或 "qwen-mt-flush",看你想用哪个
111 129 timeout_sec: 10.0
112   - llm:
  130 + llm:.
113 131 model: "qwen-flash" # 留给翻译服务自身内部使用
114 132 qwen-mt: ...
115 133 deepl: ...
... ...
docs/缓存与Redis使用说明.md
... ... @@ -20,7 +20,7 @@
20 20  
21 21 | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 |
22 22 |------------|----------|----------------|----------|------|
23   -| 文本向量缓存(embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`,前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 |
  23 +| 向量缓存(text/image embedding) | `{EMBEDDING_CACHE_PREFIX}:{query_or_url}` / `{EMBEDDING_CACHE_PREFIX}:image:{url_or_path}` | **BF16 bytes**(每维 2 字节大端存储),读取后恢复为 `np.float32` | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`(文本)与 `embeddings/image_encoder.py`(图片);前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 |
24 24 | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` |
25 25 | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` |
26 26  
... ... @@ -44,13 +44,14 @@
44 44  
45 45 - 字段说明:
46 46 - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖;
47   - - 当前实现**不再区分 language 与 normalize flag**,即无论是否归一化,key 结构都相同;
48 47 - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。
49 48  
50 49 ### 2.2 Value 与类型
51 50  
52   -- 类型:`pickle.dumps(np.ndarray)`,在读取时通过 `pickle.loads` 还原为 `np.ndarray`。
53   -- 典型示例:BGE-M3 1024 维 `float32` 向量。
  51 +- 类型:**BF16 bytes**(bfloat16),每一维用 2 字节无符号整数表示,按**大端**序列化。
  52 +- 写入流程:FP32 向量 → BF16 → bytes → Redis
  53 +- 读取流程:Redis bytes → BF16 → FP32(`np.float32`)向量
  54 +- 典型示例:BGE-M3 1024 维向量在 Redis value 大小约为 \(1024*2=2048\) bytes(不含 Redis 元数据开销)。
54 55  
55 56 ### 2.3 过期策略
56 57  
... ...
embeddings/bf16.py 0 → 100644
... ... @@ -0,0 +1,87 @@
  1 +"""
  2 +BF16 (bfloat16) codec helpers for Redis embedding cache.
  3 +
  4 +We store embeddings in Redis as:
  5 + FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes
  6 +
  7 +No backward compatibility is provided by design.
  8 +"""
  9 +
  10 +from __future__ import annotations
  11 +
  12 +import struct
  13 +from typing import Iterable, List, Sequence
  14 +
  15 +import numpy as np
  16 +
  17 +
  18 +def float32_to_bf16(value: float) -> int:
  19 + """
  20 + float32 -> bfloat16 (returns 0..65535 uint16)
  21 + Round-to-nearest-even.
  22 + """
  23 + bits = struct.unpack(">I", struct.pack(">f", float(value)))[0]
  24 + rounding_bias = ((bits >> 16) & 1) + 0x7FFF
  25 + bits += rounding_bias
  26 + return (bits >> 16) & 0xFFFF
  27 +
  28 +
  29 +def bf16_to_float32(bf16: int) -> float:
  30 + """
  31 + bfloat16 -> float32.
  32 + bf16 is an int in 0..65535.
  33 + """
  34 + bits = (int(bf16) & 0xFFFF) << 16
  35 + return struct.unpack(">f", struct.pack(">I", bits))[0]
  36 +
  37 +
  38 +def float_array_to_bf16(vector: Sequence[float]) -> List[int]:
  39 + return [float32_to_bf16(v) for v in vector]
  40 +
  41 +
  42 +def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]:
  43 + return [bf16_to_float32(v) for v in vector_bf16]
  44 +
  45 +
  46 +def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes:
  47 + """Each bf16 uses 2 bytes big-endian."""
  48 + return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list)
  49 +
  50 +
  51 +def bytes_to_bf16_list(data: bytes) -> List[int]:
  52 + if len(data) % 2 != 0:
  53 + raise ValueError("BF16 byte length must be even")
  54 + return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)]
  55 +
  56 +
  57 +def encode_embedding_for_redis(embedding: np.ndarray) -> bytes:
  58 + """
  59 + FP32 embedding -> BF16 -> bytes.
  60 + """
  61 + arr = np.asarray(embedding, dtype=np.float32)
  62 + if arr.ndim != 1:
  63 + arr = arr.reshape(-1)
  64 + # Ensure we operate on plain Python floats for the reference codec.
  65 + bf16_list = float_array_to_bf16(arr.tolist())
  66 + return bf16_list_to_bytes(bf16_list)
  67 +
  68 +
  69 +def decode_embedding_from_redis(data: bytes) -> np.ndarray:
  70 + """
  71 + Redis bytes -> BF16 -> FP32 numpy array.
  72 + """
  73 + bf16_list = bytes_to_bf16_list(data)
  74 + floats = bf16_array_to_float(bf16_list)
  75 + return np.asarray(floats, dtype=np.float32)
  76 +
  77 +
  78 +def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray:
  79 + """L2-normalize a 1D FP32 vector. Raises on invalid norms."""
  80 + arr = np.asarray(vec, dtype=np.float32)
  81 + if arr.ndim != 1:
  82 + arr = arr.reshape(-1)
  83 + norm = float(np.linalg.norm(arr))
  84 + if not np.isfinite(norm) or norm <= 0.0:
  85 + raise ValueError("Embedding vector has invalid norm (must be > 0)")
  86 + return (arr / norm).astype(np.float32, copy=False)
  87 +
... ...
embeddings/image_encoder.py
... ... @@ -11,6 +11,8 @@ from PIL import Image
11 11 logger = logging.getLogger(__name__)
12 12  
13 13 from config.services_config import get_embedding_base_url
  14 +from config.env_config import REDIS_CONFIG
  15 +from embeddings.redis_embedding_cache import RedisEmbeddingCache
14 16  
15 17  
16 18 class CLIPImageEncoder:
... ... @@ -24,7 +26,13 @@ class CLIPImageEncoder:
24 26 resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
25 27 self.service_url = str(resolved_url).rstrip("/")
26 28 self.endpoint = f"{self.service_url}/embed/image"
  29 + # Reuse embedding cache prefix, but separate namespace for images to avoid collisions.
  30 + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"
27 31 logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)
  32 + self.cache = RedisEmbeddingCache(
  33 + key_prefix=self.cache_prefix,
  34 + namespace="image",
  35 + )
28 36  
29 37 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
30 38 """
... ... @@ -67,12 +75,17 @@ class CLIPImageEncoder:
67 75 Returns:
68 76 Embedding vector
69 77 """
  78 + cached = self.cache.get(url)
  79 + if cached is not None:
  80 + return cached
  81 +
70 82 response_data = self._call_service([url], normalize_embeddings=normalize_embeddings)
71 83 if not response_data or len(response_data) != 1 or response_data[0] is None:
72 84 raise RuntimeError(f"No image embedding returned for URL: {url}")
73 85 vec = np.array(response_data[0], dtype=np.float32)
74 86 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
75 87 raise RuntimeError(f"Invalid image embedding returned for URL: {url}")
  88 + self.cache.set(url, vec)
76 89 return vec
77 90  
78 91 def encode_batch(
... ... @@ -98,8 +111,21 @@ class CLIPImageEncoder:
98 111 raise ValueError(f"Invalid image URL/path at index {i}: {img!r}")
99 112  
100 113 results: List[np.ndarray] = []
101   - for i in range(0, len(images), batch_size):
102   - batch_urls = [str(u).strip() for u in images[i:i + batch_size]]
  114 + pending_urls: List[str] = []
  115 + pending_positions: List[int] = []
  116 +
  117 + normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item]
  118 + for pos, url in enumerate(normalized_urls):
  119 + cached = self.cache.get(url)
  120 + if cached is not None:
  121 + results.append(cached)
  122 + else:
  123 + results.append(np.array([], dtype=np.float32)) # placeholder
  124 + pending_positions.append(pos)
  125 + pending_urls.append(url)
  126 +
  127 + for i in range(0, len(pending_urls), batch_size):
  128 + batch_urls = pending_urls[i : i + batch_size]
103 129 response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings)
104 130 if not response_data or len(response_data) != len(batch_urls):
105 131 raise RuntimeError(
... ... @@ -113,7 +139,9 @@ class CLIPImageEncoder:
113 139 vec = np.array(embedding, dtype=np.float32)
114 140 if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
115 141 raise RuntimeError(f"Invalid image embedding returned for URL: {url}")
116   - results.append(vec)
  142 + self.cache.set(url, vec)
  143 + pos = pending_positions[i + j]
  144 + results[pos] = vec
117 145  
118 146 return results
119 147  
... ...
embeddings/redis_embedding_cache.py 0 → 100644
... ... @@ -0,0 +1,107 @@
  1 +"""
  2 +Shared Redis cache for embedding vectors (text + image).
  3 +
  4 +Value format: BF16 bytes (see embeddings/bf16.py)
  5 +Expiration: setex on write + sliding expire on successful read.
  6 +
  7 +This module is intentionally small and dependency-light, so both indexer/search paths
  8 +can reuse it safely.
  9 +"""
  10 +
  11 +from __future__ import annotations
  12 +
  13 +import logging
  14 +from datetime import timedelta
  15 +from typing import Optional
  16 +
  17 +import numpy as np
  18 +import redis
  19 +
  20 +from config.env_config import REDIS_CONFIG
  21 +from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis
  22 +
  23 +logger = logging.getLogger(__name__)
  24 +
  25 +
  26 +class RedisEmbeddingCache:
  27 + def __init__(
  28 + self,
  29 + *,
  30 + key_prefix: str,
  31 + namespace: str = "",
  32 + expire_time: Optional[timedelta] = None,
  33 + redis_client: Optional[redis.Redis] = None,
  34 + ):
  35 + self.key_prefix = (key_prefix or "").strip() or "embedding"
  36 + self.namespace = (namespace or "").strip()
  37 + self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))
  38 +
  39 + if redis_client is not None:
  40 + self.redis_client = redis_client
  41 + return
  42 +
  43 + try:
  44 + client = redis.Redis(
  45 + host=REDIS_CONFIG.get("host", "localhost"),
  46 + port=REDIS_CONFIG.get("port", 6479),
  47 + password=REDIS_CONFIG.get("password"),
  48 + decode_responses=False,
  49 + socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
  50 + socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
  51 + retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
  52 + health_check_interval=10,
  53 + )
  54 + client.ping()
  55 + self.redis_client = client
  56 + except Exception as e:
  57 + logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e)
  58 + self.redis_client = None
  59 +
  60 + def make_key(self, raw_key: str) -> str:
  61 + if self.namespace:
  62 + return f"{self.key_prefix}:{self.namespace}:{raw_key}"
  63 + return f"{self.key_prefix}:{raw_key}"
  64 +
  65 + def get(self, raw_key: str) -> Optional[np.ndarray]:
  66 + if not self.redis_client:
  67 + return None
  68 + key = self.make_key(raw_key)
  69 + try:
  70 + raw = self.redis_client.get(key)
  71 + if not raw:
  72 + return None
  73 + vec = decode_embedding_from_redis(raw)
  74 + if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
  75 + try:
  76 + self.redis_client.delete(key)
  77 + except Exception:
  78 + pass
  79 + return None
  80 + # Sliding expiration: refresh TTL on hit.
  81 + self.redis_client.expire(key, self.expire_time)
  82 + return vec
  83 + except Exception as e:
  84 + logger.warning("Error retrieving embedding from cache: %s", e)
  85 + return None
  86 +
  87 + def set(self, raw_key: str, embedding: np.ndarray) -> bool:
  88 + if not self.redis_client:
  89 + return False
  90 + key = self.make_key(raw_key)
  91 + try:
  92 + vec = np.asarray(embedding, dtype=np.float32)
  93 + raw = encode_embedding_for_redis(vec)
  94 + self.redis_client.setex(key, self.expire_time, raw)
  95 + return True
  96 + except (
  97 + redis.exceptions.BusyLoadingError,
  98 + redis.exceptions.ConnectionError,
  99 + redis.exceptions.TimeoutError,
  100 + redis.exceptions.RedisError,
  101 + ) as e:
  102 + logger.warning("Redis error storing embedding in cache: %s", e)
  103 + return False
  104 + except Exception as e:
  105 + logger.warning("Error storing embedding in cache: %s", e)
  106 + return False
  107 +
... ...
embeddings/text_encoder.py
... ... @@ -2,17 +2,16 @@
2 2  
3 3 import logging
4 4 import os
5   -import pickle
6 5 from datetime import timedelta
7 6 from typing import Any, List, Optional, Union
8 7  
9 8 import numpy as np
10   -import redis
11 9 import requests
12 10  
13 11 logger = logging.getLogger(__name__)
14 12  
15 13 from config.services_config import get_embedding_base_url
  14 +from embeddings.redis_embedding_cache import RedisEmbeddingCache
16 15  
17 16 # Try to import REDIS_CONFIG, but allow import to fail
18 17 from config.env_config import REDIS_CONFIG
... ... @@ -30,22 +29,11 @@ class TextEmbeddingEncoder:
30 29 self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"
31 30 logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url)
32 31  
33   - try:
34   - self.redis_client = redis.Redis(
35   - host=REDIS_CONFIG.get("host", "localhost"),
36   - port=REDIS_CONFIG.get("port", 6479),
37   - password=REDIS_CONFIG.get("password"),
38   - decode_responses=False,
39   - socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
40   - socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
41   - retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
42   - health_check_interval=10,
43   - )
44   - self.redis_client.ping()
45   - logger.info("Redis cache initialized for embeddings")
46   - except Exception as e:
47   - logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e)
48   - self.redis_client = None
  32 + self.cache = RedisEmbeddingCache(
  33 + key_prefix=self.cache_prefix,
  34 + namespace="",
  35 + expire_time=self.expire_time,
  36 + )
49 37  
50 38 def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
51 39 """
... ... @@ -127,7 +115,7 @@ class TextEmbeddingEncoder:
127 115 embedding_array = np.array(embedding, dtype=np.float32)
128 116 if self._is_valid_embedding(embedding_array):
129 117 embeddings[original_idx] = embedding_array
130   - self._set_cached_embedding(text, embedding_array, normalize_embeddings)
  118 + self._set_cached_embedding(text, embedding_array)
131 119 else:
132 120 raise ValueError(
133 121 f"Invalid embedding returned from service for text index {original_idx}"
... ... @@ -161,63 +149,21 @@ class TextEmbeddingEncoder:
161 149  
162 150 def _get_cached_embedding(
163 151 self,
164   - query: str
  152 + query: str,
165 153 ) -> Optional[np.ndarray]:
166   - """Get embedding from cache if exists (with sliding expiration)"""
167   - if not self.redis_client:
168   - return None
169   -
170   - try:
171   - cache_key = f"{self.cache_prefix}:{query}"
172   - cached_data = self.redis_client.get(cache_key)
173   - if cached_data:
174   - embedding = pickle.loads(cached_data)
175   - # Validate cached embedding - if invalid, ignore cache and return None
176   - if self._is_valid_embedding(embedding):
177   - logger.debug(f"Cache hit for embedding: {query}")
178   - # Update expiration time on access (sliding expiration)
179   - self.redis_client.expire(cache_key, self.expire_time)
180   - return embedding
181   - else:
182   - logger.warning(
183   - f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "
184   - f"ignoring cache for query: {query[:50]}..."
185   - )
186   - # Delete invalid cache entry
187   - try:
188   - self.redis_client.delete(cache_key)
189   - except Exception as e:
190   - logger.debug(f"Failed to delete invalid cache entry: {e}")
191   - return None
192   - return None
193   - except Exception as e:
194   - logger.error(f"Error retrieving embedding from cache: {e}")
195   - return None
  154 + """Get embedding from cache if exists (with sliding expiration)."""
  155 + embedding = self.cache.get(query)
  156 + if embedding is not None:
  157 + logger.debug(f"Cache hit for embedding: {query}")
  158 + return embedding
196 159  
197 160 def _set_cached_embedding(
198 161 self,
199 162 query: str,
200 163 embedding: np.ndarray,
201   - normalize_embeddings: bool = True,
202 164 ) -> bool:
203   - """Store embedding in cache"""
204   - if not self.redis_client:
205   - return False
206   -
207   - try:
208   - cache_key = f"{self.cache_prefix}:{query}"
209   - serialized_data = pickle.dumps(embedding)
210   - self.redis_client.setex(
211   - cache_key,
212   - self.expire_time,
213   - serialized_data
214   - )
  165 + """Store embedding in cache."""
  166 + ok = self.cache.set(query, embedding)
  167 + if ok:
215 168 logger.debug(f"Successfully cached embedding for query: {query}")
216   - return True
217   - except (redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError,
218   - redis.exceptions.TimeoutError, redis.exceptions.RedisError) as e:
219   - logger.warning(f"Redis error storing embedding in cache: {e}")
220   - return False
221   - except Exception as e:
222   - logger.error(f"Error storing embedding in cache: {e}")
223   - return False
  169 + return ok
... ...
scripts/redis/redis_cache_health_check.py
... ... @@ -28,7 +28,6 @@ from __future__ import annotations
28 28  
29 29 import argparse
30 30 import json
31   -import pickle
32 31 import sys
33 32 from collections import defaultdict
34 33 from dataclasses import dataclass
... ... @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT))
45 44  
46 45 from config.env_config import REDIS_CONFIG # type: ignore
47 46 from config.services_config import get_translation_cache_config # type: ignore
  47 +from embeddings.bf16 import decode_embedding_from_redis # type: ignore
48 48  
49 49  
50 50 @dataclass
... ... @@ -58,10 +58,11 @@ def _load_known_cache_types() -&gt; Dict[str, CacheTypeConfig]:
58 58 """根据当前配置装配三种已知缓存类型及其前缀 pattern。"""
59 59 cache_types: Dict[str, CacheTypeConfig] = {}
60 60  
61   - # embedding 缓存:固定 embedding:* 前缀
  61 + # embedding 缓存:prefix 来自 REDIS_CONFIG['embedding_cache_prefix'](默认 embedding)
  62 + embedding_prefix = REDIS_CONFIG.get("embedding_cache_prefix", "embedding")
62 63 cache_types["embedding"] = CacheTypeConfig(
63 64 name="embedding",
64   - pattern="embedding:*",
  65 + pattern=f"{embedding_prefix}:*",
65 66 description="文本向量缓存(embeddings/text_encoder.py)",
66 67 )
67 68  
... ... @@ -153,13 +154,14 @@ def decode_value_preview(
153 154 if raw_value is None:
154 155 return "<nil>"
155 156  
156   - # embedding: pickle 序列化的 numpy.ndarray
  157 + # embedding: BF16 bytes
157 158 if cache_type == "embedding":
158 159 try:
159   - arr = pickle.loads(raw_value)
  160 + arr = decode_embedding_from_redis(raw_value)
160 161 if isinstance(arr, np.ndarray):
161   - return f"ndarray shape={arr.shape} dtype={arr.dtype}"
162   - return f"pickle object type={type(arr).__name__}"
  162 + dim = int(arr.size)
  163 + return f"bf16 bytes={len(raw_value)} dim={dim} restored_dtype={arr.dtype}"
  164 + return f"bf16 decode type={type(arr).__name__}"
163 165 except Exception:
164 166 return f"<binary {len(raw_value)} bytes>"
165 167  
... ...
tests/test_embedding_pipeline.py
1   -import pickle
2 1 from typing import Any, Dict, List, Optional
3 2  
4 3 import numpy as np
... ... @@ -13,6 +12,7 @@ from config import (
13 12 SearchConfig,
14 13 )
15 14 from embeddings.text_encoder import TextEmbeddingEncoder
  15 +from embeddings.bf16 import encode_embedding_for_redis
16 16 from query import QueryParser
17 17  
18 18  
... ... @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch):
128 128 def test_text_embedding_encoder_cache_hit(monkeypatch):
129 129 fake_redis = _FakeRedis()
130 130 cached = np.array([0.9, 0.8], dtype=np.float32)
131   - fake_redis.store["embedding:cached-text"] = pickle.dumps(cached)
  131 + fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached)
132 132 monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis)
133 133  
134 134 calls = {"count": 0}
... ...