From 4a37d233b1b7e61d86d841f7fb304aa38d39a919 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 17 Mar 2026 15:06:51 +0800 Subject: [PATCH] 1. embedding cache float32 -> bf16 2. 抽象出可复用的 embedding Redis 缓存类(图文共用) --- docs/TODO.txt | 20 +++++++++++++++++++- docs/缓存与Redis使用说明.md | 9 +++++---- embeddings/bf16.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ embeddings/image_encoder.py | 34 +++++++++++++++++++++++++++++++--- embeddings/redis_embedding_cache.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ embeddings/text_encoder.py | 88 +++++++++++++++++----------------------------------------------------------------------- scripts/redis/redis_cache_health_check.py | 16 +++++++++------- tests/test_embedding_pipeline.py | 4 ++-- 8 files changed, 277 insertions(+), 88 deletions(-) create mode 100644 embeddings/bf16.py create mode 100644 embeddings/redis_embedding_cache.py diff --git a/docs/TODO.txt b/docs/TODO.txt index dcbd517..3af2428 100644 --- a/docs/TODO.txt +++ b/docs/TODO.txt @@ -1,5 +1,23 @@ +product_enrich : Partial Mode +https://help.aliyun.com/zh/model-studio/partial-mode?spm=a2c4g.11186623.help-menu-2400256.d_0_3_0_7.74a630119Ct6zR +需在messages 数组中将最后一条消息的 role 设置为 assistant,并在其 content 中提供前缀,在此消息中设置参数 "partial": true。messages格式如下: +[ + { + "role": "user", + "content": "请补全这个斐波那契函数,勿添加其它内容" + }, + { + "role": "assistant", + "content": "def calculate_fibonacci(n):\n if n <= 1:\n return n\n else:\n", + "partial": true + } +] +模型会以前缀内容为起点开始生成。 + +支持 非思考模式。 + @@ -109,7 +127,7 @@ translation: base_url: "http://127.0.0.1:6006" model: "llm" # 或 "qwen-mt-flush",看你想用哪个 timeout_sec: 10.0 - llm: + llm:. model: "qwen-flash" # 留给翻译服务自身内部使用 qwen-mt: ... deepl: ... diff --git a/docs/缓存与Redis使用说明.md b/docs/缓存与Redis使用说明.md index 8563978..f5677c9 100644 --- a/docs/缓存与Redis使用说明.md +++ b/docs/缓存与Redis使用说明.md @@ -20,7 +20,7 @@ | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 | |------------|----------|----------------|----------|------| -| 文本向量缓存(embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`,前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | +| 向量缓存(text/image embedding) | `{EMBEDDING_CACHE_PREFIX}:{query_or_url}` / `{EMBEDDING_CACHE_PREFIX}:image:{url_or_path}` | **BF16 bytes**(每维 2 字节大端存储),读取后恢复为 `np.float32` | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`(文本)与 `embeddings/image_encoder.py`(图片);前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` | | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` | @@ -44,13 +44,14 @@ - 字段说明: - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖; - - 当前实现**不再区分 language 与 normalize flag**,即无论是否归一化,key 结构都相同; - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。 ### 2.2 Value 与类型 -- 类型:`pickle.dumps(np.ndarray)`,在读取时通过 `pickle.loads` 还原为 `np.ndarray`。 -- 典型示例:BGE-M3 1024 维 `float32` 向量。 +- 类型:**BF16 bytes**(bfloat16),每一维用 2 字节无符号整数表示,按**大端**序列化。 +- 写入流程:FP32 向量 → BF16 → bytes → Redis +- 读取流程:Redis bytes → BF16 → FP32(`np.float32`)向量 +- 典型示例:BGE-M3 1024 维向量在 Redis value 大小约为 \(1024*2=2048\) bytes(不含 Redis 元数据开销)。 ### 2.3 过期策略 diff --git a/embeddings/bf16.py b/embeddings/bf16.py new file mode 100644 index 0000000..aecdb57 --- /dev/null +++ b/embeddings/bf16.py @@ -0,0 +1,87 @@ +""" +BF16 (bfloat16) codec helpers for Redis embedding cache. + +We store embeddings in Redis as: + FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes + +No backward compatibility is provided by design. +""" + +from __future__ import annotations + +import struct +from typing import Iterable, List, Sequence + +import numpy as np + + +def float32_to_bf16(value: float) -> int: + """ + float32 -> bfloat16 (returns 0..65535 uint16) + Round-to-nearest-even. + """ + bits = struct.unpack(">I", struct.pack(">f", float(value)))[0] + rounding_bias = ((bits >> 16) & 1) + 0x7FFF + bits += rounding_bias + return (bits >> 16) & 0xFFFF + + +def bf16_to_float32(bf16: int) -> float: + """ + bfloat16 -> float32. + bf16 is an int in 0..65535. + """ + bits = (int(bf16) & 0xFFFF) << 16 + return struct.unpack(">f", struct.pack(">I", bits))[0] + + +def float_array_to_bf16(vector: Sequence[float]) -> List[int]: + return [float32_to_bf16(v) for v in vector] + + +def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]: + return [bf16_to_float32(v) for v in vector_bf16] + + +def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes: + """Each bf16 uses 2 bytes big-endian.""" + return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list) + + +def bytes_to_bf16_list(data: bytes) -> List[int]: + if len(data) % 2 != 0: + raise ValueError("BF16 byte length must be even") + return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)] + + +def encode_embedding_for_redis(embedding: np.ndarray) -> bytes: + """ + FP32 embedding -> BF16 -> bytes. + """ + arr = np.asarray(embedding, dtype=np.float32) + if arr.ndim != 1: + arr = arr.reshape(-1) + # Ensure we operate on plain Python floats for the reference codec. + bf16_list = float_array_to_bf16(arr.tolist()) + return bf16_list_to_bytes(bf16_list) + + +def decode_embedding_from_redis(data: bytes) -> np.ndarray: + """ + Redis bytes -> BF16 -> FP32 numpy array. + """ + bf16_list = bytes_to_bf16_list(data) + floats = bf16_array_to_float(bf16_list) + return np.asarray(floats, dtype=np.float32) + + +def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray: + """L2-normalize a 1D FP32 vector. Raises on invalid norms.""" + arr = np.asarray(vec, dtype=np.float32) + if arr.ndim != 1: + arr = arr.reshape(-1) + norm = float(np.linalg.norm(arr)) + if not np.isfinite(norm) or norm <= 0.0: + raise ValueError("Embedding vector has invalid norm (must be > 0)") + return (arr / norm).astype(np.float32, copy=False) + diff --git a/embeddings/image_encoder.py b/embeddings/image_encoder.py index 728c184..d2b8e4c 100644 --- a/embeddings/image_encoder.py +++ b/embeddings/image_encoder.py @@ -11,6 +11,8 @@ from PIL import Image logger = logging.getLogger(__name__) from config.services_config import get_embedding_base_url +from config.env_config import REDIS_CONFIG +from embeddings.redis_embedding_cache import RedisEmbeddingCache class CLIPImageEncoder: @@ -24,7 +26,13 @@ class CLIPImageEncoder: resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url() self.service_url = str(resolved_url).rstrip("/") self.endpoint = f"{self.service_url}/embed/image" + # Reuse embedding cache prefix, but separate namespace for images to avoid collisions. + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url) + self.cache = RedisEmbeddingCache( + key_prefix=self.cache_prefix, + namespace="image", + ) def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: """ @@ -67,12 +75,17 @@ class CLIPImageEncoder: Returns: Embedding vector """ + cached = self.cache.get(url) + if cached is not None: + return cached + response_data = self._call_service([url], normalize_embeddings=normalize_embeddings) if not response_data or len(response_data) != 1 or response_data[0] is None: raise RuntimeError(f"No image embedding returned for URL: {url}") vec = np.array(response_data[0], dtype=np.float32) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): raise RuntimeError(f"Invalid image embedding returned for URL: {url}") + self.cache.set(url, vec) return vec def encode_batch( @@ -98,8 +111,21 @@ class CLIPImageEncoder: raise ValueError(f"Invalid image URL/path at index {i}: {img!r}") results: List[np.ndarray] = [] - for i in range(0, len(images), batch_size): - batch_urls = [str(u).strip() for u in images[i:i + batch_size]] + pending_urls: List[str] = [] + pending_positions: List[int] = [] + + normalized_urls = [str(u).strip() for u in images] # type: ignore[list-item] + for pos, url in enumerate(normalized_urls): + cached = self.cache.get(url) + if cached is not None: + results.append(cached) + else: + results.append(np.array([], dtype=np.float32)) # placeholder + pending_positions.append(pos) + pending_urls.append(url) + + for i in range(0, len(pending_urls), batch_size): + batch_urls = pending_urls[i : i + batch_size] response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings) if not response_data or len(response_data) != len(batch_urls): raise RuntimeError( @@ -113,7 +139,9 @@ class CLIPImageEncoder: vec = np.array(embedding, dtype=np.float32) if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): raise RuntimeError(f"Invalid image embedding returned for URL: {url}") - results.append(vec) + self.cache.set(url, vec) + pos = pending_positions[i + j] + results[pos] = vec return results diff --git a/embeddings/redis_embedding_cache.py b/embeddings/redis_embedding_cache.py new file mode 100644 index 0000000..0a1a1e4 --- /dev/null +++ b/embeddings/redis_embedding_cache.py @@ -0,0 +1,107 @@ +""" +Shared Redis cache for embedding vectors (text + image). + +Value format: BF16 bytes (see embeddings/bf16.py) +Expiration: setex on write + sliding expire on successful read. + +This module is intentionally small and dependency-light, so both indexer/search paths +can reuse it safely. +""" + +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import Optional + +import numpy as np +import redis + +from config.env_config import REDIS_CONFIG +from embeddings.bf16 import decode_embedding_from_redis, encode_embedding_for_redis + +logger = logging.getLogger(__name__) + + +class RedisEmbeddingCache: + def __init__( + self, + *, + key_prefix: str, + namespace: str = "", + expire_time: Optional[timedelta] = None, + redis_client: Optional[redis.Redis] = None, + ): + self.key_prefix = (key_prefix or "").strip() or "embedding" + self.namespace = (namespace or "").strip() + self.expire_time = expire_time or timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) + + if redis_client is not None: + self.redis_client = redis_client + return + + try: + client = redis.Redis( + host=REDIS_CONFIG.get("host", "localhost"), + port=REDIS_CONFIG.get("port", 6479), + password=REDIS_CONFIG.get("password"), + decode_responses=False, + socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), + socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), + retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), + health_check_interval=10, + ) + client.ping() + self.redis_client = client + except Exception as e: + logger.warning("Failed to initialize Redis cache: %s, continuing without cache", e) + self.redis_client = None + + def make_key(self, raw_key: str) -> str: + if self.namespace: + return f"{self.key_prefix}:{self.namespace}:{raw_key}" + return f"{self.key_prefix}:{raw_key}" + + def get(self, raw_key: str) -> Optional[np.ndarray]: + if not self.redis_client: + return None + key = self.make_key(raw_key) + try: + raw = self.redis_client.get(key) + if not raw: + return None + vec = decode_embedding_from_redis(raw) + if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all(): + try: + self.redis_client.delete(key) + except Exception: + pass + return None + # Sliding expiration: refresh TTL on hit. + self.redis_client.expire(key, self.expire_time) + return vec + except Exception as e: + logger.warning("Error retrieving embedding from cache: %s", e) + return None + + def set(self, raw_key: str, embedding: np.ndarray) -> bool: + if not self.redis_client: + return False + key = self.make_key(raw_key) + try: + vec = np.asarray(embedding, dtype=np.float32) + raw = encode_embedding_for_redis(vec) + self.redis_client.setex(key, self.expire_time, raw) + return True + except ( + redis.exceptions.BusyLoadingError, + redis.exceptions.ConnectionError, + redis.exceptions.TimeoutError, + redis.exceptions.RedisError, + ) as e: + logger.warning("Redis error storing embedding in cache: %s", e) + return False + except Exception as e: + logger.warning("Error storing embedding in cache: %s", e) + return False + diff --git a/embeddings/text_encoder.py b/embeddings/text_encoder.py index ee54a46..d3f08fe 100644 --- a/embeddings/text_encoder.py +++ b/embeddings/text_encoder.py @@ -2,17 +2,16 @@ import logging import os -import pickle from datetime import timedelta from typing import Any, List, Optional, Union import numpy as np -import redis import requests logger = logging.getLogger(__name__) from config.services_config import get_embedding_base_url +from embeddings.redis_embedding_cache import RedisEmbeddingCache # Try to import REDIS_CONFIG, but allow import to fail from config.env_config import REDIS_CONFIG @@ -30,22 +29,11 @@ class TextEmbeddingEncoder: self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) - try: - self.redis_client = redis.Redis( - host=REDIS_CONFIG.get("host", "localhost"), - port=REDIS_CONFIG.get("port", 6479), - password=REDIS_CONFIG.get("password"), - decode_responses=False, - socket_timeout=REDIS_CONFIG.get("socket_timeout", 1), - socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1), - retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False), - health_check_interval=10, - ) - self.redis_client.ping() - logger.info("Redis cache initialized for embeddings") - except Exception as e: - logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e) - self.redis_client = None + self.cache = RedisEmbeddingCache( + key_prefix=self.cache_prefix, + namespace="", + expire_time=self.expire_time, + ) def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]: """ @@ -127,7 +115,7 @@ class TextEmbeddingEncoder: embedding_array = np.array(embedding, dtype=np.float32) if self._is_valid_embedding(embedding_array): embeddings[original_idx] = embedding_array - self._set_cached_embedding(text, embedding_array, normalize_embeddings) + self._set_cached_embedding(text, embedding_array) else: raise ValueError( f"Invalid embedding returned from service for text index {original_idx}" @@ -161,63 +149,21 @@ class TextEmbeddingEncoder: def _get_cached_embedding( self, - query: str + query: str, ) -> Optional[np.ndarray]: - """Get embedding from cache if exists (with sliding expiration)""" - if not self.redis_client: - return None - - try: - cache_key = f"{self.cache_prefix}:{query}" - cached_data = self.redis_client.get(cache_key) - if cached_data: - embedding = pickle.loads(cached_data) - # Validate cached embedding - if invalid, ignore cache and return None - if self._is_valid_embedding(embedding): - logger.debug(f"Cache hit for embedding: {query}") - # Update expiration time on access (sliding expiration) - self.redis_client.expire(cache_key, self.expire_time) - return embedding - else: - logger.warning( - f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " - f"ignoring cache for query: {query[:50]}..." - ) - # Delete invalid cache entry - try: - self.redis_client.delete(cache_key) - except Exception as e: - logger.debug(f"Failed to delete invalid cache entry: {e}") - return None - return None - except Exception as e: - logger.error(f"Error retrieving embedding from cache: {e}") - return None + """Get embedding from cache if exists (with sliding expiration).""" + embedding = self.cache.get(query) + if embedding is not None: + logger.debug(f"Cache hit for embedding: {query}") + return embedding def _set_cached_embedding( self, query: str, embedding: np.ndarray, - normalize_embeddings: bool = True, ) -> bool: - """Store embedding in cache""" - if not self.redis_client: - return False - - try: - cache_key = f"{self.cache_prefix}:{query}" - serialized_data = pickle.dumps(embedding) - self.redis_client.setex( - cache_key, - self.expire_time, - serialized_data - ) + """Store embedding in cache.""" + ok = self.cache.set(query, embedding) + if ok: logger.debug(f"Successfully cached embedding for query: {query}") - return True - except (redis.exceptions.BusyLoadingError, redis.exceptions.ConnectionError, - redis.exceptions.TimeoutError, redis.exceptions.RedisError) as e: - logger.warning(f"Redis error storing embedding in cache: {e}") - return False - except Exception as e: - logger.error(f"Error storing embedding in cache: {e}") - return False + return ok diff --git a/scripts/redis/redis_cache_health_check.py b/scripts/redis/redis_cache_health_check.py index 6c8e293..e3854f2 100644 --- a/scripts/redis/redis_cache_health_check.py +++ b/scripts/redis/redis_cache_health_check.py @@ -28,7 +28,6 @@ from __future__ import annotations import argparse import json -import pickle import sys from collections import defaultdict from dataclasses import dataclass @@ -45,6 +44,7 @@ sys.path.insert(0, str(PROJECT_ROOT)) from config.env_config import REDIS_CONFIG # type: ignore from config.services_config import get_translation_cache_config # type: ignore +from embeddings.bf16 import decode_embedding_from_redis # type: ignore @dataclass @@ -58,10 +58,11 @@ def _load_known_cache_types() -> Dict[str, CacheTypeConfig]: """根据当前配置装配三种已知缓存类型及其前缀 pattern。""" cache_types: Dict[str, CacheTypeConfig] = {} - # embedding 缓存:固定 embedding:* 前缀 + # embedding 缓存:prefix 来自 REDIS_CONFIG['embedding_cache_prefix'](默认 embedding) + embedding_prefix = REDIS_CONFIG.get("embedding_cache_prefix", "embedding") cache_types["embedding"] = CacheTypeConfig( name="embedding", - pattern="embedding:*", + pattern=f"{embedding_prefix}:*", description="文本向量缓存(embeddings/text_encoder.py)", ) @@ -153,13 +154,14 @@ def decode_value_preview( if raw_value is None: return "" - # embedding: pickle 序列化的 numpy.ndarray + # embedding: BF16 bytes if cache_type == "embedding": try: - arr = pickle.loads(raw_value) + arr = decode_embedding_from_redis(raw_value) if isinstance(arr, np.ndarray): - return f"ndarray shape={arr.shape} dtype={arr.dtype}" - return f"pickle object type={type(arr).__name__}" + dim = int(arr.size) + return f"bf16 bytes={len(raw_value)} dim={dim} restored_dtype={arr.dtype}" + return f"bf16 decode type={type(arr).__name__}" except Exception: return f"" diff --git a/tests/test_embedding_pipeline.py b/tests/test_embedding_pipeline.py index b08d4aa..8e826cd 100644 --- a/tests/test_embedding_pipeline.py +++ b/tests/test_embedding_pipeline.py @@ -1,4 +1,3 @@ -import pickle from typing import Any, Dict, List, Optional import numpy as np @@ -13,6 +12,7 @@ from config import ( SearchConfig, ) from embeddings.text_encoder import TextEmbeddingEncoder +from embeddings.bf16 import encode_embedding_for_redis from query import QueryParser @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): def test_text_embedding_encoder_cache_hit(monkeypatch): fake_redis = _FakeRedis() cached = np.array([0.9, 0.8], dtype=np.float32) - fake_redis.store["embedding:cached-text"] = pickle.dumps(cached) + fake_redis.store["embedding:cached-text"] = encode_embedding_for_redis(cached) monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) calls = {"count": 0} -- libgit2 0.21.2