import_test_data.py 8.92 KB
#!/usr/bin/env python3
"""
Import test data into MySQL Shoplazza tables.

Reads SQL file generated by generate_test_data.py and imports into MySQL.
"""

import sys
import os
import argparse
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.db_connector import create_db_connection, test_connection


def import_sql_file(db_engine, sql_file: str):
    """
    Import SQL file into database using MySQL client (more reliable for large files).

    Connection settings are read from the attributes of ``db_engine.url``
    instead of regex-parsing ``str(db_engine.url)``: the string form may mask
    the password, may omit the port, and may carry query options after the
    database name, all of which broke the previous parsing.

    If the ``mysql`` client is missing or fails, the import falls back to
    ``import_sql_file_sqlalchemy``.

    Args:
        db_engine: SQLAlchemy database engine (used to get connection info)
        sql_file: Path to SQL file

    Returns:
        True when the import succeeded (via the client or via the fallback).

    Raises:
        ValueError: If the engine URL is not a MySQL URL or lacks host/database.
        FileNotFoundError: If ``sql_file`` does not exist.
        Exception: If the mysql client times out, or whatever the fallback raises.
    """
    import subprocess
    from pathlib import Path

    # Get connection info from the engine URL object
    url = db_engine.url
    drivername = str(url.drivername or '')
    if not drivername.startswith('mysql') or not url.host or not url.database:
        # repr() is used because it does not expose the password
        raise ValueError(f"Cannot parse database URL: {url!r}")

    host = url.host
    port = url.port or 3306  # URL may omit the port; 3306 is the MySQL default
    database = url.database

    # Fail early with an accurate message: without this check a missing SQL
    # file would be reported below as "MySQL client not found".
    sql_file_path = Path(sql_file).absolute()
    if not sql_file_path.is_file():
        raise FileNotFoundError(f"SQL file not found: {sql_file_path}")

    # Build mysql command
    mysql_cmd = ['mysql', f'-h{host}', f'-P{port}']
    if url.username:
        mysql_cmd.append(f'-u{url.username}')
    if url.password:
        # NOTE(review): a password on the command line is visible in the
        # process list; consider an option file if that matters here.
        mysql_cmd.append(f'-p{url.password}')
    mysql_cmd.append(database)

    print(f"Executing SQL file using MySQL client...")
    print(f"  File: {sql_file_path}")
    print(f"  Database: {host}:{port}/{database}")

    try:
        with open(sql_file_path, 'r', encoding='utf-8') as f:
            result = subprocess.run(
                mysql_cmd,
                stdin=f,
                capture_output=True,
                text=True,
                timeout=300  # 5 minute timeout
            )

        if result.returncode != 0:
            error_msg = result.stderr or result.stdout
            print(f"ERROR: MySQL execution failed")
            print(f"Error output: {error_msg[:500]}")
            # Caught by the generic handler below, which triggers the fallback
            raise Exception(f"MySQL execution failed: {error_msg[:200]}")

        print("SQL file executed successfully")
        return True

    except FileNotFoundError:
        # The SQL file was checked above, so this means the mysql binary
        # itself could not be found.
        print("MySQL client not found, falling back to SQLAlchemy...")
        return import_sql_file_sqlalchemy(db_engine, sql_file)
    except subprocess.TimeoutExpired as exc:
        raise Exception("SQL execution timed out after 5 minutes") from exc
    except Exception as e:
        print(f"Error using MySQL client: {e}")
        print("Falling back to SQLAlchemy...")
        return import_sql_file_sqlalchemy(db_engine, sql_file)


def _split_insert_statements(sql_content):
    """
    Split SQL text on semicolons outside string literals; keep INSERT statements.

    Inside a single-quoted literal both ``''`` and backslash escapes (e.g.
    ``\\'``, MySQL's default behaviour) are treated as part of the string, so
    a semicolon or quote they protect does not end the statement.

    Args:
        sql_content: SQL text with comment lines already removed.

    Returns:
        List of stripped statements that start with ``INSERT INTO``
        (case-insensitive). A terminating semicolon, when present, is kept.
    """
    statements = []
    current = []
    in_string = False
    length = len(sql_content)
    pos = 0

    while pos < length:
        char = sql_content[pos]

        # Backslash escape inside a literal: copy the escaped char verbatim
        if in_string and char == '\\' and pos + 1 < length:
            current.append(sql_content[pos:pos + 2])
            pos += 2
            continue

        if char == "'":
            # Doubled quote: an escaped quote inside a literal, or an empty
            # literal outside one. Either way the string state is unchanged.
            if pos + 1 < length and sql_content[pos + 1] == "'":
                current.append("''")
                pos += 2
                continue
            in_string = not in_string

        current.append(char)

        # Split on semicolon only if not in string
        if char == ';' and not in_string:
            stmt = ''.join(current).strip()
            if stmt.upper().startswith('INSERT INTO'):
                statements.append(stmt)
            current = []

        pos += 1

    # Handle last statement (no trailing semicolon)
    stmt = ''.join(current).strip()
    if stmt.upper().startswith('INSERT INTO'):
        statements.append(stmt)

    return statements


def import_sql_file_sqlalchemy(db_engine, sql_file: str):
    """
    Fallback method: Import SQL file using SQLAlchemy (for when mysql client unavailable).

    Only ``INSERT INTO`` statements are executed; every other statement in the
    file is skipped. Each statement is committed individually, so statements
    executed before a failure stay committed.

    Args:
        db_engine: Engine providing ``raw_connection()`` (DB-API connection)
        sql_file: Path to SQL file (read as UTF-8)

    Returns:
        True when every statement executed successfully.

    Raises:
        Exception: Re-raises the driver error of the first failing statement
            after printing some context around its ``VALUES`` clause.
    """
    with open(sql_file, 'r', encoding='utf-8') as f:
        sql_content = f.read()

    # Remove comment lines
    sql_content = '\n'.join(
        line for line in sql_content.split('\n')
        if not line.lstrip().startswith('--')
    )

    statements = _split_insert_statements(sql_content)
    total = len(statements)

    print(f"Parsed {total} SQL statements")
    print(f"Executing {total} SQL statements...")

    # Use raw connection to avoid SQLAlchemy parameter parsing
    raw_conn = db_engine.raw_connection()
    try:
        cursor = raw_conn.cursor()
        try:
            for index, statement in enumerate(statements, 1):
                try:
                    # Execute raw SQL directly using the DB-API cursor
                    cursor.execute(statement)
                    raw_conn.commit()
                    if index % 1000 == 0 or index == total:
                        print(f"  [{index}/{total}] Executed successfully")
                except Exception as e:
                    print(f"  [{index}/{total}] ERROR: {e}")
                    error_start = max(0, statement.find('VALUES') - 100)
                    error_end = min(len(statement), error_start + 500)
                    print(f"  Statement context: ...{statement[error_start:error_end]}...")
                    raise
        finally:
            cursor.close()
    finally:
        raw_conn.close()

    return True


def verify_import(db_engine, tenant_id: str):
    """
    Report how many SPU and SKU rows exist for a tenant.

    Runs one COUNT(*) query per table, filtered by ``tenant_id`` through a
    bound parameter, and prints both totals.

    Args:
        db_engine: SQLAlchemy database engine
        tenant_id: Tenant ID to verify

    Returns:
        Tuple ``(spu_count, sku_count)``.
    """
    from sqlalchemy import text

    spu_query = text("SELECT COUNT(*) FROM shoplazza_product_spu WHERE tenant_id = :tenant_id")
    sku_query = text("SELECT COUNT(*) FROM shoplazza_product_sku WHERE tenant_id = :tenant_id")
    bind_params = {"tenant_id": tenant_id}

    with db_engine.connect() as conn:
        # SPUs first, then SKUs
        spu_count = conn.execute(spu_query, bind_params).scalar()
        sku_count = conn.execute(sku_query, bind_params).scalar()

        print("\nVerification:")
        print(f"  SPUs: {spu_count}")
        print(f"  SKUs: {sku_count}")

        return spu_count, sku_count


def main(argv=None):
    """
    Command-line entry point: connect, optionally clean, import, verify.

    When ``--tenant-id`` is given, existing SKU and SPU rows for that tenant
    are deleted before the import and counted again afterwards.

    Args:
        argv: Argument list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 when connecting or importing fails.
    """
    parser = argparse.ArgumentParser(description='Import test data into MySQL')

    # Database connection
    parser.add_argument('--db-host', required=True, help='MySQL host')
    parser.add_argument('--db-port', type=int, default=3306, help='MySQL port (default: 3306)')
    parser.add_argument('--db-database', required=True, help='MySQL database name')
    parser.add_argument('--db-username', required=True, help='MySQL username')
    parser.add_argument('--db-password', required=True, help='MySQL password')

    # Import options
    parser.add_argument('--sql-file', required=True, help='SQL file to import')
    parser.add_argument(
        '--tenant-id',
        help='Tenant ID (optional): existing rows for this tenant are deleted '
             'before the import and counted afterwards'
    )

    args = parser.parse_args(argv)

    print(f"Connecting to MySQL: {args.db_host}:{args.db_port}/{args.db_database}")

    # Connect to database
    try:
        db_engine = create_db_connection(
            host=args.db_host,
            port=args.db_port,
            database=args.db_database,
            username=args.db_username,
            password=args.db_password
        )
    except Exception as e:
        print(f"ERROR: Failed to connect to MySQL: {e}")
        return 1

    # Test connection
    if not test_connection(db_engine):
        print("ERROR: Database connection test failed")
        return 1

    print("Database connection successful")

    # Clean existing data if tenant_id provided
    if args.tenant_id:
        print(f"\nCleaning existing data for tenant_id: {args.tenant_id}")
        from sqlalchemy import text
        # Bound parameter instead of string interpolation: the tenant id
        # comes from the command line and must not be spliced into SQL.
        tenant_params = {"tenant_id": args.tenant_id}
        try:
            # engine.begin() commits on success and rolls back on error
            with db_engine.begin() as conn:
                # Delete SKUs first (foreign key constraint)
                conn.execute(
                    text("DELETE FROM shoplazza_product_sku WHERE tenant_id = :tenant_id"),
                    tenant_params
                )
                # Delete SPUs
                conn.execute(
                    text("DELETE FROM shoplazza_product_spu WHERE tenant_id = :tenant_id"),
                    tenant_params
                )
            print("✓ Existing data cleaned")
        except Exception as e:
            print(f"⚠ Warning: Failed to clean existing data: {e}")
            # Continue anyway

    # Import SQL file
    print(f"\nImporting SQL file: {args.sql_file}")
    try:
        import_sql_file(db_engine, args.sql_file)
        print("Import completed successfully")
    except Exception as e:
        print(f"ERROR: Failed to import SQL file: {e}")
        import traceback
        traceback.print_exc()
        return 1

    # Verify import if tenant_id provided
    if args.tenant_id:
        verify_import(db_engine, args.tenant_id)

    return 0


# Script entry point: the return value of main() becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())