test_cnclip_service.py 8.76 KB
#!/usr/bin/env python3
"""
CN-CLIP 服务测试脚本

用法:
    python scripts/test_cnclip_service.py

选项:
    --url TEXT       服务地址(默认:grpc://localhost:51000)
    --text           只测试文本编码
    --image          只测试图像编码
    --batch-size INT 批处理大小(默认:10)
    --help           显示帮助信息
"""

import sys
import time
import argparse
from pathlib import Path

# Put the project root (two levels up from this file: the parent of the
# directory containing this script) at the front of sys.path.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# ANSI color codes for terminal output
class Colors:
    """ANSI escape sequences used to colorize the script's terminal output."""

    GREEN = "\033[0;32m"   # success messages
    RED = "\033[0;31m"     # error messages
    YELLOW = "\033[1;33m"  # warnings (SGR attribute 1 = bold)
    BLUE = "\033[0;34m"    # informational messages
    NC = "\033[0m"         # "no color": resets all attributes


def print_success(msg):
    """Print *msg* in green, prefixed with a check mark."""
    line = f"{Colors.GREEN}✓ {msg}{Colors.NC}"
    print(line)


def print_error(msg):
    """Print *msg* in red, prefixed with a cross mark."""
    line = f"{Colors.RED}✗ {msg}{Colors.NC}"
    print(line)


def print_warning(msg):
    """Print *msg* in yellow, prefixed with a warning sign."""
    line = f"{Colors.YELLOW}⚠ {msg}{Colors.NC}"
    print(line)


def print_info(msg):
    """Print *msg* in blue, prefixed with an info sign."""
    line = f"{Colors.BLUE}ℹ {msg}{Colors.NC}"
    print(line)


def test_imports():
    """Check that the packages this script relies on can be imported.

    Prints a section banner, then tries ``clip_client`` followed by
    ``numpy``. The first failing import is reported and ends the check;
    ``numpy`` is not attempted when ``clip_client`` is missing.

    Returns:
        bool: True when both imports succeed, False otherwise.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("测试 1: 检查依赖")
    print(banner)

    try:
        import clip_client  # noqa: F401 -- imported only to prove it is installed
    except ImportError as exc:
        print_error(f"clip_client 未安装: {exc}")
        print_info("请运行: pip install clip-client")
        return False
    else:
        print_success("clip_client 已安装")

    try:
        import numpy as np  # noqa: F401 -- imported only to prove it is installed
    except ImportError as exc:
        print_error(f"numpy 未安装: {exc}")
        return False
    else:
        print_success("numpy 已安装")

    return True


def test_connection(url):
    """Create a ``clip_client.Client`` for the given service address.

    Only client construction is attempted here. NOTE(review): whether
    constructing the client actually contacts the server is not visible
    from this file -- confirm against the clip_client library.

    Args:
        url: Service address passed straight to ``Client``.

    Returns:
        The created client, or None when importing or constructing it raised.
    """
    banner = "=" * 50
    print("\n" + banner)
    print("测试 2: 连接服务")
    print(banner)
    print(f"服务地址: {url}")

    try:
        from clip_client import Client

        connected = Client(url)
        print_success("客户端创建成功")
        return connected
    except Exception as exc:
        print_error(f"连接失败: {exc}")
        print_info("请确保服务已启动: ./scripts/start_cnclip_service.sh")
        return None


def test_text_encoding(client, batch_size=10):
    """Encode a batch of sample texts and validate the returned embeddings.

    Args:
        client: Object exposing ``encode(list_of_str)``. The result must
            provide ``shape`` and ``dtype`` attributes (presumably a NumPy
            array from ``clip_client.Client`` -- confirm against the library).
        batch_size: Number of sample texts to send. Only 10 samples are
            built in, so larger values are capped at 10. ``None`` sends all
            samples. Values below 1 are rejected.

    Returns:
        bool: True when the result has one 1024-wide row per input text;
        False otherwise (the reason is printed).
    """
    print("\n" + "="*50)
    print("测试 3: 文本编码")
    print("="*50)

    # Reject non-positive sizes up front: zero would leave nothing to encode
    # and a negative value would silently slice from the end of the list.
    if batch_size is not None and batch_size < 1:
        print_error(f"文本编码失败: batch_size 必须 >= 1, 实际是 {batch_size}")
        return False

    try:
        # Prepare the sample inputs.
        test_texts = [
            '你好,世界',
            'CN-CLIP 图像编码服务',
            '这是一个测试',
            '人工智能',
            '机器学习',
            '深度学习',
            '计算机视觉',
            '自然语言处理',
            '搜索引擎',
            '多模态检索',
        ][:batch_size]

        print(f"测试文本数量: {len(test_texts)}")
        print(f"示例文本: {test_texts[0]}")

        # Time the call with the monotonic, high-resolution clock; the wall
        # clock can jump and has coarse resolution on some platforms.
        start_time = time.perf_counter()
        embeddings = client.encode(test_texts)
        elapsed_time = time.perf_counter() - start_time

        # Explicit raises instead of ``assert`` so the checks still run
        # under ``python -O``.
        if embeddings.shape[0] != len(test_texts):
            raise AssertionError("向量数量不匹配")
        if embeddings.shape[1] != 1024:
            raise AssertionError("向量维度应该是 1024")

        # A zero interval must not turn a successful encode into a failure.
        if elapsed_time > 0:
            throughput = len(test_texts) / elapsed_time
        else:
            throughput = float("inf")

        print_success("编码成功")
        print(f"  向量形状: {embeddings.shape}")
        print(f"  耗时: {elapsed_time:.2f}秒")
        print(f"  速度: {throughput:.2f} 条/秒")
        print(f"  数据类型: {embeddings.dtype}")

        return True

    except Exception as e:
        print_error(f"文本编码失败: {e}")
        return False


def test_image_encoding(client, batch_size=5):
    """Encode a batch of sample image URLs and validate the embeddings.

    Args:
        client: Object exposing ``encode(list_of_str)``. The result must
            provide ``shape`` and ``dtype`` attributes (presumably a NumPy
            array from ``clip_client.Client`` -- confirm against the library).
        batch_size: Number of sample URLs to send. Only 5 samples are built
            in, so larger values are capped at 5. ``None`` sends all samples.
            Values below 1 are rejected.

    Returns:
        bool: True when the result has one 1024-wide row per input URL;
        False otherwise (the reason is printed).
    """
    print("\n" + "="*50)
    print("测试 4: 图像编码")
    print("="*50)

    # Reject non-positive sizes up front: zero would leave nothing to encode
    # and a negative value would silently slice from the end of the list.
    # Handled outside the ``try`` so the network hint is not shown for it.
    if batch_size is not None and batch_size < 1:
        print_error(f"图像编码失败: batch_size 必须 >= 1, 实际是 {batch_size}")
        return False

    try:
        # Prepare the sample inputs (online images referenced by URL).
        test_images = [
            'https://picsum.photos/224',
            'https://picsum.photos/224?random=1',
            'https://picsum.photos/224?random=2',
            'https://picsum.photos/224?random=3',
            'https://picsum.photos/224?random=4',
        ][:batch_size]

        print(f"测试图像数量: {len(test_images)}")
        print(f"示例 URL: {test_images[0]}")

        # Time the call with the monotonic, high-resolution clock; the wall
        # clock can jump and has coarse resolution on some platforms.
        start_time = time.perf_counter()
        embeddings = client.encode(test_images)
        elapsed_time = time.perf_counter() - start_time

        # Explicit raises instead of ``assert`` so the checks still run
        # under ``python -O``.
        if embeddings.shape[0] != len(test_images):
            raise AssertionError("向量数量不匹配")
        if embeddings.shape[1] != 1024:
            raise AssertionError("向量维度应该是 1024")

        # A zero interval must not turn a successful encode into a failure.
        if elapsed_time > 0:
            throughput = len(test_images) / elapsed_time
        else:
            throughput = float("inf")

        print_success("编码成功")
        print(f"  向量形状: {embeddings.shape}")
        print(f"  耗时: {elapsed_time:.2f}秒")
        print(f"  速度: {throughput:.2f} 条/秒")
        print(f"  数据类型: {embeddings.dtype}")

        return True

    except Exception as e:
        print_error(f"图像编码失败: {e}")
        print_warning("可能需要网络连接来下载测试图片")
        return False


def test_mixed_encoding(client):
    """Encode a list that interleaves texts and image URLs in one call.

    Args:
        client: Object exposing ``encode(list_of_str)``. The result must
            provide a ``shape`` attribute (presumably a NumPy array from
            ``clip_client.Client`` -- confirm against the library).

    Returns:
        bool: True when the result has one 1024-wide row per input item;
        False otherwise (the reason is printed).
    """
    print("\n" + "="*50)
    print("测试 5: 混合编码")
    print("="*50)

    try:
        # Two texts and two image URLs, interleaved.
        mixed_data = [
            '这是一段测试文本',
            'https://picsum.photos/224?random=10',
            'CN-CLIP 图像编码',
            'https://picsum.photos/224?random=11',
        ]

        print(f"混合数据数量: {len(mixed_data)}")
        print("  文本: 2 条")
        print("  图像: 2 条")

        # Time the call with the monotonic, high-resolution clock.
        start_time = time.perf_counter()
        embeddings = client.encode(mixed_data)
        elapsed_time = time.perf_counter() - start_time

        # Explicit raises instead of ``assert`` so the checks still run
        # under ``python -O``.
        if embeddings.shape[0] != len(mixed_data):
            raise AssertionError("向量数量不匹配")
        if embeddings.shape[1] != 1024:
            raise AssertionError("向量维度应该是 1024")

        print_success("混合编码成功")
        print(f"  向量形状: {embeddings.shape}")
        print(f"  耗时: {elapsed_time:.2f}秒")

        return True

    except Exception as e:
        print_error(f"混合编码失败: {e}")
        return False


def test_single_encoding(client):
    """Encode a single text and verify that exactly one vector comes back.

    Args:
        client: Object exposing ``encode(list_of_str)``. The result must
            provide ``ndim``, ``shape`` and ``reshape`` (presumably a NumPy
            array from ``clip_client.Client`` -- confirm against the library).

    Returns:
        bool: True when the (possibly reshaped) result has shape
        ``(1, 1024)``; False otherwise (the reason is printed).
    """
    print("\n" + "="*50)
    print("测试 6: 单个数据编码")
    print("="*50)

    try:
        # Single text input.
        single_text = '测试文本'
        print(f"输入: {single_text}")

        # Send the text as a one-element list rather than a bare string.
        # NOTE(review): clip_client's ``Client.encode`` expects an iterable
        # of inputs and is understood to reject a plain ``str`` with
        # TypeError -- confirm against the installed version.
        start_time = time.perf_counter()
        embedding = client.encode([single_text])
        elapsed_time = time.perf_counter() - start_time

        # Normalise a flat 1-D result to a single row so the shape check
        # below covers both layouts.
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        # Explicit raise instead of ``assert`` so the check still runs
        # under ``python -O``.
        if embedding.shape != (1, 1024):
            raise AssertionError(
                f"向量形状应该是 (1, 1024), 实际是 {embedding.shape}"
            )

        print_success("单个文本编码成功")
        print(f"  向量形状: {embedding.shape}")
        print(f"  耗时: {elapsed_time:.2f}秒")

        return True

    except Exception as e:
        print_error(f"单个数据编码失败: {e}")
        return False


def main():
    """Parse the command line, run the selected checks and exit.

    Flow: dependency check, client creation, then the encoding tests.
    ``--text`` and ``--image`` each select their own test and may be
    combined; with neither flag all four encoding tests run.

    Exit status: 0 when every executed encoding test passed; 1 when the
    dependency check, client creation or any encoding test failed; 2 (from
    argparse) for invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(description='CN-CLIP 服务测试脚本')
    parser.add_argument(
        '--url',
        default='grpc://localhost:51000',
        help='服务地址(默认:grpc://localhost:51000)',
    )
    parser.add_argument(
        '--text',
        action='store_true',
        help='只测试文本编码',
    )
    parser.add_argument(
        '--image',
        action='store_true',
        help='只测试图像编码',
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=10,
        help='批处理大小(默认:10)',
    )

    args = parser.parse_args()

    # A non-positive size would give the tests an empty or oddly sliced
    # sample list; reject it as a usage error before doing any work.
    if args.batch_size < 1:
        parser.error("--batch-size 必须 >= 1")

    print("\n" + "="*50)
    print("CN-CLIP 服务测试")
    print("="*50)

    # Test 1: dependencies.
    if not test_imports():
        sys.exit(1)

    # Test 2: client creation.
    client = test_connection(args.url)
    if not client:
        sys.exit(1)

    # Run the selected encoding tests. Independent ``if`` checks (not
    # ``elif``) so that passing both flags runs both tests instead of
    # silently dropping the image test.
    results = []
    run_all = not (args.text or args.image)

    if args.text or run_all:
        results.append(test_text_encoding(client, args.batch_size))
    if args.image or run_all:
        # Only five sample images exist, so the batch is capped at 5.
        results.append(test_image_encoding(client, min(args.batch_size, 5)))
    if run_all:
        results.append(test_mixed_encoding(client))
        results.append(test_single_encoding(client))

    # Summary.
    print("\n" + "="*50)
    print("测试结果汇总")
    print("="*50)

    total_tests = len(results)
    passed_tests = sum(results)

    print(f"总测试数: {total_tests}")
    print(f"通过: {passed_tests}")
    print(f"失败: {total_tests - passed_tests}")

    # Emit the blank line separately so the status symbol stays on the same
    # line as its message instead of being split off by a leading newline.
    print()
    if passed_tests == total_tests:
        print_success("所有测试通过!")
        sys.exit(0)
    else:
        print_error("部分测试失败")
        sys.exit(1)


if __name__ == '__main__':
    main()