4a37d233
tangwang
1. embedding cach...
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
"""
BF16 (bfloat16) codec helpers for Redis embedding cache.
We store embeddings in Redis as:
FP32 vector -> (optional L2 normalize) -> BF16 (uint16 per element, big-endian) -> bytes
No backward compatibility is provided by design.
"""
from __future__ import annotations
import struct
from typing import Iterable, List, Sequence
import numpy as np
def float32_to_bf16(value: float) -> int:
    """
    float32 -> bfloat16 (returns 0..65535 uint16).

    ``value`` is first packed as an IEEE-754 binary32 (``struct`` format
    ``">f"``). The upper 16 bits are then kept, with round-to-nearest-even
    applied to the discarded lower 16 bits. Finite values that round past
    the largest finite BF16 become +/-inf, as RNE requires.

    NaN inputs are truncated rather than rounded, and the BF16 quiet bit is
    forced on, so a NaN always stays NaN. Without this, adding the rounding
    bias to a NaN with a full mantissa carries into the exponent/sign bits.
    For example, 0x7FFFFFFF would become 0x8000 (-0.0).

    Raises:
        OverflowError: from ``struct.pack`` when a finite ``value`` is too
            large to be represented as float32.
    """
    bits = struct.unpack(">I", struct.pack(">f", float(value)))[0]
    is_nan = (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0
    if is_nan:
        # Keep sign, exponent and top mantissa bits; 0x0040 is the BF16 quiet bit.
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # A bias of 0x7FFF, plus 1 when the kept LSB is odd, implements ties-to-even.
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding_bias) >> 16) & 0xFFFF
def bf16_to_float32(bf16: int) -> float:
    """
    bfloat16 -> float32 (returned as a Python float).

    Only the low 16 bits of ``bf16`` are used. They become the upper half of
    a binary32 word whose lower half is zero, so the conversion is exact.
    """
    upper_half = int(bf16) & 0xFFFF
    (result,) = struct.unpack(">f", struct.pack(">HH", upper_half, 0))
    return result
def float_array_to_bf16(vector: Sequence[float]) -> List[int]:
    """Convert each float of ``vector`` to its BF16 code via ``float32_to_bf16``."""
    return list(map(float32_to_bf16, vector))
def bf16_array_to_float(vector_bf16: Sequence[int]) -> List[float]:
    """Expand each BF16 code of ``vector_bf16`` to a float via ``bf16_to_float32``."""
    return list(map(bf16_to_float32, vector_bf16))
def bf16_list_to_bytes(bf16_list: Sequence[int]) -> bytes:
    """
    Pack BF16 codes into bytes, 2 bytes per code, big-endian.

    Each code is passed through ``int()`` and masked to its low 16 bits
    before packing.
    """
    words = [int(code) & 0xFFFF for code in bf16_list]
    return struct.pack(f">{len(words)}H", *words)
def bytes_to_bf16_list(data: bytes) -> List[int]:
    """
    Split big-endian bytes into BF16 codes (one uint16 per 2 bytes).

    Raises:
        ValueError: if ``data`` has an odd number of bytes.
    """
    byte_count = len(data)
    if byte_count % 2:
        raise ValueError("BF16 byte length must be even")
    return list(struct.unpack(f">{byte_count // 2}H", data))
def encode_embedding_for_redis(embedding: np.ndarray) -> bytes:
    """
    Serialize an embedding for Redis: FP32 -> BF16 -> big-endian bytes.

    The input is cast to float32 and flattened to 1-D. No normalization is
    applied here.
    """
    flat = np.asarray(embedding, dtype=np.float32).reshape(-1)
    # The reference codec works element-wise on plain Python floats.
    codes = float_array_to_bf16(flat.tolist())
    return bf16_list_to_bytes(codes)
def decode_embedding_from_redis(data: bytes) -> np.ndarray:
    """
    Deserialize Redis bytes back into a 1-D float32 numpy array.

    Inverse of the BF16 packing: bytes -> BF16 codes -> FP32 values.
    An odd byte length raises ValueError from ``bytes_to_bf16_list``.
    """
    values = bf16_array_to_float(bytes_to_bf16_list(data))
    return np.asarray(values, dtype=np.float32)
def l2_normalize_fp32(vec: np.ndarray) -> np.ndarray:
    """
    L2-normalize a vector and return it as a 1-D float32 array.

    The input is cast to float32 and flattened. The norm and the division
    are computed in float64 and the result is cast back to float32. In pure
    float32, squaring components above ~1.8e19 overflows to inf, and
    squaring very small components underflows to 0. Either way, a perfectly
    valid finite vector would be rejected.

    Raises:
        ValueError: if the norm is zero (including an empty vector), NaN or
            infinite (e.g. the input contains NaN/inf).
    """
    arr = np.asarray(vec, dtype=np.float32).reshape(-1)
    wide = arr.astype(np.float64)
    norm = float(np.linalg.norm(wide))
    if not np.isfinite(norm) or norm <= 0.0:
        raise ValueError("Embedding vector has invalid norm (must be > 0)")
    return (wide / norm).astype(np.float32)
|