Commit 0e45f7029178f20715ff1b7724250e407013c3d9
1 parent 1088c261
deepwalk refactor for memory savings and performance optimization
Showing 1 changed file with 157 additions and 138 deletions
offline_tasks/scripts/i2i_deepwalk.py
| 1 | 1 | """ |
| 2 | 2 | i2i - DeepWalk algorithm implementation |
| 3 | 3 | Trains a DeepWalk model on the user-item graph structure to obtain item-vector similarities |
| 4 | +Reuses the efficient implementation from graphembedding/deepwalk/ | |
| 4 | 5 | """ |
| 5 | -import sys | |
| 6 | -import os | |
| 7 | -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| 8 | - | |
| 9 | 6 | import pandas as pd |
| 10 | 7 | import argparse |
| 8 | +import os | |
| 9 | +import sys | |
| 11 | 10 | from datetime import datetime |
| 12 | 11 | from collections import defaultdict |
| 13 | 12 | from gensim.models import Word2Vec |
| 14 | -import numpy as np | |
| 15 | 13 | from db_service import create_db_connection |
| 16 | 14 | from offline_tasks.config.offline_config import ( |
| 17 | 15 | DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, |
| 18 | 16 | DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N |
| 19 | 17 | ) |
| 20 | 18 | from offline_tasks.scripts.debug_utils import ( |
| 21 | - setup_debug_logger, log_dataframe_info, log_dict_stats, | |
| 19 | + setup_debug_logger, log_dataframe_info, | |
| 22 | 20 | save_readable_index, fetch_name_mappings, log_algorithm_params, |
| 23 | 21 | log_processing_step |
| 24 | 22 | ) |
| 25 | 23 | |
| 24 | +# Import the DeepWalk implementation | |
| 25 | +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk')) | |
| 26 | +from deepwalk import DeepWalk | |
| 27 | + | |
| 26 | 28 | |
| 27 | -def build_item_graph(df, behavior_weights): | |
| 29 | +def build_edge_file_from_db(df, behavior_weights, output_path, logger): | |
| 28 | 30 | """ |
| 29 | - Build the item graph (based on users' shared interactions) | |
| 31 | + Build the edge file from database data | |
| 32 | + Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,... | |
| 30 | 33 | |
| 31 | 34 | Args: |
| 32 | 35 | df: DataFrame with columns: user_id, item_id, event_type |
| 33 | 36 | behavior_weights: dict of behavior weights |
| 34 | - | |
| 35 | - Returns: | |
| 36 | - edge_dict: {item_id: {neighbor_id: weight}} | |
| 37 | + output_path: output path for the edge file | |
| 38 | + logger: logger instance | |
| 37 | 39 | """ |
| 40 | + logger.info("Building item graph...") | |
| 41 | + | |
| 38 | 42 | # Build per-user item lists |
| 39 | 43 | user_items = defaultdict(list) |
| 40 | 44 | |
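For illustration, the edge file written by `build_edge_file_from_db` is tab-then-comma delimited; with made-up item ids and weights it would look like:

```
101	205:3.0000,318:1.5000
205	101:3.0000,318:4.5000
318	101:1.5000,205:4.5000
```

Each line is one node followed by its weighted neighbors, and since edges are added symmetrically later in this function, every pair appears under both endpoints.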
| ... | ... | @@ -43,13 +47,19 @@ def build_item_graph(df, behavior_weights): |
| 43 | 47 | item_id = str(row['item_id']) |
| 44 | 48 | event_type = row['event_type'] |
| 45 | 49 | weight = behavior_weights.get(event_type, 1.0) |
| 46 | - | |
| 47 | 50 | user_items[user_id].append((item_id, weight)) |
| 48 | 51 | |
| 52 | + logger.info(f"{len(user_items)} users in total") | |
| 53 | + | |
| 49 | 54 | # Build item-graph edges |
| 50 | 55 | edge_dict = defaultdict(lambda: defaultdict(float)) |
| 51 | 56 | |
| 52 | 57 | for user_id, items in user_items.items(): |
| 58 | + # Cap the number of items per user to avoid memory blowup | |
| 59 | + if len(items) > 100: | |
| 60 | + # Sort by weight and keep only the top 100 | |
| 61 | + items = sorted(items, key=lambda x: -x[1])[:100] | |
| 62 | + | |
| 53 | 63 | # Build edges from pairwise item combinations |
| 54 | 64 | for i in range(len(items)): |
| 55 | 65 | item_i, weight_i = items[i] |
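The new 100-item cap bounds the quadratic pair expansion in the loop below: a user with 1,000 interactions would otherwise contribute 1000 × 999 / 2 = 499,500 item pairs on their own, while the cap limits any single user to at most 100 × 99 / 2 = 4,950 pairs. This is presumably the main source of the memory savings this commit targets.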
| ... | ... | @@ -61,106 +71,47 @@ def build_item_graph(df, behavior_weights): |
| 61 | 71 | edge_dict[item_i][item_j] += edge_weight |
| 62 | 72 | edge_dict[item_j][item_i] += edge_weight |
| 63 | 73 | |
| 64 | - return edge_dict | |
| 65 | - | |
| 66 | - | |
| 67 | -def save_edge_file(edge_dict, output_path): | |
| 68 | - """ | |
| 69 | - Save the edge file | |
| 74 | + logger.info(f"Item graph built: {len(edge_dict)} nodes in total") | |
| 70 | 75 | |
| 71 | - Args: | |
| 72 | - edge_dict: edge dict | |
| 73 | - output_path: output path | |
| 74 | - """ | |
| 76 | + # Save the edge file | |
| 77 | + logger.info(f"Saving edge file to {output_path}") | |
| 75 | 78 | with open(output_path, 'w', encoding='utf-8') as f: |
| 76 | 79 | for item_id, neighbors in edge_dict.items(): |
| 77 | - # Format: item_id \t neighbor1:weight1,neighbor2:weight2,... | |
| 78 | 80 | neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()]) |
| 79 | 81 | f.write(f'{item_id}\t{neighbor_str}\n') |
| 80 | 82 | |
| 81 | - print(f"Edge file saved to {output_path}") | |
| 83 | + logger.info("Edge file saved") | |
| 84 | + return len(edge_dict) | |
| 82 | 85 | |
| 83 | 86 | |
| 84 | -def random_walk(graph, start_node, walk_length): | |
| 87 | +def train_word2vec_from_walks(walks_file, config, logger): | |
| 85 | 88 | """ |
| 86 | - Perform a random walk | |
| 89 | + Train a Word2Vec model from the walks file | |
| 87 | 90 | |
| 88 | 91 | Args: |
| 89 | - graph: graph structure {node: {neighbor: weight}} | |
| 90 | - start_node: starting node | |
| 91 | - walk_length: walk length | |
| 92 | - | |
| 93 | - Returns: | |
| 94 | - The walk sequence | |
| 95 | - """ | |
| 96 | - walk = [start_node] | |
| 97 | - | |
| 98 | - while len(walk) < walk_length: | |
| 99 | - cur = walk[-1] | |
| 100 | - | |
| 101 | - if cur not in graph or not graph[cur]: | |
| 102 | - break | |
| 103 | - | |
| 104 | - # Get neighbors and weights | |
| 105 | - neighbors = list(graph[cur].keys()) | |
| 106 | - weights = list(graph[cur].values()) | |
| 107 | - | |
| 108 | - # Normalize weights | |
| 109 | - total_weight = sum(weights) | |
| 110 | - if total_weight == 0: | |
| 111 | - break | |
| 112 | - | |
| 113 | - probs = [w / total_weight for w in weights] | |
| 114 | - | |
| 115 | - # Pick the next node at random, weighted by edge weight | |
| 116 | - next_node = np.random.choice(neighbors, p=probs) | |
| 117 | - walk.append(next_node) | |
| 118 | - | |
| 119 | - return walk | |
| 120 | - | |
| 121 | - | |
| 122 | -def generate_walks(graph, num_walks, walk_length): | |
| 123 | - """ | |
| 124 | - Generate random-walk sequences | |
| 125 | - | |
| 126 | - Args: | |
| 127 | - graph: graph structure | |
| 128 | - num_walks: number of walks per node | |
| 129 | - walk_length: walk length | |
| 92 | + walks_file: path to the walks file | |
| 93 | + config: Word2Vec config | |
| 94 | + logger: logger instance | |
| 130 | 95 | |
| 131 | 96 | Returns: |
| 132 | - List of walks | |
| 97 | + Word2Vec model | |
| 133 | 98 | """ |
| 134 | - walks = [] | |
| 135 | - nodes = list(graph.keys()) | |
| 136 | - | |
| 137 | - print(f"Generating {num_walks} walks per node, walk length {walk_length}...") | |
| 99 | + logger.info(f"Reading walk sequences from {walks_file}...") | |
| 138 | 100 | |
| 139 | - for _ in range(num_walks): | |
| 140 | - np.random.shuffle(nodes) | |
| 141 | - for node in nodes: | |
| 142 | - walk = random_walk(graph, node, walk_length) | |
| 101 | + # Read walk sequences | |
| 102 | + sentences = [] | |
| 103 | + with open(walks_file, 'r', encoding='utf-8') as f: | |
| 104 | + for line in f: | |
| 105 | + walk = line.strip().split() | |
| 143 | 106 | if len(walk) >= 2: |
| 144 | - walks.append(walk) | |
| 145 | - | |
| 146 | - return walks | |
| 147 | - | |
| 148 | - | |
| 149 | -def train_word2vec(walks, config): | |
| 150 | - """ | |
| 151 | - Train a Word2Vec model | |
| 107 | + sentences.append(walk) | |
| 152 | 108 | |
| 153 | - Args: | |
| 154 | - walks: list of walk sequences | |
| 155 | - config: Word2Vec config | |
| 156 | - | |
| 157 | - Returns: | |
| 158 | - Word2Vec model | |
| 159 | - """ | |
| 160 | - print(f"Training Word2Vec with {len(walks)} walks...") | |
| 109 | + logger.info(f"Read {len(sentences)} walk sequences in total") | |
| 161 | 110 | |
| 111 | + # Train Word2Vec | |
| 112 | + logger.info("Training Word2Vec model...") | |
| 162 | 113 | model = Word2Vec( |
| 163 | - sentences=walks, | |
| 114 | + sentences=sentences, | |
| 164 | 115 | vector_size=config['vector_size'], |
| 165 | 116 | window=config['window_size'], |
| 166 | 117 | min_count=config['min_count'], |
| ... | ... | @@ -170,21 +121,23 @@ def train_word2vec(walks, config): |
| 170 | 121 | seed=42 |
| 171 | 122 | ) |
| 172 | 123 | |
| 173 | - print(f"Training completed. Vocabulary size: {len(model.wv)}") | |
| 124 | + logger.info(f"Training complete. Vocabulary size: {len(model.wv)}") | |
| 174 | 125 | return model |
| 175 | 126 | |
| 176 | 127 | |
| 177 | -def generate_similarities(model, top_n=50): | |
| 128 | +def generate_similarities(model, top_n, logger): | |
| 178 | 129 | """ |
| 179 | - Generate item similarities | |
| 130 | + Generate item similarities from the Word2Vec model | |
| 180 | 131 | |
| 181 | 132 | Args: |
| 182 | 133 | model: Word2Vec model |
| 183 | 134 | top_n: Top N similar items |
| 135 | + logger: logger instance | |
| 184 | 136 | |
| 185 | 137 | Returns: |
| 186 | 138 | Dict[item_id, List[Tuple(similar_item_id, score)]] |
| 187 | 139 | """ |
| 140 | + logger.info("Generating similarities...") | |
| 188 | 141 | result = {} |
| 189 | 142 | |
| 190 | 143 | for item_id in model.wv.index_to_key: |
| ... | ... | @@ -194,9 +147,37 @@ def generate_similarities(model, top_n=50): |
| 194 | 147 | except KeyError: |
| 195 | 148 | continue |
| 196 | 149 | |
| 150 | + logger.info(f"Generated similarities for {len(result)} items") | |
| 197 | 151 | return result |
| 198 | 152 | |
| 199 | 153 | |
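`generate_similarities` delegates the neighbor search to gensim's `most_similar`, which returns cosine similarities between the learned item vectors. For reference, with a made-up item id:

```python
# Keys are the string item ids the walks were built from.
neighbors = model.wv.most_similar('10042', topn=3)
# e.g. [('10398', 0.93), ('10871', 0.91), ('10007', 0.89)]
```

The `KeyError` guard in the loop covers items that were pruned from the vocabulary by `min_count`.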
| 154 | +def save_results(result, output_file, name_mappings, logger): | |
| 155 | + """ | |
| 156 | + Save similarity results to a file | |
| 157 | + | |
| 158 | + Args: | |
| 159 | + result: similarity dict | |
| 160 | + output_file: output file path | |
| 161 | + name_mappings: mapping from item ID to name | |
| 162 | + logger: logger instance | |
| 163 | + """ | |
| 164 | + logger.info(f"Saving results to {output_file}...") | |
| 165 | + | |
| 166 | + with open(output_file, 'w', encoding='utf-8') as f: | |
| 167 | + for item_id, sims in result.items(): | |
| 168 | + # Look up the item name | |
| 169 | + item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown' | |
| 170 | + | |
| 171 | + if not sims: | |
| 172 | + continue | |
| 173 | + | |
| 174 | + # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,... | |
| 175 | + sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims]) | |
| 176 | + f.write(f'{item_id}\t{item_name}\t{sim_str}\n') | |
| 177 | + | |
| 178 | + logger.info("Results saved") | |
| 179 | + | |
| 180 | + | |
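For illustration, one line of the file `save_results` writes, with a made-up id and name:

```
10042	Wireless Mouse	10398:0.9311,10871:0.9087,10007:0.8934
```

Three tab-separated fields: the item id, its display name (or `Unknown`), and a comma-separated list of `similar_id:score` pairs.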
| 200 | 181 | def main(): |
| 201 | 182 | parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity') |
| 202 | 183 | parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'], |
| ... | ... | @@ -225,6 +206,10 @@ def main(): |
| 225 | 206 | help='Save graph edge file') |
| 226 | 207 | parser.add_argument('--debug', action='store_true', |
| 227 | 208 | help='Enable debug mode with detailed logging and readable output') |
| 209 | + parser.add_argument('--use_softmax', action='store_true', | |
| 210 | + help='Use softmax-based alias sampling (default: False)') | |
| 211 | + parser.add_argument('--temperature', type=float, default=1.0, | |
| 212 | + help='Temperature for softmax (default: 1.0)') | |
| 228 | 213 | |
| 229 | 214 | args = parser.parse_args() |
| 230 | 215 | |
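The two new flags are passed through to the shared DeepWalk implementation below; the sampling code itself lives in `graphembedding/deepwalk/`. As a sketch of what temperature-scaled softmax sampling over a node's outgoing edge weights plausibly looks like (not the library's actual code):

```python
import numpy as np

def softmax_transition_probs(weights, temperature=1.0):
    # Turn raw edge weights into transition probabilities.
    # Lower temperature sharpens the walk toward the heaviest edge;
    # higher temperature flattens it toward uniform.
    w = np.asarray(weights, dtype=float) / temperature
    w -= w.max()          # subtract max for numerical stability
    e = np.exp(w)
    return e / e.sum()
```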
| ... | ... | @@ -242,10 +227,25 @@ def main(): |
| 242 | 227 | 'epochs': args.epochs, |
| 243 | 228 | 'top_n': args.top_n, |
| 244 | 229 | 'lookback_days': args.lookback_days, |
| 245 | - 'debug': args.debug | |
| 230 | + 'debug': args.debug, | |
| 231 | + 'use_softmax': args.use_softmax, | |
| 232 | + 'temperature': args.temperature | |
| 246 | 233 | } |
| 247 | 234 | log_algorithm_params(logger, params) |
| 248 | 235 | |
| 236 | + # Create temp directory | |
| 237 | + temp_dir = os.path.join(OUTPUT_DIR, 'temp') | |
| 238 | + os.makedirs(temp_dir, exist_ok=True) | |
| 239 | + | |
| 240 | + date_str = datetime.now().strftime('%Y%m%d') | |
| 241 | + edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt') | |
| 242 | + walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt') | |
| 243 | + | |
| 244 | + # ============================================================ | |
| 245 | + # Step 1: Fetch data from the database and build the edge file | |
| 246 | + # ============================================================ | |
| 247 | + log_processing_step(logger, "Fetch data from database") | |
| 248 | + | |
| 249 | 249 | # Create database connection |
| 250 | 250 | logger.info("Connecting to database...") |
| 251 | 251 | engine = create_db_connection( |
| ... | ... | @@ -295,51 +295,67 @@ def main(): |
| 295 | 295 | } |
| 296 | 296 | logger.debug(f"Behavior weights: {behavior_weights}") |
| 297 | 297 | |
| 298 | - # Build item graph | |
| 299 | - log_processing_step(logger, "Build item graph") | |
| 300 | - graph = build_item_graph(df, behavior_weights) | |
| 301 | - logger.info(f"Item graph built: {len(graph)} nodes in total") | |
| 302 | - | |
| 303 | - # Save edge file (optional) | |
| 304 | - if args.save_graph: | |
| 305 | - edge_file = os.path.join(OUTPUT_DIR, f'item_graph_{datetime.now().strftime("%Y%m%d")}.txt') | |
| 306 | - save_edge_file(graph, edge_file) | |
| 307 | - logger.info(f"Graph edge file saved to {edge_file}") | |
| 298 | + # Build edge file | |
| 299 | + log_processing_step(logger, "Build edge file") | |
| 300 | + num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger) | |
| 301 | + | |
| 302 | + # ============================================================ | |
| 303 | + # Step 2: Run DeepWalk random walks | |
| 304 | + # ============================================================ | |
| 305 | + log_processing_step(logger, "Run DeepWalk random walks") | |
| 306 | + | |
| 307 | + logger.info("Initializing DeepWalk...") | |
| 308 | + deepwalk = DeepWalk( | |
| 309 | + edge_file=edge_file, | |
| 310 | + node_tag_file=None, # no tag-based walks | |
| 311 | + use_softmax=args.use_softmax, | |
| 312 | + temperature=args.temperature, | |
| 313 | + p_tag_walk=0.0 # no tag-based walks | |
| 314 | + ) | |
| 308 | 315 | |
| 309 | - # Generate random walks | |
| 310 | - log_processing_step(logger, "Generate random walks") | |
| 311 | - walks = generate_walks(graph, args.num_walks, args.walk_length) | |
| 312 | - logger.info(f"Generated {len(walks)} walk paths") | |
| 316 | + logger.info("Starting random walks...") | |
| 317 | + deepwalk.simulate_walks( | |
| 318 | + num_walks=args.num_walks, | |
| 319 | + walk_length=args.walk_length, | |
| 320 | + workers=args.workers, | |
| 321 | + output_file=walks_file | |
| 322 | + ) | |
| 313 | 323 | |
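The `--use_softmax` help text mentions alias sampling: Walker's alias method is the standard trick DeepWalk implementations use to draw a weighted next hop in O(1) after O(n) setup per node. A self-contained sketch of the method (the actual code in `graphembedding/deepwalk/` may differ):

```python
import random

def build_alias_table(probs):
    # Walker's alias method: O(n) preprocessing of a discrete
    # distribution so that each subsequent draw costs O(1).
    n = len(probs)
    prob, alias = [0.0] * n, [0] * n
    scaled = [p * n for p in probs]
    small = [i for i, p in enumerate(scaled) if p < 1.0]
    large = [i for i, p in enumerate(scaled) if p >= 1.0]
    while small and large:
        s, l = small.pop(), large.pop()
        prob[s], alias[s] = scaled[s], l
        scaled[l] -= 1.0 - scaled[s]
        (small if scaled[l] < 1.0 else large).append(l)
    for i in small + large:   # leftovers are (numerically) exactly 1.0
        prob[i] = 1.0
    return prob, alias

def alias_draw(prob, alias):
    # One O(1) weighted draw: pick a bucket, then flip a biased coin
    # between the bucket and its alias.
    i = random.randrange(len(prob))
    return i if random.random() < prob[i] else alias[i]
```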
| 314 | - # Train Word2Vec model | |
| 324 | + # ============================================================ | |
| 325 | + # Step 3: Train Word2Vec model | |
| 326 | + # ============================================================ | |
| 315 | 327 | log_processing_step(logger, "Train Word2Vec model") |
| 328 | + | |
| 316 | 329 | w2v_config = { |
| 317 | 330 | 'vector_size': args.vector_size, |
| 318 | 331 | 'window_size': args.window_size, |
| 319 | 332 | 'min_count': args.min_count, |
| 320 | 333 | 'workers': args.workers, |
| 321 | 334 | 'epochs': args.epochs, |
| 322 | - 'sg': 1 | |
| 335 | + 'sg': 1 # Skip-gram | |
| 323 | 336 | } |
| 324 | 337 | logger.debug(f"Word2Vec config: {w2v_config}") |
| 325 | 338 | |
| 326 | - model = train_word2vec(walks, w2v_config) | |
| 327 | - logger.info(f"Training complete. Vocabulary size: {len(model.wv)}") | |
| 339 | + model = train_word2vec_from_walks(walks_file, w2v_config, logger) | |
| 328 | 340 | |
| 329 | 341 | # Save model (optional) |
| 330 | 342 | if args.save_model: |
| 331 | - model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{datetime.now().strftime("%Y%m%d")}.model') | |
| 343 | + model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model') | |
| 332 | 344 | model.save(model_path) |
| 333 | 345 | logger.info(f"Model saved to {model_path}") |
| 334 | 346 | |
| 335 | - # Generate similarities | |
| 347 | + # ============================================================ | |
| 348 | + # Step 4: Generate similarities | |
| 349 | + # ============================================================ | |
| 336 | 350 | log_processing_step(logger, "Generate similarities") |
| 337 | - result = generate_similarities(model, top_n=args.top_n) | |
| 338 | - logger.info(f"Generated similarities for {len(result)} items") | |
| 351 | + result = generate_similarities(model, args.top_n, logger) | |
| 339 | 352 | |
| 340 | - # Write results | |
| 353 | + # ============================================================ | |
| 354 | + # Step 5: Save results | |
| 355 | + # ============================================================ | |
| 341 | 356 | log_processing_step(logger, "Save results") |
| 342 | - output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{datetime.now().strftime("%Y%m%d")}.txt') | |
| 357 | + | |
| 358 | + output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt') | |
| 343 | 359 | |
| 344 | 360 | # Fetch name mappings |
| 345 | 361 | name_mappings = {} |
| ... | ... | @@ -347,23 +363,14 @@ def main(): |
| 347 | 363 | logger.info("Fetching item name mappings...") |
| 348 | 364 | name_mappings = fetch_name_mappings(engine, debug=True) |
| 349 | 365 | |
| 350 | - logger.info(f"Writing results to {output_file}...") | |
| 351 | - with open(output_file, 'w', encoding='utf-8') as f: | |
| 352 | - for item_id, sims in result.items(): | |
| 353 | - # Look up names via name_mappings | |
| 354 | - item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown' | |
| 355 | - if item_name == 'Unknown' and 'item_name' in df.columns: | |
| 356 | - item_name = df[df['item_id'].astype(str) == item_id]['item_name'].iloc[0] if len(df[df['item_id'].astype(str) == item_id]) > 0 else 'Unknown' | |
| 357 | - | |
| 358 | - if not sims: | |
| 359 | - continue | |
| 360 | - | |
| 361 | - # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,... | |
| 362 | - sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims]) | |
| 363 | - f.write(f'{item_id}\t{item_name}\t{sim_str}\n') | |
| 366 | + save_results(result, output_file, name_mappings, logger) | |
| 364 | 367 | |
| 365 | - logger.info(f"Done! Generated similarities for {len(result)} items") | |
| 366 | - logger.info(f"Output saved to {output_file}") | |
| 368 | + logger.info("✓ DeepWalk complete!") | |
| 369 | + logger.info(f" - Output file: {output_file}") | |
| 370 | + logger.info(f" - Item count: {len(result)}") | |
| 371 | + if result: | |
| 372 | + avg_sims = sum(len(sims) for sims in result.values()) / len(result) | |
| 373 | + logger.info(f" - Avg similar items per item: {avg_sims:.1f}") | |
| 367 | 374 | |
| 368 | 375 | # If debug mode is enabled, save a human-readable format |
| 369 | 376 | if args.debug: |
| ... | ... | @@ -374,8 +381,20 @@ def main(): |
| 374 | 381 | name_mappings, |
| 375 | 382 | description='i2i:deepwalk' |
| 376 | 383 | ) |
| 384 | + | |
| 385 | + # Clean up temp files (optional) | |
| 386 | + if not args.save_graph: | |
| 387 | + if os.path.exists(edge_file): | |
| 388 | + os.remove(edge_file) | |
| 389 | + logger.debug(f"Removed temp file: {edge_file}") | |
| 390 | + if os.path.exists(walks_file): | |
| 391 | + os.remove(walks_file) | |
| 392 | + logger.debug(f"Removed temp file: {walks_file}") | |
| 393 | + | |
| 394 | + print("✓ DeepWalk similarity computation complete") | |
| 395 | + print(f" - Output file: {output_file}") | |
| 396 | + print(f" - Item count: {len(result)}") | |
| 377 | 397 | |
| 378 | 398 | |
| 379 | 399 | if __name__ == '__main__': |
| 380 | 400 | main() |
| 381 | - | |
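One design note on the cleanup block above: the temp files are removed only when the run reaches the end, so a failure in steps 2 through 4 leaves `item_graph_*.txt` and `walks_*.txt` behind under `OUTPUT_DIR/temp`. A `try`/`finally` around the pipeline body (sketched with a hypothetical wrapper over steps 1-5) would cover the failure path as well:

```python
try:
    run_pipeline(args)   # hypothetical wrapper around steps 1-5
finally:
    if not args.save_graph:
        for path in (edge_file, walks_file):
            if os.path.exists(path):
                os.remove(path)
```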