Blame view

embeddings/image_encoder.py 5.34 KB
950a640e   tangwang   embeddings
1
  """Image embedding client for the local embedding HTTP service."""
be52af70   tangwang   first commit
2
  
be52af70   tangwang   first commit
3
  import os
950a640e   tangwang   embeddings
4
5
6
  import logging
  from typing import Any, List, Optional, Union
  
be52af70   tangwang   first commit
7
  import numpy as np
950a640e   tangwang   embeddings
8
  import requests
be52af70   tangwang   first commit
9
  from PIL import Image
be52af70   tangwang   first commit
10
  
325eec03   tangwang   1. 日志、配置基础设施,使用优化
11
  logger = logging.getLogger(__name__)
be52af70   tangwang   first commit
12
  
42e3aea6   tangwang   tidy
13
14
  from config.services_config import get_embedding_base_url
  
be52af70   tangwang   first commit
15
16
17
  
  class CLIPImageEncoder:
      """
      Client that obtains image embeddings from the embedding HTTP service.

      Embeddings are requested by POSTing a JSON list of image URLs / local
      file paths to ``<service_url>/embed/image``. PIL images cannot be sent
      to the service and therefore always yield ``None``.

      The client holds only its resolved endpoint and request timeout, so it
      is cheap to instantiate per caller.
      """

      # Defaults matching the values that used to be hard-coded.
      DEFAULT_TIMEOUT_SECONDS = 60
      DEFAULT_BATCH_SIZE = 8

      def __init__(
          self,
          service_url: Optional[str] = None,
          timeout: float = DEFAULT_TIMEOUT_SECONDS,
      ):
          """
          Resolve the service endpoint and store the request timeout.

          Args:
              service_url: Base URL of the embedding service. When empty or
                  None, falls back to the ``EMBEDDING_SERVICE_URL`` environment
                  variable and then to ``get_embedding_base_url()``.
              timeout: Timeout in seconds passed to every ``requests.post`` call.
          """
          resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
          # Strip trailing slashes so the endpoint never contains "//embed".
          self.service_url = str(resolved_url).rstrip("/")
          self.endpoint = f"{self.service_url}/embed/image"
          self.timeout = timeout
          logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)

      def _call_service(self, request_data: List[str]) -> List[Any]:
          """
          Call the embedding service API.

          Args:
              request_data: List of image URLs / local file paths

          Returns:
              List of embeddings (list[float]) or nulls (None), aligned to input order

          Raises:
              requests.exceptions.RequestException: On transport errors or a
                  non-2xx response (logged, then re-raised).
              ValueError: If the response body is not a JSON list.
          """
          try:
              response = requests.post(
                  self.endpoint,
                  json=request_data,
                  timeout=self.timeout,
              )
              response.raise_for_status()
              payload = response.json()
          except requests.exceptions.RequestException as e:
              logger.error("CLIPImageEncoder service request failed: %s", e, exc_info=True)
              raise

          # Callers index the payload positionally; reject anything that is not
          # a list (e.g. an error object) here with a clear message instead of
          # letting it fail later with an unrelated KeyError/TypeError.
          if not isinstance(payload, list):
              raise ValueError(
                  f"Embedding service returned {type(payload).__name__}, expected a JSON list"
              )
          return payload

      @staticmethod
      def _to_vector(raw: Any) -> Optional[np.ndarray]:
          """Convert one response entry to a float32 array, or None if absent/malformed."""
          if raw is None:
              return None
          try:
              return np.array(raw, dtype=np.float32)
          except (TypeError, ValueError):
              # A single bad entry must not invalidate its neighbours.
              logger.warning(
                  "Discarding malformed embedding payload of type %s", type(raw).__name__
              )
              return None

      def encode_image(self, image: Image.Image) -> Optional[np.ndarray]:
          """
          Encode image to embedding vector using network service.

          Note: This method is kept for compatibility but the service only works
          with URLs, so it always logs a warning and returns None.
          """
          logger.warning("encode_image with PIL Image not supported by service, returning None")
          return None

      def encode_image_from_url(self, url: str) -> Optional[np.ndarray]:
          """
          Generate image embedding via network service using URL.

          Args:
              url: Image URL to process

          Returns:
              Embedding vector (float32) or None if failed. Failures are logged
              and never raised to the caller.
          """
          try:
              response_data = self._call_service([url])
          except Exception as e:
              logger.error("Failed to process image from URL %s: %s", url, e, exc_info=True)
              return None

          vector = self._to_vector(response_data[0]) if response_data else None
          if vector is None:
              logger.warning("No embedding for URL %s", url)
          return vector

      def encode_batch(
          self,
          images: List[Union[str, Image.Image]],
          batch_size: int = DEFAULT_BATCH_SIZE
      ) -> List[Optional[np.ndarray]]:
          """
          Encode a batch of images efficiently via network service.

          Args:
              images: List of image URLs or PIL Images. Only strings are sent to
                  the service; every other entry yields None.
              batch_size: Number of URLs sent per service request (must be >= 1)

          Returns:
              List of embeddings (or None for failed images), same length and
              order as ``images``.

          Raises:
              ValueError: If ``batch_size`` is smaller than 1.
          """
          # A zero step makes range() raise an opaque error and a negative step
          # silently skips every URL, so reject both up front.
          if batch_size < 1:
              raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

          # Every slot starts as None; only successful encodings overwrite it.
          results: List[Optional[np.ndarray]] = [None] * len(images)

          # Keep the original positions of the URL entries so results can be
          # written back in input order.
          url_images: List[str] = []
          url_indices: List[int] = []
          for i, img in enumerate(images):
              if isinstance(img, str):
                  url_images.append(img)
                  url_indices.append(i)
              elif isinstance(img, Image.Image):
                  logger.warning("PIL Image at index %d not supported by service, returning None", i)
              else:
                  logger.warning(
                      "Unsupported input of type %s at index %d, returning None",
                      type(img).__name__,
                      i,
                  )

          for start in range(0, len(url_images), batch_size):
              batch_urls = url_images[start:start + batch_size]
              batch_indices = url_indices[start:start + batch_size]

              try:
                  response_data = self._call_service(batch_urls)
              except Exception as e:
                  # The slots of this batch are already None; move on so one
                  # failed request does not abort the remaining batches.
                  logger.error("Batch processing failed: %s", e, exc_info=True)
                  continue

              # The response is expected to be aligned with the request; a
              # short response leaves the trailing URLs without an embedding.
              for j, (index, url) in enumerate(zip(batch_indices, batch_urls)):
                  raw = response_data[j] if j < len(response_data) else None
                  vector = self._to_vector(raw)
                  if vector is None:
                      logger.warning("Failed to encode URL %s: no embedding", url)
                  results[index] = vector

          return results

      def encode_image_urls(
          self,
          urls: List[str],
          batch_size: Optional[int] = None,
      ) -> List[Optional[np.ndarray]]:
          """
          Encode a list of image URLs.

          Per the original note, this mirrors the interface of ClipImageModel /
          ClipAsServiceImageEncoder and is intended to be called by the
          indexer's document_transformer.

          Args:
              urls: List of image URLs
              batch_size: Batch size; None or 0 falls back to the default of 8

          Returns:
              List of vectors with the same length as ``urls``; None for failures
          """
          return self.encode_batch(urls, batch_size=batch_size or self.DEFAULT_BATCH_SIZE)