""" 调试工具模块 提供debug日志和明文输出功能 """ import os import json import logging from datetime import datetime def setup_debug_logger(script_name, debug=False): """ 设置debug日志记录器 Args: script_name: 脚本名称 debug: 是否开启debug模式 Returns: logger对象 """ logger = logging.getLogger(script_name) # 清除已有的handlers logger.handlers.clear() # 设置日志级别 if debug: logger.setLevel(logging.DEBUG) else: logger.setLevel(logging.INFO) # 控制台输出 console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG if debug else logging.INFO) console_format = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) console_handler.setFormatter(console_format) logger.addHandler(console_handler) # 文件输出(如果开启debug) if debug: log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs', 'debug') os.makedirs(log_dir, exist_ok=True) log_file = os.path.join( log_dir, f"{script_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" ) file_handler = logging.FileHandler(log_file, encoding='utf-8') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(console_format) logger.addHandler(file_handler) logger.debug(f"Debug log file: {log_file}") return logger def log_dataframe_info(logger, df, name="DataFrame", sample_size=5): """ 记录DataFrame的详细信息 Args: logger: logger对象 df: pandas DataFrame name: 数据名称 sample_size: 采样大小 """ logger.debug(f"\n{'='*60}") logger.debug(f"{name} 信息:") logger.debug(f"{'='*60}") logger.debug(f"总行数: {len(df)}") logger.debug(f"总列数: {len(df.columns)}") logger.debug(f"列名: {list(df.columns)}") # 数据类型 logger.debug(f"\n数据类型:") for col, dtype in df.dtypes.items(): logger.debug(f" {col}: {dtype}") # 缺失值统计 null_counts = df.isnull().sum() if null_counts.sum() > 0: logger.debug(f"\n缺失值统计:") for col, count in null_counts[null_counts > 0].items(): logger.debug(f" {col}: {count} ({count/len(df)*100:.2f}%)") # 基本统计 if len(df) > 0: logger.debug(f"\n前{sample_size}行示例:") logger.debug(f"\n{df.head(sample_size).to_string()}") # 数值列的统计 numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns if len(numeric_cols) > 0: logger.debug(f"\n数值列统计:") logger.debug(f"\n{df[numeric_cols].describe().to_string()}") logger.debug(f"{'='*60}\n") def log_dict_stats(logger, data_dict, name="Dictionary", top_n=10): """ 记录字典的统计信息 Args: logger: logger对象 data_dict: 字典数据 name: 数据名称 top_n: 显示前N个元素 """ logger.debug(f"\n{'='*60}") logger.debug(f"{name} 统计:") logger.debug(f"{'='*60}") logger.debug(f"总元素数: {len(data_dict)}") if len(data_dict) > 0: # 如果值是列表或可计数的 try: item_counts = {k: len(v) if hasattr(v, '__len__') else 1 for k, v in list(data_dict.items())[:1000]} # 采样 if item_counts: total_items = sum(item_counts.values()) avg_items = total_items / len(item_counts) logger.debug(f"平均每个key的元素数: {avg_items:.2f}") except: pass # 显示前N个示例 logger.debug(f"\n前{top_n}个示例:") for i, (k, v) in enumerate(list(data_dict.items())[:top_n]): if isinstance(v, list): logger.debug(f" {k}: {v[:3]}... (total: {len(v)})") elif isinstance(v, dict): logger.debug(f" {k}: {dict(list(v.items())[:3])}... (total: {len(v)})") else: logger.debug(f" {k}: {v}") logger.debug(f"{'='*60}\n") def save_readable_index(output_file, index_data, name_mappings, description=""): """ 保存可读的明文索引文件 Args: output_file: 输出文件路径 index_data: 索引数据 {item_id: [(similar_id, score), ...]} name_mappings: 名称映射 { 'item': {id: name}, 'category': {id: name}, 'platform': {id: name}, ... } description: 描述信息 """ debug_dir = os.path.join(os.path.dirname(output_file), 'debug') os.makedirs(debug_dir, exist_ok=True) # 生成明文文件名 base_name = os.path.basename(output_file) name_without_ext = os.path.splitext(base_name)[0] readable_file = os.path.join(debug_dir, f"{name_without_ext}_readable.txt") with open(readable_file, 'w', encoding='utf-8') as f: # 写入描述信息 f.write("="*80 + "\n") f.write(f"明文索引文件\n") f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") if description: f.write(f"描述: {description}\n") f.write(f"总索引数: {len(index_data)}\n") f.write("="*80 + "\n\n") # 遍历索引数据 for idx, (key, items) in enumerate(index_data.items(), 1): # 解析key并添加名称 readable_key = format_key_with_name(key, name_mappings) f.write(f"\n[{idx}] {readable_key}\n") f.write("-" * 80 + "\n") # 解析items if isinstance(items, list): for i, item in enumerate(items, 1): if isinstance(item, tuple) and len(item) >= 2: item_id, score = item[0], item[1] item_name = name_mappings.get('item', {}).get(str(item_id), 'Unknown') f.write(f" {i}. ID:{item_id}({item_name}) - Score:{score:.4f}\n") else: item_name = name_mappings.get('item', {}).get(str(item), 'Unknown') f.write(f" {i}. ID:{item}({item_name})\n") elif isinstance(items, dict): for i, (item_id, score) in enumerate(items.items(), 1): item_name = name_mappings.get('item', {}).get(str(item_id), 'Unknown') f.write(f" {i}. ID:{item_id}({item_name}) - Score:{score:.4f}\n") else: f.write(f" {items}\n") # 每50个索引添加分隔 if idx % 50 == 0: f.write("\n" + "="*80 + "\n") f.write(f"已输出 {idx}/{len(index_data)} 个索引\n") f.write("="*80 + "\n") return readable_file def format_key_with_name(key, name_mappings): """ 格式化key,添加名称信息 Args: key: 原始key (如 "interest:hot:platform:1" 或 "i2i:swing:12345") name_mappings: 名称映射字典 Returns: 格式化后的key字符串 """ if ':' not in str(key): # 简单的item_id item_name = name_mappings.get('item', {}).get(str(key), '') return f"{key}({item_name})" if item_name else str(key) parts = str(key).split(':') formatted_parts = [] for i, part in enumerate(parts): # 尝试识别是否为ID if part.isdigit(): # 根据前一个部分判断类型 if i > 0: prev_part = parts[i-1] if 'category' in prev_part or 'level' in prev_part: name = name_mappings.get('category', {}).get(part, '') formatted_parts.append(f"{part}({name})" if name else part) elif 'platform' in prev_part: name = name_mappings.get('platform', {}).get(part, '') formatted_parts.append(f"{part}({name})" if name else part) elif 'supplier' in prev_part: name = name_mappings.get('supplier', {}).get(part, '') formatted_parts.append(f"{part}({name})" if name else part) else: # 可能是item_id name = name_mappings.get('item', {}).get(part, '') formatted_parts.append(f"{part}({name})" if name else part) else: formatted_parts.append(part) else: formatted_parts.append(part) return ':'.join(formatted_parts) def fetch_name_mappings(engine, debug=False): """ 从数据库获取ID到名称的映射 Args: engine: 数据库连接 debug: 是否输出debug信息 Returns: name_mappings字典 """ import pandas as pd mappings = { 'item': {}, 'category': {}, 'platform': {}, 'supplier': {}, 'client_platform': {} } try: # 获取商品名称 query = "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 5000000" df = pd.read_sql(query, engine) mappings['item'] = dict(zip(df['id'].astype(str), df['name'])) if debug: print(f"✓ 获取到 {len(mappings['item'])} 个商品名称") except Exception as e: if debug: print(f"✗ 获取商品名称失败: {e}") try: # 获取分类名称 query = "SELECT id, name FROM prd_category LIMIT 100000" df = pd.read_sql(query, engine) mappings['category'] = dict(zip(df['id'].astype(str), df['name'])) if debug: print(f"✓ 获取到 {len(mappings['category'])} 个分类名称") except Exception as e: if debug: print(f"✗ 获取分类名称失败: {e}") try: # 获取供应商名称 query = "SELECT id, name FROM sup_supplier LIMIT 100000" df = pd.read_sql(query, engine) mappings['supplier'] = dict(zip(df['id'].astype(str), df['name'])) if debug: print(f"✓ 获取到 {len(mappings['supplier'])} 个供应商名称") except Exception as e: if debug: print(f"✗ 获取供应商名称失败: {e}") # 平台名称(硬编码常见值) mappings['platform'] = { 'pc': 'PC端', 'h5': 'H5移动端', 'app': 'APP', 'miniprogram': '小程序', 'wechat': '微信' } mappings['client_platform'] = { 'iOS': 'iOS', 'Android': 'Android', 'Web': 'Web', 'H5': 'H5' } return mappings def log_algorithm_params(logger, params_dict): """ 记录算法参数 Args: logger: logger对象 params_dict: 参数字典 """ logger.debug(f"\n{'='*60}") logger.debug("算法参数:") logger.debug(f"{'='*60}") for key, value in params_dict.items(): logger.debug(f" {key}: {value}") logger.debug(f"{'='*60}\n") def log_processing_step(logger, step_name, start_time=None): """ 记录处理步骤 Args: logger: logger对象 step_name: 步骤名称 start_time: 开始时间(如果提供,会计算耗时) """ from datetime import datetime current_time = datetime.now() logger.debug(f"\n{'='*60}") logger.debug(f"处理步骤: {step_name}") logger.debug(f"时间: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") if start_time: elapsed = (current_time - start_time).total_seconds() logger.debug(f"耗时: {elapsed:.2f}秒") logger.debug(f"{'='*60}\n")