From 0e45f7029178f20715ff1b7724250e407013c3d9 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 21 Oct 2025 11:12:57 +0800 Subject: [PATCH] deepwalk refactor for memsave and perfermance optimize --- offline_tasks/scripts/i2i_deepwalk.py | 295 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------------------------------------------------------------------------------------------------------ 1 file changed, 157 insertions(+), 138 deletions(-) diff --git a/offline_tasks/scripts/i2i_deepwalk.py b/offline_tasks/scripts/i2i_deepwalk.py index 4bdd2a6..1d3a1ff 100644 --- a/offline_tasks/scripts/i2i_deepwalk.py +++ b/offline_tasks/scripts/i2i_deepwalk.py @@ -1,40 +1,44 @@ """ i2i - DeepWalk算法实现 基于用户-物品图结构训练DeepWalk模型,获取物品向量相似度 +复用 graphembedding/deepwalk/ 的高效实现 """ -import sys -import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - import pandas as pd import argparse +import os +import sys from datetime import datetime from collections import defaultdict from gensim.models import Word2Vec -import numpy as np from db_service import create_db_connection from offline_tasks.config.offline_config import ( DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N ) from offline_tasks.scripts.debug_utils import ( - setup_debug_logger, log_dataframe_info, log_dict_stats, + setup_debug_logger, log_dataframe_info, save_readable_index, fetch_name_mappings, log_algorithm_params, log_processing_step ) +# 导入 DeepWalk 实现 +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk')) +from deepwalk import DeepWalk + -def build_item_graph(df, behavior_weights): +def build_edge_file_from_db(df, behavior_weights, output_path, logger): """ - 构建物品图(基于用户共同交互) + 从数据库数据构建边文件 + 边文件格式: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,... Args: df: DataFrame with columns: user_id, item_id, event_type behavior_weights: 行为权重字典 - - Returns: - edge_dict: {item_id: {neighbor_id: weight}} + output_path: 边文件输出路径 + logger: 日志对象 """ + logger.info("开始构建物品图...") + # 构建用户-物品列表 user_items = defaultdict(list) @@ -43,13 +47,19 @@ def build_item_graph(df, behavior_weights): item_id = str(row['item_id']) event_type = row['event_type'] weight = behavior_weights.get(event_type, 1.0) - user_items[user_id].append((item_id, weight)) + logger.info(f"共有 {len(user_items)} 个用户") + # 构建物品图边 edge_dict = defaultdict(lambda: defaultdict(float)) for user_id, items in user_items.items(): + # 限制每个用户的物品数量,避免内存爆炸 + if len(items) > 100: + # 按权重排序,只保留前100个 + items = sorted(items, key=lambda x: -x[1])[:100] + # 物品两两组合,构建边 for i in range(len(items)): item_i, weight_i = items[i] @@ -61,106 +71,47 @@ def build_item_graph(df, behavior_weights): edge_dict[item_i][item_j] += edge_weight edge_dict[item_j][item_i] += edge_weight - return edge_dict - - -def save_edge_file(edge_dict, output_path): - """ - 保存边文件 + logger.info(f"构建物品图完成,共 {len(edge_dict)} 个节点") - Args: - edge_dict: 边字典 - output_path: 输出路径 - """ + # 保存边文件 + logger.info(f"保存边文件到 {output_path}") with open(output_path, 'w', encoding='utf-8') as f: for item_id, neighbors in edge_dict.items(): - # 格式: item_id \t neighbor1:weight1,neighbor2:weight2,... neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()]) f.write(f'{item_id}\t{neighbor_str}\n') - print(f"Edge file saved to {output_path}") + logger.info(f"边文件保存完成") + return len(edge_dict) -def random_walk(graph, start_node, walk_length): +def train_word2vec_from_walks(walks_file, config, logger): """ - 执行随机游走 + 从游走文件训练Word2Vec模型 Args: - graph: 图结构 {node: {neighbor: weight}} - start_node: 起始节点 - walk_length: 游走长度 - - Returns: - 游走序列 - """ - walk = [start_node] - - while len(walk) < walk_length: - cur = walk[-1] - - if cur not in graph or not graph[cur]: - break - - # 获取邻居和权重 - neighbors = list(graph[cur].keys()) - weights = list(graph[cur].values()) - - # 归一化权重 - total_weight = sum(weights) - if total_weight == 0: - break - - probs = [w / total_weight for w in weights] - - # 按权重随机选择下一个节点 - next_node = np.random.choice(neighbors, p=probs) - walk.append(next_node) - - return walk - - -def generate_walks(graph, num_walks, walk_length): - """ - 生成随机游走序列 - - Args: - graph: 图结构 - num_walks: 每个节点的游走次数 - walk_length: 游走长度 + walks_file: 游走序列文件路径 + config: Word2Vec配置 + logger: 日志对象 Returns: - List of walks + Word2Vec模型 """ - walks = [] - nodes = list(graph.keys()) - - print(f"Generating {num_walks} walks per node, walk length {walk_length}...") + logger.info(f"从 {walks_file} 读取游走序列...") - for _ in range(num_walks): - np.random.shuffle(nodes) - for node in nodes: - walk = random_walk(graph, node, walk_length) + # 读取游走序列 + sentences = [] + with open(walks_file, 'r', encoding='utf-8') as f: + for line in f: + walk = line.strip().split() if len(walk) >= 2: - walks.append(walk) - - return walks - - -def train_word2vec(walks, config): - """ - 训练Word2Vec模型 + sentences.append(walk) - Args: - walks: 游走序列列表 - config: Word2Vec配置 - - Returns: - Word2Vec模型 - """ - print(f"Training Word2Vec with {len(walks)} walks...") + logger.info(f"共读取 {len(sentences)} 条游走序列") + # 训练Word2Vec + logger.info("开始训练Word2Vec模型...") model = Word2Vec( - sentences=walks, + sentences=sentences, vector_size=config['vector_size'], window=config['window_size'], min_count=config['min_count'], @@ -170,21 +121,23 @@ def train_word2vec(walks, config): seed=42 ) - print(f"Training completed. Vocabulary size: {len(model.wv)}") + logger.info(f"训练完成。词汇表大小:{len(model.wv)}") return model -def generate_similarities(model, top_n=50): +def generate_similarities(model, top_n, logger): """ - 生成物品相似度 + 从Word2Vec模型生成物品相似度 Args: model: Word2Vec模型 top_n: Top N similar items + logger: 日志对象 Returns: Dict[item_id, List[Tuple(similar_item_id, score)]] """ + logger.info("生成相似度...") result = {} for item_id in model.wv.index_to_key: @@ -194,9 +147,37 @@ def generate_similarities(model, top_n=50): except KeyError: continue + logger.info(f"为 {len(result)} 个物品生成了相似度") return result +def save_results(result, output_file, name_mappings, logger): + """ + 保存相似度结果到文件 + + Args: + result: 相似度字典 + output_file: 输出文件路径 + name_mappings: ID到名称的映射 + logger: 日志对象 + """ + logger.info(f"保存结果到 {output_file}...") + + with open(output_file, 'w', encoding='utf-8') as f: + for item_id, sims in result.items(): + # 获取物品名称 + item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown' + + if not sims: + continue + + # 格式:item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,... + sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims]) + f.write(f'{item_id}\t{item_name}\t{sim_str}\n') + + logger.info(f"结果保存完成") + + def main(): parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity') parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'], @@ -225,6 +206,10 @@ def main(): help='Save graph edge file') parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging and readable output') + parser.add_argument('--use_softmax', action='store_true', + help='Use softmax-based alias sampling (default: False)') + parser.add_argument('--temperature', type=float, default=1.0, + help='Temperature for softmax (default: 1.0)') args = parser.parse_args() @@ -242,10 +227,25 @@ def main(): 'epochs': args.epochs, 'top_n': args.top_n, 'lookback_days': args.lookback_days, - 'debug': args.debug + 'debug': args.debug, + 'use_softmax': args.use_softmax, + 'temperature': args.temperature } log_algorithm_params(logger, params) + # 创建临时目录 + temp_dir = os.path.join(OUTPUT_DIR, 'temp') + os.makedirs(temp_dir, exist_ok=True) + + date_str = datetime.now().strftime('%Y%m%d') + edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt') + walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt') + + # ============================================================ + # 步骤1: 从数据库获取数据并构建边文件 + # ============================================================ + log_processing_step(logger, "从数据库获取数据") + # 创建数据库连接 logger.info("连接数据库...") engine = create_db_connection( @@ -295,51 +295,67 @@ def main(): } logger.debug(f"行为权重: {behavior_weights}") - # 构建物品图 - log_processing_step(logger, "构建物品图") - graph = build_item_graph(df, behavior_weights) - logger.info(f"构建物品图完成,共 {len(graph)} 个节点") - - # 保存边文件(可选) - if args.save_graph: - edge_file = os.path.join(OUTPUT_DIR, f'item_graph_{datetime.now().strftime("%Y%m%d")}.txt') - save_edge_file(graph, edge_file) - logger.info(f"图边文件已保存到 {edge_file}") + # 构建边文件 + log_processing_step(logger, "构建边文件") + num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger) + + # ============================================================ + # 步骤2: 使用DeepWalk进行随机游走 + # ============================================================ + log_processing_step(logger, "执行DeepWalk随机游走") + + logger.info("初始化DeepWalk...") + deepwalk = DeepWalk( + edge_file=edge_file, + node_tag_file=None, # 不使用标签游走 + use_softmax=args.use_softmax, + temperature=args.temperature, + p_tag_walk=0.0 # 不使用标签游走 + ) - # 生成随机游走 - log_processing_step(logger, "生成随机游走") - walks = generate_walks(graph, args.num_walks, args.walk_length) - logger.info(f"生成 {len(walks)} 条游走路径") + logger.info("开始随机游走...") + deepwalk.simulate_walks( + num_walks=args.num_walks, + walk_length=args.walk_length, + workers=args.workers, + output_file=walks_file + ) - # 训练Word2Vec模型 + # ============================================================ + # 步骤3: 训练Word2Vec模型 + # ============================================================ log_processing_step(logger, "训练Word2Vec模型") + w2v_config = { 'vector_size': args.vector_size, 'window_size': args.window_size, 'min_count': args.min_count, 'workers': args.workers, 'epochs': args.epochs, - 'sg': 1 + 'sg': 1 # Skip-gram } logger.debug(f"Word2Vec配置: {w2v_config}") - model = train_word2vec(walks, w2v_config) - logger.info(f"训练完成。词汇表大小:{len(model.wv)}") + model = train_word2vec_from_walks(walks_file, w2v_config, logger) # 保存模型(可选) if args.save_model: - model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{datetime.now().strftime("%Y%m%d")}.model') + model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model') model.save(model_path) logger.info(f"模型已保存到 {model_path}") - # 生成相似度 + # ============================================================ + # 步骤4: 生成相似度 + # ============================================================ log_processing_step(logger, "生成相似度") - result = generate_similarities(model, top_n=args.top_n) - logger.info(f"生成了 {len(result)} 个物品的相似度") + result = generate_similarities(model, args.top_n, logger) - # 输出结果 + # ============================================================ + # 步骤5: 保存结果 + # ============================================================ log_processing_step(logger, "保存结果") - output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{datetime.now().strftime("%Y%m%d")}.txt') + + output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt') # 获取name mappings name_mappings = {} @@ -347,23 +363,14 @@ def main(): logger.info("获取物品名称映射...") name_mappings = fetch_name_mappings(engine, debug=True) - logger.info(f"写入结果到 {output_file}...") - with open(output_file, 'w', encoding='utf-8') as f: - for item_id, sims in result.items(): - # 使用name_mappings获取名称 - item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown' - if item_name == 'Unknown' and 'item_name' in df.columns: - item_name = df[df['item_id'].astype(str) == item_id]['item_name'].iloc[0] if len(df[df['item_id'].astype(str) == item_id]) > 0 else 'Unknown' - - if not sims: - continue - - # 格式:item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,... - sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims]) - f.write(f'{item_id}\t{item_name}\t{sim_str}\n') + save_results(result, output_file, name_mappings, logger) - logger.info(f"完成!为 {len(result)} 个物品生成了相似度") - logger.info(f"输出保存到:{output_file}") + logger.info(f"✓ DeepWalk完成!") + logger.info(f" - 输出文件: {output_file}") + logger.info(f" - 商品数: {len(result)}") + if result: + avg_sims = sum(len(sims) for sims in result.values()) / len(result) + logger.info(f" - 平均相似商品数: {avg_sims:.1f}") # 如果启用debug模式,保存可读格式 if args.debug: @@ -374,8 +381,20 @@ def main(): name_mappings, description='i2i:deepwalk' ) + + # 清理临时文件(可选) + if not args.save_graph: + if os.path.exists(edge_file): + os.remove(edge_file) + logger.debug(f"已删除临时文件: {edge_file}") + if os.path.exists(walks_file): + os.remove(walks_file) + logger.debug(f"已删除临时文件: {walks_file}") + + print(f"✓ DeepWalk相似度计算完成") + print(f" - 输出文件: {output_file}") + print(f" - 商品数: {len(result)}") if __name__ == '__main__': main() - -- libgit2 0.21.2