Commit 0e45f7029178f20715ff1b7724250e407013c3d9

Authored by tangwang
1 parent 1088c261

DeepWalk refactor for memory savings and performance optimization

Showing 1 changed file with 157 additions and 138 deletions
offline_tasks/scripts/i2i_deepwalk.py
1 """ 1 """
2 i2i - DeepWalk算法实现 2 i2i - DeepWalk算法实现
3 基于用户-物品图结构训练DeepWalk模型,获取物品向量相似度 3 基于用户-物品图结构训练DeepWalk模型,获取物品向量相似度
  4 +复用 graphembedding/deepwalk/ 的高效实现
4 """ 5 """
5 -import sys  
6 -import os  
7 -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))  
8 -  
9 import pandas as pd 6 import pandas as pd
10 import argparse 7 import argparse
  8 +import os
  9 +import sys
11 from datetime import datetime 10 from datetime import datetime
12 from collections import defaultdict 11 from collections import defaultdict
13 from gensim.models import Word2Vec 12 from gensim.models import Word2Vec
14 -import numpy as np  
15 from db_service import create_db_connection 13 from db_service import create_db_connection
16 from offline_tasks.config.offline_config import ( 14 from offline_tasks.config.offline_config import (
17 DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, 15 DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
18 DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N 16 DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
19 ) 17 )
20 from offline_tasks.scripts.debug_utils import ( 18 from offline_tasks.scripts.debug_utils import (
21 - setup_debug_logger, log_dataframe_info, log_dict_stats, 19 + setup_debug_logger, log_dataframe_info,
22 save_readable_index, fetch_name_mappings, log_algorithm_params, 20 save_readable_index, fetch_name_mappings, log_algorithm_params,
23 log_processing_step 21 log_processing_step
24 ) 22 )
25 23
  24 +# 导入 DeepWalk 实现
  25 +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk'))
  26 +from deepwalk import DeepWalk
  27 +
 
-def build_item_graph(df, behavior_weights):
+def build_edge_file_from_db(df, behavior_weights, output_path, logger):
     """
-    Build the item graph (based on users' co-interactions)
+    Build the edge file from database data
+    Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,...
 
     Args:
         df: DataFrame with columns: user_id, item_id, event_type
         behavior_weights: dict of behavior weights
-
-    Returns:
-        edge_dict: {item_id: {neighbor_id: weight}}
+        output_path: output path for the edge file
+        logger: logger instance
     """
+    logger.info("Building item graph...")
+
     # Build user-item lists
     user_items = defaultdict(list)
 
@@ -43,13 +47,19 @@ def build_item_graph(df, behavior_weights):
         item_id = str(row['item_id'])
         event_type = row['event_type']
         weight = behavior_weights.get(event_type, 1.0)
-
         user_items[user_id].append((item_id, weight))
 
+    logger.info(f"Total users: {len(user_items)}")
+
     # Build item graph edges
     edge_dict = defaultdict(lambda: defaultdict(float))
 
     for user_id, items in user_items.items():
+        # Cap the number of items per user to avoid memory blow-up
+        if len(items) > 100:
+            # Sort by weight and keep only the top 100
+            items = sorted(items, key=lambda x: -x[1])[:100]
+
         # Build edges from pairwise item combinations
         for i in range(len(items)):
             item_i, weight_i = items[i]
@@ -61,106 +71,47 @@ def build_item_graph(df, behavior_weights):
                 edge_dict[item_i][item_j] += edge_weight
                 edge_dict[item_j][item_i] += edge_weight
 
-    return edge_dict
-
-
-def save_edge_file(edge_dict, output_path):
-    """
-    Save the edge file
+    logger.info(f"Item graph built: {len(edge_dict)} nodes")
 
-    Args:
-        edge_dict: edge dict
-        output_path: output path
-    """
+    # Save the edge file
+    logger.info(f"Saving edge file to {output_path}")
     with open(output_path, 'w', encoding='utf-8') as f:
         for item_id, neighbors in edge_dict.items():
-            # Format: item_id \t neighbor1:weight1,neighbor2:weight2,...
             neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()])
             f.write(f'{item_id}\t{neighbor_str}\n')
 
-    print(f"Edge file saved to {output_path}")
+    logger.info("Edge file saved")
+    return len(edge_dict)
 
 
-def random_walk(graph, start_node, walk_length):
+def train_word2vec_from_walks(walks_file, config, logger):
     """
-    Perform a random walk
+    Train a Word2Vec model from the walks file
 
     Args:
-        graph: graph structure {node: {neighbor: weight}}
-        start_node: starting node
-        walk_length: walk length
-
-    Returns:
-        the walk sequence
-    """
-    walk = [start_node]
-
-    while len(walk) < walk_length:
-        cur = walk[-1]
-
-        if cur not in graph or not graph[cur]:
-            break
-
-        # Get neighbors and weights
-        neighbors = list(graph[cur].keys())
-        weights = list(graph[cur].values())
-
-        # Normalize the weights
-        total_weight = sum(weights)
-        if total_weight == 0:
-            break
-
-        probs = [w / total_weight for w in weights]
-
-        # Pick the next node at random, proportionally to weight
-        next_node = np.random.choice(neighbors, p=probs)
-        walk.append(next_node)
-
-    return walk
-
-
-def generate_walks(graph, num_walks, walk_length):
-    """
-    Generate random walk sequences
-
-    Args:
-        graph: graph structure
-        num_walks: number of walks per node
-        walk_length: walk length
+        walks_file: path to the walk sequence file
+        config: Word2Vec config
+        logger: logger instance
 
     Returns:
-        List of walks
+        Word2Vec model
     """
-    walks = []
-    nodes = list(graph.keys())
-
-    print(f"Generating {num_walks} walks per node, walk length {walk_length}...")
+    logger.info(f"Reading walk sequences from {walks_file}...")
 
-    for _ in range(num_walks):
-        np.random.shuffle(nodes)
-        for node in nodes:
-            walk = random_walk(graph, node, walk_length)
+    # Read the walk sequences
+    sentences = []
+    with open(walks_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            walk = line.strip().split()
             if len(walk) >= 2:
-                walks.append(walk)
-
-    return walks
-
-
-def train_word2vec(walks, config):
-    """
-    Train a Word2Vec model
+                sentences.append(walk)
 
-    Args:
-        walks: list of walk sequences
-        config: Word2Vec config
-
-    Returns:
-        Word2Vec model
-    """
-    print(f"Training Word2Vec with {len(walks)} walks...")
+    logger.info(f"Read {len(sentences)} walk sequences")
 
+    # Train Word2Vec
+    logger.info("Training Word2Vec model...")
     model = Word2Vec(
-        sentences=walks,
+        sentences=sentences,
         vector_size=config['vector_size'],
         window=config['window_size'],
         min_count=config['min_count'],
@@ -170,21 +121,23 @@ def train_word2vec(walks, config):
         seed=42
     )
 
-    print(f"Training completed. Vocabulary size: {len(model.wv)}")
+    logger.info(f"Training completed. Vocabulary size: {len(model.wv)}")
     return model
 
 
-def generate_similarities(model, top_n=50):
+def generate_similarities(model, top_n, logger):
     """
-    Generate item similarities
+    Generate item similarities from the Word2Vec model
 
     Args:
         model: Word2Vec model
         top_n: Top N similar items
+        logger: logger instance
 
     Returns:
         Dict[item_id, List[Tuple(similar_item_id, score)]]
     """
+    logger.info("Generating similarities...")
     result = {}
 
     for item_id in model.wv.index_to_key:
@@ -194,9 +147,37 @@ def generate_similarities(model, top_n=50):
         except KeyError:
             continue
 
+    logger.info(f"Generated similarities for {len(result)} items")
     return result
 
 
+def save_results(result, output_file, name_mappings, logger):
+    """
+    Save the similarity results to a file
+
+    Args:
+        result: similarity dict
+        output_file: output file path
+        name_mappings: ID-to-name mapping
+        logger: logger instance
+    """
+    logger.info(f"Saving results to {output_file}...")
+
+    with open(output_file, 'w', encoding='utf-8') as f:
+        for item_id, sims in result.items():
+            if not sims:
+                continue
+
+            # Look up the item name
+            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
+
+            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
+            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
+            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
+
+    logger.info("Results saved")
+
+
 def main():
     parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity')
     parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'],
@@ -225,6 +206,10 @@ def main():
                         help='Save graph edge file')
     parser.add_argument('--debug', action='store_true',
                         help='Enable debug mode with detailed logging and readable output')
+    parser.add_argument('--use_softmax', action='store_true',
+                        help='Use softmax-based alias sampling (default: False)')
+    parser.add_argument('--temperature', type=float, default=1.0,
+                        help='Temperature for softmax (default: 1.0)')
 
     args = parser.parse_args()
 
@@ -242,10 +227,25 @@ def main():
         'epochs': args.epochs,
         'top_n': args.top_n,
         'lookback_days': args.lookback_days,
-        'debug': args.debug
+        'debug': args.debug,
+        'use_softmax': args.use_softmax,
+        'temperature': args.temperature
     }
     log_algorithm_params(logger, params)
 
+    # Create a temp directory
+    temp_dir = os.path.join(OUTPUT_DIR, 'temp')
+    os.makedirs(temp_dir, exist_ok=True)
+
+    date_str = datetime.now().strftime('%Y%m%d')
+    edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt')
+    walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt')
+
+    # ============================================================
+    # Step 1: Fetch data from the database and build the edge file
+    # ============================================================
+    log_processing_step(logger, "Fetching data from the database")
+
     # Create the database connection
     logger.info("Connecting to the database...")
     engine = create_db_connection(
@@ -295,51 +295,67 @@ def main():
     }
     logger.debug(f"Behavior weights: {behavior_weights}")
 
-    # Build the item graph
-    log_processing_step(logger, "Building item graph")
-    graph = build_item_graph(df, behavior_weights)
-    logger.info(f"Item graph built: {len(graph)} nodes")
-
-    # Save the edge file (optional)
-    if args.save_graph:
-        edge_file = os.path.join(OUTPUT_DIR, f'item_graph_{datetime.now().strftime("%Y%m%d")}.txt')
-        save_edge_file(graph, edge_file)
-        logger.info(f"Graph edge file saved to {edge_file}")
+    # Build the edge file
+    log_processing_step(logger, "Building edge file")
+    num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger)
+
+    # ============================================================
+    # Step 2: Run DeepWalk random walks
+    # ============================================================
+    log_processing_step(logger, "Running DeepWalk random walks")
+
+    logger.info("Initializing DeepWalk...")
+    deepwalk = DeepWalk(
+        edge_file=edge_file,
+        node_tag_file=None,  # no tag-based walks
+        use_softmax=args.use_softmax,
+        temperature=args.temperature,
+        p_tag_walk=0.0  # no tag-based walks
+    )
 
-    # Generate random walks
-    log_processing_step(logger, "Generating random walks")
-    walks = generate_walks(graph, args.num_walks, args.walk_length)
-    logger.info(f"Generated {len(walks)} walks")
+    logger.info("Starting random walks...")
+    deepwalk.simulate_walks(
+        num_walks=args.num_walks,
+        walk_length=args.walk_length,
+        workers=args.workers,
+        output_file=walks_file
+    )
 
-    # Train the Word2Vec model
+    # ============================================================
+    # Step 3: Train the Word2Vec model
+    # ============================================================
     log_processing_step(logger, "Training Word2Vec model")
+
     w2v_config = {
         'vector_size': args.vector_size,
         'window_size': args.window_size,
         'min_count': args.min_count,
         'workers': args.workers,
         'epochs': args.epochs,
-        'sg': 1
+        'sg': 1  # Skip-gram
     }
     logger.debug(f"Word2Vec config: {w2v_config}")
 
-    model = train_word2vec(walks, w2v_config)
-    logger.info(f"Training completed. Vocabulary size: {len(model.wv)}")
+    model = train_word2vec_from_walks(walks_file, w2v_config, logger)
 
     # Save the model (optional)
     if args.save_model:
-        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{datetime.now().strftime("%Y%m%d")}.model')
+        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model')
         model.save(model_path)
         logger.info(f"Model saved to {model_path}")
 
-    # Generate similarities
+    # ============================================================
+    # Step 4: Generate similarities
+    # ============================================================
     log_processing_step(logger, "Generating similarities")
-    result = generate_similarities(model, top_n=args.top_n)
-    logger.info(f"Generated similarities for {len(result)} items")
+    result = generate_similarities(model, args.top_n, logger)
 
-    # Write the results
+    # ============================================================
+    # Step 5: Save the results
+    # ============================================================
     log_processing_step(logger, "Saving results")
-    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{datetime.now().strftime("%Y%m%d")}.txt')
+
+    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt')
 
     # Fetch name mappings
     name_mappings = {}
@@ -347,23 +363,14 @@ def main():
         logger.info("Fetching item name mappings...")
         name_mappings = fetch_name_mappings(engine, debug=True)
 
-    logger.info(f"Writing results to {output_file}...")
-    with open(output_file, 'w', encoding='utf-8') as f:
-        for item_id, sims in result.items():
-            # Use name_mappings to look up the name
-            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
-            if item_name == 'Unknown' and 'item_name' in df.columns:
-                item_name = df[df['item_id'].astype(str) == item_id]['item_name'].iloc[0] if len(df[df['item_id'].astype(str) == item_id]) > 0 else 'Unknown'
-
-            if not sims:
-                continue
-
-            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
-            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
-            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
+    save_results(result, output_file, name_mappings, logger)
 
-    logger.info(f"Done! Generated similarities for {len(result)} items")
-    logger.info(f"Output saved to {output_file}")
+    logger.info("✓ DeepWalk finished!")
+    logger.info(f"  - Output file: {output_file}")
+    logger.info(f"  - Item count: {len(result)}")
+    if result:
+        avg_sims = sum(len(sims) for sims in result.values()) / len(result)
+        logger.info(f"  - Avg similar items per item: {avg_sims:.1f}")
 
     # If debug mode is enabled, save a readable format
     if args.debug:
@@ -374,8 +381,20 @@ def main():
             name_mappings,
             description='i2i:deepwalk'
         )
+
+    # Clean up temp files (optional)
+    if not args.save_graph:
+        if os.path.exists(edge_file):
+            os.remove(edge_file)
+            logger.debug(f"Removed temp file: {edge_file}")
+        if os.path.exists(walks_file):
+            os.remove(walks_file)
+            logger.debug(f"Removed temp file: {walks_file}")
+
+    print("✓ DeepWalk similarity computation finished")
+    print(f"  - Output file: {output_file}")
+    print(f"  - Item count: {len(result)}")
 
 
 if __name__ == '__main__':
     main()
-
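
A note on the intermediate format: the edge file written by the new build_edge_file_from_db is a plain adjacency list, one node per line as item_id \t nbr1:weight1,nbr2:weight2,... A minimal sketch of a reader for that format (the function name load_edge_file is illustrative, not part of this commit):

from collections import defaultdict

def load_edge_file(path):
    """Parse the adjacency-list edge file written by build_edge_file_from_db.

    Each line looks like: item_id \t nbr1:weight1,nbr2:weight2,...
    """
    graph = defaultdict(dict)
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            item_id, neighbor_str = line.rstrip('\n').split('\t')
            for pair in neighbor_str.split(','):
                nbr, weight = pair.rsplit(':', 1)
                graph[item_id][nbr] = float(weight)
    return graph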
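On the new --use_softmax / --temperature flags: the DeepWalk class itself is outside this diff, so the actual sampling code is not shown here, but the help text suggests edge weights are turned into transition probabilities (optionally via a temperature-scaled softmax) and then sampled with precomputed alias tables for O(1) draws per step. A hedged sketch of the probability part only, under those assumptions:

import math

def transition_probs(weights, use_softmax=False, temperature=1.0):
    """Turn raw edge weights into transition probabilities.

    With use_softmax, p_i is proportional to exp(w_i / T): a low temperature
    sharpens the walk toward the heaviest edges, a high one flattens it
    toward uniform. Plain weight normalization is the default.
    """
    if use_softmax:
        m = max(weights)  # subtract the max for numerical stability
        exps = [math.exp((w - m) / temperature) for w in weights]
        total = sum(exps)
        return [e / total for e in exps]
    total = sum(weights)
    return [w / total for w in weights]

# Example with weights [1.0, 2.0, 4.0]:
#   plain normalization      -> [0.143, 0.286, 0.571]
#   softmax, temperature=1.0 -> [0.042, 0.114, 0.844]  (sharper)
#   softmax, temperature=10  -> [0.289, 0.320, 0.391]  (flatter)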
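One possible follow-up for the memory-saving goal: train_word2vec_from_walks still materializes every walk in a Python list before training. gensim can stream a whitespace-delimited walks file straight from disk via its corpus_file fast path (the LineSentence format), which would avoid that list entirely. A sketch of such a variant, not part of this commit:

from gensim.models import Word2Vec

def train_word2vec_streaming(walks_file, config, logger):
    """Stream walks from disk instead of materializing them in RAM.

    corpus_file expects LineSentence format (one whitespace-separated walk
    per line), which is exactly what the walks file already contains; gensim
    re-reads the file for the vocab pass and for each training epoch.
    """
    logger.info(f"Streaming walk sequences from {walks_file}...")
    model = Word2Vec(
        corpus_file=walks_file,
        vector_size=config['vector_size'],
        window=config['window_size'],
        min_count=config['min_count'],
        workers=config['workers'],
        epochs=config['epochs'],
        sg=1,
        seed=42
    )
    logger.info(f"Training completed. Vocabulary size: {len(model.wv)}")
    return model

The up-front "len(walk) >= 2" filter disappears in this variant, but single-node walks produce no skip-gram training pairs anyway, so the result should be unaffected.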
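For quick inspection of a model saved with --save_model, gensim's most_similar returns the same kind of neighbor list that generate_similarities exports (the path and item id below are placeholders):

from gensim.models import Word2Vec

# Hypothetical path and item_id, for illustration only
model = Word2Vec.load('output/deepwalk_model_20240101.model')
for sim_id, score in model.wv.most_similar('12345', topn=10):
    print(f'{sim_id}\t{score:.4f}')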