test_cnclip_service.py
4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""
CN-CLIP 服务测试脚本
用途:
测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议)
使用方法:
python tests/test_cnclip_service.py [PORT]
参数:
PORT: 服务端口(默认:51000)
"""
import sys
import os
import numpy as np
# Skip clip_client version check (it imports pkg_resources in legacy path).
os.environ.setdefault("NO_VERSION_CHECK", "1")
# Ensure vendored client is importable in direct `python tests/test_cnclip_service.py` mode.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
VENDORED_CLIENT = os.path.join(ROOT, "third-party", "clip-as-service", "client")
if os.path.isdir(VENDORED_CLIENT) and VENDORED_CLIENT not in sys.path:
sys.path.insert(0, VENDORED_CLIENT)
try:
from clip_client import Client
except ImportError as e:
print("✗ 无法导入 clip_client。请先安装/暴露客户端依赖:")
print(" 1) pip install -e third-party/clip-as-service/client")
print(" 或")
print(" 2) export PYTHONPATH=third-party/clip-as-service/client:$PYTHONPATH")
print(f" 详细错误: {e}")
sys.exit(1)
def _test_encoding(client, test_name, inputs):
"""测试编码功能"""
print(f"\n{test_name}...")
try:
result = client.encode(inputs)
if isinstance(result, np.ndarray):
print(f"✓ 成功! 形状: {result.shape}")
print(f" 输入数量: {len(inputs)}")
print(f" 输出维度: {result.shape[1]}")
# 显示每个 embedding 的维度和前20个数字
for i in range(min(len(inputs), result.shape[0])):
emb = result[i]
first_20 = emb[:20].tolist()
# 计算 L2 归一化
norm = np.linalg.norm(emb)
normalized_emb = emb / norm if norm > 0 else emb
normalized_first_20 = normalized_emb[:20].tolist()
print(f" input: {inputs[i]}")
print(f" Embedding[{i}] 维度: {len(emb)}")
print(f" 前20个数字: {first_20}")
print(f" normalize后的前20个数字: {normalized_first_20}")
return True
else:
print(f"✗ 失败: 返回类型错误: {type(result)}")
return False
except Exception as e:
print(f"✗ 失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
# 获取端口参数
port = sys.argv[1] if len(sys.argv) > 1 else "51000"
grpc_url = f"grpc://localhost:{port}"
print("=" * 50)
print("CN-CLIP 服务测试")
print("=" * 50)
print(f"服务地址: {grpc_url} (gRPC 协议)")
print()
# 创建客户端
try:
client = Client(grpc_url)
except ModuleNotFoundError as e:
if str(e) == "No module named 'pkg_resources'":
print("✗ 当前环境缺少 pkg_resources,clip_client/jina 无法初始化。")
print(" 请使用专用环境运行(不要在主 .venv 安装旧依赖):")
print(" .venv-embedding/bin/python tests/test_cnclip_service.py 51000")
print(" 或 .venv-cnclip/bin/python tests/test_cnclip_service.py 51000")
sys.exit(1)
print(f"✗ 客户端创建失败: {e}")
sys.exit(1)
except Exception as e:
print(f"✗ 客户端创建失败: {e}")
sys.exit(1)
# 运行测试
results = []
# 测试1: 文本编码
results.append(_test_encoding(
client,
"测试1: 编码文本",
['这是一个测试文本', '另一个测试文本']
))
# 测试2: 图像编码
results.append(_test_encoding(
client,
"测试2: 编码图像(远程 URL)",
['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
))
# 测试3: 混合编码
results.append(_test_encoding(
client,
"测试3: 混合编码(文本和图像)",
['这是一段文本', 'https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
))
# 汇总
print("\n" + "=" * 50)
print("测试结果汇总")
print("=" * 50)
print(f"总测试数: {len(results)}")
print(f"通过: {sum(results)}")
print(f"失败: {len(results) - sum(results)}")
if all(results):
print("\n✓ 所有测试通过!")
sys.exit(0)
else:
print("\n✗ 部分测试失败")
sys.exit(1)
if __name__ == '__main__':
main()