text_encoder.py
4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Text embedding encoder using network service.
Generates embeddings via HTTP API service running on localhost:5001.
"""
import sys
import requests
import time
import threading
import numpy as np
import logging
from typing import List, Union, Dict, Any
logger = logging.getLogger(__name__)
class BgeEncoder:
    """
    Singleton text encoder backed by a network embedding service.

    Every call to ``BgeEncoder(...)`` returns the same instance. Instance
    creation is guarded by a class-level lock. The ``service_url`` passed on
    the *first* construction wins; a different URL passed later is ignored
    because the already-configured instance is returned unchanged.
    """

    # Width of the zero vector substituted when the service gives no usable
    # embedding. Assumes the service returns vectors of this width; a mismatch
    # makes the result ragged and triggers the all-zero fallback in encode().
    EMBEDDING_DIM = 1024

    # Response fields probed for an embedding, in priority order.
    _EMBEDDING_FIELDS = ("embedding_zh", "embedding_en", "embedding_ru")

    # Timeout (seconds) applied to each HTTP request to the service.
    _REQUEST_TIMEOUT = 60

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, service_url='http://localhost:5001'):
        """
        Return the shared encoder, creating and configuring it on first use.

        Args:
            service_url: Base URL of the embedding service. Only honoured by
                the call that actually creates the instance.
        """
        with cls._lock:
            if cls._instance is None:
                instance = super(BgeEncoder, cls).__new__(cls)
                logger.info(
                    "Creating BgeEncoder instance with service URL: %s",
                    service_url,
                )
                instance.service_url = service_url
                # Strip a trailing slash so the endpoint never contains "//".
                base_url = service_url.rstrip('/')
                instance.endpoint = f"{base_url}/embedding/generate_embeddings"
                # Publish only once the instance is fully configured.
                cls._instance = instance
            return cls._instance

    def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        POST ``request_data`` as JSON to the embedding endpoint.

        Args:
            request_data: List of dictionaries, one per text to embed.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.exceptions.RequestException: On connection failure,
                timeout, or a non-2xx status; logged and then re-raised.
        """
        try:
            response = requests.post(
                self.endpoint,
                json=request_data,
                timeout=self._REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("BgeEncoder service request failed: %s", e, exc_info=True)
            raise

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = 'cpu',
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Encode text into embeddings via the network service.

        All sentences are sent to the service in a single request. Each text
        is submitted in the ``name_zh`` field; ``name_en`` and ``name_ru`` are
        sent as ``None``.

        Args:
            sentences: Single string or an iterable of strings to encode.
            normalize_embeddings: Accepted for interface compatibility; not
                used by this implementation.
            device: Accepted for interface compatibility; not used.
            batch_size: Accepted for interface compatibility; not used (the
                whole input goes out in one request).

        Returns:
            float32 numpy array with one row per input sentence. Rows for
            texts the service did not answer are zero vectors of width
            ``EMBEDDING_DIM``. If the request or response processing fails,
            an all-zero array of shape ``(n, EMBEDDING_DIM)`` is returned
            instead of raising. Empty input yields shape
            ``(0, EMBEDDING_DIM)`` without contacting the service.
        """
        # Normalise to a concrete list: a bare string becomes one sentence,
        # and one-shot iterables are materialised so they can be walked twice.
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)

        dim = self.EMBEDDING_DIM

        # Nothing to encode: skip the network round trip and keep the result
        # two-dimensional so callers can rely on the (n, dim) shape.
        if not sentences:
            return np.zeros((0, dim), dtype=np.float32)

        # English and Russian fields are left empty for now; language
        # detection could fill them in the future.
        request_data = [
            {
                "id": str(index),
                "name_zh": text,
                "name_en": None,
                "name_ru": None,
            }
            for index, text in enumerate(sentences)
        ]

        try:
            response_data = self._call_service(request_data)

            # Index the response by id once instead of rescanning it for
            # every sentence. setdefault keeps the first item per id.
            items_by_id: Dict[str, Dict[str, Any]] = {}
            for item in response_data:
                if isinstance(item, dict):
                    items_by_id.setdefault(str(item.get("id")), item)

            embeddings = []
            for index, text in enumerate(sentences):
                item = items_by_id.get(str(index))
                if not item:
                    logger.warning("No response found for text %s", index)
                    embeddings.append([0.0] * dim)
                    continue

                # Try Chinese embedding first, then English, then Russian.
                embedding = next(
                    (
                        item[field]
                        for field in self._EMBEDDING_FIELDS
                        if item.get(field) is not None
                    ),
                    None,
                )
                if embedding is None:
                    logger.warning(
                        "No embedding found for text %s: %s...", index, text[:50]
                    )
                    embedding = [0.0] * dim
                embeddings.append(embedding)

            return np.array(embeddings, dtype=np.float32)
        except Exception as e:
            # Deliberate best-effort: callers get zero embeddings rather than
            # an exception when the service is unavailable or misbehaves.
            logger.error("Failed to encode texts: %s", e, exc_info=True)
            return np.zeros((len(sentences), dim), dtype=np.float32)

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = 'cpu'
    ) -> np.ndarray:
        """
        Encode a list of texts; thin wrapper around :meth:`encode`.

        Args:
            texts: List of texts to encode.
            batch_size: Forwarded to ``encode`` (not used there).
            device: Forwarded to ``encode`` (not used there).

        Returns:
            numpy array of embeddings, as returned by ``encode``.
        """
        return self.encode(texts, batch_size=batch_size, device=device)