Blame view

embeddings/image_encoder.py 5.12 KB
be52af70   tangwang   first commit
1
  """
325eec03   tangwang   1. 日志、配置基础设施,使用优化
2
  Image embedding encoder using network service.
be52af70   tangwang   first commit
3
  
7bfb9946   tangwang   向量化模块
4
  Generates embeddings via HTTP API service (default localhost:6005).
be52af70   tangwang   first commit
5
6
7
8
  """
  
  import sys
  import os
be52af70   tangwang   first commit
9
  import requests
be52af70   tangwang   first commit
10
11
12
13
  import numpy as np
  from PIL import Image
  import logging
  import threading
325eec03   tangwang   1. 日志、配置基础设施,使用优化
14
  from typing import List, Optional, Union, Dict, Any
be52af70   tangwang   first commit
15
  
325eec03   tangwang   1. 日志、配置基础设施,使用优化
16
  logger = logging.getLogger(__name__)
be52af70   tangwang   first commit
17
  
42e3aea6   tangwang   tidy
18
19
  from config.services_config import get_embedding_base_url
  
be52af70   tangwang   first commit
20
21
22
  
  class CLIPImageEncoder:
      """
      Image encoder that obtains image embeddings from a network service.

      The class is a singleton: the first construction resolves the service
      URL and every later construction returns that same instance. Instance
      creation is serialized by a class-level lock.

      Attributes set on the singleton when it is first created:
          service_url: Base URL of the embedding service, as resolved.
          endpoint: URL that image-embedding requests are POSTed to.
      """

      _instance = None
      _lock = threading.Lock()

      # Timeout, in seconds, passed to requests.post for every service call.
      _REQUEST_TIMEOUT = 60

      def __new__(cls, service_url: Optional[str] = None):
          """
          Return the shared encoder, creating it on first use.

          Args:
              service_url: Base URL of the embedding service. When empty, the
                  ``EMBEDDING_SERVICE_URL`` environment variable is used, then
                  ``get_embedding_base_url()``. Only honoured by the call that
                  actually creates the instance.

          Returns:
              The singleton CLIPImageEncoder instance.
          """
          with cls._lock:
              if cls._instance is None:
                  instance = super(CLIPImageEncoder, cls).__new__(cls)
                  resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
                  logger.info(f"Creating CLIPImageEncoder instance with service URL: {resolved_url}")
                  instance.service_url = resolved_url
                  # Drop trailing slashes so the endpoint never contains "//embed/image".
                  base_url = str(resolved_url).rstrip("/")
                  instance.endpoint = f"{base_url}/embed/image"
                  # Publish the singleton only once it is fully initialised; if URL
                  # resolution raises, no half-built instance is left cached.
                  cls._instance = instance
              elif service_url and service_url != cls._instance.service_url:
                  logger.warning(
                      f"CLIPImageEncoder already created with service URL "
                      f"{cls._instance.service_url}; ignoring service_url={service_url}"
                  )
          return cls._instance

      def _call_service(self, request_data: List[str]) -> List[Any]:
          """
          Call the embedding service API.

          Args:
              request_data: List of image URLs / local file paths

          Returns:
              List of embeddings (list[float]) or nulls (None), aligned to input order

          Raises:
              requests.exceptions.RequestException: If the HTTP request fails or
                  the service answers with an error status (logged, then re-raised).
              ValueError: If the response body is not a JSON list.
          """
          try:
              response = requests.post(
                  self.endpoint,
                  json=request_data,
                  timeout=self._REQUEST_TIMEOUT
              )
              response.raise_for_status()
              payload = response.json()
          except requests.exceptions.RequestException as e:
              logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True)
              raise

          # Callers index the payload positionally, so anything but a list is a
          # malformed response; fail with a clear message instead of a KeyError
          # or conversion error further down.
          if not isinstance(payload, list):
              raise ValueError(
                  f"CLIPImageEncoder service returned {type(payload).__name__}, expected a list"
              )
          return payload

      def encode_image(self, image: Image.Image) -> Optional[np.ndarray]:
          """
          Encode image to embedding vector using network service.

          Note: This method is kept for compatibility but the service only works with URLs.

          Args:
              image: PIL image (unused).

          Returns:
              Always None; a warning is logged.
          """
          logger.warning("encode_image with PIL Image not supported by service, returning None")
          return None

      def encode_image_from_url(self, url: str) -> Optional[np.ndarray]:
          """
          Generate image embedding via network service using URL.

          Args:
              url: Image URL to process

          Returns:
              Embedding vector (float32 array) or None if failed. Failures are
              logged and never raised to the caller.
          """
          try:
              response_data = self._call_service([url])
              if response_data and response_data[0] is not None:
                  return np.array(response_data[0], dtype=np.float32)
              logger.warning(f"No embedding for URL {url}")
              return None
          except Exception as e:
              # Deliberate best-effort boundary: any failure maps to None.
              logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True)
              return None

      def encode_batch(
          self,
          images: List[Union[str, Image.Image]],
          batch_size: int = 8
      ) -> List[Optional[np.ndarray]]:
          """
          Encode a batch of images efficiently via network service.

          Args:
              images: List of image URLs or PIL Images. Only string entries are
                  sent to the service; every other entry yields None.
              batch_size: Number of URLs sent per service request (must be >= 1).

          Returns:
              List of embeddings (or None for failed images), aligned with ``images``.

          Raises:
              ValueError: If ``batch_size`` is less than 1.
          """
          # A zero step makes range() raise an opaque error and a negative step
          # silently skips every URL, so reject both up front.
          if batch_size < 1:
              raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

          # Every slot starts as None; only successful encodings overwrite it.
          results: List[Optional[np.ndarray]] = [None] * len(images)

          # Keep (original index, url) pairs so results can be written back in place.
          url_entries = []
          for i, img in enumerate(images):
              if isinstance(img, str):
                  url_entries.append((i, img))
              elif isinstance(img, Image.Image):
                  logger.warning(f"PIL Image at index {i} not supported by service, returning None")
              else:
                  logger.warning(
                      f"Unsupported input type {type(img).__name__} at index {i}, returning None"
                  )

          for start in range(0, len(url_entries), batch_size):
              chunk = url_entries[start:start + batch_size]
              batch_urls = [url for _, url in chunk]

              try:
                  response_data = self._call_service(batch_urls)

                  if len(response_data) != len(batch_urls):
                      logger.warning(
                          f"Embedding service returned {len(response_data)} results "
                          f"for {len(batch_urls)} URLs"
                      )

                  # Convert the whole batch before storing anything, so a bad
                  # item leaves the entire batch as None (all-or-nothing).
                  batch_results = []
                  for j, url in enumerate(batch_urls):
                      item = response_data[j] if j < len(response_data) else None
                      if item is not None:
                          batch_results.append(np.array(item, dtype=np.float32))
                      else:
                          logger.warning(f"Failed to encode URL {url}: no embedding")
                          batch_results.append(None)
              except Exception as e:
                  # Best-effort: a failed batch keeps its None placeholders and
                  # the remaining batches are still attempted.
                  logger.error(f"Batch processing failed: {e}", exc_info=True)
                  continue

              # Insert results at the correct positions
              for (index, _), vector in zip(chunk, batch_results):
                  results[index] = vector

          return results