Blame view

tests/test_cnclip_service.py 4.5 KB
768ad710   tangwang   MySQL到ES字段映射说明-业务...
1
2
3
  #!/usr/bin/env python3
  """
  CN-CLIP 服务测试脚本
74cca190   tangwang   cnclip
4
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
5
6
  用途:
      测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议)
74cca190   tangwang   cnclip
7
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
8
  使用方法:
ed948666   tangwang   tidy
9
      python tests/test_cnclip_service.py [PORT]
74cca190   tangwang   cnclip
10
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
11
12
13
  参数:
      PORT: 服务端口(默认:51000
  """
74cca190   tangwang   cnclip
14
  
40f1e391   tangwang   cnclip
15
  import sys
ed948666   tangwang   tidy
16
  import os
40f1e391   tangwang   cnclip
17
  
ed948666   tangwang   tidy
18
19
20
21
22
23
24
25
26
27
  import numpy as np
  
  # Skip clip_client version check (it imports pkg_resources in legacy path).
  os.environ.setdefault("NO_VERSION_CHECK", "1")
  
  # Ensure vendored client is importable in direct `python tests/test_cnclip_service.py` mode.
  ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  VENDORED_CLIENT = os.path.join(ROOT, "third-party", "clip-as-service", "client")
  if os.path.isdir(VENDORED_CLIENT) and VENDORED_CLIENT not in sys.path:
      sys.path.insert(0, VENDORED_CLIENT)
768ad710   tangwang   MySQL到ES字段映射说明-业务...
28
  
7299bae6   tangwang   tests
29
  try:
7299bae6   tangwang   tests
30
      from clip_client import Client
ed948666   tangwang   tidy
31
32
33
34
35
36
37
  except ImportError as e:
      print("✗ 无法导入 clip_client。请先安装/暴露客户端依赖:")
      print("  1) pip install -e third-party/clip-as-service/client")
      print("  或")
      print("  2) export PYTHONPATH=third-party/clip-as-service/client:$PYTHONPATH")
      print(f"  详细错误: {e}")
      sys.exit(1)
7299bae6   tangwang   tests
38
39
40
  
  
  def _test_encoding(client, test_name, inputs):
768ad710   tangwang   MySQL到ES字段映射说明-业务...
41
      """测试编码功能"""
74cca190   tangwang   cnclip
42
      print(f"\n{test_name}...")
40f1e391   tangwang   cnclip
43
      try:
74cca190   tangwang   cnclip
44
45
46
47
48
49
50
51
52
53
          result = client.encode(inputs)
          if isinstance(result, np.ndarray):
              print(f"✓ 成功! 形状: {result.shape}")
              print(f"  输入数量: {len(inputs)}")
              print(f"  输出维度: {result.shape[1]}")
              
              # 显示每个 embedding 的维度和前20个数字
              for i in range(min(len(inputs), result.shape[0])):
                  emb = result[i]
                  first_20 = emb[:20].tolist()
768ad710   tangwang   MySQL到ES字段映射说明-业务...
54
55
56
57
58
59
60
                  
                  # 计算 L2 归一化
                  norm = np.linalg.norm(emb)
                  normalized_emb = emb / norm if norm > 0 else emb
                  normalized_first_20 = normalized_emb[:20].tolist()
                  
                  print(f"  input: {inputs[i]}")
74cca190   tangwang   cnclip
61
62
                  print(f"  Embedding[{i}] 维度: {len(emb)}")
                  print(f"  前20个数字: {first_20}")
768ad710   tangwang   MySQL到ES字段映射说明-业务...
63
                  print(f"  normalize后的前20个数字: {normalized_first_20}")
74cca190   tangwang   cnclip
64
65
66
67
              return True
          else:
              print(f"✗ 失败: 返回类型错误: {type(result)}")
              return False
40f1e391   tangwang   cnclip
68
      except Exception as e:
74cca190   tangwang   cnclip
69
70
71
          print(f"✗ 失败: {e}")
          import traceback
          traceback.print_exc()
40f1e391   tangwang   cnclip
72
73
          return False
  
74cca190   tangwang   cnclip
74
  
768ad710   tangwang   MySQL到ES字段映射说明-业务...
75
76
77
78
79
80
81
82
83
84
85
86
87
88
  def main():
      # 获取端口参数
      port = sys.argv[1] if len(sys.argv) > 1 else "51000"
      grpc_url = f"grpc://localhost:{port}"
      
      print("=" * 50)
      print("CN-CLIP 服务测试")
      print("=" * 50)
      print(f"服务地址: {grpc_url} (gRPC 协议)")
      print()
      
      # 创建客户端
      try:
          client = Client(grpc_url)
ed948666   tangwang   tidy
89
90
91
      except ModuleNotFoundError as e:
          if str(e) == "No module named 'pkg_resources'":
              print("✗ 当前环境缺少 pkg_resources,clip_client/jina 无法初始化。")
07cf5a93   tangwang   START_EMBEDDING=...
92
93
94
              print("  请使用专用环境运行(不要在主 .venv 安装旧依赖):")
              print("  .venv-embedding/bin/python tests/test_cnclip_service.py 51000")
              print("  或 .venv-cnclip/bin/python tests/test_cnclip_service.py 51000")
ed948666   tangwang   tidy
95
96
97
              sys.exit(1)
          print(f"✗ 客户端创建失败: {e}")
          sys.exit(1)
768ad710   tangwang   MySQL到ES字段映射说明-业务...
98
99
100
101
102
103
104
105
      except Exception as e:
          print(f"✗ 客户端创建失败: {e}")
          sys.exit(1)
      
      # 运行测试
      results = []
      
      # 测试1: 文本编码
7299bae6   tangwang   tests
106
      results.append(_test_encoding(
768ad710   tangwang   MySQL到ES字段映射说明-业务...
107
108
109
110
111
112
          client,
          "测试1: 编码文本",
          ['这是一个测试文本', '另一个测试文本']
      ))
      
      # 测试2: 图像编码
7299bae6   tangwang   tests
113
      results.append(_test_encoding(
768ad710   tangwang   MySQL到ES字段映射说明-业务...
114
115
116
117
118
119
          client,
          "测试2: 编码图像(远程 URL)",
          ['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
      ))
      
      # 测试3: 混合编码
7299bae6   tangwang   tests
120
      results.append(_test_encoding(
768ad710   tangwang   MySQL到ES字段映射说明-业务...
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
          client,
          "测试3: 混合编码(文本和图像)",
          ['这是一段文本', 'https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
      ))
      
      # 汇总
      print("\n" + "=" * 50)
      print("测试结果汇总")
      print("=" * 50)
      print(f"总测试数: {len(results)}")
      print(f"通过: {sum(results)}")
      print(f"失败: {len(results) - sum(results)}")
      
      if all(results):
          print("\n✓ 所有测试通过!")
          sys.exit(0)
      else:
          print("\n✗ 部分测试失败")
          sys.exit(1)
  
  
  if __name__ == '__main__':
      main()