# recreate_and_import.py (6.42 KB)
#!/usr/bin/env python3
"""
重建索引并导入数据的脚本。

清除旧索引,使用新的mapping重建索引,然后导入数据。
"""

import sys
import os
import argparse
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.db_connector import create_db_connection
from utils.es_client import ESClient
from indexer.mapping_generator import load_mapping, delete_index_if_exists, DEFAULT_INDEX_NAME
from indexer.spu_transformer import SPUTransformer
from indexer.bulk_indexer import BulkIndexer


# Total number of progress steps announced by main(); keeps the "[n/N]"
# banners consistent with each other.
_TOTAL_STEPS = 6


def _print_step(number, title):
    """Print a progress banner of the form ``[number/total] title``."""
    print(f"\n[{number}/{_TOTAL_STEPS}] {title}")


def _resolve_db_settings(args):
    """Resolve MySQL connection settings from CLI arguments and environment.

    Each command-line value takes precedence over its ``DB_*`` environment
    variable; the port falls back to 3306 when neither is given.

    Args:
        args: Parsed ``argparse`` namespace carrying the ``db_host``,
            ``db_port``, ``db_database``, ``db_username`` and ``db_password``
            attributes.

    Returns:
        A dict with ``host``, ``port``, ``database``, ``username`` and
        ``password`` keys, or ``None`` (after printing the reason) when a
        required value is missing or ``DB_PORT`` is not an integer.
    """
    port = args.db_port
    if not port:
        raw_port = os.environ.get('DB_PORT', '3306')
        try:
            port = int(raw_port)
        except ValueError:
            print(f"✗ 无效的 DB_PORT: {raw_port!r}")
            return None

    settings = {
        'host': args.db_host or os.environ.get('DB_HOST'),
        'port': port,
        'database': args.db_database or os.environ.get('DB_DATABASE'),
        'username': args.db_username or os.environ.get('DB_USERNAME'),
        'password': args.db_password or os.environ.get('DB_PASSWORD'),
    }

    required = ('host', 'database', 'username', 'password')
    if not all(settings[key] for key in required):
        print("✗ MySQL连接参数不完整")
        print("请提供 --db-host, --db-database, --db-username, --db-password")
        print("或设置环境变量: DB_HOST, DB_DATABASE, DB_USERNAME, DB_PASSWORD")
        return None

    return settings


def main() -> int:
    """Rebuild the Elasticsearch index and import the tenant's data into it.

    Steps: load the mapping, connect to Elasticsearch, delete the old index
    (unless ``--skip-delete`` is given), create the index if it does not
    exist, connect to MySQL, then transform and bulk-index the documents.

    MySQL settings are validated before any Elasticsearch step so that a
    missing or malformed parameter cannot leave the index deleted and empty.

    Returns:
        0 when the import succeeds (or there is nothing to import), 1 when
        any step fails or at least one document fails to index.
    """
    import traceback

    parser = argparse.ArgumentParser(description='重建ES索引并导入数据')

    # Database connection
    parser.add_argument('--db-host', help='MySQL host (或使用环境变量 DB_HOST)')
    parser.add_argument('--db-port', type=int, help='MySQL port (或使用环境变量 DB_PORT, 默认: 3306)')
    parser.add_argument('--db-database', help='MySQL database (或使用环境变量 DB_DATABASE)')
    parser.add_argument('--db-username', help='MySQL username (或使用环境变量 DB_USERNAME)')
    parser.add_argument('--db-password', help='MySQL password (或使用环境变量 DB_PASSWORD)')

    # Tenant and ES
    parser.add_argument('--tenant-id', required=True, help='Tenant ID (必需)')
    parser.add_argument('--es-host', help='Elasticsearch host (或使用环境变量 ES_HOST, 默认: http://localhost:9200)')

    # Options
    parser.add_argument('--batch-size', type=int, default=500, help='批量导入大小 (默认: 500)')
    parser.add_argument('--skip-delete', action='store_true', help='跳过删除旧索引步骤')

    args = parser.parse_args()

    print("=" * 60)
    print("重建ES索引并导入数据")
    print("=" * 60)

    # Fail fast: check the MySQL settings before the destructive index
    # deletion, otherwise a forgotten flag would wipe the index for nothing.
    db_settings = _resolve_db_settings(args)
    if db_settings is None:
        return 1

    # Load the mapping
    _print_step(1, "加载mapping配置...")
    try:
        mapping = load_mapping()
        print("✓ 成功加载mapping配置")
    except Exception as e:
        print(f"✗ 加载mapping失败: {e}")
        return 1

    index_name = DEFAULT_INDEX_NAME
    print(f"索引名称: {index_name}")

    # Connect to Elasticsearch
    _print_step(2, "连接Elasticsearch...")
    es_host = args.es_host or os.environ.get('ES_HOST', 'http://localhost:9200')
    es_username = os.environ.get('ES_USERNAME')
    es_password = os.environ.get('ES_PASSWORD')

    print(f"ES地址: {es_host}")
    if es_username:
        print(f"ES用户名: {es_username}")

    try:
        # Credentials are passed only when both username and password are set.
        if es_username and es_password:
            es_client = ESClient(hosts=[es_host], username=es_username, password=es_password)
        else:
            es_client = ESClient(hosts=[es_host])

        if not es_client.ping():
            print(f"✗ 无法连接到Elasticsearch: {es_host}")
            return 1
        print("✓ Elasticsearch连接成功")
    except Exception as e:
        print(f"✗ 连接Elasticsearch失败: {e}")
        return 1

    # Delete the old index
    if not args.skip_delete:
        _print_step(3, "删除旧索引...")
        if es_client.index_exists(index_name):
            print(f"发现已存在的索引: {index_name}")
            if delete_index_if_exists(es_client, index_name):
                print(f"✓ 成功删除索引: {index_name}")
            else:
                print(f"✗ 删除索引失败: {index_name}")
                return 1
        else:
            print(f"索引不存在,跳过删除: {index_name}")
    else:
        _print_step(3, "跳过删除旧索引步骤")

    # Create the new index (kept as-is when --skip-delete left one in place)
    _print_step(4, "创建新索引...")
    try:
        if es_client.index_exists(index_name):
            print(f"✓ 索引已存在: {index_name},跳过创建")
        else:
            print(f"创建索引: {index_name}")
            if es_client.create_index(index_name, mapping):
                print(f"✓ 成功创建索引: {index_name}")
            else:
                print(f"✗ 创建索引失败: {index_name}")
                return 1
    except Exception as e:
        print(f"✗ 创建索引失败: {e}")
        traceback.print_exc()
        return 1

    # Connect to MySQL
    _print_step(5, "连接MySQL...")
    print(f"MySQL: {db_settings['host']}:{db_settings['port']}/{db_settings['database']}")
    try:
        db_engine = create_db_connection(**db_settings)
        print("✓ MySQL连接成功")
    except Exception as e:
        print(f"✗ 连接MySQL失败: {e}")
        return 1

    # Import the data
    _print_step(6, "导入数据...")
    print(f"Tenant ID: {args.tenant_id}")
    print(f"批量大小: {args.batch_size}")

    try:
        transformer = SPUTransformer(db_engine, args.tenant_id)
        print("正在转换数据...")
        documents = transformer.transform_batch()
        print(f"✓ 转换完成: {len(documents)} 个文档")

        if not documents:
            print("⚠ 没有数据需要导入")
            return 0

        print(f"正在导入数据到ES (批量大小: {args.batch_size})...")
        indexer = BulkIndexer(es_client, index_name, batch_size=args.batch_size)
        results = indexer.index_documents(documents, id_field="spu_id", show_progress=True)

        print(f"\n{'='*60}")
        print("导入完成!")
        print(f"{'='*60}")
        print(f"成功: {results['success']}")
        print(f"失败: {results['failed']}")
        print(f"耗时: {results.get('elapsed_time', 0):.2f}秒")

        # Any failed document makes the whole run exit non-zero.
        if results['failed'] > 0:
            print(f"\n⚠ 警告: {results['failed']} 个文档导入失败")
            return 1

        return 0
    except Exception as e:
        print(f"✗ 导入数据失败: {e}")
        traceback.print_exc()
        return 1


# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == '__main__':
    raise SystemExit(main())