#!/usr/bin/env python3 """ CN-CLIP 服务测试脚本 用途: 测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议) 使用方法: python scripts/test_cnclip_service.py [PORT] 参数: PORT: 服务端口(默认:51000) """ import sys import numpy as np from clip_client import Client def test_encoding(client, test_name, inputs): """测试编码功能""" print(f"\n{test_name}...") try: result = client.encode(inputs) if isinstance(result, np.ndarray): print(f"✓ 成功! 形状: {result.shape}") print(f" 输入数量: {len(inputs)}") print(f" 输出维度: {result.shape[1]}") # 显示每个 embedding 的维度和前20个数字 for i in range(min(len(inputs), result.shape[0])): emb = result[i] first_20 = emb[:20].tolist() # 计算 L2 归一化 norm = np.linalg.norm(emb) normalized_emb = emb / norm if norm > 0 else emb normalized_first_20 = normalized_emb[:20].tolist() print(f" input: {inputs[i]}") print(f" Embedding[{i}] 维度: {len(emb)}") print(f" 前20个数字: {first_20}") print(f" normalize后的前20个数字: {normalized_first_20}") return True else: print(f"✗ 失败: 返回类型错误: {type(result)}") return False except Exception as e: print(f"✗ 失败: {e}") import traceback traceback.print_exc() return False def main(): # 获取端口参数 port = sys.argv[1] if len(sys.argv) > 1 else "51000" grpc_url = f"grpc://localhost:{port}" print("=" * 50) print("CN-CLIP 服务测试") print("=" * 50) print(f"服务地址: {grpc_url} (gRPC 协议)") print() # 创建客户端 try: client = Client(grpc_url) except Exception as e: print(f"✗ 客户端创建失败: {e}") sys.exit(1) # 运行测试 results = [] # 测试1: 文本编码 results.append(test_encoding( client, "测试1: 编码文本", ['这是一个测试文本', '另一个测试文本'] )) # 测试2: 图像编码 results.append(test_encoding( client, "测试2: 编码图像(远程 URL)", ['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg'] )) # 测试3: 混合编码 results.append(test_encoding( client, "测试3: 混合编码(文本和图像)", ['这是一段文本', 'https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg'] )) # 汇总 print("\n" + "=" * 50) print("测试结果汇总") print("=" * 50) print(f"总测试数: {len(results)}") print(f"通过: {sum(results)}") print(f"失败: {len(results) - sum(results)}") if all(results): print("\n✓ 所有测试通过!") sys.exit(0) else: print("\n✗ 部分测试失败") sys.exit(1) if __name__ == '__main__': main()