test_cnclip_example.py 4.89 KB
#!/usr/bin/env python3
"""
CN-CLIP 快速测试脚本

测试文本和图像编码功能
"""

from clip_client import Client
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Remote image fed to the image-encoding and image-text retrieval checks.
TEST_IMAGE = (
    "https://oss.essa.cn/"
    "98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg"
)

# Candidate captions ranked against TEST_IMAGE in the retrieval check.
TEST_TEXTS = [
    "一只可爱的猫咪",      # a cute kitten
    "美丽的高山风景",      # beautiful mountain scenery
    "汽车在公路上行驶",    # a car driving on the road
    "现代建筑",            # modern architecture
]

def test_connection(server='grpc://localhost:51000'):
    """Create a CN-CLIP client and verify that the service actually answers.

    Building a ``Client`` by itself sends no request, so a stopped service
    would go unnoticed here. NOTE(review): this assumes the clip_client
    constructor is lazy; confirm it against the installed version. A
    one-item ``encode`` call is therefore issued as a round-trip probe.

    Args:
        server: Service address. Defaults to the local gRPC endpoint that
            the rest of this script assumes.

    Returns:
        The ``Client`` if construction and the probe request both succeed,
        otherwise ``None`` (after printing how to start the service).
    """
    print("=" * 60)
    print("测试 1: 连接服务")
    print("=" * 60)

    try:
        client = Client(server)
        # Force a real round trip; without it only the address is validated.
        client.encode(['连接测试'])
    except Exception as e:
        # Top-level boundary of a smoke test: report and let main() stop.
        print(f"✗ 连接失败: {e}")
        print("\n请确保服务已启动:")
        print("  ./scripts/start_cnclip_service.sh")
        return None

    print("✓ 服务连接成功")
    return client

def test_text_encoding(client):
    """Encode every caption in TEST_TEXTS and print summary statistics.

    Args:
        client: Object exposing ``encode(list_of_texts)``.

    Returns:
        The array returned by ``client.encode``, or ``None`` if encoding
        (or printing its statistics) raises.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("测试 2: 文本编码")
    print(banner)

    print("\n测试文本:")
    for number, caption in enumerate(TEST_TEXTS, start=1):
        print(f"  {number}. {caption}")

    try:
        vectors = client.encode(TEST_TEXTS)
        print("\n✓ 文本编码成功")
        print(f"  编码数量: {len(vectors)}")
        print(f"  向量形状: {vectors.shape}")
        print(f"  数据类型: {vectors.dtype}")
        print(f"  值域: [{vectors.min():.4f}, {vectors.max():.4f}]")
    except Exception as err:
        print(f"✗ 文本编码失败: {err}")
        return None
    return vectors

def test_image_encoding(client):
    """Encode TEST_IMAGE and print summary statistics of the result.

    Args:
        client: Object exposing ``encode(list_of_inputs)``.

    Returns:
        The array returned by ``client.encode([TEST_IMAGE])``, or ``None``
        if encoding (or printing its statistics) raises.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("测试 3: 图像编码")
    print(banner)

    print(f"\n测试图片: {TEST_IMAGE}")

    try:
        vectors = client.encode([TEST_IMAGE])
        print("\n✓ 图像编码成功")
        print(f"  向量形状: {vectors.shape}")
        print(f"  数据类型: {vectors.dtype}")
        print(f"  值域: [{vectors.min():.4f}, {vectors.max():.4f}]")
    except Exception as exc:
        print(f"✗ 图像编码失败: {exc}")
        return None
    return vectors

def test_image_text_retrieval(client, image_embedding, text_embeddings, texts=None):
    """Rank candidate captions by cosine similarity to the test image.

    Args:
        client: Not used here; accepted so every test step shares the same
            leading argument.
        image_embedding: Image embedding(s). For a 2-D array, only the first
            row is scored; a single 1-D vector is also accepted.
        text_embeddings: One embedding row per caption, in the same order
            as *texts*.
        texts: Caption labels for the rows of *text_embeddings*. Defaults to
            ``TEST_TEXTS``.

    Returns:
        A 1-D array with one similarity score per caption, in input order.
        Returns ``None`` if the computation fails, or if the number of
        labels differs from the number of text embeddings.
    """
    print("\n" + "=" * 60)
    print("测试 4: 图文检索(计算相似度)")
    print("=" * 60)

    if texts is None:
        texts = TEST_TEXTS

    print("\n使用图片搜索最匹配的文本...")

    try:
        # cosine_similarity requires 2-D input; promote a bare vector to one row.
        query = np.atleast_2d(image_embedding)
        similarities = cosine_similarity(query, text_embeddings)[0]

        # Labels must line up with embedding rows, or the ranking is mislabelled.
        if len(similarities) != len(texts):
            raise ValueError(
                f"文本数量 ({len(texts)}) 与向量数量 ({len(similarities)}) 不一致"
            )

        # Indices ordered from highest to lowest similarity.
        order = np.argsort(similarities)[::-1]

        print("\n相似度排序:")
        for rank, idx in enumerate(order, 1):
            score = similarities[idx]
            # Bar length is proportional to the score; str * negative == "".
            bar = "█" * int(score * 50)
            print(f"  {rank}. {score:.4f} {bar} {texts[idx]}")

        best = order[0]
        print(f"\n最佳匹配: {texts[best]}")
        print(f"相似度分数: {similarities[best]:.4f}")

        return similarities
    except Exception as e:
        print(f"✗ 相似度计算失败: {e}")
        return None

def test_batch_encoding(client):
    """测试批量编码"""
    print("\n" + "=" * 60)
    print("测试 5: 批量编码性能")
    print("=" * 60)

    import time

    # 准备测试数据
    batch_texts = [f"测试文本 {i}" for i in range(50)]

    print(f"\n编码 {len(batch_texts)} 条文本...")

    try:
        start = time.time()
        embeddings = client.encode(batch_texts)
        elapsed = time.time() - start

        print(f"\n✓ 批量编码成功")
        print(f"  耗时: {elapsed:.2f}秒")
        print(f"  速度: {len(batch_texts)/elapsed:.2f} 条/秒")
        print(f"  平均延迟: {elapsed/len(batch_texts)*1000:.2f}ms/条")

    except Exception as e:
        print(f"✗ 批量编码失败: {e}")

def main():
    print("\n" + "=" * 60)
    print("CN-CLIP 服务测试")
    print("=" * 60)
    print(f"\n测试图片: {TEST_IMAGE}")
    print(f"服务地址: grpc://localhost:51000")

    # 测试连接
    client = test_connection()
    if not client:
        return

    # 测试文本编码
    text_embeddings = test_text_encoding(client)
    if text_embeddings is None:
        return

    # 测试图像编码
    image_embeddings = test_image_encoding(client)
    if image_embeddings is None:
        return

    # 测试图文检索
    test_image_text_retrieval(client, image_embeddings, text_embeddings)

    # 测试批量编码性能
    test_batch_encoding(client)

    # 总结
    print("\n" + "=" * 60)
    print("测试总结")
    print("=" * 60)
    print("\n✓ 所有测试通过!")
    print("\n服务运行正常,可以开始使用。")
    print("\n下一步:")
    print("  1. 查看使用文档: cat docs/CNCLIP_USAGE_GUIDE.md")
    print("  2. 集成到你的项目")
    print("  3. 调整服务配置(如需要)")
    print()

if __name__ == '__main__':
    main()