be52af70
tangwang
first commit
|
1
|
"""
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
2
|
Text embedding encoder using network service.
|
be52af70
tangwang
first commit
|
3
|
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
4
|
Generates embeddings via HTTP API service running on localhost:5001.
|
be52af70
tangwang
first commit
|
5
6
7
|
"""
import sys
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
8
|
import requests
|
be52af70
tangwang
first commit
|
9
10
|
import time
import threading
|
be52af70
tangwang
first commit
|
11
|
import numpy as np
|
453992a8
tangwang
需求:
|
12
13
14
15
|
import pickle
import redis
from datetime import timedelta
from typing import List, Union, Dict, Any, Optional
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
16
|
import logging
|
325eec03
tangwang
1. 日志、配置基础设施,使用优化
|
17
18
|
logger = logging.getLogger(__name__)
|
be52af70
tangwang
first commit
|
19
|
|
453992a8
tangwang
需求:
|
20
21
22
23
24
25
|
# Try to import REDIS_CONFIG, but allow import to fail
try:
from config.env_config import REDIS_CONFIG
except ImportError:
REDIS_CONFIG = {}
|
be52af70
tangwang
first commit
|
26
27
28
|
class BgeEncoder:
    """
    Singleton text encoder backed by a remote embedding service.

    Embeddings are requested over HTTP from
    ``<service_url>/embedding/generate_embeddings`` and, when a Redis server
    is reachable at construction time, cached there as pickled numpy arrays.

    Instance creation is serialized with a class-level lock and the instance
    is configured only once: the ``service_url`` passed to any later
    constructor call is ignored.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, service_url='http://localhost:5001'):
        """
        Return the shared encoder, creating and configuring it on first use.

        Args:
            service_url: Base URL of the embedding service. Only honoured by
                the call that actually creates the instance.
        """
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(BgeEncoder, cls).__new__(cls)
                logger.info(f"Creating BgeEncoder instance with service URL: {service_url}")
                cls._instance.service_url = service_url
                cls._instance.endpoint = f"{service_url}/embedding/generate_embeddings"

                # Initialize the Redis cache. Any failure here (bad config,
                # unreachable server, failed ping) disables caching instead of
                # preventing the encoder from being created.
                try:
                    cls._instance.redis_client = redis.Redis(
                        host=REDIS_CONFIG.get('host', 'localhost'),
                        port=REDIS_CONFIG.get('port', 6479),
                        password=REDIS_CONFIG.get('password'),
                        decode_responses=False,  # Keep binary data as is
                        socket_timeout=REDIS_CONFIG.get('socket_timeout', 1),
                        socket_connect_timeout=REDIS_CONFIG.get('socket_connect_timeout', 1),
                        retry_on_timeout=REDIS_CONFIG.get('retry_on_timeout', False),
                        health_check_interval=10  # Avoid reusing broken connections
                    )
                    # Test connection
                    cls._instance.redis_client.ping()
                    cls._instance.expire_time = timedelta(days=REDIS_CONFIG.get('cache_expire_days', 180))
                    logger.info("Redis cache initialized for embeddings")
                except Exception as e:
                    logger.warning(f"Failed to initialize Redis cache for embeddings: {e}, continuing without cache")
                    cls._instance.redis_client = None

        return cls._instance

    def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Call the embedding service API.

        Args:
            request_data: List of dictionaries with id and text fields

        Returns:
            The decoded JSON body of the response (expected to be a list of
            dictionaries with id and embedding fields).

        Raises:
            requests.exceptions.RequestException: On any transport error or
                non-2xx status; the error is logged and re-raised.
        """
        try:
            response = requests.post(
                self.endpoint,
                json=request_data,
                timeout=60
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"BgeEncoder service request failed: {e}", exc_info=True)
            raise

    @staticmethod
    def _build_request_data(indices: List[int], texts: List[str]) -> List[Dict[str, Any]]:
        """Build one request item per uncached text, keyed by its original position."""
        return [
            {
                "id": str(original_idx),
                "name_zh": text,
                # English and Russian fields are sent empty for now; they could
                # be filled via language detection in the future.
                "name_en": None,
                "name_ru": None,
            }
            for original_idx, text in zip(indices, texts)
        ]

    @staticmethod
    def _select_embedding(response_item: Dict[str, Any]) -> Optional[Any]:
        """Return the first non-None vector among the zh, en, ru fields, else None."""
        for field in ("embedding_zh", "embedding_en", "embedding_ru"):
            if response_item.get(field) is not None:
                return response_item[field]
        return None

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = 'cpu',
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Encode text into embeddings via network service with Redis caching.

        Each text is looked up in the cache first; only the misses are sent to
        the service, in a single request. Failures never raise: affected texts
        simply have no embedding (None) in the result.

        Args:
            sentences: Single string or list of strings to encode
            normalize_embeddings: Accepted for interface compatibility; not used
            device: Accepted for interface compatibility; not used
            batch_size: Accepted for interface compatibility; not used (all
                uncached texts are sent in one request)

        Returns:
            ``np.array(results, dtype=object)`` where ``results[i]`` is either
            an np.ndarray (valid embedding for ``sentences[i]``) or None (no
            embedding available for that text).
            NOTE(review): when every text yields a vector of the same length,
            numpy stacks them into a 2-D object array of shape (n, dim); with
            any None present the array is 1-D. Callers should index by row.
        """
        # Convert single string to list
        if isinstance(sentences, str):
            sentences = [sentences]

        # Check cache first. This must happen BEFORE the request is built:
        # the request is derived from the cache misses collected here.
        # 'en' is used as the default language for the cache key.
        embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
        uncached_indices: List[int] = []
        uncached_texts: List[str] = []
        for i, text in enumerate(sentences):
            cached = self._get_cached_embedding(text, 'en')
            if cached is not None:
                embeddings[i] = cached
            else:
                uncached_indices.append(i)
                uncached_texts.append(text)

        # If there are uncached texts, call service
        if uncached_texts:
            request_data = self._build_request_data(uncached_indices, uncached_texts)
            try:
                response_data = self._call_service(request_data)

                # Index the response by id once; the first item with a given
                # id wins, as with a front-to-back scan.
                responses_by_id: Dict[str, Any] = {}
                for item in response_data:
                    responses_by_id.setdefault(str(item.get("id")), item)

                for original_idx, text in zip(uncached_indices, uncached_texts):
                    # No fallback vector is ever generated: on every failure
                    # path below the slot keeps its initial value of None.
                    response_item = responses_by_id.get(str(original_idx))
                    if not response_item:
                        logger.warning(f"No response found for text {original_idx}")
                        continue

                    # Try Chinese embedding first, then English, then Russian
                    embedding = self._select_embedding(response_item)
                    if embedding is None:
                        logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...")
                        continue

                    embedding_array = np.array(embedding, dtype=np.float32)
                    # Validate embedding from service - if invalid, treat as no result
                    if self._is_valid_embedding(embedding_array):
                        embeddings[original_idx] = embedding_array
                        # Cache the embedding
                        self._set_cached_embedding(text, 'en', embedding_array)
                    else:
                        logger.warning(
                            f"Invalid embedding returned from service for text {original_idx} "
                            f"(contains NaN/Inf or invalid shape), treating as no result. "
                            f"Text preview: {text[:50]}..."
                        )
            except Exception as e:
                # Deliberate best-effort: on error do not fabricate all-zero
                # fallback vectors; unresolved texts stay None.
                logger.error(f"Failed to encode texts: {e}", exc_info=True)

        # Return a numpy array (dtype=object) built from np.ndarray / None entries
        return np.array(embeddings, dtype=object)

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = 'cpu'
    ) -> np.ndarray:
        """
        Encode a batch of texts via network service.

        Thin wrapper that forwards to :meth:`encode`.

        Args:
            texts: List of texts to encode
            batch_size: Forwarded to ``encode``
            device: Forwarded to ``encode``

        Returns:
            Whatever ``encode`` returns for ``texts``
        """
        return self.encode(texts, batch_size=batch_size, device=device)

    def _get_cache_key(self, query: str, language: str) -> str:
        """Generate a cache key of the form ``embedding:<language>:<query>``."""
        return f"embedding:{language}:{query}"

    def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
        """
        Check if embedding is usable: a non-empty np.ndarray with no NaN/Inf.

        Args:
            embedding: Embedding array to validate

        Returns:
            True if valid, False otherwise. Never raises: arrays whose dtype
            cannot be checked for finiteness (e.g. object or string arrays)
            are reported as invalid.
        """
        if embedding is None:
            return False
        if not isinstance(embedding, np.ndarray):
            return False
        if embedding.size == 0:
            return False
        # Check for NaN or Inf values
        try:
            return bool(np.isfinite(embedding).all())
        except TypeError:
            # np.isfinite is not defined for non-numeric dtypes.
            return False

    def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]:
        """
        Get embedding from cache if exists (with sliding expiration).

        Args:
            query: Text whose embedding is wanted
            language: Language tag used in the cache key

        Returns:
            The cached embedding, or None when caching is disabled, the key is
            absent, the lookup fails, or the stored value is invalid (in which
            case the entry is also deleted).
        """
        if not self.redis_client:
            return None
        try:
            cache_key = self._get_cache_key(query, language)
            cached_data = self.redis_client.get(cache_key)
            if not cached_data:
                return None
            # SECURITY NOTE: pickle.loads can execute arbitrary code embedded
            # in the payload. This is only acceptable while the Redis instance
            # is trusted and not writable by untrusted parties.
            embedding = pickle.loads(cached_data)
        except Exception as e:
            logger.error(f"Error retrieving embedding from cache: {e}")
            return None

        # Validate cached embedding - if invalid, ignore cache and return None
        if self._is_valid_embedding(embedding):
            logger.debug(f"Cache hit for embedding: {query}")
            # Update expiration time on access (sliding expiration). This is
            # best-effort: failing to refresh the TTL must not discard a hit.
            try:
                self.redis_client.expire(cache_key, self.expire_time)
            except Exception as e:
                logger.debug(f"Failed to refresh cache expiration: {e}")
            return embedding

        logger.warning(
            f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "
            f"ignoring cache for query: {query[:50]}..."
        )
        # Delete invalid cache entry
        try:
            self.redis_client.delete(cache_key)
        except Exception as e:
            logger.debug(f"Failed to delete invalid cache entry: {e}")
        return None

    def _set_cached_embedding(self, query: str, language: str, embedding: np.ndarray) -> bool:
        """
        Store embedding in cache.

        Args:
            query: Text the embedding belongs to
            language: Language tag used in the cache key
            embedding: Vector to store (pickled)

        Returns:
            True if the value was written, False if caching is disabled or the
            write failed (failures are logged, never raised).
        """
        if not self.redis_client:
            return False
        try:
            cache_key = self._get_cache_key(query, language)
            serialized_data = pickle.dumps(embedding)
            self.redis_client.setex(
                cache_key,
                self.expire_time,
                serialized_data
            )
            logger.debug(f"Successfully cached embedding for query: {query}")
            return True
        except redis.exceptions.RedisError as e:
            # RedisError is the base class of BusyLoadingError,
            # ConnectionError and TimeoutError, so it covers them all.
            logger.warning(f"Redis error storing embedding in cache: {e}")
            return False
        except Exception as e:
            logger.error(f"Error storing embedding in cache: {e}")
            return False
|