test_cnclip_service.py 4.5 KB
#!/usr/bin/env python3
"""
CN-CLIP 服务测试脚本

用途:
    测试 CN-CLIP 服务的文本和图像编码功能(使用 gRPC 协议)

使用方法:
    python tests/test_cnclip_service.py [PORT]

参数:
    PORT: 服务端口(默认:51000)
"""

import sys
import os

import numpy as np

# clip_client's startup version check goes through pkg_resources (a legacy
# import path); opting out via NO_VERSION_CHECK avoids that dependency.
# setdefault keeps any value the caller has already exported.
os.environ.setdefault("NO_VERSION_CHECK", "1")

# When the script is launched directly (``python tests/test_cnclip_service.py``)
# instead of through an installed package, put the vendored clip-as-service
# client at the front of the import path.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
VENDORED_CLIENT = os.path.join(ROOT, "third-party", "clip-as-service", "client")
if VENDORED_CLIENT not in sys.path and os.path.isdir(VENDORED_CLIENT):
    sys.path.insert(0, VENDORED_CLIENT)

try:
    from clip_client import Client
except ImportError as import_error:
    # Nothing below can run without clip_client: explain the two supported
    # ways to make it importable, then abort.
    for hint_line in (
        "✗ 无法导入 clip_client。请先安装/暴露客户端依赖:",
        "  1) pip install -e third-party/clip-as-service/client",
        "  或",
        "  2) export PYTHONPATH=third-party/clip-as-service/client:$PYTHONPATH",
        f"  详细错误: {import_error}",
    ):
        print(hint_line)
    sys.exit(1)


def _test_encoding(client, test_name, inputs):
    """Encode ``inputs`` through ``client`` and report whether it succeeded.

    For each input this prints the embedding dimension, its first 20 raw
    values, and its first 20 values after L2 normalization, so the output
    can be inspected by eye.

    Args:
        client: Object exposing ``encode(inputs)``; in this script it is a
            ``clip_client.Client``.
        test_name: Label printed before the test runs.
        inputs: Sized sequence of texts and/or image URLs to encode.

    Returns:
        True if ``encode`` returned a 2-D ``numpy.ndarray`` with exactly one
        finite row per input. False otherwise, including when ``encode``
        raises (the traceback is printed).
    """
    print(f"\n{test_name}...")
    try:
        result = client.encode(inputs)
        if not isinstance(result, np.ndarray):
            print(f"✗ 失败: 返回类型错误: {type(result)}")
            return False
        # A 1-D result would make ``result.shape[1]`` raise. A row count that
        # differs from the input count means embeddings were dropped or
        # duplicated. Either must fail the test instead of passing silently.
        if result.ndim != 2 or result.shape[0] != len(inputs):
            print(f"✗ 失败: 返回形状错误: {result.shape}, 期望 ({len(inputs)}, D)")
            return False
        # NaN/Inf embeddings (e.g. from numeric overflow) are unusable.
        if not np.all(np.isfinite(result)):
            print("✗ 失败: embedding 中包含 NaN/Inf")
            return False

        print(f"✓ 成功! 形状: {result.shape}")
        print(f"  输入数量: {len(inputs)}")
        print(f"  输出维度: {result.shape[1]}")

        # Per input: dimension, first 20 values, first 20 values after L2
        # normalization.
        for i, (item, emb) in enumerate(zip(inputs, result)):
            norm = np.linalg.norm(emb)
            # Leave an all-zero embedding unchanged to avoid dividing by zero.
            normalized_emb = emb / norm if norm > 0 else emb
            print(f"  input: {item}")
            print(f"  Embedding[{i}] 维度: {len(emb)}")
            print(f"  前20个数字: {emb[:20].tolist()}")
            print(f"  normalize后的前20个数字: {normalized_emb[:20].tolist()}")
        return True
    except Exception as e:  # test boundary: report and mark this test failed
        print(f"✗ 失败: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run the CN-CLIP gRPC encoding smoke tests and exit with a status code.

    The service port comes from ``sys.argv[1]`` (default ``51000``), and the
    client connects to ``grpc://localhost:<port>``. The process exits with
    status 0 when every test passes. It exits with status 1 otherwise,
    including when the port argument is invalid or the client cannot be
    created.
    """
    # Service port: optional first CLI argument.
    port = sys.argv[1] if len(sys.argv) > 1 else "51000"
    # Reject non-numeric or out-of-range ports up front, rather than building
    # a malformed URL that only fails later with an obscure client error.
    try:
        port_num = int(port)
    except ValueError:
        port_num = -1
    if not 0 < port_num < 65536:
        print(f"✗ 无效端口: {port}")
        sys.exit(1)
    grpc_url = f"grpc://localhost:{port_num}"

    print("=" * 50)
    print("CN-CLIP 服务测试")
    print("=" * 50)
    print(f"服务地址: {grpc_url} (gRPC 协议)")
    print()

    # Create the client.
    try:
        client = Client(grpc_url)
    except ModuleNotFoundError as e:
        # Match the missing module by its ``name`` attribute. The exact
        # message text is not a stable interface.
        if e.name == "pkg_resources":
            print("✗ 当前环境缺少 pkg_resources,clip_client/jina 无法初始化。")
            print("  请使用专用环境运行(不要在主 .venv 安装旧依赖):")
            print("  .venv-embedding/bin/python tests/test_cnclip_service.py 51000")
            print("  或 .venv-cnclip/bin/python tests/test_cnclip_service.py 51000")
        else:
            print(f"✗ 客户端创建失败: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"✗ 客户端创建失败: {e}")
        sys.exit(1)

    # Test cases: text only, image only (remote URL), and mixed text + image.
    sample_image_url = "https://oss.essa.cn/98532128-cf8e-456c-9e30-6f2a5ea0c19f.jpg"
    test_cases = [
        ("测试1: 编码文本", ['这是一个测试文本', '另一个测试文本']),
        ("测试2: 编码图像(远程 URL)", [sample_image_url]),
        ("测试3: 混合编码(文本和图像)", ['这是一段文本', sample_image_url]),
    ]
    results = [
        _test_encoding(client, name, inputs) for name, inputs in test_cases
    ]

    # Summary.
    passed = sum(results)
    print("\n" + "=" * 50)
    print("测试结果汇总")
    print("=" * 50)
    print(f"总测试数: {len(results)}")
    print(f"通过: {passed}")
    print(f"失败: {len(results) - passed}")

    if all(results):
        print("\n✓ 所有测试通过!")
        sys.exit(0)
    print("\n✗ 部分测试失败")
    sys.exit(1)


if __name__ == "__main__":
    # Script entry point; main() ends the process via sys.exit.
    main()