768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
1
2
3
|
#!/usr/bin/env python3
"""
CN-CLIP 服务测试脚本
|
74cca190
tangwang
cnclip
|
4
|
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
5
6
|
用途:
测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议)
|
74cca190
tangwang
cnclip
|
7
|
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
8
9
|
使用方法:
python scripts/test_cnclip_service.py [PORT]
|
74cca190
tangwang
cnclip
|
10
|
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
11
12
13
|
参数:
PORT: 服务端口(默认:51000)
"""
|
74cca190
tangwang
cnclip
|
14
|
|
40f1e391
tangwang
cnclip
|
15
|
import sys
|
74cca190
tangwang
cnclip
|
16
17
|
import numpy as np
from clip_client import Client
|
40f1e391
tangwang
cnclip
|
18
|
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
19
|
|
74cca190
tangwang
cnclip
|
20
|
def test_encoding(client, test_name, inputs):
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
21
|
"""测试编码功能"""
|
74cca190
tangwang
cnclip
|
22
|
print(f"\n{test_name}...")
|
40f1e391
tangwang
cnclip
|
23
|
try:
|
74cca190
tangwang
cnclip
|
24
25
26
27
28
29
30
31
32
33
|
result = client.encode(inputs)
if isinstance(result, np.ndarray):
print(f"✓ 成功! 形状: {result.shape}")
print(f" 输入数量: {len(inputs)}")
print(f" 输出维度: {result.shape[1]}")
# 显示每个 embedding 的维度和前20个数字
for i in range(min(len(inputs), result.shape[0])):
emb = result[i]
first_20 = emb[:20].tolist()
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
34
35
36
37
38
39
40
|
# 计算 L2 归一化
norm = np.linalg.norm(emb)
normalized_emb = emb / norm if norm > 0 else emb
normalized_first_20 = normalized_emb[:20].tolist()
print(f" input: {inputs[i]}")
|
74cca190
tangwang
cnclip
|
41
42
|
print(f" Embedding[{i}] 维度: {len(emb)}")
print(f" 前20个数字: {first_20}")
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
43
|
print(f" normalize后的前20个数字: {normalized_first_20}")
|
74cca190
tangwang
cnclip
|
44
45
46
47
|
return True
else:
print(f"✗ 失败: 返回类型错误: {type(result)}")
return False
|
40f1e391
tangwang
cnclip
|
48
|
except Exception as e:
|
74cca190
tangwang
cnclip
|
49
50
51
|
print(f"✗ 失败: {e}")
import traceback
traceback.print_exc()
|
40f1e391
tangwang
cnclip
|
52
53
|
return False
|
74cca190
tangwang
cnclip
|
54
|
|
768ad710
tangwang
MySQL到ES字段映射说明-业务...
|
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
def main():
# 获取端口参数
port = sys.argv[1] if len(sys.argv) > 1 else "51000"
grpc_url = f"grpc://localhost:{port}"
print("=" * 50)
print("CN-CLIP 服务测试")
print("=" * 50)
print(f"服务地址: {grpc_url} (gRPC 协议)")
print()
# 创建客户端
try:
client = Client(grpc_url)
except Exception as e:
print(f"✗ 客户端创建失败: {e}")
sys.exit(1)
# 运行测试
results = []
# 测试1: 文本编码
results.append(test_encoding(
client,
"测试1: 编码文本",
['这是一个测试文本', '另一个测试文本']
))
# 测试2: 图像编码
results.append(test_encoding(
client,
"测试2: 编码图像(远程 URL)",
['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
))
# 测试3: 混合编码
results.append(test_encoding(
client,
"测试3: 混合编码(文本和图像)",
['这是一段文本', 'https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
))
# 汇总
print("\n" + "=" * 50)
print("测试结果汇总")
print("=" * 50)
print(f"总测试数: {len(results)}")
print(f"通过: {sum(results)}")
print(f"失败: {len(results) - sum(results)}")
if all(results):
print("\n✓ 所有测试通过!")
sys.exit(0)
else:
print("\n✗ 部分测试失败")
sys.exit(1)
if __name__ == '__main__':
main()
|
74cca190
tangwang
cnclip
|
|
|