Blame view

embeddings/bf16.py 2.53 KB
4a37d233   tangwang   1. embedding cach...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
  """
  BF16 (bfloat16) codec helpers for Redis embedding cache.
  
  We store embeddings in Redis as:
    FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes
  
  No backward compatibility is provided by design.
  """
  
  from __future__ import annotations
  
  import struct
  from typing import Iterable, List, Sequence
  
  import numpy as np
  
  
  def float32_to_bf16(value: float) -> int:
      """
      Convert a number to its bfloat16 bit pattern (an int in 0..65535).

      The value is packed as a big-endian IEEE-754 float32; the upper 16 bits
      are kept after applying round-to-nearest-even to the discarded lower 16.

      NaN inputs are handled separately. Adding the rounding bias to a NaN
      whose mantissa bits are mostly set would carry into the exponent/sign
      bits and produce a zero pattern, so NaNs are truncated instead and the
      quiet bit is forced to guarantee the result is still a NaN.

      Args:
          value: Number to convert; coerced with float().

      Returns:
          The 16 bfloat16 bits as an unsigned int.

      Note:
          struct.pack(">f", ...) rejects finite values outside the float32
          range (OverflowError in CPython); that behaviour is unchanged.
      """
      bits = struct.unpack(">I", struct.pack(">f", float(value)))[0]
      # Exponent all ones with a non-zero mantissa means NaN (sign ignored).
      if (bits & 0x7FFFFFFF) > 0x7F800000:
          # Truncate, keep the sign, and set the top mantissa bit (quiet NaN).
          return ((bits >> 16) | 0x0040) & 0xFFFF
      # 0x7FFF plus the LSB of the kept half implements ties-to-even.
      rounding_bias = ((bits >> 16) & 1) + 0x7FFF
      return ((bits + rounding_bias) >> 16) & 0xFFFF
  
  
  def bf16_to_float32(bf16: int) -> float:
      """
      Expand a bfloat16 bit pattern into the float32 value it represents.

      Only the low 16 bits of ``bf16`` are used. They become the upper half
      of a big-endian float32 word whose lower half is zero.
      """
      high_half = int(bf16) & 0xFFFF
      # Two big-endian uint16 fields (value, 0) form the 32-bit float pattern.
      (result,) = struct.unpack(">f", struct.pack(">HH", high_half, 0))
      return result
  
  
  def float_array_to_bf16(vector: Sequence[float]) -> List[int]:
      """Apply float32_to_bf16 to every element and return the results as a list."""
      return list(map(float32_to_bf16, vector))
  
  
  def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]:
      """Apply bf16_to_float32 to every element and return the results as a list."""
      return list(map(bf16_to_float32, vector_bf16))
  
  
  def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes:
      """
      Serialize bf16 values as consecutive 2-byte big-endian words.

      Each element is coerced with int() and masked to its low 16 bits.
      """
      words = [int(item) & 0xFFFF for item in bf16_list]
      # One pack call with a repeat count; a count of 0 yields b"".
      return struct.pack(f">{len(words)}H", *words)
  
  
  def bytes_to_bf16_list(data: bytes) -> List[int]:
      """
      Split a byte string into big-endian uint16 values.

      Raises:
          ValueError: if the length of ``data`` is odd.
      """
      word_count, leftover = divmod(len(data), 2)
      if leftover:
          raise ValueError("BF16 byte length must be even")
      # Decode every 2-byte word in a single call via a repeat count.
      return list(struct.unpack(f">{word_count}H", data))
  
  
  def encode_embedding_for_redis(embedding: np.ndarray) -> bytes:
      """
      Encode an embedding as BF16 bytes (2 bytes per element, big-endian).

      The input is converted to a float32 array and flattened to 1-D when it
      has any other rank, then each element is passed through the scalar
      codec helpers.
      """
      values = np.asarray(embedding, dtype=np.float32)
      flat = values if values.ndim == 1 else values.reshape(-1)
      # tolist() hands plain Python floats to the reference (scalar) codec.
      return bf16_list_to_bytes(float_array_to_bf16(flat.tolist()))
  
  
  def decode_embedding_from_redis(data: bytes) -> np.ndarray:
      """
      Decode BF16 bytes into a 1-D float32 numpy array.

      The bytes are split into 16-bit words, each word is expanded to a
      float, and the resulting list is wrapped as a float32 array.
      """
      words = bytes_to_bf16_list(data)
      return np.asarray(bf16_array_to_float(words), dtype=np.float32)
  
  
  def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray:
      """
      Return ``vec`` scaled to unit L2 norm as a float32 array.

      The input is converted to float32 and flattened to 1-D when it has any
      other rank.

      Raises:
          ValueError: if the norm is not finite or is not strictly positive.
      """
      flat = np.asarray(vec, dtype=np.float32)
      flat = flat if flat.ndim == 1 else flat.reshape(-1)
      length = float(np.linalg.norm(flat))
      if length <= 0.0 or not np.isfinite(length):
          raise ValueError("Embedding vector has invalid norm (must be > 0)")
      unit = flat / length
      return unit.astype(np.float32, copy=False)