#!/usr/bin/env python3
"""
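CN-CLIP service test script.

Usage:
    python scripts/test_cnclip_service.py

Options:
    --url TEXT        Service address (default: grpc://localhost:51000)
    --text            Only test text encoding
    --image           Only test image encoding
    --batch-size INT  Batch size (default: 10)
    --help            Show this help message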
"""
import sys
import time
import argparse
from pathlib import Path
# Put the project root (the parent of this script's directory) on sys.path.
# The path is resolved first so that a relative ``__file__`` cannot turn the
# root into a location that depends on the current working directory.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
# Colored terminal output
class Colors:
    """ANSI escape sequences used by the print helpers in this script."""

    GREEN = "\033[0;32m"   # used for success messages
    RED = "\033[0;31m"     # used for error messages
    YELLOW = "\033[1;33m"  # used for warnings
    BLUE = "\033[0;34m"    # used for informational messages
    NC = "\033[0m"         # reset sequence appended after every message
def print_success(msg):
    """Print ``msg`` in green, prefixed with a check mark."""
    line = "{}✓ {}{}".format(Colors.GREEN, msg, Colors.NC)
    print(line)
def print_error(msg):
    """Print ``msg`` in red, prefixed with a cross mark."""
    line = "{}✗ {}{}".format(Colors.RED, msg, Colors.NC)
    print(line)
def print_warning(msg):
    """Print ``msg`` in yellow, prefixed with a warning sign."""
    line = "{}⚠ {}{}".format(Colors.YELLOW, msg, Colors.NC)
    print(line)
def print_info(msg):
    """Print ``msg`` in blue, prefixed with an information sign."""
    line = "{}ℹ {}{}".format(Colors.BLUE, msg, Colors.NC)
    print(line)
def test_imports():
    """Check that the packages this script relies on can be imported.

    Returns:
        True when both ``clip_client`` and ``numpy`` import successfully,
        False as soon as one of them is missing (the reason is printed).
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 1: 检查依赖")
    print(separator)

    # The modules are imported only to probe availability; nothing is used.
    try:
        import clip_client  # noqa: F401
    except ImportError as err:
        print_error(f"clip_client 未安装: {err}")
        print_info("请运行: pip install clip-client")
        return False
    print_success("clip_client 已安装")

    try:
        import numpy  # noqa: F401
    except ImportError as err:
        print_error(f"numpy 未安装: {err}")
        return False
    print_success("numpy 已安装")

    return True
def test_connection(url):
    """Build a ``clip_client.Client`` for the given service address.

    Args:
        url: Service address handed unchanged to ``Client``.

    Returns:
        The client object, or None when importing or constructing it raised
        (the error and a start-up hint are printed).

    NOTE(review): only client construction is exercised here; whether that
    already contacts the server is not visible in this file — confirm.
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 2: 连接服务")
    print(separator)
    print(f"服务地址: {url}")

    try:
        # Imported lazily so a missing package is reported like any other
        # failure of this step.
        from clip_client import Client

        handle = Client(url)
        print_success("客户端创建成功")
    except Exception as err:
        print_error(f"连接失败: {err}")
        print_info("请确保服务已启动: ./scripts/start_cnclip_service.sh")
        handle = None
    return handle
def test_text_encoding(client, batch_size=10, expected_dim=1024):
    """Encode a batch of sample texts and validate the returned embeddings.

    Args:
        client: Object whose ``encode`` method accepts a list of strings and
            returns a result exposing ``shape`` and ``dtype``.
        batch_size: How many of the ten bundled sample texts to send; larger
            values are capped by the number of samples available.
        expected_dim: Embedding width the service is expected to produce.

    Returns:
        True when one ``expected_dim``-wide vector comes back per text,
        False otherwise (the reason is printed).
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 3: 文本编码")
    print(separator)

    # Reject non-positive sizes up front: 0 would leave no sample to show and
    # a negative value would silently slice samples off the end of the list.
    if batch_size < 1:
        print_error(f"文本编码失败: 批处理大小必须大于 0, 实际是 {batch_size}")
        return False

    try:
        # Sample inputs
        test_texts = [
            '你好,世界',
            'CN-CLIP 图像编码服务',
            '这是一个测试',
            '人工智能',
            '机器学习',
            '深度学习',
            '计算机视觉',
            '自然语言处理',
            '搜索引擎',
            '多模态检索',
        ][:batch_size]
        print(f"测试文本数量: {len(test_texts)}")
        print(f"示例文本: {test_texts[0]}")

        # perf_counter is monotonic, so the interval cannot come out negative.
        start_time = time.perf_counter()
        embeddings = client.encode(test_texts)
        elapsed_time = time.perf_counter() - start_time

        # Explicit checks instead of ``assert`` so they still run under -O.
        if embeddings.shape[0] != len(test_texts):
            raise ValueError("向量数量不匹配")
        if embeddings.shape[1] != expected_dim:
            raise ValueError(f"向量维度应该是 {expected_dim}")

        print_success("编码成功")
        print(f" 向量形状: {embeddings.shape}")
        print(f" 耗时: {elapsed_time:.2f}秒")
        # Skip the throughput line rather than divide by a zero interval.
        if elapsed_time > 0:
            print(f" 速度: {len(test_texts)/elapsed_time:.2f} 条/秒")
        print(f" 数据类型: {embeddings.dtype}")
        return True
    except Exception as e:
        print_error(f"文本编码失败: {e}")
        return False
def test_image_encoding(client, batch_size=5, expected_dim=1024):
    """Encode a batch of online sample images and validate the embeddings.

    Args:
        client: Object whose ``encode`` method accepts a list of image URLs
            and returns a result exposing ``shape`` and ``dtype``.
        batch_size: How many of the five bundled image URLs to send; larger
            values are capped by the number of URLs available.
        expected_dim: Embedding width the service is expected to produce.

    Returns:
        True when one ``expected_dim``-wide vector comes back per image,
        False otherwise (the reason is printed).
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 4: 图像编码")
    print(separator)

    # Handled before the ``try`` so that a bad argument is not followed by the
    # network-connectivity hint printed for encoding failures.
    if batch_size < 1:
        print_error(f"图像编码失败: 批处理大小必须大于 0, 实际是 {batch_size}")
        return False

    try:
        # Sample inputs (online images)
        test_images = [
            'https://picsum.photos/224',
            'https://picsum.photos/224?random=1',
            'https://picsum.photos/224?random=2',
            'https://picsum.photos/224?random=3',
            'https://picsum.photos/224?random=4',
        ][:batch_size]
        print(f"测试图像数量: {len(test_images)}")
        print(f"示例 URL: {test_images[0]}")

        # perf_counter is monotonic, so the interval cannot come out negative.
        start_time = time.perf_counter()
        embeddings = client.encode(test_images)
        elapsed_time = time.perf_counter() - start_time

        # Explicit checks instead of ``assert`` so they still run under -O.
        if embeddings.shape[0] != len(test_images):
            raise ValueError("向量数量不匹配")
        if embeddings.shape[1] != expected_dim:
            raise ValueError(f"向量维度应该是 {expected_dim}")

        print_success("编码成功")
        print(f" 向量形状: {embeddings.shape}")
        print(f" 耗时: {elapsed_time:.2f}秒")
        # Skip the throughput line rather than divide by a zero interval.
        if elapsed_time > 0:
            print(f" 速度: {len(test_images)/elapsed_time:.2f} 条/秒")
        print(f" 数据类型: {embeddings.dtype}")
        return True
    except Exception as e:
        print_error(f"图像编码失败: {e}")
        print_warning("可能需要网络连接来下载测试图片")
        return False
def test_mixed_encoding(client, expected_dim=1024):
    """Encode one batch that interleaves texts and image URLs.

    Args:
        client: Object whose ``encode`` method accepts a list of strings and
            returns a result exposing ``shape``.
        expected_dim: Embedding width the service is expected to produce.

    Returns:
        True when one ``expected_dim``-wide vector comes back per input,
        False otherwise (the reason is printed).
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 5: 混合编码")
    print(separator)

    try:
        # Two texts and two image URLs, interleaved
        mixed_data = [
            '这是一段测试文本',
            'https://picsum.photos/224?random=10',
            'CN-CLIP 图像编码',
            'https://picsum.photos/224?random=11',
        ]
        print(f"混合数据数量: {len(mixed_data)}")
        print(" 文本: 2 条")
        print(" 图像: 2 条")

        # perf_counter is monotonic, so the interval cannot come out negative.
        start_time = time.perf_counter()
        embeddings = client.encode(mixed_data)
        elapsed_time = time.perf_counter() - start_time

        # Explicit checks instead of ``assert`` so they still run under -O.
        if embeddings.shape[0] != len(mixed_data):
            raise ValueError("向量数量不匹配")
        if embeddings.shape[1] != expected_dim:
            raise ValueError(f"向量维度应该是 {expected_dim}")

        print_success("混合编码成功")
        print(f" 向量形状: {embeddings.shape}")
        print(f" 耗时: {elapsed_time:.2f}秒")
        return True
    except Exception as e:
        print_error(f"混合编码失败: {e}")
        return False
def test_single_encoding(client, expected_dim=1024):
    """Encode a single text and check that exactly one vector comes back.

    Args:
        client: Object whose ``encode`` method accepts a list of strings and
            returns a result exposing ``ndim``, ``shape`` and ``reshape``.
        expected_dim: Embedding width the service is expected to produce.

    Returns:
        True when the result is (or reshapes to) ``(1, expected_dim)``,
        False otherwise (the reason is printed).
    """
    separator = "=" * 50
    print("\n" + separator)
    print("测试 6: 单个数据编码")
    print(separator)

    try:
        single_text = '测试文本'
        print(f"输入: {single_text}")

        start_time = time.perf_counter()
        # The text is sent as a one-element list, matching how ``encode`` is
        # called everywhere else in this script.
        # NOTE(review): upstream clip_client is believed to reject a bare
        # ``str`` with a TypeError — confirm against the installed version.
        embedding = client.encode([single_text])
        elapsed_time = time.perf_counter() - start_time

        # Defensive: if a flat vector comes back, view it as a single row.
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        # Explicit check instead of ``assert`` so it still runs under -O.
        if embedding.shape != (1, expected_dim):
            raise ValueError(
                f"向量形状应该是 (1, {expected_dim}), 实际是 {embedding.shape}"
            )

        print_success("单个文本编码成功")
        print(f" 向量形状: {embedding.shape}")
        print(f" 耗时: {elapsed_time:.2f}秒")
        return True
    except Exception as e:
        print_error(f"单个数据编码失败: {e}")
        return False
def main():
    """Parse the command line, run the selected checks and exit.

    The process exits with status 0 when every executed check passes and
    with status 1 otherwise, including when a dependency is missing, the
    client cannot be created, or ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser(description='CN-CLIP 服务测试脚本')
    parser.add_argument('--url',
                        default='grpc://localhost:51000',
                        help='服务地址(默认:grpc://localhost:51000)')
    parser.add_argument('--text',
                        action='store_true',
                        help='只测试文本编码')
    parser.add_argument('--image',
                        action='store_true',
                        help='只测试图像编码')
    parser.add_argument('--batch-size',
                        type=int,
                        default=10,
                        help='批处理大小(默认:10)')
    args = parser.parse_args()

    print("\n" + "="*50)
    print("CN-CLIP 服务测试")
    print("="*50)

    # A non-positive size would either leave nothing to encode or silently
    # drop samples from the end of the lists, so refuse it outright.
    if args.batch_size < 1:
        print_error(f"--batch-size 必须大于 0, 实际是 {args.batch_size}")
        sys.exit(1)

    # Test 1: dependencies
    if not test_imports():
        sys.exit(1)

    # Test 2: client creation. Compare against None explicitly so the check
    # does not depend on the truthiness of the client object.
    client = test_connection(args.url)
    if client is None:
        sys.exit(1)

    results = []
    if args.text or args.image:
        # Each flag selects its own test; passing both runs both instead of
        # silently ignoring --image.
        if args.text:
            results.append(test_text_encoding(client, args.batch_size))
        if args.image:
            results.append(test_image_encoding(client, args.batch_size))
    else:
        # No selector given: run the full suite.
        results.append(test_text_encoding(client, args.batch_size))
        results.append(test_image_encoding(client, min(args.batch_size, 5)))
        results.append(test_mixed_encoding(client))
        results.append(test_single_encoding(client))

    # Summary
    print("\n" + "="*50)
    print("测试结果汇总")
    print("="*50)
    total_tests = len(results)
    passed_tests = sum(1 for ok in results if ok)
    print(f"总测试数: {total_tests}")
    print(f"通过: {passed_tests}")
    print(f"失败: {total_tests - passed_tests}")

    if passed_tests == total_tests:
        print_success("\n所有测试通过!")
        sys.exit(0)
    else:
        print_error("\n部分测试失败")
        sys.exit(1)
# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()