950a640e
tangwang
embeddings
|
1
|
"""Text embedding client for the local embedding HTTP service."""
|
be52af70
tangwang
first commit
|
2
|
|
950a640e
tangwang
embeddings
|
3
4
5
6
7
|
import logging
import os
import pickle
from datetime import timedelta
from typing import Any, List, Optional, Union
|
be52af70
tangwang
first commit
|
8
|
|
be52af70
tangwang
first commit
|
9
|
import numpy as np
|
453992a8
tangwang
需求:
|
10
|
import redis
|
950a640e
tangwang
embeddings
|
11
|
import requests
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
12
13
|
logger = logging.getLogger(__name__)
|
be52af70
tangwang
first commit
|
14
|
|
42e3aea6
tangwang
tidy
|
15
16
|
from config.services_config import get_embedding_base_url
|
453992a8
tangwang
需求:
|
17
18
19
20
21
22
|
# REDIS_CONFIG is optional: when the project's env_config module cannot be
# imported, fall back to an empty mapping so that every later
# REDIS_CONFIG.get(key, default) lookup simply yields its default.
try:
    from config.env_config import REDIS_CONFIG
except ImportError:
    REDIS_CONFIG = {}
|
be52af70
tangwang
first commit
|
23
|
|
950a640e
tangwang
embeddings
|
24
|
class TextEmbeddingEncoder:
    """
    Text embedding encoder backed by a remote HTTP embedding service.

    Vectors are requested from ``<service_url>/embed/text``. When a Redis
    connection can be established at construction time, vectors are also
    cached in Redis (pickled) with an expiration that is refreshed on every
    cache hit.
    """

    def __init__(self, service_url: Optional[str] = None):
        """
        Resolve the service URL and try to connect to the Redis cache.

        Args:
            service_url: Base URL of the embedding service. When omitted, the
                ``EMBEDDING_SERVICE_URL`` environment variable is used, then
                ``get_embedding_base_url()``.
        """
        resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
        self.service_url = str(resolved_url).rstrip("/")
        self.endpoint = f"{self.service_url}/embed/text"
        self.expire_time = timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180))
        logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url)
        try:
            self.redis_client = redis.Redis(
                host=REDIS_CONFIG.get("host", "localhost"),
                # NOTE(review): the fallback port is 6479, not Redis's
                # conventional 6379 - confirm this is intentional.
                port=REDIS_CONFIG.get("port", 6479),
                password=REDIS_CONFIG.get("password"),
                # Cached values are pickled bytes, so responses must not be decoded.
                decode_responses=False,
                socket_timeout=REDIS_CONFIG.get("socket_timeout", 1),
                socket_connect_timeout=REDIS_CONFIG.get("socket_connect_timeout", 1),
                retry_on_timeout=REDIS_CONFIG.get("retry_on_timeout", False),
                health_check_interval=10,
            )
            # Fail fast here so an unreachable cache is disabled once, up front.
            self.redis_client.ping()
            logger.info("Redis cache initialized for embeddings")
        except Exception as e:
            # Deliberate best-effort: the encoder keeps working without a cache.
            logger.warning("Failed to initialize Redis cache for embeddings: %s, continuing without cache", e)
            self.redis_client = None

    def _call_service(self, request_data: List[str], normalize_embeddings: bool = True) -> List[Any]:
        """
        Call the embedding service API.

        Args:
            request_data: List of texts, sent as the JSON request body.
            normalize_embeddings: Forwarded to the service as the
                ``normalize`` query parameter ("true"/"false").

        Returns:
            The decoded JSON response. Callers treat it as a list of
            embeddings (list[float]) or nulls (None), aligned to input order.

        Raises:
            requests.exceptions.RequestException: On transport errors or a
                non-success HTTP status; logged and re-raised unchanged.
        """
        try:
            response = requests.post(
                self.endpoint,
                params={"normalize": "true" if normalize_embeddings else "false"},
                json=request_data,
                timeout=60,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("TextEmbeddingEncoder service request failed: %s", e, exc_info=True)
            raise

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = 'cpu',
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Encode text into embeddings via network service with Redis caching.

        Each text is looked up in the cache first; all cache misses are then
        sent to the service in a single request and the valid results are
        written back to the cache.

        Args:
            sentences: Single string or list of strings to encode
            normalize_embeddings: Whether to request normalized embeddings from service
            device: Ignored; accepted for interface compatibility
            batch_size: Currently unused; accepted for interface compatibility

        Returns:
            1-D numpy array of dtype=object with one element per input text,
            in input order. Every element is a valid np.ndarray vector.

        Raises:
            ValueError: If the service returns no embedding, or an invalid one
                (empty or containing NaN/Inf), for any input text.
            requests.exceptions.RequestException: If the service call fails.
        """
        # Convert single string to list
        if isinstance(sentences, str):
            sentences = [sentences]

        # Check cache first, remembering where each miss belongs in the output
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []
        embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
        for i, text in enumerate(sentences):
            cached = self._get_cached_embedding(text, "generic", normalize_embeddings)
            if cached is not None:
                embeddings[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)

        # If there are uncached texts, call service
        if uncached_texts:
            response_data = self._call_service(
                list(uncached_texts), normalize_embeddings=normalize_embeddings
            )

            for i, (original_idx, text) in enumerate(zip(uncached_indices, uncached_texts)):
                # A short or empty response is treated like an explicit null.
                if response_data and i < len(response_data):
                    embedding = response_data[i]
                else:
                    embedding = None

                if embedding is None:
                    raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...")

                embedding_array = np.array(embedding, dtype=np.float32)
                if not self._is_valid_embedding(embedding_array):
                    raise ValueError(
                        f"Invalid embedding returned from service for text index {original_idx}"
                    )
                embeddings[original_idx] = embedding_array
                self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings)

        # Fill a pre-allocated 1-D object array element by element.
        # np.array(list_of_equal_length_vectors, dtype=object) would instead
        # build a 2-D (n, dim) array of scalars, breaking the documented
        # "array of vectors" contract.
        result = np.empty(len(embeddings), dtype=object)
        for i, vector in enumerate(embeddings):
            result[i] = vector
        return result

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = 'cpu',
        normalize_embeddings: bool = True,
    ) -> np.ndarray:
        """
        Encode a batch of texts via network service.

        Thin wrapper that forwards every argument to :meth:`encode`.

        Args:
            texts: List of texts to encode
            batch_size: Forwarded to ``encode``
            device: Forwarded to ``encode``
            normalize_embeddings: Whether to request normalized embeddings from service

        Returns:
            numpy array of embeddings, exactly as returned by ``encode``
        """
        return self.encode(
            texts,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
        )

    def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str:
        """
        Generate the cache key for a query.

        The key has the form ``embedding:<language>:<norm1|norm0>:<query>``, so
        normalized and non-normalized vectors of the same text never collide.
        """
        norm_flag = "norm1" if normalize_embeddings else "norm0"
        return f"embedding:{language}:{norm_flag}:{query}"

    def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
        """
        Check whether an embedding is usable.

        Args:
            embedding: Embedding array to validate

        Returns:
            True if ``embedding`` is a non-empty np.ndarray whose values are
            all finite (no NaN/Inf), False otherwise
        """
        if embedding is None:
            return False
        if not isinstance(embedding, np.ndarray):
            return False
        if embedding.size == 0:
            return False
        # Check for NaN or Inf values
        if not np.isfinite(embedding).all():
            return False
        return True

    def _get_cached_embedding(
        self,
        query: str,
        language: str,
        normalize_embeddings: bool = True,
    ) -> Optional[np.ndarray]:
        """
        Get an embedding from the cache if it exists (with sliding expiration).

        Invalid cached entries are deleted and reported as a miss. Any error
        while reading the cache is logged and also reported as a miss.

        Returns:
            The cached embedding, or None on a miss, an invalid entry, a cache
            error, or when no Redis client is available.
        """
        if not self.redis_client:
            return None
        try:
            cache_key = self._get_cache_key(query, language, normalize_embeddings)
            cached_data = self.redis_client.get(cache_key)
            if not cached_data:
                return None

            # SECURITY NOTE(review): pickle.loads executes arbitrary code from
            # its input; this is only safe while the Redis instance is trusted
            # and not writable by untrusted parties.
            embedding = pickle.loads(cached_data)

            # Validate cached embedding - if invalid, ignore cache and return None
            if self._is_valid_embedding(embedding):
                logger.debug("Cache hit for embedding: %s", query)
                # Update expiration time on access (sliding expiration)
                self.redis_client.expire(cache_key, self.expire_time)
                return embedding

            logger.warning(
                f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "
                f"ignoring cache for query: {query[:50]}..."
            )
            # Delete invalid cache entry; a failure here must not mask the miss
            try:
                self.redis_client.delete(cache_key)
            except Exception as e:
                logger.debug(f"Failed to delete invalid cache entry: {e}")
            return None
        except Exception as e:
            logger.error(f"Error retrieving embedding from cache: {e}")
            return None

    def _set_cached_embedding(
        self,
        query: str,
        language: str,
        embedding: np.ndarray,
        normalize_embeddings: bool = True,
    ) -> bool:
        """
        Store an embedding in the cache.

        The embedding is pickled and written with ``self.expire_time`` as its
        time-to-live.

        Returns:
            True if the value was written, False when no Redis client is
            available or the write failed (failures are logged, never raised).
        """
        if not self.redis_client:
            return False
        try:
            cache_key = self._get_cache_key(query, language, normalize_embeddings)
            serialized_data = pickle.dumps(embedding)
            self.redis_client.setex(
                cache_key,
                self.expire_time,
                serialized_data
            )
            logger.debug("Successfully cached embedding for query: %s", query)
            return True
        except redis.exceptions.RedisError as e:
            # RedisError is the base class of BusyLoadingError, ConnectionError
            # and TimeoutError, so those are all handled here.
            logger.warning(f"Redis error storing embedding in cache: {e}")
            return False
        except Exception as e:
            logger.error(f"Error storing embedding in cache: {e}")
            return False
|