Commit 4a37d233b1b7e61d86d841f7fb304aa38d39a919
1 parent
77516841
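1. embedding cache float32 -> bf16
2. Extract a reusable embedding Redis cache class (shared by text and image)
Details:
1. Store the embedding cache in Redis as BF16 (restored to FP32 on read)
Key behavior (implemented following the flow you provided)
Before write: FP32 embedding → L2 normalize (when normalize_embeddings=True) →
convert to BF16 → bytes (2 bytes per dimension, big-endian) → redis.setex
After read: redis.get bytes → BF16 → restore FP32 (np.float32 vector)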
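The write/read flow above can be sketched in a few lines of vectorized NumPy. This is a minimal illustration, not the code added by this commit (which uses a per-element `struct` codec in embeddings/bf16.py); the function names `encode_bf16_bytes` and `decode_bf16_bytes` are hypothetical.

```python
import numpy as np


def encode_bf16_bytes(vec: np.ndarray) -> bytes:
    """FP32 vector -> BF16 (round-to-nearest-even) -> big-endian bytes."""
    arr = np.ascontiguousarray(vec, dtype=np.float32).reshape(-1)
    # Reinterpret the FP32 bit patterns, widened so the bias add cannot wrap.
    bits = arr.view(np.uint32).astype(np.uint64)
    # Round to nearest even: add 0x7FFF plus the lowest bit that will be kept.
    bias = ((bits >> 16) & 1) + 0x7FFF
    bf16 = ((bits + bias) >> 16) & 0xFFFF
    # ">u2" is a big-endian unsigned 16-bit integer: 2 bytes per dimension.
    return bf16.astype(">u2").tobytes()


def decode_bf16_bytes(data: bytes) -> np.ndarray:
    """Big-endian BF16 bytes -> FP32 numpy vector."""
    if len(data) % 2 != 0:
        raise ValueError("BF16 byte length must be even")
    bf16 = np.frombuffer(data, dtype=">u2").astype(np.uint32)
    # BF16 is the upper 16 bits of an FP32 value; the lower 16 bits become zero.
    return (bf16 << 16).view(np.float32)
```

Because BF16 keeps 8 significant bits, each normal element read back differs from the original by a relative error of at most 2^-8, so a unit-norm vector comes back with a norm close to, but not exactly, 1.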
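Changes
Added embeddings/bf16.py
Provides float32_to_bf16 / bf16_to_float32
encode_embedding_for_redis(): FP32 → BF16 → bytes
decode_embedding_from_redis(): bytes → BF16 → FP32
l2_normalize_fp32(): normalizes on demand
Modified embeddings/text_encoder.py
Redis value changed from pickle.dumps(np.ndarray) to BF16 bytes
Cache key now includes a normalize flag: {prefix}:{n0|n1}:{query} (so that
entries written with different normalize settings never share a cache slot)
Modified tests/test_embedding_pipeline.py
The cache-hit case now writes BF16 bytes and uses the new
key: embedding:n1:cached-text
Modified docs/缓存与Redis使用说明.md
Updated the embedding cache Key/Value format to BF16 bytes + n0/n1
Modified scripts/redis/redis_cache_health_check.py
The embedding pattern is no longer hard-coded as embedding:*; it is read from
REDIS_CONFIG["embedding_cache_prefix"]
The value preview now decodes BF16 instead of pickle and shows dim/bytes/dtype
Self-check
With the environment activated, ran a BF16 encode/decode round-trip sanity check: byte
length and dimension are restored correctly; a normalized vector read back has a norm close to 1 (subject to BF16 quantization error).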
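2. Extract a reusable embedding Redis cache class (shared by text and image)
Added
embeddings/redis_embedding_cache.py: RedisEmbeddingCache
Unified Redis initialization (reads REDIS_CONFIG)
Unified BF16 bytes encode/decode (reuses embeddings/bf16.py)
Unified expiration policy: setex(expire_time) on write; expire(expire_time) after a cache hit,
which refreshes the TTL (sliding expiration)
Unified error/bad-data handling: if decoding fails, or the vector is not 1D, is empty, or contains NaN/Inf, the key is deleted
and treated as a miss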
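The expiration and bad-data policy described above can be sketched against an in-memory stand-in for Redis. Everything here is illustrative: `FakeRedis` and `cache_get` are hypothetical names, TTLs are plain seconds rather than `timedelta`, and vectors are Python lists so the sketch needs only the standard library.

```python
import math
import struct
import time
from typing import Dict, List, Optional, Tuple


class FakeRedis:
    """Tiny in-memory stand-in (hypothetical) exposing get/setex/expire/delete."""

    def __init__(self) -> None:
        self._data: Dict[str, Tuple[bytes, float]] = {}

    def setex(self, key: str, ttl_seconds: float, value: bytes) -> None:
        self._data[key] = (value, time.monotonic() + ttl_seconds)

    def get(self, key: str) -> Optional[bytes]:
        item = self._data.get(key)
        if item is None or item[1] <= time.monotonic():
            self._data.pop(key, None)
            return None
        return item[0]

    def expire(self, key: str, ttl_seconds: float) -> None:
        if key in self._data:
            self._data[key] = (self._data[key][0], time.monotonic() + ttl_seconds)

    def delete(self, key: str) -> None:
        self._data.pop(key, None)


def decode_bf16(data: bytes) -> List[float]:
    """Big-endian BF16 bytes -> list of floats (2 bytes per dimension)."""
    if len(data) % 2:
        raise ValueError("BF16 byte length must be even")
    words = struct.unpack(f">{len(data) // 2}H", data)
    return [struct.unpack(">f", struct.pack(">I", w << 16))[0] for w in words]


def cache_get(client: FakeRedis, key: str, ttl_seconds: float) -> Optional[List[float]]:
    raw = client.get(key)
    if not raw:
        return None
    try:
        vec = decode_bf16(raw)
    except ValueError:
        vec = []  # undecodable payload is handled like an empty vector
    if not vec or not all(math.isfinite(x) for x in vec):
        client.delete(key)  # drop the bad entry and report a miss
        return None
    client.expire(key, ttl_seconds)  # sliding expiration: refresh TTL on hit
    return vec
```

With this shape, an entry that is read regularly never expires, while a corrupted entry is removed the first time it is touched and is recomputed by the caller.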
Integrated
Text: embeddings/text_encoder.py
Uses self.cache = RedisEmbeddingCache(key_prefix=..., namespace="")
Key remains: {prefix}:{query}
Image: embeddings/image_encoder.py
Uses self.cache = RedisEmbeddingCache(key_prefix=..., namespace="image")
Key remains: {prefix}:image:{url_or_path}
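The two key shapes listed above follow from one rule: an empty namespace is omitted, a non-empty one is inserted between the prefix and the raw key. A minimal sketch, with a free function `make_key` standing in for the class method and example inputs that are made up for illustration:

```python
def make_key(prefix: str, raw_key: str, namespace: str = "") -> str:
    """Build a cache key; falls back to the "embedding" prefix when none is given."""
    prefix = (prefix or "").strip() or "embedding"
    namespace = (namespace or "").strip()
    if namespace:
        return f"{prefix}:{namespace}:{raw_key}"
    return f"{prefix}:{raw_key}"


text_key = make_key("embedding", "red dress")
image_key = make_key("embedding", "https://example.com/a.jpg", namespace="image")
```

One consequence of this scheme, as sketched, is that a text query that itself starts with `image:` would produce the same key as an image entry.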
Showing 8 changed files with 277 additions and 88 deletions
docs/TODO.txt
| 1 | 1 | ||
| 2 | 2 | ||
| 3 | +product_enrich : Partial Mode | ||
| 4 | +https://help.aliyun.com/zh/model-studio/partial-mode?spm=a2c4g.11186623.help-menu-2400256.d_0_3_0_7.74a630119Ct6zR | ||
| 5 | +In the messages array, set the role of the last message to assistant, provide the prefix in its content, and set the parameter "partial": true on that message. The messages format is as follows: | ||
| 6 | +[ | ||
| 7 | + { | ||
| 8 | + "role": "user", | ||
| 9 | + "content": "Please complete this Fibonacci function and do not add anything else" | ||
| 10 | + }, | ||
| 11 | + { | ||
| 12 | + "role": "assistant", | ||
| 13 | + "content": "def calculate_fibonacci(n):\n if n <= 1:\n return n\n else:\n", | ||
| 14 | + "partial": true | ||
| 15 | + } | ||
| 16 | +] | ||
| 17 | +The model starts generating from the prefix content. | ||
| 18 | + | ||
| 19 | +Supported in non-thinking mode. | ||
| 20 | + | ||
| 3 | 21 | ||
| 4 | 22 | ||
| 5 | 23 | ||
| @@ -109,7 +127,7 @@ translation: | @@ -109,7 +127,7 @@ translation: | ||
| 109 | base_url: "http://127.0.0.1:6006" | 127 | base_url: "http://127.0.0.1:6006" |
| 110 | model: "llm" # or "qwen-mt-flush", whichever you prefer | 128 | model: "llm" # or "qwen-mt-flush", whichever you prefer |
| 111 | timeout_sec: 10.0 | 129 | timeout_sec: 10.0 |
| 112 | - llm: | 130 | + llm:. |
| 113 | model: "qwen-flash" # reserved for the translation service's own internal use | 131 | model: "qwen-flash" # reserved for the translation service's own internal use |
| 114 | qwen-mt: ... | 132 | qwen-mt: ... |
| 115 | deepl: ... | 133 | deepl: ... |
docs/缓存与Redis使用说明.md
| @@ -20,7 +20,7 @@ | @@ -20,7 +20,7 @@ | ||
| 20 | 20 | ||
| 21 | | Module / scenario | Key template | Value example | Expiration policy | Notes | | 21 | | Module / scenario | Key template | Value example | Expiration policy | Notes | |
| 22 | |------------|----------|----------------|----------|------| | 22 | |------------|----------|----------------|----------|------| |
| 23 | -| Text vector cache (embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`, e.g. a 1024-dim BGE vector | TTL=`REDIS_CONFIG["cache_expire_days"]` days; sliding expiration on access | See `embeddings/text_encoder.py`; the prefix is controlled by `REDIS_CONFIG["embedding_cache_prefix"]` | | 23 | +| Vector cache (text/image embedding) | `{EMBEDDING_CACHE_PREFIX}:{query_or_url}` / `{EMBEDDING_CACHE_PREFIX}:image:{url_or_path}` | **BF16 bytes** (2 bytes per dimension, stored big-endian), restored to `np.float32` on read | TTL=`REDIS_CONFIG["cache_expire_days"]` days; sliding expiration on access | See `embeddings/text_encoder.py` (text) and `embeddings/image_encoder.py` (image); the prefix is controlled by `REDIS_CONFIG["embedding_cache_prefix"]` | |
| 24 | | Translation result cache (Qwen-MT translation) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | A single machine-translated string | TTL=`services.translation.cache.ttl_seconds` seconds; sliding expiration is configurable | See `query/qwen_mt_translate.py` + `config/config.yaml` | | 24 | | Translation result cache (Qwen-MT translation) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | A single machine-translated string | TTL=`services.translation.cache.ttl_seconds` seconds; sliding expiration is configurable | See `query/qwen_mt_translate.py` + `config/config.yaml` | |
| 25 | | Product content understanding cache (anchors / semantic attributes / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`, containing id/title/category/tags/anchor_text, etc. | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` days | See `indexer/product_enrich.py` | | 25 | | Product content understanding cache (anchors / semantic attributes / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`, containing id/title/category/tags/anchor_text, etc. | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` days | See `indexer/product_enrich.py` | |
| 26 | 26 | ||
| @@ -44,13 +44,14 @@ | @@ -44,13 +44,14 @@ | ||
| 44 | 44 | ||
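| 45 | - Field descriptions: | 45 | - Field descriptions: |
| 46 | - `EMBEDDING_CACHE_PREFIX`: taken from `REDIS_CONFIG["embedding_cache_prefix"]`; defaults to `"embedding"` and can be overridden via the environment variable `REDIS_EMBEDDING_CACHE_PREFIX`; | 46 | - `EMBEDDING_CACHE_PREFIX`: taken from `REDIS_CONFIG["embedding_cache_prefix"]`; defaults to `"embedding"` and can be overridden via the environment variable `REDIS_EMBEDDING_CACHE_PREFIX`; |
| 47 | - - The current implementation **no longer distinguishes language or the normalize flag**, i.e. the key structure is the same whether or not the vector is normalized; | ||
| 48 | - `query`: the raw text (not hashed); note that very long queries appear verbatim in the key. | 47 | - `query`: the raw text (not hashed); note that very long queries appear verbatim in the key. |
| 49 | 48 | ||
| 50 | ### 2.2 Value and type | 49 | ### 2.2 Value and type |
| 51 | 50 | ||
| 52 | -- Type: `pickle.dumps(np.ndarray)`, restored to `np.ndarray` via `pickle.loads` on read. | ||
| 53 | -- Typical example: a 1024-dim BGE-M3 `float32` vector. | 51 | +- Type: **BF16 bytes** (bfloat16); each dimension is represented as a 2-byte unsigned integer, serialized **big-endian**. |
| 52 | +- Write path: FP32 vector → BF16 → bytes → Redis | ||
| 53 | +- Read path: Redis bytes → BF16 → FP32 (`np.float32`) vector | ||
| 54 | +- Typical example: a 1024-dim BGE-M3 vector takes about \(1024*2=2048\) bytes as a Redis value (excluding Redis metadata overhead). | ||
| 54 | 55 | ||
| 55 | ### 2.3 Expiration policy | 56 | ### 2.3 Expiration policy |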
| 56 | 57 |
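The size figure in the updated section 2.2 can be checked with a short calculation. This sketch compares raw payload sizes only; the previous pickle format adds its own header on top of the 4-bytes-per-dimension FP32 payload, and the key count of one million is a made-up figure used purely for scale.

```python
import struct


def payload_bytes(dim: int, fmt: str) -> int:
    """Raw payload size for `dim` elements of the given struct format."""
    return dim * struct.calcsize(fmt)


dim = 1024
fp32_size = payload_bytes(dim, ">f")  # 4 bytes per dimension -> 4096
bf16_size = payload_bytes(dim, ">H")  # 2 bytes per dimension -> 2048

hypothetical_key_count = 1_000_000  # assumption, not a measured number
saved_bytes = (fp32_size - bf16_size) * hypothetical_key_count  # 2_048_000_000
```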
embeddings/bf16.py
| @@ -0,0 +1,87 @@ | @@ -0,0 +1,87 @@ | ||
| 1 | +""" | ||
| 2 | +BF16 (bfloat16) codec helpers for Redis embedding cache. | ||
| 3 | + | ||
| 4 | +We store embeddings in Redis as: | ||
| 5 | + FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes | ||
| 6 | + | ||
| 7 | +No backward compatibility is provided by design. | ||
| 8 | +""" | ||
| 9 | + | ||
| 10 | +from __future__ import annotations | ||
| 11 | + | ||
| 12 | +import struct | ||
| 13 | +from typing import Iterable, List, Sequence | ||
| 14 | + | ||
| 15 | +import numpy as np | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +def float32_to_bf16(value: float) -> int: | ||
| 19 | + """ | ||
| 20 | + float32 -> bfloat16 (returns 0..65535 uint16) | ||
| 21 | + Round-to-nearest-even. | ||
| 22 | + """ | ||
| 23 | + bits = struct.unpack(">I", struct.pack(">f", float(value)))[0] | ||
| 24 | + rounding_bias = ((bits >> 16) & 1) + 0x7FFF | ||
| 25 | + bits += rounding_bias | ||
| 26 | + return (bits >> 16) & 0xFFFF | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def bf16_to_float32(bf16: int) -> float: | ||
| 30 | + """ | ||
| 31 | + bfloat16 -> float32. | ||
| 32 | + bf16 is an int in 0..65535. | ||
| 33 | + """ | ||
| 34 | + bits = (int(bf16) & 0xFFFF) << 16 | ||
| 35 | + return struct.unpack(">f", struct.pack(">I", bits))[0] | ||
| 36 | + | ||
| 37 | + | ||
| 38 | +def float_array_to_bf16(vector: Sequence[float]) -> List[int]: | ||
| 39 | + return [float32_to_bf16(v) for v in vector] | ||
| 40 | + | ||
| 41 | + | ||
| 42 | +def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]: | ||
| 43 | + return [bf16_to_float32(v) for v in vector_bf16] | ||
| 44 | + | ||
| 45 | + | ||
| 46 | +def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes: | ||
| 47 | + """Each bf16 uses 2 bytes big-endian.""" | ||
| 48 | + return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list) | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +def bytes_to_bf16_list(data: bytes) -> List[int]: | ||
| 52 | + if len(data) % 2 != 0: | ||
| 53 | + raise ValueError("BF16 byte length must be even") | ||
| 54 | + return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)] | ||
| 55 | + | ||
| 56 | + | ||
| 57 | +def encode_embedding_for_redis(embedding: np.ndarray) -> bytes: | ||
| 58 | + """ | ||
| 59 | + FP32 embedding -> BF16 -> bytes. | ||
| 60 | + """ | ||
| 61 | + arr = np.asarray(embedding, dtype=np.float32) | ||
| 62 | + if arr.ndim != 1: | ||
| 63 | + arr = arr.reshape(-1) | ||
| 64 | + # Ensure we operate on plain Python floats for the reference codec. | ||
| 65 | + bf16_list = float_array_to_bf16(arr.tolist()) | ||
| 66 | + return bf16_list_to_bytes(bf16_list) | ||
| 67 | + | ||
| 68 | + | ||
| 69 | +def decode_embedding_from_redis(data: bytes) -> np.ndarray: | ||
| 70 | + """ | ||
| 71 | + Redis bytes -> BF16 -> FP32 numpy array. | ||
| 72 | + """ | ||
| 73 | + bf16_list = bytes_to_bf16_list(data) | ||
| 74 | + floats = bf16_array_to_float(bf16_list) | ||
| 75 | + return np.asarray(floats, dtype=np.float32) | ||
| 76 | + | ||
| 77 | + | ||
| 78 | +def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray: | ||
| 79 | + """L2-normalize a 1D FP32 vector. Raises on invalid norms.""" | ||
| 80 | + arr = np.asarray(vec, dtype=np.float32) | ||
| 81 | + if arr.ndim != 1: | ||
| 82 | + arr = arr.reshape(-1) | ||
| 83 | + norm = float(np.linalg.norm(arr)) | ||
| 84 | + if not np.isfinite(norm) or norm <= 0.0: | ||
| 85 | + raise ValueError("Embedding vector has invalid norm (must be > 0)") | ||
| 86 | + return (arr / norm).astype(np.float32, copy=False) | ||
| 87 | + |
embeddings/image_encoder.py
| @@ -11,6 +11,8 @@ from PIL import Image | @@ -11,6 +11,8 @@ from PIL import Image | ||
| 11 | logger = logging.getLogger(__name__) | 11 | logger = logging.getLogger(__name__) |
| 12 | 12 | ||
| 13 | from config.services_config import get_embedding_base_url | 13 | from config.services_config import get_embedding_base_url |
| 14 | +from config.env_config import REDIS_CONFIG | ||
| 15 | +from embeddings.redis_embedding_cache import RedisEmbeddingCache | ||
| 14 | 16 | ||
| 15 | 17 | ||
| 16 | class CLIPImageEncoder: | 18 | class CLIPImageEncoder: |
| @@ -24,7 +26,13 @@ class CLIPImageEncoder: | @@ -24,7 +26,13 @@ class CLIPImageEncoder: | ||
| 24 | resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() | 26 | resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() |
| 25 | self.service_url = str(resolved_url).rstrip("/") | 27 | self.service_url = str(resolved_url).rstrip("/") |
| 26 | self.endpoint = f"{self.service_url}/embed/image" | 28 | self.endpoint = f"{self.service_url}/embed/image" |
| 29 | + # Reuse embedding cache prefix, but separate namespace for images to avoid collisions. | ||
| 30 | + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" | ||
| 27 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) | 31 | logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) |
| 32 | + self.cache = RedisEmbeddingCache( | ||
| 33 | + key_prefix=self.cache_prefix, | ||
| 34 | + namespace="image", | ||
| 35 | + ) | ||
| 28 | 36 | ||
| 29 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: | 37 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 30 | """ | 38 | """ |
| @@ -67,12 +75,17 @@ class CLIPImageEncoder: | @@ -67,12 +75,17 @@ class CLIPImageEncoder: | ||
| 67 | Returns: | 75 | Returns: |
| 68 | Embedding vector | 76 | Embedding vector |
| 69 | """ | 77 | """ |
| 78 | + cached = self.cache.get(url) | ||
| 79 | + if cached is not None: | ||
| 80 | + return cached | ||
| 81 | + | ||
| 70 | response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) | 82 | response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) |
| 71 | if not response_data or len(response_data) != 1 or response_data[0] is None: | 83 | if not response_data or len(response_data) != 1 or response_data[0] is None: |
| 72 | raise RuntimeError(f"No image embedding returned for URL: {url}") | 84 | raise RuntimeError(f"No image embedding returned for URL: {url}") |
| 73 | vec = np.array(response_data[0], dtype=np.float32) | 85 | vec = np.array(response_data[0], dtype=np.float32) |
| 74 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): | 86 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 75 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") | 87 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 88 | + self.cache.set(url, vec) | ||
| 76 | return vec | 89 | return vec |
| 77 | 90 | ||
| 78 | def encode_batch( | 91 | def encode_batch( |
| @@ -98,8 +111,21 @@ class CLIPImageEncoder: | @@ -98,8 +111,21 @@ class CLIPImageEncoder: | ||
| 98 | raise ValueError(f"Invalid image URL/path at index {i}: {img!r}") | 111 | raise ValueError(f"Invalid image URL/path at index {i}: {img!r}") |
| 99 | 112 | ||
| 100 | results: List[np.ndarray] = [] | 113 | results: List[np.ndarray] = [] |
| 101 | - for i in range(0, len(images), batch_size): | ||
| 102 | - batch_urls = [str(u).strip() for u in images[i:i + batch_size]] | 114 | + pending_urls: List[str] = [] |
| 115 | + pending_positions: List[int] = [] | ||
| 116 | + | ||
| 117 | + normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item] | ||
| 118 | + for pos, url in enumerate(normalized_urls): | ||
| 119 | + cached = self.cache.get(url) | ||
| 120 | + if cached is not None: | ||
| 121 | + results.append(cached) | ||
| 122 | + else: | ||
| 123 | + results.append(np.array([], dtype=np.float32)) # placeholder | ||
| 124 | + pending_positions.append(pos) | ||
| 125 | + pending_urls.append(url) | ||
| 126 | + | ||
| 127 | + for i in range(0, len(pending_urls), batch_size): | ||
| 128 | + batch_urls = pending_urls[i : i + batch_size] | ||
| 103 | response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) | 129 | response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) |
| 104 | if not response_data or len(response_data) != len(batch_urls): | 130 | if not response_data or len(response_data) != len(batch_urls): |
| 105 | raise RuntimeError( | 131 | raise RuntimeError( |
| @@ -113,7 +139,9 @@ class CLIPImageEncoder: | @@ -113,7 +139,9 @@ class CLIPImageEncoder: | ||
| 113 | vec = np.array(embedding, dtype=np.float32) | 139 | vec = np.array(embedding, dtype=np.float32) |
| 114 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): | 140 | if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): |
| 115 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") | 141 | raise RuntimeError(f"Invalid image embedding returned for URL: {url}") |
| 116 | - results.append(vec) | 142 | + self.cache.set(url, vec) |
| 143 | + pos = pending_positions[i + j] | ||
| 144 | + results[pos] = vec | ||
| 117 | 145 | ||
| 118 | return results | 146 | return results |
| 119 | 147 |
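The batch path in the diff above follows a cache-aside pattern: look every input up first, remember the positions of the misses, compute only those in batches, then write the results back into their original slots. A minimal standalone sketch of that pattern, where `encode_batch_with_cache` and its callback parameters are hypothetical names and vectors are plain lists:

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def encode_batch_with_cache(
    keys: Sequence[str],
    cache_get: Callable[[str], Optional[Vector]],
    cache_set: Callable[[str, Vector], None],
    compute_batch: Callable[[List[str]], List[Vector]],
    batch_size: int = 8,
) -> List[Optional[Vector]]:
    results: List[Optional[Vector]] = [None] * len(keys)
    pending_keys: List[str] = []
    pending_positions: List[int] = []

    # Pass 1: serve hits from the cache and record where the misses belong.
    for pos, key in enumerate(keys):
        cached = cache_get(key)
        if cached is not None:
            results[pos] = cached
        else:
            pending_keys.append(key)
            pending_positions.append(pos)

    # Pass 2: compute only the misses, in chunks of batch_size.
    for start in range(0, len(pending_keys), batch_size):
        chunk = pending_keys[start : start + batch_size]
        vectors = compute_batch(chunk)
        if len(vectors) != len(chunk):
            raise RuntimeError("Embedding service returned an unexpected count")
        for offset, (key, vec) in enumerate(zip(chunk, vectors)):
            cache_set(key, vec)
            # Write back into the slot the input originally occupied.
            results[pending_positions[start + offset]] = vec

    return results
```

The output keeps the input order regardless of which entries were cached, and the service is called only for the misses.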
embeddings/redis_embedding_cache.py
| @@ -0,0 +1,107 @@ | @@ -0,0 +1,107 @@ | ||
| 1 | +""" | ||
| 2 | +Shared Redis cache for embedding vectors (text + image). | ||
| 3 | + | ||
| 4 | +Value format: BF16 bytes (see embeddings/bf16.py) | ||
| 5 | +Expiration: setex on write + sliding expire on successful read. | ||
| 6 | + | ||
| 7 | +This module is intentionally small and dependency-light, so both indexer/search paths | ||
| 8 | +can reuse it safely. | ||
| 9 | +""" | ||
| 10 | + | ||
| 11 | +from __future__ import annotations | ||
| 12 | + | ||
| 13 | +import logging | ||
| 14 | +from datetime import timedelta | ||
| 15 | +from typing import Optional | ||
| 16 | + | ||
| 17 | +import numpy as np | ||
| 18 | +import redis | ||
| 19 | + | ||
| 20 | +from config.env_config import REDIS_CONFIG | ||
| 21 | +from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis | ||
| 22 | + | ||
| 23 | +logger = logging.getLogger(__name__) | ||
| 24 | + | ||
| 25 | + | ||
| 26 | +class RedisEmbeddingCache: | ||
| 27 | + def __init__( | ||
| 28 | + self, | ||
| 29 | + *, | ||
| 30 | + key_prefix: str, | ||
| 31 | + namespace: str = "", | ||
| 32 | + expire_time: Optional[timedelta] = None, | ||
| 33 | + redis_client: Optional[redis.Redis] = None, | ||
| 34 | + ): | ||
| 35 | + self.key_prefix = (key_prefix or "").strip() or "embedding" | ||
| 36 | + self.namespace = (namespace or "").strip() | ||
| 37 | + self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) | ||
| 38 | + | ||
| 39 | + if redis_client is not None: | ||
| 40 | + self.redis_client = redis_client | ||
| 41 | + return | ||
| 42 | + | ||
| 43 | + try: | ||
| 44 | + client = redis.Redis( | ||
| 45 | + host=REDIS_CONFIG.get("host", "localhost"), | ||
| 46 | + port=REDIS_CONFIG.get("port", 6479), | ||
| 47 | + password=REDIS_CONFIG.get("password"), | ||
| 48 | + decode_responses=False, | ||
| 49 | + socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), | ||
| 50 | + socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), | ||
| 51 | + retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), | ||
| 52 | + health_check_interval=10, | ||
| 53 | + ) | ||
| 54 | + client.ping() | ||
| 55 | + self.redis_client = client | ||
| 56 | + except Exception as e: | ||
| 57 | + logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e) | ||
| 58 | + self.redis_client = None | ||
| 59 | + | ||
| 60 | + def make_key(self, raw_key: str) -> str: | ||
| 61 | + if self.namespace: | ||
| 62 | + return f"{self.key_prefix}:{self.namespace}:{raw_key}" | ||
| 63 | + return f"{self.key_prefix}:{raw_key}" | ||
| 64 | + | ||
| 65 | + def get(self, raw_key: str) -> Optional[np.ndarray]: | ||
| 66 | + if not self.redis_client: | ||
| 67 | + return None | ||
| 68 | + key = self.make_key(raw_key) | ||
| 69 | + try: | ||
| 70 | + raw = self.redis_client.get(key) | ||
| 71 | + if not raw: | ||
| 72 | + return None | ||
| 73 | + vec = decode_embedding_from_redis(raw) | ||
| 74 | + if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): | ||
| 75 | + try: | ||
| 76 | + self.redis_client.delete(key) | ||
| 77 | + except Exception: | ||
| 78 | + pass | ||
| 79 | + return None | ||
| 80 | + # Sliding expiration: refresh TTL on hit. | ||
| 81 | + self.redis_client.expire(key, self.expire_time) | ||
| 82 | + return vec | ||
| 83 | + except Exception as e: | ||
| 84 | + logger.warning("Error retrieving embedding from cache: %s", e) | ||
| 85 | + return None | ||
| 86 | + | ||
| 87 | + def set(self, raw_key: str, embedding: np.ndarray) -> bool: | ||
| 88 | + if not self.redis_client: | ||
| 89 | + return False | ||
| 90 | + key = self.make_key(raw_key) | ||
| 91 | + try: | ||
| 92 | + vec = np.asarray(embedding, dtype=np.float32) | ||
| 93 | + raw = encode_embedding_for_redis(vec) | ||
| 94 | + self.redis_client.setex(key, self.expire_time, raw) | ||
| 95 | + return True | ||
| 96 | + except ( | ||
| 97 | + redis.exceptions.BusyLoadingError, | ||
| 98 | + redis.exceptions.ConnectionError, | ||
| 99 | + redis.exceptions.TimeoutError, | ||
| 100 | + redis.exceptions.RedisError, | ||
| 101 | + ) as e: | ||
| 102 | + logger.warning("Redis error storing embedding in cache: %s", e) | ||
| 103 | + return False | ||
| 104 | + except Exception as e: | ||
| 105 | + logger.warning("Error storing embedding in cache: %s", e) | ||
| 106 | + return False | ||
| 107 | + |
embeddings/text_encoder.py
| @@ -2,17 +2,16 @@ | @@ -2,17 +2,16 @@ | ||
| 2 | 2 | ||
| 3 | import logging | 3 | import logging |
| 4 | import os | 4 | import os |
| 5 | -import pickle | ||
| 6 | from datetime import timedelta | 5 | from datetime import timedelta |
| 7 | from typing import Any, List, Optional, Union | 6 | from typing import Any, List, Optional, Union |
| 8 | 7 | ||
| 9 | import numpy as np | 8 | import numpy as np |
| 10 | -import redis | ||
| 11 | import requests | 9 | import requests |
| 12 | 10 | ||
| 13 | logger = logging.getLogger(__name__) | 11 | logger = logging.getLogger(__name__) |
| 14 | 12 | ||
| 15 | from config.services_config import get_embedding_base_url | 13 | from config.services_config import get_embedding_base_url |
| 14 | +from embeddings.redis_embedding_cache import RedisEmbeddingCache | ||
| 16 | 15 | ||
| 17 | # Try to import REDIS_CONFIG, but allow import to fail | 16 | # Try to import REDIS_CONFIG, but allow import to fail |
| 18 | from config.env_config import REDIS_CONFIG | 17 | from config.env_config import REDIS_CONFIG |
| @@ -30,22 +29,11 @@ class TextEmbeddingEncoder: | @@ -30,22 +29,11 @@ class TextEmbeddingEncoder: | ||
| 30 | self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" | 29 | self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" |
| 31 | logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) | 30 | logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) |
| 32 | 31 | ||
| 33 | - try: | ||
| 34 | - self.redis_client = redis.Redis( | ||
| 35 | - host=REDIS_CONFIG.get("host", "localhost"), | ||
| 36 | - port=REDIS_CONFIG.get("port", 6479), | ||
| 37 | - password=REDIS_CONFIG.get("password"), | ||
| 38 | - decode_responses=False, | ||
| 39 | - socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), | ||
| 40 | - socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), | ||
| 41 | - retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), | ||
| 42 | - health_check_interval=10, | ||
| 43 | - ) | ||
| 44 | - self.redis_client.ping() | ||
| 45 | - logger.info("Redis cache initialized for embeddings") | ||
| 46 | - except Exception as e: | ||
| 47 | - logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) | ||
| 48 | - self.redis_client = None | 32 | + self.cache = RedisEmbeddingCache( |
| 33 | + key_prefix=self.cache_prefix, | ||
| 34 | + namespace="", | ||
| 35 | + expire_time=self.expire_time, | ||
| 36 | + ) | ||
| 49 | 37 | ||
| 50 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: | 38 | def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: |
| 51 | """ | 39 | """ |
| @@ -127,7 +115,7 @@ class TextEmbeddingEncoder: | @@ -127,7 +115,7 @@ class TextEmbeddingEncoder: | ||
| 127 | embedding_array = np.array(embedding, dtype=np.float32) | 115 | embedding_array = np.array(embedding, dtype=np.float32) |
| 128 | if self._is_valid_embedding(embedding_array): | 116 | if self._is_valid_embedding(embedding_array): |
| 129 | embeddings[original_idx] = embedding_array | 117 | embeddings[original_idx] = embedding_array |
| 130 | - self._set_cached_embedding(text, embedding_array, normalize_embeddings) | 118 | + self._set_cached_embedding(text, embedding_array) |
| 131 | else: | 119 | else: |
| 132 | raise ValueError( | 120 | raise ValueError( |
| 133 | f"Invalid embedding returned from service for text index {original_idx}" | 121 | f"Invalid embedding returned from service for text index {original_idx}" |
| @@ -161,63 +149,21 @@ class TextEmbeddingEncoder: | @@ -161,63 +149,21 @@ class TextEmbeddingEncoder: | ||
| 161 | 149 | ||
| 162 | def _get_cached_embedding( | 150 | def _get_cached_embedding( |
| 163 | self, | 151 | self, |
| 164 | - query: str | 152 | + query: str, |
| 165 | ) -> Optional[np.ndarray]: | 153 | ) -> Optional[np.ndarray]: |
| 166 | - """Get embedding from cache if exists (with sliding expiration)""" | ||
| 167 | - if not self.redis_client: | ||
| 168 | - return None | ||
| 169 | - | ||
| 170 | - try: | ||
| 171 | - cache_key = f"{self.cache_prefix}:{query}" | ||
| 172 | - cached_data = self.redis_client.get(cache_key) | ||
| 173 | - if cached_data: | ||
| 174 | - embedding = pickle.loads(cached_data) | ||
| 175 | - # Validate cached embedding - if invalid, ignore cache and return None | ||
| 176 | - if self._is_valid_embedding(embedding): | ||
| 177 | - logger.debug(f"Cache hit for embedding: {query}") | ||
| 178 | - # Update expiration time on access (sliding expiration) | ||
| 179 | - self.redis_client.expire(cache_key, self.expire_time) | ||
| 180 | - return embedding | ||
| 181 | - else: | ||
| 182 | - logger.warning( | ||
| 183 | - f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " | ||
| 184 | - f"ignoring cache for query: {query[:50]}..." | ||
| 185 | - ) | ||
| 186 | - # Delete invalid cache entry | ||
| 187 | - try: | ||
| 188 | - self.redis_client.delete(cache_key) | ||
| 189 | - except Exception as e: | ||
| 190 | - logger.debug(f"Failed to delete invalid cache entry: {e}") | ||
| 191 | - return None | ||
| 192 | - return None | ||
| 193 | - except Exception as e: | ||
| 194 | - logger.error(f"Error retrieving embedding from cache: {e}") | ||
| 195 | - return None | 154 | + """Get embedding from cache if exists (with sliding expiration).""" |
| 155 | + embedding = self.cache.get(query) | ||
| 156 | + if embedding is not None: | ||
| 157 | + logger.debug(f"Cache hit for embedding: {query}") | ||
| 158 | + return embedding | ||
| 196 | 159 | ||
| 197 | def _set_cached_embedding( | 160 | def _set_cached_embedding( |
| 198 | self, | 161 | self, |
| 199 | query: str, | 162 | query: str, |
| 200 | embedding: np.ndarray, | 163 | embedding: np.ndarray, |
| 201 | - normalize_embeddings: bool = True, | ||
| 202 | ) -> bool: | 164 | ) -> bool: |
| 203 | - """Store embedding in cache""" | ||
| 204 | - if not self.redis_client: | ||
| 205 | - return False | ||
| 206 | - | ||
| 207 | - try: | ||
| 208 | - cache_key = f"{self.cache_prefix}:{query}" | ||
| 209 | - serialized_data = pickle.dumps(embedding) | ||
| 210 | - self.redis_client.setex( | ||
| 211 | - cache_key, | ||
| 212 | - self.expire_time, | ||
| 213 | - serialized_data | ||
| 214 | - ) | 165 | + """Store embedding in cache.""" |
| 166 | + ok = self.cache.set(query, embedding) | ||
| 167 | + if ok: | ||
| 215 | logger.debug(f"Successfully cached embedding for query: {query}") | 168 | logger.debug(f"Successfully cached embedding for query: {query}") |
| 216 | - return True | ||
| 217 | - except (redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError, | ||
| 218 | - redis.exceptions.TimeoutError, redis.exceptions.RedisError) as e: | ||
| 219 | - logger.warning(f"Redis error storing embedding in cache: {e}") | ||
| 220 | - return False | ||
| 221 | - except Exception as e: | ||
| 222 | - logger.error(f"Error storing embedding in cache: {e}") | ||
| 223 | - return False | 169 | + return ok |
scripts/redis/redis_cache_health_check.py
| @@ -28,7 +28,6 @@ from __future__ import annotations | @@ -28,7 +28,6 @@ from __future__ import annotations | ||
| 28 | 28 | ||
| 29 | import argparse | 29 | import argparse |
| 30 | import json | 30 | import json |
| 31 | -import pickle | ||
| 32 | import sys | 31 | import sys |
| 33 | from collections import defaultdict | 32 | from collections import defaultdict |
| 34 | from dataclasses import dataclass | 33 | from dataclasses import dataclass |
| @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT)) | @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT)) | ||
| 45 | 44 | ||
| 46 | from config.env_config import REDIS_CONFIG # type: ignore | 45 | from config.env_config import REDIS_CONFIG # type: ignore |
| 47 | from config.services_config import get_translation_cache_config # type: ignore | 46 | from config.services_config import get_translation_cache_config # type: ignore |
| 47 | +from embeddings.bf16 import decode_embedding_from_redis # type: ignore | ||
| 48 | 48 | ||
| 49 | 49 | ||
| 50 | @dataclass | 50 | @dataclass |
| @@ -58,10 +58,11 @@ def _load_known_cache_types() -> Dict[str, CacheTypeConfig]: | @@ -58,10 +58,11 @@ def _load_known_cache_types() -> Dict[str, CacheTypeConfig]: | ||
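| 58 | """Assemble the three known cache types and their prefix patterns from the current config.""" | 58 | """Assemble the three known cache types and their prefix patterns from the current config.""" |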
| 59 | cache_types: Dict[str, CacheTypeConfig] = {} | 59 | cache_types: Dict[str, CacheTypeConfig] = {} |
| 60 | 60 | ||
| 61 | - # embedding cache: fixed embedding:* prefix | 61 | + # embedding cache: prefix comes from REDIS_CONFIG['embedding_cache_prefix'] (default: embedding) |
| 62 | + embedding_prefix = REDIS_CONFIG.get("embedding_cache_prefix", "embedding") | ||
| 62 | cache_types["embedding"] = CacheTypeConfig( | 63 | cache_types["embedding"] = CacheTypeConfig( |
| 63 | name="embedding", | 64 | name="embedding", |
| 64 | - pattern="embedding:*", | 65 | + pattern=f"{embedding_prefix}:*", |
| 65 | description="Text vector cache (embeddings/text_encoder.py)", | 66 | description="Text vector cache (embeddings/text_encoder.py)", |
| 66 | ) | 67 | ) |
| 67 | 68 | ||
| @@ -153,13 +154,14 @@ def decode_value_preview( | @@ -153,13 +154,14 @@ def decode_value_preview( | ||
| 153 | if raw_value is None: | 154 | if raw_value is None: |
| 154 | return "<nil>" | 155 | return "<nil>" |
| 155 | 156 | ||
| 156 | - # embedding: pickle-serialized numpy.ndarray | 157 | + # embedding: BF16 bytes |
| 157 | if cache_type == "embedding": | 158 | if cache_type == "embedding": |
| 158 | try: | 159 | try: |
| 159 | - arr = pickle.loads(raw_value) | 160 | + arr = decode_embedding_from_redis(raw_value) |
| 160 | if isinstance(arr, np.ndarray): | 161 | if isinstance(arr, np.ndarray): |
| 161 | - return f"ndarray shape={arr.shape} dtype={arr.dtype}" | ||
| 162 | - return f"pickle object type={type(arr).__name__}" | 162 | + dim = int(arr.size) |
| 163 | + return f"bf16 bytes={len(raw_value)} dim={dim} restored_dtype={arr.dtype}" | ||
| 164 | + return f"bf16 decode type={type(arr).__name__}" | ||
| 163 | except Exception: | 165 | except Exception: |
| 164 | return f"<binary {len(raw_value)} bytes>" | 166 | return f"<binary {len(raw_value)} bytes>" |
| 165 | 167 |
tests/test_embedding_pipeline.py
| 1 | -import pickle | ||
| 2 | from typing import Any, Dict, List, Optional | 1 | from typing import Any, Dict, List, Optional |
| 3 | 2 | ||
| 4 | import numpy as np | 3 | import numpy as np |
| @@ -13,6 +12,7 @@ from config import ( | @@ -13,6 +12,7 @@ from config import ( | ||
| 13 | SearchConfig, | 12 | SearchConfig, |
| 14 | ) | 13 | ) |
| 15 | from embeddings.text_encoder import TextEmbeddingEncoder | 14 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 15 | +from embeddings.bf16 import encode_embedding_for_redis | ||
| 16 | from query import QueryParser | 16 | from query import QueryParser |
| 17 | 17 | ||
| 18 | 18 | ||
| @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): | @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): | ||
| 128 | def test_text_embedding_encoder_cache_hit(monkeypatch): | 128 | def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 129 | fake_redis = _FakeRedis() | 129 | fake_redis = _FakeRedis() |
| 130 | cached = np.array([0.9, 0.8], dtype=np.float32) | 130 | cached = np.array([0.9, 0.8], dtype=np.float32) |
| 131 | - fake_redis.store["embedding:cached-text"] = pickle.dumps(cached) | 131 | + fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached) |
| 132 | monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) | 132 | monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) |
| 133 | 133 | ||
| 134 | calls = {"count": 0} | 134 | calls = {"count": 0} |