# debug_utils.py 11.7 KB
"""
调试工具模块
提供debug日志和明文输出功能
"""
import os
import json
import logging
from datetime import datetime


def setup_debug_logger(script_name, debug=False):
    """
    Configure and return the logger named ``script_name``.

    Any handlers already attached to that logger (e.g. from an earlier call)
    are detached and closed, then a console handler is attached. When
    ``debug`` is true, a timestamped log file is additionally written under
    ``<parent of this module's directory>/logs/debug``.

    Args:
        script_name: Logger name; also used as the log-file name prefix.
        debug: If True, log at DEBUG level and also write to a file;
            otherwise log at INFO level to the console only.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(script_name)

    # Detach AND close existing handlers. Only clearing the handler list
    # would leave a FileHandler from a previous call holding its file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console output.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File output, only in debug mode.
    if debug:
        # abspath: with a relative __file__, dirname(dirname(...)) can
        # collapse to '' and silently resolve against the current directory.
        parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        log_dir = os.path.join(parent_dir, 'logs', 'debug')
        os.makedirs(log_dir, exist_ok=True)

        log_file = os.path.join(
            log_dir,
            f"{script_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        )
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        logger.debug(f"Debug log file: {log_file}")

    return logger


def log_dataframe_info(logger, df, name="DataFrame", sample_size=5):
    """
    Log a structural summary of a pandas DataFrame at DEBUG level.

    Emits the row/column counts, the column names, each column's dtype,
    missing-value counts (only for columns that have any), the first
    ``sample_size`` rows, and ``describe()`` statistics for all numeric
    columns.

    Args:
        logger: ``logging.Logger`` to write to.
        df: pandas DataFrame to summarize.
        name: Label printed in the section header.
        sample_size: Number of leading rows to print as a sample.
    """
    # Rendering the sample and describe() is expensive on large frames;
    # skip all of it when DEBUG records would be discarded anyway.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    separator = '=' * 60
    row_count = len(df)

    logger.debug(f"\n{separator}")
    logger.debug(f"{name} 信息:")
    logger.debug(separator)
    logger.debug(f"总行数: {row_count}")
    logger.debug(f"总列数: {len(df.columns)}")
    logger.debug(f"列名: {list(df.columns)}")

    # Per-column dtypes.
    logger.debug("\n数据类型:")
    for col, dtype in df.dtypes.items():
        logger.debug(f"  {col}: {dtype}")

    # Missing values, only for columns that actually have some.
    null_counts = df.isnull().sum()
    missing = null_counts[null_counts > 0]
    if not missing.empty:
        logger.debug("\n缺失值统计:")
        for col, count in missing.items():
            # row_count > 0 is guaranteed here: a frame without rows has no nulls.
            logger.debug(f"  {col}: {count} ({count/row_count*100:.2f}%)")

    if row_count > 0:
        logger.debug(f"\n前{sample_size}行示例:")
        logger.debug(f"\n{df.head(sample_size).to_string()}")

        # 'number' covers every numeric width (int32, float32, nullable Int64,
        # ...), not only int64/float64, which silently skipped other widths.
        numeric_cols = df.select_dtypes(include='number').columns
        if len(numeric_cols) > 0:
            logger.debug("\n数值列统计:")
            logger.debug(f"\n{df[numeric_cols].describe().to_string()}")

    logger.debug(f"{separator}\n")


def log_dict_stats(logger, data_dict, name="Dictionary", top_n=10):
    """
    Log summary statistics of a dict at DEBUG level.

    Emits the number of keys and the average ``len()`` of the values, sampled
    over the first 1000 entries (values without ``__len__`` count as 1). It
    then logs the first ``top_n`` entries, with list/dict values truncated to
    3 elements.

    Args:
        logger: ``logging.Logger`` to write to.
        data_dict: Dict to summarize.
        name: Label printed in the section header.
        top_n: Number of leading entries to print as examples
            (``None`` = all; negative values print none).
    """
    from itertools import islice

    # Nothing below is visible unless DEBUG is enabled; skip the work.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    separator = '=' * 60
    logger.debug(f"\n{separator}")
    logger.debug(f"{name} 统计:")
    logger.debug(separator)
    logger.debug(f"总元素数: {len(data_dict)}")

    if len(data_dict) > 0:
        # Average value size, estimated from a sample. islice takes the
        # prefix lazily instead of materializing the whole items() list.
        try:
            sizes = [len(v) if hasattr(v, '__len__') else 1
                     for v in islice(data_dict.values(), 1000)]
            logger.debug(f"平均每个key的元素数: {sum(sizes) / len(sizes):.2f}")
        except Exception as exc:
            # Best effort only (e.g. a value that is a class such as `list`,
            # which has __len__ but cannot be len()-ed). Report it rather than
            # silently swallowing everything with a bare except.
            logger.debug(f"平均元素数计算失败: {exc!r}")

        # First top_n entries as examples.
        logger.debug(f"\n前{top_n}个示例:")
        limit = None if top_n is None else max(top_n, 0)
        for k, v in islice(data_dict.items(), limit):
            if isinstance(v, list):
                logger.debug(f"  {k}: {v[:3]}... (total: {len(v)})")
            elif isinstance(v, dict):
                logger.debug(f"  {k}: {dict(list(v.items())[:3])}... (total: {len(v)})")
            else:
                logger.debug(f"  {k}: {v}")

    logger.debug(f"{separator}\n")


def save_readable_index(output_file, index_data, name_mappings, description=""):
    """
    Write a human-readable dump of an index next to ``output_file``.

    The dump is written to ``<dir of output_file>/debug/<stem>_readable.txt``.
    The ``debug`` directory is created if needed.

    Args:
        output_file: Path of the real index file. Only its directory and base
            name are used, to place the readable dump.
        index_data: Mapping ``key -> items``. ``items`` may be a list whose
            elements are ``(id, score, ...)`` pairs (tuples or lists) or bare
            ids, a ``{id: score}`` dict, or anything else (written via str()).
        name_mappings: Name lookup tables such as
            ``{'item': {id_str: name}, 'category': {...}, 'platform': {...}}``.
            ``None`` is treated as "no mappings".
        description: Optional free-text line written into the header.

    Returns:
        Path of the readable file that was written.
    """
    name_mappings = name_mappings or {}

    debug_dir = os.path.join(os.path.dirname(output_file), 'debug')
    os.makedirs(debug_dir, exist_ok=True)

    # Readable file name: <stem>_readable.txt
    stem = os.path.splitext(os.path.basename(output_file))[0]
    readable_file = os.path.join(debug_dir, f"{stem}_readable.txt")

    # Loop invariants, hoisted out of the per-entry / per-item loops.
    item_names = name_mappings.get('item') or {}
    total = len(index_data)
    banner = "=" * 80 + "\n"

    def _fmt_score(score):
        # Scores are normally numeric; fall back to str() rather than
        # aborting the whole dump on a None / string score.
        try:
            return f"{score:.4f}"
        except (TypeError, ValueError):
            return str(score)

    with open(readable_file, 'w', encoding='utf-8') as f:
        # Header.
        f.write(banner)
        f.write("明文索引文件\n")
        f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        if description:
            f.write(f"描述: {description}\n")
        f.write(f"总索引数: {total}\n")
        f.write(banner + "\n")

        for idx, (key, items) in enumerate(index_data.items(), 1):
            readable_key = format_key_with_name(key, name_mappings)

            f.write(f"\n[{idx}] {readable_key}\n")
            f.write("-" * 80 + "\n")

            if isinstance(items, list):
                for i, item in enumerate(items, 1):
                    # Accept lists as well as tuples: e.g. after a JSON
                    # round-trip, (id, score) pairs arrive as [id, score].
                    if isinstance(item, (tuple, list)) and len(item) >= 2:
                        item_id, score = item[0], item[1]
                        item_name = item_names.get(str(item_id), 'Unknown')
                        f.write(f"  {i}. ID:{item_id}({item_name}) - Score:{_fmt_score(score)}\n")
                    else:
                        item_name = item_names.get(str(item), 'Unknown')
                        f.write(f"  {i}. ID:{item}({item_name})\n")
            elif isinstance(items, dict):
                for i, (item_id, score) in enumerate(items.items(), 1):
                    item_name = item_names.get(str(item_id), 'Unknown')
                    f.write(f"  {i}. ID:{item_id}({item_name}) - Score:{_fmt_score(score)}\n")
            else:
                f.write(f"  {items}\n")

            # Progress marker every 50 entries.
            if idx % 50 == 0:
                f.write("\n" + banner)
                f.write(f"已输出 {idx}/{total} 个索引\n")
                f.write(banner)

    return readable_file


def _resolve_mapping_type(prev_part, part):
    """Return the name_mappings table used to resolve ``part`` from its predecessor, or None."""
    # client_platform must be tested before platform: it contains 'platform'.
    if 'client_platform' in prev_part:
        return 'client_platform'
    if 'category' in prev_part or 'level' in prev_part:
        return 'category'
    if 'platform' in prev_part:
        return 'platform'
    if 'supplier' in prev_part:
        return 'supplier'
    # Otherwise only a numeric segment is plausibly an item id.
    return 'item' if part.isdigit() else None


def format_key_with_name(key, name_mappings):
    """
    Annotate the ID segments of an index key with human-readable names.

    A key without ':' is treated as a bare item id. Otherwise the key is split
    on ':' and every segment after the first is looked up in a table chosen by
    the segment before it:

    - predecessor contains ``client_platform`` -> ``client_platform`` table
    - predecessor contains ``category`` or ``level`` -> ``category`` table
    - predecessor contains ``platform`` -> ``platform`` table
    - predecessor contains ``supplier`` -> ``supplier`` table
    - any other all-digit segment -> ``item`` table

    Non-numeric segments are looked up too, so string-keyed tables such as
    platform codes (``'pc'``) resolve. Segments with a known name become
    ``segment(name)``; all others are kept unchanged.

    Args:
        key: Raw key, e.g. "interest:hot:platform:pc" or "i2i:swing:12345".
        name_mappings: Dict of lookup tables keyed by the table names above.
            ``None`` is treated as "no mappings".

    Returns:
        The formatted key string.
    """
    mappings = name_mappings or {}
    key_str = str(key)

    if ':' not in key_str:
        # Bare item id.
        item_name = (mappings.get('item') or {}).get(key_str, '')
        return f"{key}({item_name})" if item_name else key_str

    parts = key_str.split(':')
    formatted_parts = [parts[0]]  # the first segment has no predecessor
    for prev_part, part in zip(parts, parts[1:]):
        table = _resolve_mapping_type(prev_part, part)
        name = (mappings.get(table) or {}).get(part, '') if table else ''
        formatted_parts.append(f"{part}({name})" if name else part)

    return ':'.join(formatted_parts)


def fetch_name_mappings(engine, debug=False):
    """
    Build ID -> name lookup tables, mostly from the database.

    Item, category and supplier names are read with ``pandas.read_sql``;
    each query is best-effort, and a failing query leaves that table empty.
    Rows whose name is NULL are skipped. Platform and client-platform names
    are hard-coded.

    Args:
        engine: Connection / connectable accepted by ``pandas.read_sql``.
        debug: If True, print one success/failure line per query.

    Returns:
        Dict with keys 'item', 'category', 'platform', 'supplier' and
        'client_platform', each mapping id string -> display name.
    """
    import pandas as pd

    mappings = {
        'item': {},
        'category': {},
        'platform': {},
        'supplier': {},
        'client_platform': {}
    }

    # (mapping key, SQL, label used in the progress messages)
    db_sources = (
        ('item', "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 100000", "商品名称"),
        ('category', "SELECT id, name FROM prd_category LIMIT 10000", "分类名称"),
        ('supplier', "SELECT id, name FROM sup_supplier LIMIT 10000", "供应商名称"),
    )

    for mapping_key, query, label in db_sources:
        try:
            # NULL names would otherwise be stored and shown as "(None)"/"(nan)"
            # instead of letting callers fall back to their default.
            df = pd.read_sql(query, engine).dropna(subset=['name'])
            mappings[mapping_key] = dict(zip(df['id'].astype(str), df['name']))
        except Exception as e:
            # Best effort: one missing table must not prevent the others.
            if debug:
                print(f"✗ 获取{label}失败: {e}")
            continue
        if debug:
            print(f"✓ 获取到 {len(mappings[mapping_key])} 个{label}")

    # Platform names (hard-coded common values).
    mappings['platform'] = {
        'pc': 'PC端',
        'h5': 'H5移动端',
        'app': 'APP',
        'miniprogram': '小程序',
        'wechat': '微信'
    }

    mappings['client_platform'] = {
        'iOS': 'iOS',
        'Android': 'Android',
        'Web': 'Web',
        'H5': 'H5'
    }

    return mappings


def log_algorithm_params(logger, params_dict):
    """
    Log every algorithm parameter at DEBUG level.

    Output is a framed section: a separator, a title line, a separator, then
    one indented ``name: value`` line per entry of ``params_dict`` in
    iteration order, and a closing separator.

    Args:
        logger: ``logging.Logger`` to write to.
        params_dict: Mapping of parameter name -> value.
    """
    rule = '=' * 60
    for header_line in (f"\n{rule}", "算法参数:", rule):
        logger.debug(header_line)
    for param_name, param_value in params_dict.items():
        logger.debug(f"  {param_name}: {param_value}")
    logger.debug(f"{rule}\n")


def log_processing_step(logger, step_name, start_time=None):
    """
    Log a framed "processing step" marker at DEBUG level.

    The marker contains the step name and the current local time. If a truthy
    ``start_time`` is given, the seconds elapsed since then are logged too.

    Args:
        logger: ``logging.Logger`` to write to.
        step_name: Human-readable step name.
        start_time: Optional ``datetime`` the step (or pipeline) started at.
    """
    now = datetime.now()
    rule = '=' * 60

    logger.debug(f"\n{rule}")
    logger.debug(f"处理步骤: {step_name}")
    logger.debug("时间: " + now.strftime('%Y-%m-%d %H:%M:%S'))

    if start_time:
        seconds = (now - start_time).total_seconds()
        logger.debug(f"耗时: {seconds:.2f}秒")

    logger.debug(f"{rule}\n")