Blame view

embeddings/image_encoder.py 6.71 KB
950a640e   tangwang   embeddings
1
  """Image embedding client for the local embedding HTTP service."""
be52af70   tangwang   first commit
2
  
be52af70   tangwang   first commit
3
  import os
950a640e   tangwang   embeddings
4
5
6
  import logging
  from typing import Any, List, Optional, Union
  
be52af70   tangwang   first commit
7
  import numpy as np
950a640e   tangwang   embeddings
8
  import requests
be52af70   tangwang   first commit
9
  from PIL import Image
be52af70   tangwang   first commit
10
  
325eec03   tangwang   1. 日志、配置基础设施,使用优化
11
  logger = logging.getLogger(__name__)
be52af70   tangwang   first commit
12
  
7214c2e7   tangwang   mplemented**
13
  from config.services_config import get_embedding_image_base_url
4a37d233   tangwang   1. embedding cach...
14
  from config.env_config import REDIS_CONFIG
7214c2e7   tangwang   mplemented**
15
  from embeddings.cache_keys import build_image_cache_key
4a37d233   tangwang   1. embedding cach...
16
  from embeddings.redis_embedding_cache import RedisEmbeddingCache
42e3aea6   tangwang   tidy
17
  
be52af70   tangwang   first commit
18
19
20
  
  class CLIPImageEncoder:
      """
      Image Encoder for generating image embeddings using network service.

      Instances hold no per-request state: they only keep the resolved service
      endpoint, an HTTP timeout and a ``RedisEmbeddingCache`` handle, so it is
      safe to instantiate one per caller.

      Only image URLs / local file paths (strings) are supported; in-memory PIL
      images are rejected.
      """

      # Default HTTP timeout (seconds) for embedding-service requests. Also acts as
      # a fallback for instances that were created without running ``__init__``.
      timeout: float = 60

      def __init__(self, service_url: Optional[str] = None, *, timeout: float = 60):
          """
          Resolve the service endpoint and set up the embedding cache.

          Args:
              service_url: Base URL of the embedding service. When falsy, the
                  ``EMBEDDING_IMAGE_SERVICE_URL`` env var, then
                  ``EMBEDDING_SERVICE_URL``, then ``get_embedding_image_base_url()``
                  are tried in that order.
              timeout: Per-request HTTP timeout in seconds (default 60).
          """
          resolved_url = (
              service_url
              or os.getenv("EMBEDDING_IMAGE_SERVICE_URL")
              or os.getenv("EMBEDDING_SERVICE_URL")
              or get_embedding_image_base_url()
          )
          self.service_url = str(resolved_url).rstrip("/")
          self.endpoint = f"{self.service_url}/embed/image"
          self.timeout = timeout

          # Reuse embedding cache prefix, but separate namespace for images to avoid collisions.
          self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding"
          logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)
          self.cache = RedisEmbeddingCache(
              key_prefix=self.cache_prefix,
              namespace="image",
          )

      def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
          """
          Call the embedding service API.

          Args:
              request_data: List of image URLs / local file paths.
              normalize_embeddings: Sent as the ``normalize`` query parameter
                  ("true" / "false").

          Returns:
              The decoded JSON body. Expected to be a list of embeddings
              (list[float]) or nulls (None) aligned to input order; callers
              validate its shape.

          Raises:
              requests.exceptions.RequestException: On transport or HTTP errors.
              ValueError: If the response body is not valid JSON.
          """
          try:
              response = requests.post(
                  self.endpoint,
                  params={"normalize": "true" if normalize_embeddings else "false"},
                  json=request_data,
                  timeout=self.timeout,
              )
              response.raise_for_status()
              return response.json()
          except (requests.exceptions.RequestException, ValueError) as e:
              # ValueError covers JSON decode failures on requests versions where
              # they are not RequestException subclasses; log, then propagate.
              logger.exception("CLIPImageEncoder service request failed: %s", e)
              raise

      @staticmethod
      def _normalize_urls(images: List[Union[str, Image.Image]]) -> List[str]:
          """Validate inputs and return them as whitespace-stripped URL/path strings."""
          urls: List[str] = []
          for i, img in enumerate(images):
              if isinstance(img, Image.Image):
                  raise NotImplementedError(f"PIL Image at index {i} is not supported by service")
              if not isinstance(img, str) or not img.strip():
                  raise ValueError(f"Invalid image URL/path at index {i}: {img!r}")
              urls.append(img.strip())
          return urls

      @staticmethod
      def _to_vector(embedding: Any, url: str) -> np.ndarray:
          """Convert one service result to a non-empty, finite 1-D float32 vector or raise RuntimeError."""
          if embedding is None:
              raise RuntimeError(f"No image embedding returned for URL: {url}")
          vec = np.array(embedding, dtype=np.float32)
          if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
              raise RuntimeError(f"Invalid image embedding returned for URL: {url}")
          return vec

      def encode_image(self, image: Image.Image) -> np.ndarray:
          """
          Encode image to embedding vector using network service.

          Note: This method is kept for compatibility but the service only works with URLs.

          Raises:
              NotImplementedError: Always.
          """
          raise NotImplementedError("encode_image with PIL Image is not supported by embedding service")

      def encode_image_from_url(self, url: str, normalize_embeddings: bool = True) -> np.ndarray:
          """
          Generate image embedding via network service using URL.

          Delegates to ``encode_batch`` so single and batched calls share the same
          URL normalization (surrounding whitespace stripped), cache keys and
          response validation.

          Args:
              url: Image URL (or local file path) to process.
              normalize_embeddings: Request normalized embeddings; also part of
                  the cache key.

          Returns:
              Embedding vector (the cached value when one exists).

          Raises:
              NotImplementedError: If ``url`` is a PIL image.
              ValueError: If ``url`` is not a non-blank string.
              RuntimeError: If the service returns no or an invalid embedding.
          """
          return self.encode_batch([url], batch_size=1, normalize_embeddings=normalize_embeddings)[0]

      def encode_batch(
          self,
          images: List[Union[str, Image.Image]],
          batch_size: int = 8,
          normalize_embeddings: bool = True,
      ) -> List[np.ndarray]:
          """
          Encode a batch of images efficiently via network service.

          Cached embeddings are returned without contacting the service. The
          remaining URLs are de-duplicated and sent in chunks of ``batch_size``.
          Freshly computed embeddings are written back to the cache.

          Args:
              images: List of image URLs / local paths (PIL images are rejected).
              batch_size: Maximum number of URLs per service request; must be > 0.
              normalize_embeddings: Request normalized embeddings; also part of
                  the cache key.

          Returns:
              List of embeddings aligned to ``images``.

          Raises:
              NotImplementedError: If any element is a PIL image.
              ValueError: On blank/non-string inputs or ``batch_size <= 0``.
              RuntimeError: If the service response is malformed or lacks an embedding.
          """
          urls = self._normalize_urls(images)
          if batch_size <= 0:
              # A non-positive step would make range() empty (or raise), silently
              # leaving uncached slots unfilled.
              raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

          results: List[Optional[np.ndarray]] = [None] * len(urls)
          # Uncached URL -> every position it occupies, so duplicates hit the service once.
          pending: dict = {}
          for pos, url in enumerate(urls):
              if url in pending:
                  pending[url].append(pos)
                  continue
              cached = self.cache.get(build_image_cache_key(url, normalize=normalize_embeddings))
              if cached is not None:
                  results[pos] = cached
              else:
                  pending[url] = [pos]

          pending_urls = list(pending)
          for start in range(0, len(pending_urls), batch_size):
              batch_urls = pending_urls[start : start + batch_size]
              response_data = self._call_service(batch_urls, normalize_embeddings=normalize_embeddings)
              if not isinstance(response_data, list) or len(response_data) != len(batch_urls):
                  got = len(response_data) if isinstance(response_data, list) else type(response_data).__name__
                  raise RuntimeError(
                      f"Image embedding response length mismatch: expected {len(batch_urls)}, got {got}"
                  )
              for url, embedding in zip(batch_urls, response_data):
                  vec = self._to_vector(embedding, url)
                  self.cache.set(build_image_cache_key(url, normalize=normalize_embeddings), vec)
                  first, *duplicates = pending[url]
                  results[first] = vec
                  # Duplicates get their own copy so in-place edits by callers don't alias.
                  for pos in duplicates:
                      results[pos] = vec.copy()

          # Every slot is filled at this point: either from cache or from the service.
          return results  # type: ignore[return-value]

      def encode_image_urls(
          self,
          urls: List[str],
          batch_size: Optional[int] = None,
          normalize_embeddings: bool = True,
      ) -> List[np.ndarray]:
          """
          Interface consistent with ClipImageModel / ClipAsServiceImageEncoder, called
          by the indexer's document_transformer.

          Args:
              urls: List of image URLs.
              batch_size: Batch size; ``None`` or 0 falls back to 8.
              normalize_embeddings: Request normalized embeddings; also part of
                  the cache key.

          Returns:
              List of vectors with the same length as ``urls``.
          """
          return self.encode_batch(
              urls,
              batch_size=batch_size or 8,
              normalize_embeddings=normalize_embeddings,
          )