""" 获取商品基础属性(前置任务) 从数据库获取ID->名称的映射,保存到本地文件供其他任务使用 避免每个任务重复查询数据库 """ import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) import pandas as pd import json import argparse from datetime import datetime from db_service import create_db_connection from offline_tasks.config.offline_config import DB_CONFIG, OUTPUT_DIR from offline_tasks.scripts.debug_utils import setup_debug_logger def fetch_and_save_mappings(engine, output_dir, logger=None, debug=False): """ 从数据库获取各种ID->名称映射并保存 Args: engine: 数据库连接 output_dir: 输出目录 logger: 日志记录器 debug: 是否开启debug模式 Returns: mappings字典和输出文件路径 """ if logger: logger.info("开始获取ID到名称的映射...") mappings = { 'item': {}, 'category': {}, 'platform': {}, 'supplier': {}, 'client_platform': {} } stats = {} # 1. 获取商品名称 try: if logger: logger.info("获取商品名称...") query = "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 5000000" df = pd.read_sql(query, engine) mappings['item'] = dict(zip(df['id'].astype(str), df['name'])) stats['item'] = len(mappings['item']) if logger: logger.info(f"✓ 获取到 {stats['item']} 个商品名称") except Exception as e: if logger: logger.error(f"✗ 获取商品名称失败: {e}") stats['item'] = 0 # 2. 获取分类名称 try: if logger: logger.info("获取分类名称...") query = "SELECT id, name FROM prd_category LIMIT 100000" df = pd.read_sql(query, engine) mappings['category'] = dict(zip(df['id'].astype(str), df['name'])) stats['category'] = len(mappings['category']) if logger: logger.info(f"✓ 获取到 {stats['category']} 个分类名称") except Exception as e: if logger: logger.error(f"✗ 获取分类名称失败: {e}") stats['category'] = 0 # 3. 获取供应商名称 try: if logger: logger.info("获取供应商名称...") query = "SELECT id, name FROM sup_supplier LIMIT 100000" df = pd.read_sql(query, engine) mappings['supplier'] = dict(zip(df['id'].astype(str), df['name'])) stats['supplier'] = len(mappings['supplier']) if logger: logger.info(f"✓ 获取到 {stats['supplier']} 个供应商名称") except Exception as e: if logger: logger.error(f"✗ 获取供应商名称失败: {e}") stats['supplier'] = 0 # 4. 平台名称(硬编码) mappings['platform'] = { 'pc': 'PC端', 'h5': 'H5移动端', 'app': 'APP', 'miniprogram': '小程序', 'wechat': '微信' } stats['platform'] = len(mappings['platform']) # 5. 客户端平台(硬编码) mappings['client_platform'] = { 'iOS': 'iOS', 'Android': 'Android', 'Web': 'Web', 'H5': 'H5' } stats['client_platform'] = len(mappings['client_platform']) # 保存到文件 os.makedirs(output_dir, exist_ok=True) output_file = os.path.join(output_dir, 'item_attributes_mappings.json') if logger: logger.info(f"保存映射到: {output_file}") with open(output_file, 'w', encoding='utf-8') as f: json.dump(mappings, f, ensure_ascii=False, indent=2) # 保存统计信息 stats_file = os.path.join(output_dir, 'item_attributes_stats.txt') with open(stats_file, 'w', encoding='utf-8') as f: f.write(f"商品属性映射统计信息\n") f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"=" * 60 + "\n") for key, count in stats.items(): f.write(f"{key}: {count}\n") f.write(f"=" * 60 + "\n") f.write(f"总计: {sum(stats.values())}\n") if logger: logger.info(f"统计信息已保存到: {stats_file}") logger.info(f"总计获取 {sum(stats.values())} 个映射") return mappings, output_file def main(): parser = argparse.ArgumentParser(description='Fetch item attributes and save mappings') parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Output directory for mappings file') parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging') args = parser.parse_args() # 设置日志 logger = setup_debug_logger('fetch_item_attributes', debug=args.debug) logger.info("="*60) logger.info("商品属性获取任务(前置任务)") logger.info("="*60) # 创建数据库连接 logger.info("连接数据库...") engine = create_db_connection( DB_CONFIG['host'], DB_CONFIG['port'], DB_CONFIG['database'], DB_CONFIG['username'], DB_CONFIG['password'] ) # 获取并保存映射 mappings, output_file = fetch_and_save_mappings( engine, args.output_dir, logger=logger, debug=args.debug ) logger.info("="*60) logger.info("✓ 商品属性获取完成!") logger.info(f"映射文件: {output_file}") logger.info("="*60) if __name__ == '__main__': main()