#!/usr/bin/env python3
"""
CN-CLIP service test script.

Usage:
    python scripts/test_cnclip_service.py

Options:
    --url TEXT        Service address (default: grpc://localhost:51000)
    --text            Run only the text-encoding test
    --image           Run only the image-encoding test
                      (--text and --image may be combined)
    --batch-size INT  Batch size, must be >= 1 (default: 10)
    --dim INT         Expected embedding dimension (default: 1024)
    --help            Show help and exit

Despite the ``test_`` file name this is a standalone script, not a pytest
module: each ``test_*`` function takes a live client and returns a bool.
"""

import sys
import time
import argparse
from pathlib import Path

# Put the project root on sys.path so project-local modules can be imported.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# The file name matches pytest's default ``test_*.py`` pattern, but the
# ``test_*`` functions below need a client argument that is not a pytest
# fixture. Opt the module out of collection so a repo-wide ``pytest`` run does
# not report fixture errors.
__test__ = False

# ``main`` is the only entry point meant for outside use; this also keeps the
# ``test_*`` helpers out of ``from ... import *``.
__all__ = ['main']

# Embedding width the checks expect. 1024 is what this script always checked;
# other CN-CLIP model variants may produce a different width (use --dim).
DEFAULT_DIM = 1024

# Text samples; cycled when a larger batch is requested.
SAMPLE_TEXTS = (
    '你好,世界',
    'CN-CLIP 图像编码服务',
    '这是一个测试',
    '人工智能',
    '机器学习',
    '深度学习',
    '计算机视觉',
    '自然语言处理',
    '搜索引擎',
    '多模态检索',
)

# Online placeholder images (requires network access).
IMAGE_URL_BASE = 'https://picsum.photos/224'


class Colors:
    """ANSI escape sequences for coloured terminal output."""
    GREEN = '\033[0;32m'
    RED = '\033[0;31m'
    YELLOW = '\033[1;33m'
    BLUE = '\033[0;34m'
    NC = '\033[0m'  # reset ("no colour")


def print_success(msg):
    """Print *msg* in green with a check mark."""
    print(f"{Colors.GREEN}✓ {msg}{Colors.NC}")


def print_error(msg):
    """Print *msg* in red with a cross."""
    print(f"{Colors.RED}✗ {msg}{Colors.NC}")


def print_warning(msg):
    """Print *msg* in yellow with a warning sign."""
    print(f"{Colors.YELLOW}⚠ {msg}{Colors.NC}")


def print_info(msg):
    """Print *msg* in blue with an info sign."""
    print(f"{Colors.BLUE}ℹ {msg}{Colors.NC}")


def _print_section(title):
    """Print a section banner: a blank line, then *title* framed by '=' rules."""
    print("\n" + "=" * 50)
    print(title)
    print("=" * 50)


def _sample_texts(count):
    """Return *count* sample texts, cycling through SAMPLE_TEXTS if needed.

    Raises:
        ValueError: if *count* < 1.
    """
    if count < 1:
        raise ValueError(f"批处理大小必须 >= 1, 实际是 {count}")
    return [SAMPLE_TEXTS[i % len(SAMPLE_TEXTS)] for i in range(count)]


def _sample_image_urls(count):
    """Return *count* picsum.photos URLs.

    The first URL has no query string; the i-th of the rest uses ``?random=i``,
    so ``_sample_image_urls(5)`` equals the five URLs this script always used.

    Raises:
        ValueError: if *count* < 1.
    """
    if count < 1:
        raise ValueError(f"批处理大小必须 >= 1, 实际是 {count}")
    return [IMAGE_URL_BASE] + [f'{IMAGE_URL_BASE}?random={i}' for i in range(1, count)]


def _timed_encode(client, content):
    """Encode *content* with *client*; return ``(embeddings, elapsed_seconds)``.

    ``time.perf_counter`` is used because it is monotonic and high-resolution,
    unlike wall-clock ``time.time``.
    """
    start = time.perf_counter()
    embeddings = client.encode(content)
    return embeddings, time.perf_counter() - start


def _check_shape(embeddings, expected_count, expected_dim):
    """Raise ValueError unless *embeddings* is ``(expected_count, expected_dim)``.

    Explicit checks instead of ``assert`` so validation still runs under
    ``python -O``.
    """
    if embeddings.shape[0] != expected_count:
        raise ValueError("向量数量不匹配")
    if embeddings.shape[1] != expected_dim:
        raise ValueError(f"向量维度应该是 {expected_dim}")


def _report(embeddings, elapsed, count=None):
    """Print shape and elapsed time; when *count* is given, also throughput and dtype."""
    print(f" 向量形状: {embeddings.shape}")
    print(f" 耗时: {elapsed:.2f}秒")
    if count is not None:
        # Guard against a zero interval on very coarse clocks.
        rate = count / elapsed if elapsed > 0 else float('inf')
        print(f" 速度: {rate:.2f} 条/秒")
        print(f" 数据类型: {embeddings.dtype}")


def test_imports():
    """Check that the client-side dependencies (clip_client, numpy) import.

    Returns:
        True if both import, False otherwise (an error is printed).
    """
    _print_section("测试 1: 检查依赖")

    try:
        import clip_client  # noqa: F401  (availability check only)
    except ImportError as e:
        print_error(f"clip_client 未安装: {e}")
        print_info("请运行: pip install clip-client")
        return False
    print_success("clip_client 已安装")

    try:
        import numpy  # noqa: F401  (availability check only)
    except ImportError as e:
        print_error(f"numpy 未安装: {e}")
        return False
    print_success("numpy 已安装")

    return True


def test_connection(url):
    """Create a ``clip_client.Client`` for *url*.

    NOTE(review): creating the client may not contact the server by itself, so
    success here likely does not prove reachability; the encoding tests are what
    exercise the service. Confirm against the clip_client docs.

    Returns:
        The client on success, or None if construction failed.
    """
    _print_section("测试 2: 连接服务")
    print(f"服务地址: {url}")

    try:
        from clip_client import Client
        client = Client(url)
    except Exception as e:  # broad on purpose: report any failure, main() exits
        print_error(f"连接失败: {e}")
        print_info("请确保服务已启动: ./scripts/start_cnclip_service.sh")
        return None

    print_success("客户端创建成功")
    return client


def test_text_encoding(client, batch_size=10, expected_dim=DEFAULT_DIM):
    """Encode *batch_size* sample texts and validate the embedding matrix.

    Args:
        client: object with an ``encode(iterable)`` method returning an array.
        batch_size: number of texts to send; samples are cycled beyond 10.
        expected_dim: required embedding width.

    Returns:
        True on success, False on any failure (the error is printed).
    """
    _print_section("测试 3: 文本编码")

    try:
        test_texts = _sample_texts(batch_size)
        print(f"测试文本数量: {len(test_texts)}")
        print(f"示例文本: {test_texts[0]}")

        embeddings, elapsed = _timed_encode(client, test_texts)
        _check_shape(embeddings, len(test_texts), expected_dim)

        print_success("编码成功")
        _report(embeddings, elapsed, count=len(test_texts))
        return True
    except Exception as e:
        print_error(f"文本编码失败: {e}")
        return False


def test_image_encoding(client, batch_size=5, expected_dim=DEFAULT_DIM):
    """Encode *batch_size* online sample images and validate the result.

    Requires network access to picsum.photos.

    Returns:
        True on success, False on any failure (the error is printed).
    """
    _print_section("测试 4: 图像编码")

    try:
        test_images = _sample_image_urls(batch_size)
        print(f"测试图像数量: {len(test_images)}")
        print(f"示例 URL: {test_images[0]}")

        embeddings, elapsed = _timed_encode(client, test_images)
        _check_shape(embeddings, len(test_images), expected_dim)

        print_success("编码成功")
        _report(embeddings, elapsed, count=len(test_images))
        return True
    except Exception as e:
        print_error(f"图像编码失败: {e}")
        print_warning("可能需要网络连接来下载测试图片")
        return False


def test_mixed_encoding(client, expected_dim=DEFAULT_DIM):
    """Encode one batch that interleaves texts and image URLs.

    Returns:
        True on success, False on any failure (the error is printed).
    """
    _print_section("测试 5: 混合编码")

    try:
        mixed_data = [
            '这是一段测试文本',
            'https://picsum.photos/224?random=10',
            'CN-CLIP 图像编码',
            'https://picsum.photos/224?random=11',
        ]
        # Derive the counts from the data rather than hard-coding them.
        n_images = sum(item.startswith(('http://', 'https://')) for item in mixed_data)
        n_texts = len(mixed_data) - n_images
        print(f"混合数据数量: {len(mixed_data)}")
        print(f" 文本: {n_texts} 条")
        print(f" 图像: {n_images} 条")

        embeddings, elapsed = _timed_encode(client, mixed_data)
        _check_shape(embeddings, len(mixed_data), expected_dim)

        print_success("混合编码成功")
        _report(embeddings, elapsed)
        return True
    except Exception as e:
        print_error(f"混合编码失败: {e}")
        return False


def test_single_encoding(client, expected_dim=DEFAULT_DIM):
    """Encode a single text and expect one ``(1, expected_dim)`` embedding.

    Returns:
        True on success, False on any failure (the error is printed).
    """
    _print_section("测试 6: 单个数据编码")

    try:
        single_text = '测试文本'
        print(f"输入: {single_text}")

        # encode() takes an iterable of inputs; a bare str would be iterated
        # character by character (one row per character), so wrap it in a list.
        embedding, elapsed = _timed_encode(client, [single_text])

        # Defensive: normalise a 1-D result to a single row.
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)
        expected_shape = (1, expected_dim)
        if embedding.shape != expected_shape:
            raise ValueError(f"向量形状应该是 {expected_shape}, 实际是 {embedding.shape}")

        print_success("单个文本编码成功")
        _report(embedding, elapsed)
        return True
    except Exception as e:
        print_error(f"单个数据编码失败: {e}")
        return False


def _positive_int(value):
    """argparse ``type=`` callback: parse *value* as an int >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"无效的整数: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"必须 >= 1, 实际是 {number}")
    return number


def main():
    """Parse options, run the selected tests, print a summary and exit.

    Exits with status 0 when every selected test passes. Exits with status 1
    otherwise, including missing dependencies or a client-creation failure.
    """
    parser = argparse.ArgumentParser(description='CN-CLIP 服务测试脚本')
    parser.add_argument('--url', default='grpc://localhost:51000',
                        help='服务地址(默认:grpc://localhost:51000)')
    parser.add_argument('--text', action='store_true', help='只测试文本编码')
    parser.add_argument('--image', action='store_true', help='只测试图像编码')
    parser.add_argument('--batch-size', type=_positive_int, default=10,
                        help='批处理大小(默认:10)')
    parser.add_argument('--dim', type=_positive_int, default=DEFAULT_DIM,
                        help=f'期望的向量维度(默认:{DEFAULT_DIM})')
    args = parser.parse_args()

    _print_section("CN-CLIP 服务测试")

    # Test 1: dependencies.
    if not test_imports():
        sys.exit(1)

    # Test 2: client creation.
    client = test_connection(args.url)
    if client is None:
        sys.exit(1)

    results = []
    if args.text or args.image:
        # --text / --image select individual tests; both may be combined.
        if args.text:
            results.append(test_text_encoding(client, args.batch_size, args.dim))
        if args.image:
            results.append(test_image_encoding(client, args.batch_size, args.dim))
    else:
        results.append(test_text_encoding(client, args.batch_size, args.dim))
        # Keep the full run light on downloads: at most 5 images.
        results.append(test_image_encoding(client, min(args.batch_size, 5), args.dim))
        results.append(test_mixed_encoding(client, args.dim))
        results.append(test_single_encoding(client, args.dim))

    _print_section("测试结果汇总")

    total_tests = len(results)
    passed_tests = sum(results)
    print(f"总测试数: {total_tests}")
    print(f"通过: {passed_tests}")
    print(f"失败: {total_tests - passed_tests}")

    if passed_tests == total_tests:
        print_success("\n所有测试通过!")
        sys.exit(0)
    print_error("\n部分测试失败")
    sys.exit(1)


if __name__ == '__main__':
    main()