# fetch_item_attributes.py 5.43 KB
"""
获取商品基础属性(前置任务)
从数据库获取ID->名称的映射,保存到本地文件供其他任务使用
避免每个任务重复查询数据库
"""
import pandas as pd
import os
import json
import argparse
from datetime import datetime
from db_service import create_db_connection
from config.offline_config import DB_CONFIG, OUTPUT_DIR
from scripts.debug_utils import setup_debug_logger


def fetch_and_save_mappings(engine, output_dir, logger=None, debug=False):
    """
    从数据库获取各种ID->名称映射并保存
    
    Args:
        engine: 数据库连接
        output_dir: 输出目录
        logger: 日志记录器
        debug: 是否开启debug模式
    
    Returns:
        mappings字典和输出文件路径
    """
    if logger:
        logger.info("开始获取ID到名称的映射...")
    
    mappings = {
        'item': {},
        'category': {},
        'platform': {},
        'supplier': {},
        'client_platform': {}
    }
    
    stats = {}
    
    # 1. 获取商品名称
    try:
        if logger:
            logger.info("获取商品名称...")
        query = "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 5000000"
        df = pd.read_sql(query, engine)
        mappings['item'] = dict(zip(df['id'].astype(str), df['name']))
        stats['item'] = len(mappings['item'])
        if logger:
            logger.info(f"✓ 获取到 {stats['item']} 个商品名称")
    except Exception as e:
        if logger:
            logger.error(f"✗ 获取商品名称失败: {e}")
        stats['item'] = 0
    
    # 2. 获取分类名称
    try:
        if logger:
            logger.info("获取分类名称...")
        query = "SELECT id, name FROM prd_category LIMIT 100000"
        df = pd.read_sql(query, engine)
        mappings['category'] = dict(zip(df['id'].astype(str), df['name']))
        stats['category'] = len(mappings['category'])
        if logger:
            logger.info(f"✓ 获取到 {stats['category']} 个分类名称")
    except Exception as e:
        if logger:
            logger.error(f"✗ 获取分类名称失败: {e}")
        stats['category'] = 0
    
    # 3. 获取供应商名称
    try:
        if logger:
            logger.info("获取供应商名称...")
        query = "SELECT id, name FROM sup_supplier LIMIT 100000"
        df = pd.read_sql(query, engine)
        mappings['supplier'] = dict(zip(df['id'].astype(str), df['name']))
        stats['supplier'] = len(mappings['supplier'])
        if logger:
            logger.info(f"✓ 获取到 {stats['supplier']} 个供应商名称")
    except Exception as e:
        if logger:
            logger.error(f"✗ 获取供应商名称失败: {e}")
        stats['supplier'] = 0
    
    # 4. 平台名称(硬编码)
    mappings['platform'] = {
        'pc': 'PC端',
        'h5': 'H5移动端',
        'app': 'APP',
        'miniprogram': '小程序',
        'wechat': '微信'
    }
    stats['platform'] = len(mappings['platform'])
    
    # 5. 客户端平台(硬编码)
    mappings['client_platform'] = {
        'iOS': 'iOS',
        'Android': 'Android',
        'Web': 'Web',
        'H5': 'H5'
    }
    stats['client_platform'] = len(mappings['client_platform'])
    
    # 保存到文件
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, 'item_attributes_mappings.json')
    
    if logger:
        logger.info(f"保存映射到: {output_file}")
    
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(mappings, f, ensure_ascii=False, indent=2)
    
    # 保存统计信息
    stats_file = os.path.join(output_dir, 'item_attributes_stats.txt')
    with open(stats_file, 'w', encoding='utf-8') as f:
        f.write(f"商品属性映射统计信息\n")
        f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"=" * 60 + "\n")
        for key, count in stats.items():
            f.write(f"{key}: {count}\n")
        f.write(f"=" * 60 + "\n")
        f.write(f"总计: {sum(stats.values())}\n")
    
    if logger:
        logger.info(f"统计信息已保存到: {stats_file}")
        logger.info(f"总计获取 {sum(stats.values())} 个映射")
    
    return mappings, output_file


def main():
    """
    Command-line entry point for the item-attribute prerequisite task.

    Parses ``--output_dir`` (default: ``OUTPUT_DIR``) and ``--debug``, sets up
    the logger, opens the database connection from ``DB_CONFIG`` and calls
    ``fetch_and_save_mappings``, logging a banner before and after.
    """
    cli = argparse.ArgumentParser(
        description='Fetch item attributes and save mappings',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default=OUTPUT_DIR,
        help='Output directory for mappings file',
    )
    cli.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug mode with detailed logging',
    )
    options = cli.parse_args()

    # Set up logging.
    logger = setup_debug_logger('fetch_item_attributes', debug=options.debug)

    banner = "=" * 60
    for line in (banner, "商品属性获取任务(前置任务)", banner):
        logger.info(line)

    # Open the database connection; arguments are passed positionally in
    # this fixed order.
    logger.info("连接数据库...")
    connection_fields = ('host', 'port', 'database', 'username', 'password')
    engine = create_db_connection(*(DB_CONFIG[field] for field in connection_fields))

    # Fetch and persist the mappings; only the output path is needed here.
    _, output_file = fetch_and_save_mappings(
        engine,
        options.output_dir,
        logger=logger,
        debug=options.debug,
    )

    closing_lines = (
        banner,
        "✓ 商品属性获取完成!",
        f"映射文件: {output_file}",
        banner,
    )
    for line in closing_lines:
        logger.info(line)

# Run the task only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()