Blame view

embeddings/clip_model.py 5.81 KB
325eec03   tangwang   1. 日志、配置基础设施,使用优化
1
  """
7bfb9946   tangwang   向量化模块
2
  CN-CLIP local image embedding implementation.
325eec03   tangwang   1. 日志、配置基础设施,使用优化
3
  
7bfb9946   tangwang   向量化模块
4
  Internal model implementation used by the embedding service.
325eec03   tangwang   1. 日志、配置基础设施,使用优化
5
6
  """
  
325eec03   tangwang   1. 日志、配置基础设施,使用优化
7
  import io
7bfb9946   tangwang   向量化模块
8
9
10
11
  import threading
  from typing import List, Optional, Union
  
  import numpy as np
325eec03   tangwang   1. 日志、配置基础设施,使用优化
12
13
  import requests
  import torch
325eec03   tangwang   1. 日志、配置基础设施,使用优化
14
  from PIL import Image
325eec03   tangwang   1. 日志、配置基础设施,使用优化
15
  from cn_clip.clip import load_from_name
7bfb9946   tangwang   向量化模块
16
  import cn_clip.clip as clip
325eec03   tangwang   1. 日志、配置基础设施,使用优化
17
18
  
  
4747e2f4   tangwang   embedding perform...
19
# Default CN-CLIP vision backbone; other published checkpoints work too.
DEFAULT_MODEL_NAME = "ViT-L-14"  # alternatives: "ViT-H-14", "ViT-L-14-336"
# Directory where cn_clip downloads / caches model weights (load_from_name's download_root).
MODEL_DOWNLOAD_DIR = "/data/"
325eec03   tangwang   1. 日志、配置基础设施,使用优化
21
22
  
  
7bfb9946   tangwang   向量化模块
23
class ClipImageModel(object):
    """
    Thread-safe singleton image encoder using cn_clip (local inference).

    The first instantiation loads the CN-CLIP checkpoint (``_initialize_model``);
    every later ``ClipImageModel(...)`` call returns the same instance, and any
    constructor arguments passed on those later calls are silently ignored.
    """

    # Lazily-created process-wide singleton (set in __new__).
    _instance = None
    # Guards singleton creation across threads.
    _lock = threading.Lock()
  
7bfb9946   tangwang   向量化模块
31
      def __new__(cls, model_name: str = DEFAULT_MODEL_NAME, device: Optional[str] = None):
325eec03   tangwang   1. 日志、配置基础设施,使用优化
32
33
          with cls._lock:
              if cls._instance is None:
7bfb9946   tangwang   向量化模块
34
                  cls._instance = super(ClipImageModel, cls).__new__(cls)
325eec03   tangwang   1. 日志、配置基础设施,使用优化
35
36
37
                  cls._instance._initialize_model(model_name, device)
          return cls._instance
  
7bfb9946   tangwang   向量化模块
38
39
40
41
42
43
44
      def _initialize_model(self, model_name: str, device: Optional[str]):
          self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
          self.model, self.preprocess = load_from_name(
              model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR
          )
          self.model.eval()
          self.model_name = model_name
325eec03   tangwang   1. 日志、配置基础设施,使用优化
45
46
  
      def validate_image(self, image_data: bytes) -> Image.Image:
7bfb9946   tangwang   向量化模块
47
48
49
50
51
52
53
54
          image_stream = io.BytesIO(image_data)
          image = Image.open(image_stream)
          image.verify()
          image_stream.seek(0)
          image = Image.open(image_stream)
          if image.mode != "RGB":
              image = image.convert("RGB")
          return image
325eec03   tangwang   1. 日志、配置基础设施,使用优化
55
56
  
      def download_image(self, url: str, timeout: int = 10) -> bytes:
7bfb9946   tangwang   向量化模块
57
58
59
60
61
62
63
          if url.startswith(("http://", "https://")):
              response = requests.get(url, timeout=timeout)
              if response.status_code != 200:
                  raise ValueError("HTTP %s" % response.status_code)
              return response.content
          with open(url, "rb") as f:
              return f.read()
325eec03   tangwang   1. 日志、配置基础设施,使用优化
64
65
  
      def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image:
325eec03   tangwang   1. 日志、配置基础设施,使用优化
66
          if max(image.size) > max_size:
7bfb9946   tangwang   向量化模块
67
              ratio = float(max_size) / float(max(image.size))
325eec03   tangwang   1. 日志、配置基础设施,使用优化
68
69
70
71
72
              new_size = tuple(int(dim * ratio) for dim in image.size)
              image = image.resize(new_size, Image.Resampling.LANCZOS)
          return image
  
      def encode_text(self, text):
7bfb9946   tangwang   向量化模块
73
          text_data = clip.tokenize([text] if isinstance(text, str) else text).to(self.device)
325eec03   tangwang   1. 日志、配置基础设施,使用优化
74
75
76
77
78
          with torch.no_grad():
              text_features = self.model.encode_text(text_data)
              text_features /= text_features.norm(dim=-1, keepdim=True)
          return text_features
  
200fdddf   tangwang   embed norm
79
      def encode_image(self, image: Image.Image, normalize_embeddings: bool = True) -> Optional[np.ndarray]:
325eec03   tangwang   1. 日志、配置基础设施,使用优化
80
          if not isinstance(image, Image.Image):
7bfb9946   tangwang   向量化模块
81
82
83
84
              raise ValueError("ClipImageModel.encode_image input must be a PIL.Image")
          infer_data = self.preprocess(image).unsqueeze(0).to(self.device)
          with torch.no_grad():
              image_features = self.model.encode_image(infer_data)
200fdddf   tangwang   embed norm
85
86
              if normalize_embeddings:
                  image_features /= image_features.norm(dim=-1, keepdim=True)
7bfb9946   tangwang   向量化模块
87
          return image_features.cpu().numpy().astype("float32")[0]
325eec03   tangwang   1. 日志、配置基础设施,使用优化
88
  
af03fdef   tangwang   embedding模块代码整理
89
      def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray:
7bfb9946   tangwang   向量化模块
90
91
92
          image_data = self.download_image(url)
          image = self.validate_image(image_data)
          image = self.preprocess_image(image)
200fdddf   tangwang   embed norm
93
          return self.encode_image(image, normalize_embeddings=normalize_embeddings)
325eec03   tangwang   1. 日志、配置基础设施,使用优化
94
  
c10f90fe   tangwang   cnclip
95
96
97
98
      def encode_image_urls(
          self,
          urls: List[str],
          batch_size: Optional[int] = None,
200fdddf   tangwang   embed norm
99
          normalize_embeddings: bool = True,
af03fdef   tangwang   embedding模块代码整理
100
      ) -> List[np.ndarray]:
c10f90fe   tangwang   cnclip
101
102
103
104
105
106
107
108
          """
          Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder.
  
          Args:
              urls: list of image URLs or local paths.
              batch_size: batch size for internal batching (default 8).
  
          Returns:
af03fdef   tangwang   embedding模块代码整理
109
              List of vectors, same length as urls.
c10f90fe   tangwang   cnclip
110
          """
200fdddf   tangwang   embed norm
111
112
113
114
115
          return self.encode_batch(
              urls,
              batch_size=batch_size or 8,
              normalize_embeddings=normalize_embeddings,
          )
c10f90fe   tangwang   cnclip
116
  
200fdddf   tangwang   embed norm
117
118
119
120
121
      def encode_batch(
          self,
          images: List[Union[str, Image.Image]],
          batch_size: int = 8,
          normalize_embeddings: bool = True,
af03fdef   tangwang   embedding模块代码整理
122
123
      ) -> List[np.ndarray]:
          results: List[np.ndarray] = []
325eec03   tangwang   1. 日志、配置基础设施,使用优化
124
          for i in range(0, len(images), batch_size):
7bfb9946   tangwang   向量化模块
125
              batch = images[i : i + batch_size]
325eec03   tangwang   1. 日志、配置基础设施,使用优化
126
127
              for img in batch:
                  if isinstance(img, str):
200fdddf   tangwang   embed norm
128
                      results.append(self.encode_image_from_url(img, normalize_embeddings=normalize_embeddings))
325eec03   tangwang   1. 日志、配置基础设施,使用优化
129
                  elif isinstance(img, Image.Image):
200fdddf   tangwang   embed norm
130
                      results.append(self.encode_image(img, normalize_embeddings=normalize_embeddings))
325eec03   tangwang   1. 日志、配置基础设施,使用优化
131
                  else:
af03fdef   tangwang   embedding模块代码整理
132
                      raise ValueError(f"Unsupported image input type: {type(img)!r}")
7bfb9946   tangwang   向量化模块
133
          return results
7a013ca7   tangwang   多模态文本向量服务ok
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
  
      def encode_clip_texts(
          self,
          texts: List[str],
          batch_size: Optional[int] = None,
          normalize_embeddings: bool = True,
      ) -> List[np.ndarray]:
          """
          CN-CLIP 文本塔向量,与 encode_image 同空间;供 ``POST /embed/clip_text`` 使用。
          """
          if not texts:
              return []
          bs = batch_size or 8
          out: List[np.ndarray] = []
          for i in range(0, len(texts), bs):
              batch = texts[i : i + bs]
              text_data = clip.tokenize(batch).to(self.device)
              with torch.no_grad():
                  feats = self.model.encode_text(text_data)
                  if normalize_embeddings:
                      feats = feats / feats.norm(dim=-1, keepdim=True)
              arr = feats.cpu().numpy().astype("float32")
              for row in arr:
                  out.append(np.asarray(row, dtype=np.float32))
          return out