# image_encoder__local.py (6.09 KB)
"""
Image embedding encoder using CN-CLIP model.

Generates L2-normalized, 1024-dimensional embeddings for images (and text) using the CN-CLIP ViT-H-14 model by default.
"""

import sys
import os
import io
import requests
import torch
import numpy as np
from PIL import Image
import logging
import threading
from typing import List, Optional, Union
import cn_clip.clip as clip
from cn_clip.clip import load_from_name


# Model loaded by CLIPImageEncoder unless another name is passed.
# Alternative CN-CLIP model names: "ViT-B-16", "ViT-L-14", "ViT-L-14-336",
# "ViT-H-14", "RN50".
DEFAULT_MODEL_NAME = "ViT-H-14"

# Directory passed to cn_clip's load_from_name as download_root.
# Can be overridden with the CN_CLIP_DOWNLOAD_ROOT environment variable; an
# unset or empty variable falls back to the original deployment path.
MODEL_DOWNLOAD_DIR = os.environ.get("CN_CLIP_DOWNLOAD_ROOT") or "/data/tw/uat/EsSearcher"


class CLIPImageEncoder:
    """
    CLIP image (and text) encoder built on cn_clip.

    Process-wide singleton: the first construction loads the model, and every
    later ``CLIPImageEncoder(...)`` call returns that same instance. Creation
    is serialized with a class-level lock, so concurrent first calls load the
    model only once. Arguments passed to later calls are ignored; a warning
    is logged when they differ from the loaded configuration.

    All embeddings are L2-normalized along the last dimension.
    """

    _instance = None
    _lock = threading.Lock()
    _logger = logging.getLogger(__name__)

    def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None):
        with cls._lock:
            if cls._instance is None:
                cls._logger.info(
                    "[CLIPImageEncoder] Creating new instance with model: %s", model_name
                )
                instance = super().__new__(cls)
                # Publish the singleton only after the model loaded successfully;
                # otherwise a failed load would leave a half-initialized instance
                # (no .model / .preprocess) that every later call would return.
                instance._initialize_model(model_name, device)
                cls._instance = instance
            else:
                cls._instance._warn_if_mismatched(model_name, device)
        return cls._instance

    def _warn_if_mismatched(self, model_name, device):
        """Log a warning when a later constructor call asks for a different config."""
        device_differs = device is not None and str(device) != str(self.device)
        if model_name != self.model_name or device_differs:
            self._logger.warning(
                "[CLIPImageEncoder] Already initialized with model %s on %s; "
                "ignoring requested model %s on %s",
                self.model_name, self.device, model_name, device,
            )

    def _initialize_model(self, model_name, device):
        """Load the cn_clip model and its preprocessing transform.

        Args:
            model_name: cn_clip model name passed to ``load_from_name``.
            device: Target device; when falsy, "cuda" is used if available,
                otherwise "cpu".

        Raises:
            Any exception raised by ``load_from_name`` (logged, then re-raised).
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        try:
            self.model, self.preprocess = load_from_name(
                model_name,
                device=self.device,
                download_root=MODEL_DOWNLOAD_DIR,
            )
        except Exception:
            self._logger.exception(
                "[CLIPImageEncoder] Failed to initialize model %s", model_name
            )
            raise
        # Inference only: put the model in evaluation mode.
        self.model.eval()
        self.model_name = model_name
        self._logger.info(
            "[CLIPImageEncoder] Model %s initialized successfully on device %s",
            model_name, self.device,
        )

    def validate_image(self, image_data: bytes) -> Image.Image:
        """Decode raw bytes into a fully loaded RGB PIL image.

        Args:
            image_data: Encoded image bytes (any format PIL can open).

        Returns:
            The decoded image, converted to RGB if necessary.

        Raises:
            ValueError: if the bytes cannot be verified or decoded as an image.
        """
        try:
            # verify() checks integrity but leaves the image object unusable,
            # so the image is reopened from a fresh stream afterwards.
            with Image.open(io.BytesIO(image_data)) as probe:
                probe.verify()
            image = Image.open(io.BytesIO(image_data))
            # Force a full decode so truncated/corrupt pixel data fails here
            # rather than later during encoding.
            image.load()
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            raise ValueError(f"Invalid image data: {str(e)}") from e

    def download_image(self, url: str, timeout: int = 10) -> bytes:
        """Fetch raw image bytes from an http(s) URL or a local file path.

        Args:
            url: ``http://`` / ``https://`` URL, or any other string, which is
                treated as a local file path.
            timeout: Request timeout in seconds for HTTP downloads.

        Raises:
            ValueError: on a non-200 HTTP status, a request error, or an
                unreadable local file.
        """
        # NOTE(review): any non-http(s) string is opened as a local path. If
        # ``url`` can come from untrusted callers, this permits reading
        # arbitrary local files; consider restricting it.
        try:
            if url.startswith(('http://', 'https://')):
                response = requests.get(url, timeout=timeout)
                if response.status_code != 200:
                    raise ValueError(f"HTTP {response.status_code}")
                return response.content
            with open(url, 'rb') as f:
                return f.read()
        except Exception as e:
            raise ValueError(f"Failed to download image from {url}: {str(e)}") from e

    def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image:
        """Downscale ``image`` so its longer side is at most ``max_size`` pixels.

        The aspect ratio is preserved; images already within the limit are
        returned unchanged.
        """
        longest = max(image.size)
        if longest > max_size:
            ratio = max_size / longest
            # Clamp each side to >= 1 px: very thin images would otherwise
            # round a side down to 0, which PIL's resize rejects.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image

    def encode_text(self, text):
        """Encode a string or a list of strings into L2-normalized text features.

        Returns:
            A torch tensor on ``self.device`` (unlike ``encode_image``, which
            returns a NumPy array).
        """
        texts = [text] if isinstance(text, str) else text
        text_data = clip.tokenize(texts).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_data)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image: Image.Image) -> Optional[np.ndarray]:
        """Encode one PIL image into an L2-normalized float32 embedding.

        Returns:
            The first (only) row of the model output as a float32 array, or
            None if preprocessing or inference fails (the error is logged).

        Raises:
            ValueError: if ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image")

        try:
            infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(infer_data)
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().astype('float32')[0]
        except Exception as e:
            self._logger.exception("Failed to process image. Reason: %s", e)
            return None

    def encode_image_from_url(self, url: str) -> Optional[np.ndarray]:
        """Download, validate, downscale and encode the image at ``url``.

        Returns:
            The embedding, or None if any step fails (the error is logged).
        """
        try:
            image = self.validate_image(self.download_image(url))
            image = self.preprocess_image(image)
            return self.encode_image(image)
        except Exception as e:
            self._logger.warning("Error processing image from URL %s: %s", url, e)
            return None

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8
    ) -> List[Optional[np.ndarray]]:
        """
        Encode images in chunks of ``batch_size``, one forward pass per chunk.

        Args:
            images: List of image URLs / local paths or PIL Images.
            batch_size: Number of images per forward pass; must be >= 1.

        Returns:
            One entry per input, in input order: the embedding, or None for
            inputs that could not be loaded, preprocessed or encoded
            (including inputs that are neither str nor PIL.Image).

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        results: List[Optional[np.ndarray]] = []
        for start in range(0, len(images), batch_size):
            chunk = [self._load_batch_item(item) for item in images[start:start + batch_size]]
            results.extend(self._encode_loaded_chunk(chunk))
        return results

    def _load_batch_item(self, item) -> Optional[Image.Image]:
        """Resolve one ``encode_batch`` input to a PIL image, or None if unusable."""
        if isinstance(item, Image.Image):
            # PIL inputs are used as-is (no downscaling), as in encode_image.
            return item
        if isinstance(item, str):
            try:
                image = self.validate_image(self.download_image(item))
                return self.preprocess_image(image)
            except Exception as e:
                self._logger.warning("Error processing image from URL %s: %s", item, e)
                return None
        return None

    def _encode_loaded_chunk(
        self, chunk: List[Optional[Image.Image]]
    ) -> List[Optional[np.ndarray]]:
        """Encode the non-None images of ``chunk`` in a single forward pass."""
        embeddings: List[Optional[np.ndarray]] = [None] * len(chunk)
        tensors = []
        positions = []
        for pos, image in enumerate(chunk):
            if image is None:
                continue
            try:
                tensors.append(self.preprocess(image))
                positions.append(pos)
            except Exception as e:
                self._logger.warning("Failed to process image. Reason: %s", e)
        if not tensors:
            return embeddings

        try:
            infer_data = torch.stack(tensors).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(infer_data)
                features /= features.norm(dim=-1, keepdim=True)
            rows = features.cpu().numpy().astype('float32')
        except Exception:
            # A batched failure (e.g. out of memory) should not lose the whole
            # chunk: retry one image at a time so only the culprits get None.
            self._logger.exception(
                "Batched encoding failed; falling back to per-image encoding"
            )
            for pos in positions:
                embeddings[pos] = self.encode_image(chunk[pos])
            return embeddings

        for pos, row in zip(positions, rows):
            embeddings[pos] = row
        return embeddings