""" BF16 (bfloat16) codec helpers for Redis embedding cache. We store embeddings in Redis as: FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes No backward compatibility is provided by design. """ from __future__ import annotations import struct from typing import Iterable, List, Sequence import numpy as np def float32_to_bf16(value: float) -> int: """ float32 -> bfloat16 (returns 0..65535 uint16) Round-to-nearest-even. """ bits = struct.unpack(">I", struct.pack(">f", float(value)))[0] rounding_bias = ((bits >> 16) & 1) + 0x7FFF bits += rounding_bias return (bits >> 16) & 0xFFFF def bf16_to_float32(bf16: int) -> float: """ bfloat16 -> float32. bf16 is an int in 0..65535. """ bits = (int(bf16) & 0xFFFF) << 16 return struct.unpack(">f", struct.pack(">I", bits))[0] def float_array_to_bf16(vector: Sequence[float]) -> List[int]: return [float32_to_bf16(v) for v in vector] def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]: return [bf16_to_float32(v) for v in vector_bf16] def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes: """Each bf16 uses 2 bytes big-endian.""" return b"".join(struct.pack(">H", int(x) & 0xFFFF) for x in bf16_list) def bytes_to_bf16_list(data: bytes) -> List[int]: if len(data) % 2 != 0: raise ValueError("BF16 byte length must be even") return [struct.unpack(">H", data[i : i + 2])[0] for i in range(0, len(data), 2)] def encode_embedding_for_redis(embedding: np.ndarray) -> bytes: """ FP32 embedding -> BF16 -> bytes. """ arr = np.asarray(embedding, dtype=np.float32) if arr.ndim != 1: arr = arr.reshape(-1) # Ensure we operate on plain Python floats for the reference codec. bf16_list = float_array_to_bf16(arr.tolist()) return bf16_list_to_bytes(bf16_list) def decode_embedding_from_redis(data: bytes) -> np.ndarray: """ Redis bytes -> BF16 -> FP32 numpy array. """ bf16_list = bytes_to_bf16_list(data) floats = bf16_array_to_float(bf16_list) return np.asarray(floats, dtype=np.float32) def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray: """L2-normalize a 1D FP32 vector. Raises on invalid norms.""" arr = np.asarray(vec, dtype=np.float32) if arr.ndim != 1: arr = arr.reshape(-1) norm = float(np.linalg.norm(arr)) if not np.isfinite(norm) or norm <= 0.0: raise ValueError("Embedding vector has invalid norm (must be > 0)") return (arr / norm).astype(np.float32, copy=False)