deploy_models.py 6.42 KB
#!/usr/bin/env python3
"""
Qwen3 模型部署脚本
自动部署 Qwen3-Embedding 和 Qwen3-Reranker 模型
"""

import time
import sys
from xinference_client import RESTfulClient as Client


def print_section(title):
    """打印分节标题"""
    print("\n" + "="*60)
    print(f"  {title}")
    print("="*60 + "\n")


def deploy_qwen3_models(host="http://localhost:9997", gpu_idx=[0]):
    """
    部署 Qwen3 模型

    Args:
        host: Xinference 服务地址
        gpu_idx: GPU 索引
    """
    print_section("Qwen3 模型自动部署")

    # 连接到 Xinference 服务
    print(f"🔗 连接到 Xinference 服务: {host}")
    try:
        client = Client(host)
        print("✅ 连接成功!\n")
    except Exception as e:
        print(f"❌ 连接失败: {e}")
        print("\n💡 请确保 Xinference 服务已启动:")
        print("   ./start.sh")
        sys.exit(1)

    # 部署 Qwen3-Embedding 模型
    print_section("部署 Qwen3-Embedding 模型 (4B)")
    print("⏳ 正在部署,首次运行需要下载模型,请耐心等待...")
    print("   模型大小: ~8GB")
    print("   上下文长度: 8192 tokens")
    print("   向量维度: 1024\n")

    try:
        embedding_uid = client.launch_model(
            model_name="qwen3-embedding",
            model_size_in_billions=4,
            model_type="embedding",
            engine="vllm",
            gpu_idx=gpu_idx,
        )
        print(f"✅ Qwen3-Embedding 部署成功!")
        print(f"   模型 UID: {embedding_uid}\n")

        # 等待模型加载完成
        print("⏳ 等待模型完全加载...")
        time.sleep(5)

        # 测试模型
        embedding_model = client.get_model(embedding_uid)
        test_result = embedding_model.create_embedding("测试文本")
        if test_result and "data" in test_result:
            vector_dim = len(test_result["data"][0]["embedding"])
            print(f"✅ 模型测试成功!向量维度: {vector_dim}\n")
        else:
            print("⚠️  模型部署成功但测试失败\n")

    except Exception as e:
        print(f"❌ Qwen3-Embedding 部署失败: {e}\n")
        return None

    # 部署 Qwen3-Reranker 模型
    print_section("部署 Qwen3-Reranker 模型 (4B)")
    print("⏳ 正在部署,首次运行需要下载模型,请耐心等待...")
    print("   模型大小: ~8GB")
    print("   架构: Cross-Encoder\n")

    try:
        reranker_uid = client.launch_model(
            model_name="qwen3-reranker",
            model_size_in_billions=4,
            model_type="rerank",
            engine="vllm",
            gpu_idx=gpu_idx,
        )
        print(f"✅ Qwen3-Reranker 部署成功!")
        print(f"   模型 UID: {reranker_uid}\n")

        # 等待模型加载完成
        print("⏳ 等待模型完全加载...")
        time.sleep(5)

        # 测试模型
        reranker_model = client.get_model(reranker_uid)
        test_result = reranker_model.rerank(
            [("测试查询", "测试文档")]
        )
        if test_result and len(test_result) > 0:
            print(f"✅ 模型测试成功!\n")
        else:
            print("⚠️  模型部署成功但测试失败\n")

    except Exception as e:
        print(f"❌ Qwen3-Reranker 部署失败: {e}")
        print("💡 可能的原因: GPU 显存不足,请尝试:")
        print("   1. 使用不同的 GPU 索引: python deploy_models.py --gpu 1")
        print("   2. 只部署 embedding 模型: python deploy_models.py --embedding-only")
        return None

    # 显示部署摘要
    print_section("🎉 模型部署完成!")
    print(f"✅ Qwen3-Embedding UID: {embedding_uid}")
    print(f"✅ Qwen3-Reranker UID: {reranker_uid}")
    print("\n📝 下一步:")
    print("   1. 运行电商搜索示例: python ecommerce_demo.py")
    print("   2. 查看 API 调用示例: cat api_examples.sh")
    print("   3. 查看 Dashboard: http://localhost:9998")
    print("   4. 查看所有模型: curl http://localhost:9997/v1/models")
    print("")

    return {
        "embedding_uid": embedding_uid,
        "reranker_uid": reranker_uid
    }


def list_models(host="http://localhost:9997"):
    """列出所有已部署的模型"""
    print_section("已部署模型列表")
    try:
        client = Client(host)
        models = client.list_models()

        if not models:
            print("📭 当前没有已部署的模型")
        else:
            for model in models:
                model_type = model.get("model_type", "unknown")
                model_uid = model.get("model_uid", "unknown")
                print(f"📦 {model_type.upper()}: {model_uid}")
        print()
    except Exception as e:
        print(f"❌ 获取模型列表失败: {e}\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="部署 Qwen3 模型到 Xinference")
    parser.add_argument("--host", default="http://localhost:9997", help="Xinference 服务地址")
    parser.add_argument("--gpu", default="0", help="GPU 索引(逗号分隔,如: 0 或 0,1)")
    parser.add_argument("--embedding-only", action="store_true", help="仅部署 embedding 模型")
    parser.add_argument("--reranker-only", action="store_true", help="仅部署 reranker 模型")
    parser.add_argument("--list", action="store_true", help="列出已部署的模型")

    args = parser.parse_args()

    # 将 GPU 字符串转换为列表
    gpu_idx = [int(x.strip()) for x in args.gpu.split(",")]

    if args.list:
        list_models(args.host)
    elif args.embedding_only:
        # 仅部署 embedding
        print_section("部署 Qwen3-Embedding 模型 (4B)")
        client = Client(args.host)
        embedding_uid = client.launch_model(
            model_name="qwen3-embedding",
            model_size_in_billions=4,
            model_type="embedding",
            engine="vllm",
            gpu_idx=gpu_idx,
        )
        print(f"✅ Embedding 模型部署成功: {embedding_uid}")
    elif args.reranker_only:
        # 仅部署 reranker
        print_section("部署 Qwen3-Reranker 模型 (4B)")
        client = Client(args.host)
        reranker_uid = client.launch_model(
            model_name="qwen3-reranker",
            model_size_in_billions=4,
            model_type="rerank",
            engine="vllm",
            gpu_idx=gpu_idx,
        )
        print(f"✅ Reranker 模型部署成功: {reranker_uid}")
    else:
        # 部署所有模型
        deploy_qwen3_models(args.host, gpu_idx)