Blame view

tests/test_cnclip_service.py 3.16 KB
768ad710   tangwang   MySQL到ES字段映射说明-业务...
1
2
3
  #!/usr/bin/env python3
  """
  CN-CLIP 服务测试脚本
74cca190   tangwang   cnclip
4
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
5
6
  用途:
      测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议)
74cca190   tangwang   cnclip
7
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
8
9
  使用方法:
      python scripts/test_cnclip_service.py [PORT]
74cca190   tangwang   cnclip
10
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
11
12
13
  参数:
      PORT: 服务端口(默认:51000
  """
74cca190   tangwang   cnclip
14
  
40f1e391   tangwang   cnclip
15
  import sys
74cca190   tangwang   cnclip
16
17
  import numpy as np
  from clip_client import Client
40f1e391   tangwang   cnclip
18
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
19
  
74cca190   tangwang   cnclip
20
  def test_encoding(client, test_name, inputs):
768ad710   tangwang   MySQL到ES字段映射说明-业务...
21
      """测试编码功能"""
74cca190   tangwang   cnclip
22
      print(f"\n{test_name}...")
40f1e391   tangwang   cnclip
23
      try:
74cca190   tangwang   cnclip
24
25
26
27
28
29
30
31
32
33
          result = client.encode(inputs)
          if isinstance(result, np.ndarray):
              print(f"✓ 成功! 形状: {result.shape}")
              print(f"  输入数量: {len(inputs)}")
              print(f"  输出维度: {result.shape[1]}")
              
              # 显示每个 embedding 的维度和前20个数字
              for i in range(min(len(inputs), result.shape[0])):
                  emb = result[i]
                  first_20 = emb[:20].tolist()
768ad710   tangwang   MySQL到ES字段映射说明-业务...
34
35
36
37
38
39
40
                  
                  # 计算 L2 归一化
                  norm = np.linalg.norm(emb)
                  normalized_emb = emb / norm if norm > 0 else emb
                  normalized_first_20 = normalized_emb[:20].tolist()
                  
                  print(f"  input: {inputs[i]}")
74cca190   tangwang   cnclip
41
42
                  print(f"  Embedding[{i}] 维度: {len(emb)}")
                  print(f"  前20个数字: {first_20}")
768ad710   tangwang   MySQL到ES字段映射说明-业务...
43
                  print(f"  normalize后的前20个数字: {normalized_first_20}")
74cca190   tangwang   cnclip
44
45
46
47
              return True
          else:
              print(f"✗ 失败: 返回类型错误: {type(result)}")
              return False
40f1e391   tangwang   cnclip
48
      except Exception as e:
74cca190   tangwang   cnclip
49
50
51
          print(f"✗ 失败: {e}")
          import traceback
          traceback.print_exc()
40f1e391   tangwang   cnclip
52
53
          return False
  
74cca190   tangwang   cnclip
54
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
  def main():
      # 获取端口参数
      port = sys.argv[1] if len(sys.argv) > 1 else "51000"
      grpc_url = f"grpc://localhost:{port}"
      
      print("=" * 50)
      print("CN-CLIP 服务测试")
      print("=" * 50)
      print(f"服务地址: {grpc_url} (gRPC 协议)")
      print()
      
      # 创建客户端
      try:
          client = Client(grpc_url)
      except Exception as e:
          print(f"✗ 客户端创建失败: {e}")
          sys.exit(1)
      
      # 运行测试
      results = []
      
      # 测试1: 文本编码
      results.append(test_encoding(
          client,
          "测试1: 编码文本",
          ['这是一个测试文本', '另一个测试文本']
      ))
      
      # 测试2: 图像编码
      results.append(test_encoding(
          client,
          "测试2: 编码图像(远程 URL)",
          ['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
      ))
      
      # 测试3: 混合编码
      results.append(test_encoding(
          client,
          "测试3: 混合编码(文本和图像)",
          ['这是一段文本', 'https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
      ))
      
      # 汇总
      print("\n" + "=" * 50)
      print("测试结果汇总")
      print("=" * 50)
      print(f"总测试数: {len(results)}")
      print(f"通过: {sum(results)}")
      print(f"失败: {len(results) - sum(results)}")
      
      if all(results):
          print("\n✓ 所有测试通过!")
          sys.exit(0)
      else:
          print("\n✗ 部分测试失败")
          sys.exit(1)
  
  
  if __name__ == '__main__':
      main()
74cca190   tangwang   cnclip