test_cnclip_client.py 2.97 KB
#!/usr/bin/env python3
"""
CN-CLIP 服务客户端测试脚本

用法:
    python scripts/test_cnclip_client.py [--url URL]

注意:如果服务配置了 protocol: http,必须使用 http:// 而不是 grpc://
"""

import sys
import argparse
from pathlib import Path

# 添加项目路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

try:
    from clip_client import Client
except ImportError:
    print("错误: 请先安装 clip-client: pip install clip-client")
    sys.exit(1)


def test_text_encoding(client):
    """测试文本编码"""
    print("\n测试文本编码...")
    try:
        texts = ['这是测试文本', '另一个测试文本']
        result = client.encode(texts)
        print(f"✓ 成功! 形状: {result.shape}")
        print(f"  输入: {len(texts)} 个文本")
        print(f"  输出维度: {result.shape[1]}")
        return True
    except Exception as e:
        print(f"✗ 失败: {e}")
        return False


def test_image_encoding(client):
    """测试图像编码"""
    print("\n测试图像编码...")
    try:
        images = ['https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg']
        result = client.encode(images)
        print(f"✓ 成功! 形状: {result.shape}")
        print(f"  输入: {len(images)} 个图像")
        print(f"  输出维度: {result.shape[1]}")
        return True
    except Exception as e:
        print(f"✗ 失败: {e}")
        print("  注意: CN-CLIP 的图像编码可能存在兼容性问题")
        return False


def main():
    parser = argparse.ArgumentParser(description='CN-CLIP 服务客户端测试')
    parser.add_argument(
        '--url',
        default='http://localhost:51000',
        help='服务地址(默认: http://localhost:51000)'
    )
    
    args = parser.parse_args()
    
    print("=" * 50)
    print("CN-CLIP 服务客户端测试")
    print("=" * 50)
    print(f"服务地址: {args.url}")
    
    # 检查协议
    if args.url.startswith('grpc://'):
        print("\n⚠ 警告: 服务配置了 protocol: http,请使用 http:// 而不是 grpc://")
        print("  将自动转换为 http://")
        args.url = args.url.replace('grpc://', 'http://')
        print(f"  新地址: {args.url}")
    
    try:
        client = Client(args.url)
        print("✓ 客户端创建成功")
    except Exception as e:
        print(f"✗ 客户端创建失败: {e}")
        sys.exit(1)
    
    # 运行测试
    results = []
    results.append(test_text_encoding(client))
    results.append(test_image_encoding(client))
    
    # 汇总
    print("\n" + "=" * 50)
    print("测试结果汇总")
    print("=" * 50)
    print(f"总测试数: {len(results)}")
    print(f"通过: {sum(results)}")
    print(f"失败: {len(results) - sum(results)}")
    
    if all(results):
        print("\n✓ 所有测试通过!")
        sys.exit(0)
    else:
        print("\n✗ 部分测试失败")
        sys.exit(1)


if __name__ == '__main__':
    main()