Commit 4a37d233b1b7e61d86d841f7fb304aa38d39a919
1 parent
77516841
1. embedding cache float32 -> bf16
2. 抽象出可复用的 embedding Redis 缓存类(图文共用)
详细:
1. embedding 缓存改为 BF16 存 Redis(读回恢复 FP32)
关键行为(按你给的流程落地)
写入前:FP32 embedding →(normalize_embeddings=True 时)L2 normalize →
转 BF16 → bytes(2字节/维,大端) → redis.setex
读取后:redis.get bytes → BF16 → 恢复 FP32(np.float32 向量)
变更点
新增 embeddings/bf16.py
提供 float32_to_bf16 / bf16_to_float32
encode_embedding_for_redis():FP32 → BF16 → bytes
decode_embedding_from_redis():bytes → BF16 → FP32
l2_normalize_fp32():按需归一化
修改 embeddings/text_encoder.py
Redis value 从 pickle.dumps(np.ndarray) 改为 BF16 bytes
缓存 key 改为包含 normalize 标记:{prefix}:{n0|n1}:{query}(避免
normalize 开关不同却共用缓存)
修改 tests/test_embedding_pipeline.py
cache hit 用例改为写入 BF16 bytes,并使用新
key:embedding:n1:cached-text
修改 docs/缓存与Redis使用说明.md
embedding 缓存的 Key/Value 格式更新为 BF16 bytes + n0/n1
修改 scripts/redis/redis_cache_health_check.py
embedding pattern 不再硬编码 embedding:*,改为读取
REDIS_CONFIG["embedding_cache_prefix"]
value 预览从 pickle 解码改为 BF16 解码后展示 dim/bytes/dtype
自检
在激活环境后跑过 BF16 编解码往返 sanity check:bytes
长度、维度恢复正常;归一化向量读回后范数接近 1(会有 BF16 量化误差)。
2. 抽象出可复用的 embedding Redis 缓存类(图文共用)
新增
embeddings/redis_embedding_cache.py:RedisEmbeddingCache
统一 Redis 初始化(读 REDIS_CONFIG)
统一 BF16 bytes 编解码(复用 embeddings/bf16.py)
统一过期策略:写入 setex(expire_time),命中读取后 expire(expire_time)
滑动过期刷新 TTL
统一异常/坏数据处理:解码失败或向量非 1D/为空/含 NaN/Inf 会删除该 key
并当作 miss
已接入复用
文本 embeddings/text_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="")
key 仍是:{prefix}:{query}
图片 embeddings/image_encoder.py
用 self.cache = RedisEmbeddingCache(key_prefix=..., namespace="image")
key 仍是:{prefix}:image:{url_or_path}
Showing
8 changed files
with
277 additions
and
88 deletions
Show diff stats
docs/TODO.txt
| 1 | 1 | |
| 2 | 2 | |
| 3 | +product_enrich : Partial Mode | |
| 4 | +https://help.aliyun.com/zh/model-studio/partial-mode?spm=a2c4g.11186623.help-menu-2400256.d_0_3_0_7.74a630119Ct6zR | |
| 5 | +需在messages 数组中将最后一条消息的 role 设置为 assistant,并在其 content 中提供前缀,在此消息中设置参数 "partial": true。messages格式如下: | |
| 6 | +[ | |
| 7 | + { | |
| 8 | + "role": "user", | |
| 9 | + "content": "请补全这个斐波那契函数,勿添加其它内容" | |
| 10 | + }, | |
| 11 | + { | |
| 12 | + "role": "assistant", | |
| 13 | + "content": "def calculate_fibonacci(n):\n if n <= 1:\n return n\n else:\n", | |
| 14 | + "partial": true | |
| 15 | + } | |
| 16 | +] | |
| 17 | +模型会以前缀内容为起点开始生成。 | |
| 18 | + | |
| 19 | +支持 非思考模式。 | |
| 20 | + | |
| 3 | 21 | |
| 4 | 22 | |
| 5 | 23 | |
| ... | ... | @@ -109,7 +127,7 @@ translation: |
| 109 | 127 | base_url: "http://127.0.0.1:6006" |
| 110 | 128 | model: "llm" # 或 "qwen-mt-flush",看你想用哪个 |
| 111 | 129 | timeout_sec: 10.0 |
| 112 | - llm: | |
| 130 | + llm:. | |
| 113 | 131 | model: "qwen-flash" # 留给翻译服务自身内部使用 |
| 114 | 132 | qwen-mt: ... |
| 115 | 133 | deepl: ... | ... | ... |
docs/缓存与Redis使用说明.md
| ... | ... | @@ -20,7 +20,7 @@ |
| 20 | 20 | |
| 21 | 21 | | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 | |
| 22 | 22 | |------------|----------|----------------|----------|------| |
| 23 | -| 文本向量缓存(embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`,前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | | |
| 23 | +| 向量缓存(text/image embedding) | `{EMBEDDING_CACHE_PREFIX}:{query_or_url}` / `{EMBEDDING_CACHE_PREFIX}:image:{url_or_path}` | **BF16 bytes**(每维 2 字节大端存储),读取后恢复为 `np.float32` | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`(文本)与 `embeddings/image_encoder.py`(图片);前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | | |
| 24 | 24 | | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` | |
| 25 | 25 | | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` | |
| 26 | 26 | |
| ... | ... | @@ -44,13 +44,14 @@ |
| 44 | 44 | |
| 45 | 45 | - 字段说明: |
| 46 | 46 | - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖; |
| 47 | - - 当前实现**不再区分 language 与 normalize flag**,即无论是否归一化,key 结构都相同; | |
| 48 | 47 | - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。 |
| 49 | 48 | |
| 50 | 49 | ### 2.2 Value 与类型 |
| 51 | 50 | |
| 52 | -- 类型:`pickle.dumps(np.ndarray)`,在读取时通过 `pickle.loads` 还原为 `np.ndarray`。 | |
| 53 | -- 典型示例:BGE-M3 1024 维 `float32` 向量。 | |
| 51 | +- 类型:**BF16 bytes**(bfloat16),每一维用 2 字节无符号整数表示,按**大端**序列化。 | |
| 52 | +- 写入流程:FP32 向量 → BF16 → bytes → Redis | |
| 53 | +- 读取流程:Redis bytes → BF16 → FP32(`np.float32`)向量 | |
| 54 | +- 典型示例:BGE-M3 1024 维向量在 Redis value 大小约为 \(1024*2=2048\) bytes(不含 Redis 元数据开销)。 | |
| 54 | 55 | |
| 55 | 56 | ### 2.3 过期策略 |
| 56 | 57 | ... | ... |
| ... | ... | @@ -0,0 +1,87 @@ |
| 1 | +""" | |
| 2 | +BF16 (bfloat16) codec helpers for Redis embedding cache. | |
| 3 | + | |
| 4 | +We store embeddings in Redis as: | |
| 5 | + FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes | |
| 6 | + | |
| 7 | +No backward compatibility is provided by design. | |
| 8 | +""" | |
| 9 | + | |
| 10 | +from __future__ import annotations | |
| 11 | + | |
| 12 | +import struct | |
| 13 | +from typing import Iterable, List, Sequence | |
| 14 | + | |
| 15 | +import numpy as np | |
| 16 | + | |
| 17 | + | |
| 18 | +def float32_to_bf16(value: float) -> int: | |
| 19 | + """ | |
| 20 | + float32 -> bfloat16 (returns 0..65535 uint16) | |
| 21 | + Round-to-nearest-even. | |
| 22 | + """ | |
| 23 | + bits = struct.unpack(">I", struct.pack(">f", float(value)))[0] | |
| 24 | + rounding_bias = ((bits >> 16) & 1) + 0x7FFF | |
| 25 | + bits += rounding_bias | |
| 26 | + return (bits >> 16) & 0xFFFF | |
| 27 | + | |
| 28 | + | |
| 29 | +def bf16_to_float32(bf16: int) -> float: | |
| 30 | + """ | |
| 31 | + bfloat16 -> float32. | |
| 32 | + bf16 is an int in 0..65535. | |
| 33 | + """ | |
| 34 | + bits = (int(bf16) & 0xFFFF) << 16 | |
| 35 | + return struct.unpack(">f", struct.pack(">I", bits))[0] | |
| 36 | + | |
| 37 | + | |
| 38 | +def float_array_to_bf16(vector: Sequence[float]) -> List[int]: | |
| 39 | + return [float32_to_bf16(v) for v in vector] | |
| 40 | + | |
| 41 | + | |
| 42 | +def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]: | |
| 43 | + return [bf16_to_float32(v) for v in vector_bf16] | |
| 44 | + | |
| 45 | + | |
| 46 | +def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes: | |
| 47 | + """Each bf16 uses 2 bytes big-endian.""" | |
| 48 | + return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list) | |
| 49 | + | |
| 50 | + | |
| 51 | +def bytes_to_bf16_list(data: bytes) -> List[int]: | |
| 52 | + if len(data) % 2 != 0: | |
| 53 | + raise ValueError("BF16 byte length must be even") | |
| 54 | + return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)] | |
| 55 | + | |
| 56 | + | |
| 57 | +def encode_embedding_for_redis(embedding: np.ndarray) -> bytes: | |
| 58 | + """ | |
| 59 | + FP32 embedding -> BF16 -> bytes. | |
| 60 | + """ | |
| 61 | + arr = np.asarray(embedding, dtype=np.float32) | |
| 62 | + if arr.ndim != 1: | |
| 63 | + arr = arr.reshape(-1) | |
| 64 | + # Ensure we operate on plain Python floats for the reference codec. | |
| 65 | + bf16_list = float_array_to_bf16(arr.tolist()) | |
| 66 | + return bf16_list_to_bytes(bf16_list) | |
| 67 | + | |
| 68 | + | |
| 69 | +def decode_embedding_from_redis(data: bytes) -> np.ndarray: | |
| 70 | + """ | |
| 71 | + Redis bytes -> BF16 -> FP32 numpy array. | |
| 72 | + """ | |
| 73 | + bf16_list = bytes_to_bf16_list(data) | |
| 74 | + floats = bf16_array_to_float(bf16_list) | |
| 75 | + return np.asarray(floats, dtype=np.float32) | |
| 76 | + | |
| 77 | + | |
| 78 | +def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray: | |
| 79 | + """L2-normalize a 1D FP32 vector. Raises on invalid norms.""" | |
| 80 | + arr = np.asarray(vec, dtype=np.float32) | |
| 81 | + if arr.ndim != 1: | |
| 82 | + arr = arr.reshape(-1) | |
| 83 | + norm = float(np.linalg.norm(arr)) | |
| 84 | + if not np.isfinite(norm) or norm <= 0.0: | |
| 85 | + raise ValueError("Embedding vector has invalid norm (must be > 0)") | |
| 86 | + return (arr / norm).astype(np.float32, copy=False) | |
| 87 | + | ... | ... |
embeddings/image_encoder.py
| ... | ... | @@ -11,6 +11,8 @@ from PIL import Image |
| 11 | 11 | logger = logging.getLogger(__name__) |
| 12 | 12 | |
| 13 | 13 | from config.services_config import get_embedding_base_url |
| 14 | +from config.env_config import REDIS_CONFIG | |
| 15 | +from embeddings.redis_embedding_cache import RedisEmbeddingCache | |
| 14 | 16 | |
| 15 | 17 | |
| 16 | 18 | class CLIPImageEncoder: |
| ... | ... | @@ -24,7 +26,13 @@ class CLIPImageEncoder: |
| 24 | 26 | resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() |
| 25 | 27 | self.service_url = str(resolved_url).rstrip("/") |
| 26 | 28 | self.endpoint = f"{self.service_url}/embed/image" |
| 29 | + # Reuse embedding cache prefix, but separate namespace for images to avoid collisions. | |
| 30 | + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" | |
| 27 | 31 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) |
| 32 | + self.cache = RedisEmbeddingCache( | |
| 33 | + key_prefix=self.cache_prefix, | |
| 34 | + namespace="image", | |
| 35 | + ) | |
| 28 | 36 | |
| 29 | 37 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 30 | 38 | """ |
| ... | ... | @@ -67,12 +75,17 @@ class CLIPImageEncoder: |
| 67 | 75 | Returns: |
| 68 | 76 | Embedding vector |
| 69 | 77 | """ |
| 78 | + cached = self.cache.get(url) | |
| 79 | + if cached is not None: | |
| 80 | + return cached | |
| 81 | + | |
| 70 | 82 | response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) |
| 71 | 83 | if not response_data or len(response_data) != 1 or response_data[0] is None: |
| 72 | 84 | raise RuntimeError(f"No image embedding returned for URL: {url}") |
| 73 | 85 | vec = np.array(response_data[0], dtype=np.float32) |
| 74 | 86 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 75 | 87 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 88 | + self.cache.set(url, vec) | |
| 76 | 89 | return vec |
| 77 | 90 | |
| 78 | 91 | def encode_batch( |
| ... | ... | @@ -98,8 +111,21 @@ class CLIPImageEncoder: |
| 98 | 111 | raise ValueError(f"Invalid image URL/path at index {i}: {img!r}") |
| 99 | 112 | |
| 100 | 113 | results: List[np.ndarray] = [] |
| 101 | - for i in range(0, len(images), batch_size): | |
| 102 | - batch_urls = [str(u).strip() for u in images[i:i + batch_size]] | |
| 114 | + pending_urls: List[str] = [] | |
| 115 | + pending_positions: List[int] = [] | |
| 116 | + | |
| 117 | + normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item] | |
| 118 | + for pos, url in enumerate(normalized_urls): | |
| 119 | + cached = self.cache.get(url) | |
| 120 | + if cached is not None: | |
| 121 | + results.append(cached) | |
| 122 | + else: | |
| 123 | + results.append(np.array([], dtype=np.float32)) # placeholder | |
| 124 | + pending_positions.append(pos) | |
| 125 | + pending_urls.append(url) | |
| 126 | + | |
| 127 | + for i in range(0, len(pending_urls), batch_size): | |
| 128 | + batch_urls = pending_urls[i : i + batch_size] | |
| 103 | 129 | response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) |
| 104 | 130 | if not response_data or len(response_data) != len(batch_urls): |
| 105 | 131 | raise RuntimeError( |
| ... | ... | @@ -113,7 +139,9 @@ class CLIPImageEncoder: |
| 113 | 139 | vec = np.array(embedding, dtype=np.float32) |
| 114 | 140 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 115 | 141 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 116 | - results.append(vec) | |
| 142 | + self.cache.set(url, vec) | |
| 143 | + pos = pending_positions[i + j] | |
| 144 | + results[pos] = vec | |
| 117 | 145 | |
| 118 | 146 | return results |
| 119 | 147 | ... | ... |
| ... | ... | @@ -0,0 +1,107 @@ |
| 1 | +""" | |
| 2 | +Shared Redis cache for embedding vectors (text + image). | |
| 3 | + | |
| 4 | +Value format: BF16 bytes (see embeddings/bf16.py) | |
| 5 | +Expiration: setex on write + sliding expire on successful read. | |
| 6 | + | |
| 7 | +This module is intentionally small and dependency-light, so both indexer/search paths | |
| 8 | +can reuse it safely. | |
| 9 | +""" | |
| 10 | + | |
| 11 | +from __future__ import annotations | |
| 12 | + | |
| 13 | +import logging | |
| 14 | +from datetime import timedelta | |
| 15 | +from typing import Optional | |
| 16 | + | |
| 17 | +import numpy as np | |
| 18 | +import redis | |
| 19 | + | |
| 20 | +from config.env_config import REDIS_CONFIG | |
| 21 | +from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis | |
| 22 | + | |
| 23 | +logger = logging.getLogger(__name__) | |
| 24 | + | |
| 25 | + | |
| 26 | +class RedisEmbeddingCache: | |
| 27 | + def __init__( | |
| 28 | + self, | |
| 29 | + *, | |
| 30 | + key_prefix: str, | |
| 31 | + namespace: str = "", | |
| 32 | + expire_time: Optional[timedelta] = None, | |
| 33 | + redis_client: Optional[redis.Redis] = None, | |
| 34 | + ): | |
| 35 | + self.key_prefix = (key_prefix or "").strip() or "embedding" | |
| 36 | + self.namespace = (namespace or "").strip() | |
| 37 | + self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) | |
| 38 | + | |
| 39 | + if redis_client is not None: | |
| 40 | + self.redis_client = redis_client | |
| 41 | + return | |
| 42 | + | |
| 43 | + try: | |
| 44 | + client = redis.Redis( | |
| 45 | + host=REDIS_CONFIG.get("host", "localhost"), | |
| 46 | + port=REDIS_CONFIG.get("port", 6479), | |
| 47 | + password=REDIS_CONFIG.get("password"), | |
| 48 | + decode_responses=False, | |
| 49 | + socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), | |
| 50 | + socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), | |
| 51 | + retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), | |
| 52 | + health_check_interval=10, | |
| 53 | + ) | |
| 54 | + client.ping() | |
| 55 | + self.redis_client = client | |
| 56 | + except Exception as e: | |
| 57 | + logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e) | |
| 58 | + self.redis_client = None | |
| 59 | + | |
| 60 | + def make_key(self, raw_key: str) -> str: | |
| 61 | + if self.namespace: | |
| 62 | + return f"{self.key_prefix}:{self.namespace}:{raw_key}" | |
| 63 | + return f"{self.key_prefix}:{raw_key}" | |
| 64 | + | |
| 65 | + def get(self, raw_key: str) -> Optional[np.ndarray]: | |
| 66 | + if not self.redis_client: | |
| 67 | + return None | |
| 68 | + key = self.make_key(raw_key) | |
| 69 | + try: | |
| 70 | + raw = self.redis_client.get(key) | |
| 71 | + if not raw: | |
| 72 | + return None | |
| 73 | + vec = decode_embedding_from_redis(raw) | |
| 74 | + if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): | |
| 75 | + try: | |
| 76 | + self.redis_client.delete(key) | |
| 77 | + except Exception: | |
| 78 | + pass | |
| 79 | + return None | |
| 80 | + # Sliding expiration: refresh TTL on hit. | |
| 81 | + self.redis_client.expire(key, self.expire_time) | |
| 82 | + return vec | |
| 83 | + except Exception as e: | |
| 84 | + logger.warning("Error retrieving embedding from cache: %s", e) | |
| 85 | + return None | |
| 86 | + | |
| 87 | + def set(self, raw_key: str, embedding: np.ndarray) -> bool: | |
| 88 | + if not self.redis_client: | |
| 89 | + return False | |
| 90 | + key = self.make_key(raw_key) | |
| 91 | + try: | |
| 92 | + vec = np.asarray(embedding, dtype=np.float32) | |
| 93 | + raw = encode_embedding_for_redis(vec) | |
| 94 | + self.redis_client.setex(key, self.expire_time, raw) | |
| 95 | + return True | |
| 96 | + except ( | |
| 97 | + redis.exceptions.BusyLoadingError, | |
| 98 | + redis.exceptions.ConnectionError, | |
| 99 | + redis.exceptions.TimeoutError, | |
| 100 | + redis.exceptions.RedisError, | |
| 101 | + ) as e: | |
| 102 | + logger.warning("Redis error storing embedding in cache: %s", e) | |
| 103 | + return False | |
| 104 | + except Exception as e: | |
| 105 | + logger.warning("Error storing embedding in cache: %s", e) | |
| 106 | + return False | |
| 107 | + | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -2,17 +2,16 @@ |
| 2 | 2 | |
| 3 | 3 | import logging |
| 4 | 4 | import os |
| 5 | -import pickle | |
| 6 | 5 | from datetime import timedelta |
| 7 | 6 | from typing import Any, List, Optional, Union |
| 8 | 7 | |
| 9 | 8 | import numpy as np |
| 10 | -import redis | |
| 11 | 9 | import requests |
| 12 | 10 | |
| 13 | 11 | logger = logging.getLogger(__name__) |
| 14 | 12 | |
| 15 | 13 | from config.services_config import get_embedding_base_url |
| 14 | +from embeddings.redis_embedding_cache import RedisEmbeddingCache | |
| 16 | 15 | |
| 17 | 16 | # Try to import REDIS_CONFIG, but allow import to fail |
| 18 | 17 | from config.env_config import REDIS_CONFIG |
| ... | ... | @@ -30,22 +29,11 @@ class TextEmbeddingEncoder: |
| 30 | 29 | self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" |
| 31 | 30 | logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) |
| 32 | 31 | |
| 33 | - try: | |
| 34 | - self.redis_client = redis.Redis( | |
| 35 | - host=REDIS_CONFIG.get("host", "localhost"), | |
| 36 | - port=REDIS_CONFIG.get("port", 6479), | |
| 37 | - password=REDIS_CONFIG.get("password"), | |
| 38 | - decode_responses=False, | |
| 39 | - socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), | |
| 40 | - socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), | |
| 41 | - retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), | |
| 42 | - health_check_interval=10, | |
| 43 | - ) | |
| 44 | - self.redis_client.ping() | |
| 45 | - logger.info("Redis cache initialized for embeddings") | |
| 46 | - except Exception as e: | |
| 47 | - logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) | |
| 48 | - self.redis_client = None | |
| 32 | + self.cache = RedisEmbeddingCache( | |
| 33 | + key_prefix=self.cache_prefix, | |
| 34 | + namespace="", | |
| 35 | + expire_time=self.expire_time, | |
| 36 | + ) | |
| 49 | 37 | |
| 50 | 38 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 51 | 39 | """ |
| ... | ... | @@ -127,7 +115,7 @@ class TextEmbeddingEncoder: |
| 127 | 115 | embedding_array = np.array(embedding, dtype=np.float32) |
| 128 | 116 | if self._is_valid_embedding(embedding_array): |
| 129 | 117 | embeddings[original_idx] = embedding_array |
| 130 | - self._set_cached_embedding(text, embedding_array, normalize_embeddings) | |
| 118 | + self._set_cached_embedding(text, embedding_array) | |
| 131 | 119 | else: |
| 132 | 120 | raise ValueError( |
| 133 | 121 | f"Invalid embedding returned from service for text index {original_idx}" |
| ... | ... | @@ -161,63 +149,21 @@ class TextEmbeddingEncoder: |
| 161 | 149 | |
| 162 | 150 | def _get_cached_embedding( |
| 163 | 151 | self, |
| 164 | - query: str | |
| 152 | + query: str, | |
| 165 | 153 | ) -> Optional[np.ndarray]: |
| 166 | - """Get embedding from cache if exists (with sliding expiration)""" | |
| 167 | - if not self.redis_client: | |
| 168 | - return None | |
| 169 | - | |
| 170 | - try: | |
| 171 | - cache_key = f"{self.cache_prefix}:{query}" | |
| 172 | - cached_data = self.redis_client.get(cache_key) | |
| 173 | - if cached_data: | |
| 174 | - embedding = pickle.loads(cached_data) | |
| 175 | - # Validate cached embedding - if invalid, ignore cache and return None | |
| 176 | - if self._is_valid_embedding(embedding): | |
| 177 | - logger.debug(f"Cache hit for embedding: {query}") | |
| 178 | - # Update expiration time on access (sliding expiration) | |
| 179 | - self.redis_client.expire(cache_key, self.expire_time) | |
| 180 | - return embedding | |
| 181 | - else: | |
| 182 | - logger.warning( | |
| 183 | - f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " | |
| 184 | - f"ignoring cache for query: {query[:50]}..." | |
| 185 | - ) | |
| 186 | - # Delete invalid cache entry | |
| 187 | - try: | |
| 188 | - self.redis_client.delete(cache_key) | |
| 189 | - except Exception as e: | |
| 190 | - logger.debug(f"Failed to delete invalid cache entry: {e}") | |
| 191 | - return None | |
| 192 | - return None | |
| 193 | - except Exception as e: | |
| 194 | - logger.error(f"Error retrieving embedding from cache: {e}") | |
| 195 | - return None | |
| 154 | + """Get embedding from cache if exists (with sliding expiration).""" | |
| 155 | + embedding = self.cache.get(query) | |
| 156 | + if embedding is not None: | |
| 157 | + logger.debug(f"Cache hit for embedding: {query}") | |
| 158 | + return embedding | |
| 196 | 159 | |
| 197 | 160 | def _set_cached_embedding( |
| 198 | 161 | self, |
| 199 | 162 | query: str, |
| 200 | 163 | embedding: np.ndarray, |
| 201 | - normalize_embeddings: bool = True, | |
| 202 | 164 | ) -> bool: |
| 203 | - """Store embedding in cache""" | |
| 204 | - if not self.redis_client: | |
| 205 | - return False | |
| 206 | - | |
| 207 | - try: | |
| 208 | - cache_key = f"{self.cache_prefix}:{query}" | |
| 209 | - serialized_data = pickle.dumps(embedding) | |
| 210 | - self.redis_client.setex( | |
| 211 | - cache_key, | |
| 212 | - self.expire_time, | |
| 213 | - serialized_data | |
| 214 | - ) | |
| 165 | + """Store embedding in cache.""" | |
| 166 | + ok = self.cache.set(query, embedding) | |
| 167 | + if ok: | |
| 215 | 168 | logger.debug(f"Successfully cached embedding for query: {query}") |
| 216 | - return True | |
| 217 | - except (redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError, | |
| 218 | - redis.exceptions.TimeoutError, redis.exceptions.RedisError) as e: | |
| 219 | - logger.warning(f"Redis error storing embedding in cache: {e}") | |
| 220 | - return False | |
| 221 | - except Exception as e: | |
| 222 | - logger.error(f"Error storing embedding in cache: {e}") | |
| 223 | - return False | |
| 169 | + return ok | ... | ... |
scripts/redis/redis_cache_health_check.py
| ... | ... | @@ -28,7 +28,6 @@ from __future__ import annotations |
| 28 | 28 | |
| 29 | 29 | import argparse |
| 30 | 30 | import json |
| 31 | -import pickle | |
| 32 | 31 | import sys |
| 33 | 32 | from collections import defaultdict |
| 34 | 33 | from dataclasses import dataclass |
| ... | ... | @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT)) |
| 45 | 44 | |
| 46 | 45 | from config.env_config import REDIS_CONFIG # type: ignore |
| 47 | 46 | from config.services_config import get_translation_cache_config # type: ignore |
| 47 | +from embeddings.bf16 import decode_embedding_from_redis # type: ignore | |
| 48 | 48 | |
| 49 | 49 | |
| 50 | 50 | @dataclass |
| ... | ... | @@ -58,10 +58,11 @@ def _load_known_cache_types() -> Dict[str, CacheTypeConfig]: |
| 58 | 58 | """根据当前配置装配三种已知缓存类型及其前缀 pattern。""" |
| 59 | 59 | cache_types: Dict[str, CacheTypeConfig] = {} |
| 60 | 60 | |
| 61 | - # embedding 缓存:固定 embedding:* 前缀 | |
| 61 | + # embedding 缓存:prefix 来自 REDIS_CONFIG['embedding_cache_prefix'](默认 embedding) | |
| 62 | + embedding_prefix = REDIS_CONFIG.get("embedding_cache_prefix", "embedding") | |
| 62 | 63 | cache_types["embedding"] = CacheTypeConfig( |
| 63 | 64 | name="embedding", |
| 64 | - pattern="embedding:*", | |
| 65 | + pattern=f"{embedding_prefix}:*", | |
| 65 | 66 | description="文本向量缓存(embeddings/text_encoder.py)", |
| 66 | 67 | ) |
| 67 | 68 | |
| ... | ... | @@ -153,13 +154,14 @@ def decode_value_preview( |
| 153 | 154 | if raw_value is None: |
| 154 | 155 | return "<nil>" |
| 155 | 156 | |
| 156 | - # embedding: pickle 序列化的 numpy.ndarray | |
| 157 | + # embedding: BF16 bytes | |
| 157 | 158 | if cache_type == "embedding": |
| 158 | 159 | try: |
| 159 | - arr = pickle.loads(raw_value) | |
| 160 | + arr = decode_embedding_from_redis(raw_value) | |
| 160 | 161 | if isinstance(arr, np.ndarray): |
| 161 | - return f"ndarray shape={arr.shape} dtype={arr.dtype}" | |
| 162 | - return f"pickle object type={type(arr).__name__}" | |
| 162 | + dim = int(arr.size) | |
| 163 | + return f"bf16 bytes={len(raw_value)} dim={dim} restored_dtype={arr.dtype}" | |
| 164 | + return f"bf16 decode type={type(arr).__name__}" | |
| 163 | 165 | except Exception: |
| 164 | 166 | return f"<binary {len(raw_value)} bytes>" |
| 165 | 167 | ... | ... |
tests/test_embedding_pipeline.py
| 1 | -import pickle | |
| 2 | 1 | from typing import Any, Dict, List, Optional |
| 3 | 2 | |
| 4 | 3 | import numpy as np |
| ... | ... | @@ -13,6 +12,7 @@ from config import ( |
| 13 | 12 | SearchConfig, |
| 14 | 13 | ) |
| 15 | 14 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 15 | +from embeddings.bf16 import encode_embedding_for_redis | |
| 16 | 16 | from query import QueryParser |
| 17 | 17 | |
| 18 | 18 | |
| ... | ... | @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): |
| 128 | 128 | def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 129 | 129 | fake_redis = _FakeRedis() |
| 130 | 130 | cached = np.array([0.9, 0.8], dtype=np.float32) |
| 131 | - fake_redis.store["embedding:cached-text"] = pickle.dumps(cached) | |
| 131 | + fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached) | |
| 132 | 132 | monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) |
| 133 | 133 | |
| 134 | 134 | calls = {"count": 0} | ... | ... |