""" Image embedding encoder using CN-CLIP model. Generates 1024-dimensional vectors for images using the CN-CLIP ViT-H-14 model. """ import sys import os import io import requests import torch import numpy as np from PIL import Image import logging import threading from typing import List, Optional, Union import cn_clip.clip as clip from cn_clip.clip import load_from_name # DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] DEFAULT_MODEL_NAME = "ViT-H-14" MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" class CLIPImageEncoder: """ CLIP Image Encoder for generating image embeddings using cn_clip. Thread-safe singleton pattern. """ _instance = None _lock = threading.Lock() def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): with cls._lock: if cls._instance is None: cls._instance = super(CLIPImageEncoder, cls).__new__(cls) print(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") cls._instance._initialize_model(model_name, device) return cls._instance def _initialize_model(self, model_name, device): """Initialize the CLIP model using cn_clip""" try: self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") self.model, self.preprocess = load_from_name( model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR ) self.model.eval() self.model_name = model_name print(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") except Exception as e: print(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") raise def validate_image(self, image_data: bytes) -> Image.Image: """Validate image data and return PIL Image if valid""" try: image_stream = io.BytesIO(image_data) image = Image.open(image_stream) image.verify() image_stream.seek(0) image = Image.open(image_stream) if image.mode != 'RGB': image = image.convert('RGB') return image except Exception as e: raise ValueError(f"Invalid image data: {str(e)}") def download_image(self, url: str, timeout: int = 10) -> bytes: """Download image from URL""" try: if url.startswith(('http://', 'https://')): response = requests.get(url, timeout=timeout) if response.status_code != 200: raise ValueError(f"HTTP {response.status_code}") return response.content else: # Local file path with open(url, 'rb') as f: return f.read() except Exception as e: raise ValueError(f"Failed to download image from {url}: {str(e)}") def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: """Preprocess image for CLIP model""" # Resize if too large if max(image.size) > max_size: ratio = max_size / max(image.size) new_size = tuple(int(dim * ratio) for dim in image.size) image = image.resize(new_size, Image.Resampling.LANCZOS) return image def encode_text(self, text): """Encode text to embedding vector using cn_clip""" text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) with torch.no_grad(): text_features = self.model.encode_text(text_data) text_features /= text_features.norm(dim=-1, keepdim=True) return text_features def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: """Encode image to embedding vector using cn_clip""" if not isinstance(image, Image.Image): raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") try: infer_data = self.preprocess(image).unsqueeze(0).to(self.device) with torch.no_grad(): image_features = self.model.encode_image(infer_data) image_features /= image_features.norm(dim=-1, keepdim=True) return 
image_features.cpu().numpy().astype('float32')[0] except Exception as e: print(f"Failed to process image. Reason: {str(e)}") return None def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: """Complete pipeline: download, validate, preprocess and encode image from URL""" try: # Download image image_data = self.download_image(url) # Validate image image = self.validate_image(image_data) # Preprocess image image = self.preprocess_image(image) # Encode image embedding = self.encode_image(image) return embedding except Exception as e: print(f"Error processing image from URL {url}: {str(e)}") return None def encode_batch( self, images: List[Union[str, Image.Image]], batch_size: int = 8 ) -> List[Optional[np.ndarray]]: """ Encode a batch of images efficiently. Args: images: List of image URLs or PIL Images batch_size: Batch size for processing Returns: List of embeddings (or None for failed images) """ results = [] for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] batch_embeddings = [] for img in batch: if isinstance(img, str): # URL or file path emb = self.encode_image_from_url(img) elif isinstance(img, Image.Image): # PIL Image emb = self.encode_image(img) else: emb = None batch_embeddings.append(emb) results.extend(batch_embeddings) return results
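

# --- Usage sketch (illustrative only) ---
# A minimal example of how the encoder above might be exercised. It assumes the
# CN-CLIP ViT-H-14 weights can be downloaded to MODEL_DOWNLOAD_DIR and that a
# CUDA device is optional; the image URL and text query below are hypothetical
# placeholders, not part of the module's API.
if __name__ == "__main__":
    encoder = CLIPImageEncoder()
    # Repeated construction returns the same singleton instance
    assert CLIPImageEncoder() is encoder

    # Encode a single image from a (hypothetical) URL; returns None on failure
    embedding = encoder.encode_image_from_url("https://example.com/sample.jpg")
    if embedding is not None:
        print(f"Image embedding shape: {embedding.shape}")  # (1024,) for ViT-H-14

    # Encode a text query; returns an L2-normalized torch tensor of shape (1, 1024)
    text_features = encoder.encode_text("a dog running on the grass")
    print(f"Text embedding shape: {tuple(text_features.shape)}")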