# clip_model.py
"""
CN-CLIP local image embedding implementation.

Internal model implementation used by the embedding service.
"""

import io
import threading
from typing import List, Optional, Union

import numpy as np
import requests
import torch
from PIL import Image
from cn_clip.clip import load_from_name
import cn_clip.clip as clip


DEFAULT_MODEL_NAME = "ViT-H-14"  # ViT-H-14: 1024-dim; ViT-L-14: 768-dim — must match config and the ES image_embedding.dims mapping
# Directory where cn_clip downloads/caches model checkpoints (load_from_name download_root).
MODEL_DOWNLOAD_DIR = "/data/"


class ClipImageModel(object):
    """
    Thread-safe singleton image/text encoder using cn_clip (local inference).

    The first instantiation loads the checkpoint; every later call to
    ``ClipImageModel(...)`` returns the already-initialized instance and its
    arguments are silently ignored (see ``__new__``).
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None):
        # Lock-guarded singleton: the model is loaded exactly once even under
        # concurrent first calls.
        # NOTE(review): arguments supplied after the first call are ignored —
        # a caller asking for a different model_name still gets the original
        # instance. Documented here rather than raising, to keep behavior.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ClipImageModel, cls).__new__(cls)
                cls._instance._initialize_model(model_name, device)
        return cls._instance

    def _initialize_model(self, model_name: str, device: Optional[str]) -> None:
        """Load the CN-CLIP checkpoint plus its preprocessing transform."""
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess = load_from_name(
            model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR
        )
        self.model.eval()  # inference only
        self.model_name = model_name

    def validate_image(self, image_data: bytes) -> Image.Image:
        """
        Decode raw bytes into an RGB PIL image, raising if the data is corrupt.

        ``Image.verify()`` consumes/invalidates the parsed image, so the
        stream is rewound and the image reopened before use.
        """
        image_stream = io.BytesIO(image_data)
        image = Image.open(image_stream)
        image.verify()
        image_stream.seek(0)
        image = Image.open(image_stream)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def download_image(self, url: str, timeout: int = 10) -> bytes:
        """
        Fetch image bytes from an HTTP(S) URL or read them from a local path.

        Raises:
            ValueError: when an HTTP request returns a non-200 status.
            OSError: when a local path cannot be read.
        """
        if url.startswith(("http://", "https://")):
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                raise ValueError("HTTP %s" % response.status_code)
            return response.content
        with open(url, "rb") as f:
            return f.read()

    def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image:
        """Downscale so the longest side is at most ``max_size`` (aspect kept)."""
        if max(image.size) > max_size:
            ratio = float(max_size) / float(max(image.size))
            new_size = tuple(int(dim * ratio) for dim in image.size)
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image

    def encode_text(self, text):
        """
        Encode a string (or list of strings) with the text tower.

        Returns the L2-normalized feature tensor (still on ``self.device``),
        unlike the image methods which return numpy — kept for compatibility.
        """
        text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_data)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
        """Encode one PIL image to a float32 vector (L2-normalized by default)."""
        if not isinstance(image, Image.Image):
            raise ValueError("ClipImageModel.encode_image input must be a PIL.Image")
        infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(infer_data)
            if normalize_embeddings:
                image_features /= image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy().astype("float32")[0]

    def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray:
        """Download, validate, resize and encode a single image URL/path."""
        image_data = self.download_image(url)
        image = self.validate_image(image_data)
        image = self.preprocess_image(image)
        return self.encode_image(image, normalize_embeddings=normalize_embeddings)

    def encode_image_urls(
        self,
        urls: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """
        Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder.

        Args:
            urls: list of image URLs or local paths.
            batch_size: batch size for internal batching (default 8).

        Returns:
            List of vectors, same length as urls.
        """
        return self.encode_batch(
            urls,
            batch_size=batch_size or 8,
            normalize_embeddings=normalize_embeddings,
        )

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """
        Encode a mixed list of URLs/paths and PIL images to float32 vectors.

        Fix over the previous version: the model forward pass now runs once
        per batch (stacked tensor) instead of once per image, so ``batch_size``
        actually controls GPU batching. Downloads are still sequential.

        Raises:
            ValueError: on an element that is neither ``str`` nor ``PIL.Image``.
        """
        results: List[np.ndarray] = []
        for i in range(0, len(images), batch_size):
            batch = images[i : i + batch_size]
            pil_images: List[Image.Image] = []
            for img in batch:
                if isinstance(img, str):
                    # Same pipeline as encode_image_from_url: fetch -> validate -> resize.
                    data = self.download_image(img)
                    pil = self.validate_image(data)
                    pil_images.append(self.preprocess_image(pil))
                elif isinstance(img, Image.Image):
                    pil_images.append(img)
                else:
                    raise ValueError(f"Unsupported image input type: {type(img)!r}")
            # Single batched forward pass; per-row results match per-image
            # inference (eval-mode ViT, batch-independent normalization).
            infer_data = torch.stack([self.preprocess(p) for p in pil_images]).to(self.device)
            with torch.no_grad():
                feats = self.model.encode_image(infer_data)
                if normalize_embeddings:
                    feats = feats / feats.norm(dim=-1, keepdim=True)
            results.extend(feats.cpu().numpy().astype("float32"))
        return results

    def encode_clip_texts(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
    ) -> List[np.ndarray]:
        """
        Encode texts with the CN-CLIP text tower (same embedding space as
        ``encode_image``); used by ``POST /embed/clip_text``.

        Returns:
            List of float32 vectors, same length and order as ``texts``;
            empty list for empty input.
        """
        if not texts:
            return []
        bs = batch_size or 8
        out: List[np.ndarray] = []
        for i in range(0, len(texts), bs):
            batch = texts[i : i + bs]
            text_data = clip.tokenize(batch).to(self.device)
            with torch.no_grad():
                feats = self.model.encode_text(text_data)
                if normalize_embeddings:
                    feats = feats / feats.norm(dim=-1, keepdim=True)
            # Rows are already float32 after astype; extend avoids the
            # redundant per-row np.asarray re-wrap of the old code.
            out.extend(feats.cpu().numpy().astype("float32"))
        return out