"""CN-CLIP local image embedding implementation.

Internal model implementation used by the embedding service.
"""
import io
import logging
import threading
from typing import List, Optional, Union

import numpy as np
import requests
import torch
from PIL import Image

from cn_clip.clip import load_from_name
import cn_clip.clip as clip

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "ViT-H-14"
MODEL_DOWNLOAD_DIR = "/data/"


class ClipImageModel(object):
    """Thread-safe singleton image encoder using cn_clip (local inference).

    NOTE: only the FIRST instantiation's ``model_name`` / ``device``
    take effect; subsequent calls return the already-built singleton
    unchanged, silently ignoring their arguments.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None):
        # Creation is guarded by a class-level lock so that concurrent
        # first calls build (and download) the model exactly once.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ClipImageModel, cls).__new__(cls)
                cls._instance._initialize_model(model_name, device)
            return cls._instance

    def _initialize_model(self, model_name: str, device: Optional[str]) -> None:
        """Load the CN-CLIP weights once and switch the model to eval mode.

        Args:
            model_name: cn_clip model identifier (e.g. "ViT-H-14").
            device: explicit torch device string; auto-detects CUDA if None.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess = load_from_name(
            model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR
        )
        self.model.eval()  # inference only — disable dropout/batchnorm updates
        self.model_name = model_name

    def validate_image(self, image_data: bytes) -> Image.Image:
        """Decode raw bytes into an RGB PIL image, raising on corrupt data.

        ``Image.verify()`` consumes/invalidates the image object, so the
        stream is rewound and the image reopened before returning it.

        Raises:
            PIL.UnidentifiedImageError / OSError: if the bytes are not a
                valid image.
        """
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()
        image_stream.seek(0)
        image = Image.open(image_stream)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def download_image(self, url: str, timeout: int = 10) -> bytes:
        """Fetch image bytes from an HTTP(S) URL or read a local file path.

        Args:
            url: "http(s)://..." URL, or any other string treated as a
                local filesystem path.
            timeout: HTTP request timeout in seconds.

        Raises:
            ValueError: on a non-200 HTTP response.
            OSError: if a local path cannot be read.
        """
        if url.startswith(("http://", "https://")):
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                raise ValueError("HTTP %s" % response.status_code)
            return response.content
        with open(url, "rb") as f:
            return f.read()

    def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image:
        """Downscale so the longest side is <= max_size, keeping aspect ratio."""
        if max(image.size) > max_size:
            ratio = float(max_size) / float(max(image.size))
            new_size = tuple(int(dim * ratio) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image

    def encode_text(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Encode text into L2-normalized CLIP embeddings.

        Args:
            text: a single string or a list of strings.

        Returns:
            torch.Tensor of shape (batch, dim), unit-normalized per row.
            (Kept as a tensor — unlike encode_image — for backward
            compatibility with existing callers.)
        """
        text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_data)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
        """Encode one PIL image into a float32 embedding vector.

        Args:
            image: a PIL.Image (any mode accepted by ``self.preprocess``).
            normalize_embeddings: L2-normalize the embedding if True.

        Returns:
            1-D np.ndarray (float32).

        Raises:
            ValueError: if ``image`` is not a PIL.Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("ClipImageModel.encode_image input must be a PIL.Image")
        infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(infer_data)
            if normalize_embeddings:
                image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().astype("float32")[0]

    def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
        """Download, validate, resize and encode a single image URL/path.

        Propagates download/decoding errors; batch callers handle them.
        """
        image_data = self.download_image(url)
        image = self.validate_image(image_data)
        image = self.preprocess_image(image)
        return self.encode_image(image, normalize_embeddings=normalize_embeddings)

    def encode_image_urls(
        self,
        urls: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
    ) -> List[Optional[np.ndarray]]:
        """
        Encode a list of image URLs to vectors.

        Same interface as ClipAsServiceImageEncoder.

        Args:
            urls: list of image URLs or local paths.
            batch_size: batch size for internal batching (default 8).

        Returns:
            List of vectors (or None for failed items), same length as urls.
        """
        return self.encode_batch(
            urls,
            batch_size=batch_size or 8,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8,
        normalize_embeddings: bool = True,
    ) -> List[Optional[np.ndarray]]:
        """Encode a mixed list of URLs/paths and PIL images.

        BUG FIX: previously any per-item download/decoding/encoding error
        aborted the whole batch, contradicting the documented contract
        ("None for failed items"). Failures are now caught per item,
        logged, and recorded as None so the result always has the same
        length as ``images``.

        Args:
            images: list of URL/path strings and/or PIL.Image objects;
                anything else yields None at that position.
            batch_size: internal chunk size (items are still encoded
                one at a time; chunking bounds per-iteration work).
            normalize_embeddings: L2-normalize each embedding if True.

        Returns:
            List of 1-D float32 vectors (or None for failed/invalid
            items), same length and order as ``images``.
        """
        results: List[Optional[np.ndarray]] = []
        for i in range(0, len(images), batch_size):
            batch = images[i : i + batch_size]
            for img in batch:
                try:
                    if isinstance(img, str):
                        results.append(
                            self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings)
                        )
                    elif isinstance(img, Image.Image):
                        results.append(
                            self.encode_image(img, normalize_embeddings=normalize_embeddings)
                        )
                    else:
                        results.append(None)
                except Exception:
                    # Best-effort batch semantics: record the failure and
                    # keep going so one bad item cannot sink the batch.
                    logger.exception("encode_batch: failed to encode item %r", img)
                    results.append(None)
        return results