bf16.py 2.53 KB
"""
BF16 (bfloat16) codec helpers for Redis embedding cache.

We store embeddings in Redis as:
  FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes

No backward compatibility is provided by design.
"""

from __future__ import annotations

import struct
from typing import Iterable, List, Sequence

import numpy as np


def float32_to_bf16(value: float) -> int:
    """
    float32 -> bfloat16 (returns 0..65535 uint16)
    Round-to-nearest-even.

    The value is packed as an IEEE 754 binary32. Its upper 16 bits form the
    BF16 pattern after rounding on the discarded lower 16 bits. Infinities
    pass through unchanged. Finite values that round past the largest BF16
    magnitude become infinity. NaN inputs always map to a quiet BF16 NaN with
    the same sign bit.

    Raises:
        OverflowError: if ``value`` is finite but too large to be represented
            as a float32 (raised by ``struct.pack``).
    """
    bits = struct.unpack(">I", struct.pack(">f", float(value)))[0]
    if (bits & 0x7FFFFFFF) > 0x7F800000:
        # NaN: the rounding add below could round a low-payload NaN to
        # infinity (0x7F800001 -> 0x7F80) or carry a full mantissa into the
        # sign bit (0x7FFFFFFF -> 0x8000, i.e. -0.0). Truncate instead and set
        # the quiet bit so the result is guaranteed to remain a NaN.
        return (bits >> 16) | 0x0040
    # 0x7FFF plus the LSB of the kept half: exact ties round to the even value.
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    bits += rounding_bias
    return (bits >> 16) & 0xFFFF


def bf16_to_float32(bf16: int) -> float:
    """
    Expand a bfloat16 bit pattern into a float32 value.

    ``bf16`` is masked to its low 16 bits. Those bits become the high half of
    a big-endian binary32 word whose low half is zero, and that word is
    reinterpreted as a float32, which is returned as a Python float.
    """
    high_half = int(bf16) & 0xFFFF
    (value,) = struct.unpack(">f", struct.pack(">HH", high_half, 0))
    return value


def float_array_to_bf16(vector: Sequence[float]) -> List[int]:
    """Return a list with the BF16 bit pattern (float32_to_bf16) of each element."""
    return list(map(float32_to_bf16, vector))


def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]:
    """Return a list with each BF16 bit pattern expanded via bf16_to_float32."""
    return list(map(bf16_to_float32, vector_bf16))


def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes:
    """
    Pack BF16 bit patterns into bytes.

    Every element is converted with ``int`` and masked to 16 bits, then
    emitted as an unsigned big-endian 2-byte value, in input order.
    """
    halves = [int(item) & 0xFFFF for item in bf16_list]
    return struct.pack(f">{len(halves)}H", *halves)


def bytes_to_bf16_list(data: bytes) -> List[int]:
    """
    Unpack big-endian 2-byte words into a list of BF16 bit patterns (0..65535).

    Raises:
        ValueError: if ``data`` has an odd number of bytes.
    """
    count, leftover = divmod(len(data), 2)
    if leftover:
        raise ValueError("BF16 byte length must be even")
    return list(struct.unpack(f">{count}H", data))


def encode_embedding_for_redis(embedding: np.ndarray) -> bytes:
    """
    Serialize an embedding into the Redis BF16 byte format.

    The input is coerced to a float32 NumPy array and flattened to 1-D. Each
    element is rounded to BF16 (float_array_to_bf16) and the results are packed
    as big-endian uint16 values, 2 bytes per element (bf16_list_to_bytes).
    """
    flat = np.asarray(embedding, dtype=np.float32).reshape(-1)
    # .tolist() yields plain Python floats for the scalar reference codec.
    return bf16_list_to_bytes(float_array_to_bf16(flat.tolist()))


def decode_embedding_from_redis(data: bytes) -> np.ndarray:
    """
    Deserialize Redis BF16 bytes into a 1-D float32 NumPy array.

    ``data`` is split into big-endian uint16 BF16 patterns, each pattern is
    widened to float32, and the values are returned as a float32 array.
    Odd-length input raises ValueError (from bytes_to_bf16_list).
    """
    return np.asarray(
        bf16_array_to_float(bytes_to_bf16_list(data)),
        dtype=np.float32,
    )


def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray:
    """
    L2-normalize an FP32 vector (flattened to 1-D). Raises on invalid norms.

    The input is first coerced to float32 and flattened to 1-D if it is not
    already. The norm and the division are computed in float64, and the
    result is rounded back to float32.

    Args:
        vec: array-like of numbers.

    Returns:
        A 1-D float32 array with unit L2 norm (up to float32 rounding).

    Raises:
        ValueError: if the norm is zero, infinite, or NaN (e.g. an all-zero or
            empty vector, or one containing inf/NaN).
    """
    arr = np.asarray(vec, dtype=np.float32)
    if arr.ndim != 1:
        arr = arr.reshape(-1)
    # np.linalg.norm on float32 squares in float32: components below about
    # 1e-23 square to 0 and components above about 1.8e19 square to inf,
    # which wrongly rejects valid nonzero vectors. Squaring any float32 value
    # in float64 neither underflows to zero nor overflows.
    wide = arr.astype(np.float64)
    norm = float(np.linalg.norm(wide))
    if not np.isfinite(norm) or norm <= 0.0:
        raise ValueError("Embedding vector has invalid norm (must be > 0)")
    return (wide / norm).astype(np.float32)