clip_model.py
5.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
CN-CLIP local image embedding implementation.
Internal model implementation used by the embedding service.
"""
import io
import threading
from typing import List, Optional, Union
import numpy as np
import requests
import torch
from PIL import Image
from cn_clip.clip import load_from_name
import cn_clip.clip as clip
# ViT-H-14: 1024-dim; ViT-L-14: 768-dim — must match config and the ES image_embedding.dims mapping
DEFAULT_MODEL_NAME = "ViT-H-14"
# Directory where cn_clip downloads/caches model checkpoints.
MODEL_DOWNLOAD_DIR = "/data/"
class ClipImageModel(object):
    """
    Thread-safe singleton image encoder using cn_clip (local inference).

    NOTE(review): constructor arguments are honored only on the FIRST
    instantiation; later calls with a different ``model_name``/``device``
    silently return the already-initialized singleton.
    """

    # Singleton instance and the lock guarding its one-time construction.
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None):
        # Acquire the lock on every call (no double-checked fast path) so
        # initialization happens exactly once even under concurrent first use.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ClipImageModel, cls).__new__(cls)
                # __init__ is not defined, so all setup lives here.
                cls._instance._initialize_model(model_name, device)
        return cls._instance

    def _initialize_model(self, model_name: str, device: Optional[str]) -> None:
        """Load the CN-CLIP checkpoint once; called only from ``__new__``.

        Args:
            model_name: cn_clip model identifier (e.g. "ViT-H-14").
            device: explicit torch device string, or None to auto-select.
        """
        # Prefer CUDA when available unless the caller pinned a device.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # load_from_name returns (model, preprocess transform); downloads the
        # checkpoint into MODEL_DOWNLOAD_DIR on first use.
        self.model, self.preprocess = load_from_name(
            model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR
        )
        self.model.eval()  # inference mode: freeze dropout / batch-norm behavior
        self.model_name = model_name

    def validate_image(self, image_data: bytes) -> Image.Image:
        """Decode raw bytes into an RGB PIL image, raising on corrupt data.

        ``Image.verify()`` invalidates the image object it is called on, so
        the stream must be rewound and the image reopened afterwards — keep
        this statement order.

        Raises:
            PIL.UnidentifiedImageError / OSError: if the bytes are not a
                decodable image.
        """
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()
        image_stream.seek(0)
        image = Image.open(image_stream)
        # CLIP preprocessing expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def download_image(self, url: str, timeout: int = 10) -> bytes:
        """Fetch image bytes from an http(s) URL or a local file path.

        Args:
            url: remote URL (http/https) or local filesystem path.
            timeout: per-request timeout in seconds for remote fetches.

        Raises:
            ValueError: on a non-200 HTTP response.
            OSError: if a local path cannot be read.
        """
        if url.startswith(("http://", "https://")):
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                raise ValueError("HTTP %s" % response.status_code)
            return response.content
        # Anything that is not an http(s) URL is treated as a local path.
        with open(url, "rb") as f:
            return f.read()

    def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image:
        """Downscale *image* so its longest side is at most ``max_size`` px.

        Aspect ratio is preserved; images already small enough are returned
        unchanged. This only caps decode cost — the model's own ``preprocess``
        transform still resizes to the network input resolution later.
        """
        if max(image.size) > max_size:
            ratio = float(max_size) / float(max(image.size))
            new_size = tuple(int(dim * ratio) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image

    def encode_text(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Encode text through the CN-CLIP text tower.

        Args:
            text: a single string or a list of strings.

        Returns:
            L2-normalized feature tensor on ``self.device`` (one row per
            input text). NOTE(review): unlike ``encode_clip_texts`` this
            returns a torch.Tensor, not numpy arrays.
        """
        text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_data)
            # In-place L2 normalization; safe under no_grad.
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> np.ndarray:
        """Encode a single PIL image to a float32 vector.

        Args:
            image: RGB PIL image (see ``validate_image``).
            normalize_embeddings: if True, L2-normalize the vector.

        Returns:
            1-D float32 numpy vector (dim depends on the model, e.g. 1024
            for ViT-H-14). Never returns None despite historical annotation.

        Raises:
            ValueError: if *image* is not a PIL.Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("ClipImageModel.encode_image input must be a PIL.Image")
        # preprocess -> (C,H,W); unsqueeze adds the batch dim of 1.
        infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(infer_data)
            if normalize_embeddings:
                image_features /= image_features.norm(dim=-1, keepdim=True)
        # [0] strips the singleton batch dimension.
        return image_features.cpu().numpy().astype("float32")[0]

    def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray:
        """Download, validate, downscale, and encode one image URL/path."""
        image_data = self.download_image(url)
        image = self.validate_image(image_data)
        image = self.preprocess_image(image)
        return self.encode_image(image, normalize_embeddings=normalize_embeddings)

    def encode_image_urls(
        self,
        urls: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """
        Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder.
        Args:
            urls: list of image URLs or local paths.
            batch_size: batch size for internal batching (default 8).
        Returns:
            List of vectors, same length as urls.
        """
        return self.encode_batch(
            urls,
            batch_size=batch_size or 8,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """Encode a mixed list of URLs/paths and PIL images, in order.

        NOTE(review): despite the ``batch_size`` chunking, each image is
        still run through the model individually — the parameter currently
        has no effect on throughput.

        Raises:
            ValueError: for any element that is neither str nor PIL.Image.
        """
        results: List[np.ndarray] = []
        for i in range(0, len(images), batch_size):
            batch = images[i : i + batch_size]
            for img in batch:
                if isinstance(img, str):
                    results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings))
                elif isinstance(img, Image.Image):
                    results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings))
                else:
                    raise ValueError(f"Unsupported image input type: {type(img)!r}")
        return results

    def encode_clip_texts(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """
        CN-CLIP text-tower vectors, in the same embedding space as
        ``encode_image``; used by ``POST /embed/clip_text``.

        Args:
            texts: list of input strings; empty list yields [].
            batch_size: tokenizer/model batch size (default 8).
            normalize_embeddings: if True, L2-normalize each vector.

        Returns:
            List of 1-D float32 numpy vectors, same length/order as *texts*.
        """
        if not texts:
            return []
        bs = batch_size or 8
        out: List[np.ndarray] = []
        for i in range(0, len(texts), bs):
            batch = texts[i : i + bs]
            # Unlike encode_batch, this genuinely batches through the model.
            text_data = clip.tokenize(batch).to(self.device)
            with torch.no_grad():
                feats = self.model.encode_text(text_data)
                if normalize_embeddings:
                    # Out-of-place division (cf. encode_text's in-place /=).
                    feats = feats / feats.norm(dim=-1, keepdim=True)
            arr = feats.cpu().numpy().astype("float32")
            for row in arr:
                out.append(np.asarray(row, dtype=np.float32))
        return out