Commit 0e45f7029178f20715ff1b7724250e407013c3d9

Authored by tangwang
1 parent 1088c261

deepwalk refactor for memory saving and performance optimization

Showing 1 changed file with 157 additions and 138 deletions
offline_tasks/scripts/i2i_deepwalk.py
1 1 """
2 2 i2i - DeepWalk algorithm implementation
3 3 Trains a DeepWalk model on the user-item graph structure to obtain item vector similarities
  4 +Reuses the efficient implementation under graphembedding/deepwalk/
4 5 """
5   -import sys
6   -import os
7   -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
8   -
9 6 import pandas as pd
10 7 import argparse
  8 +import os
  9 +import sys
11 10 from datetime import datetime
12 11 from collections import defaultdict
13 12 from gensim.models import Word2Vec
14   -import numpy as np
15 13 from db_service import create_db_connection
16 14 from offline_tasks.config.offline_config import (
17 15 DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
18 16 DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
19 17 )
20 18 from offline_tasks.scripts.debug_utils import (
21   - setup_debug_logger, log_dataframe_info, log_dict_stats,
  19 + setup_debug_logger, log_dataframe_info,
22 20 save_readable_index, fetch_name_mappings, log_algorithm_params,
23 21 log_processing_step
24 22 )
25 23  
  24 +# Import the DeepWalk implementation
  25 +sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk'))
  26 +from deepwalk import DeepWalk
  27 +
26 28  
27   -def build_item_graph(df, behavior_weights):
  29 +def build_edge_file_from_db(df, behavior_weights, output_path, logger):
28 30 """
29   - Build the item graph (based on items users co-interacted with)
  31 + Build the edge file from database data
  32 + Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,...
30 33  
31 34 Args:
32 35 df: DataFrame with columns: user_id, item_id, event_type
33 36 behavior_weights: dict of behavior weights
34   -
35   - Returns:
36   - edge_dict: {item_id: {neighbor_id: weight}}
  37 + output_path: path of the edge file to write
  38 + logger: logger instance
37 39 """
  40 + logger.info("Building item graph...")
  41 +
38 42 # Build per-user item lists
39 43 user_items = defaultdict(list)
40 44  
... ... @@ -43,13 +47,19 @@ def build_item_graph(df, behavior_weights):
43 47 item_id = str(row['item_id'])
44 48 event_type = row['event_type']
45 49 weight = behavior_weights.get(event_type, 1.0)
46   -
47 50 user_items[user_id].append((item_id, weight))
48 51  
  52 + logger.info(f"{len(user_items)} users in total")
  53 +
49 54 # Build item graph edges
50 55 edge_dict = defaultdict(lambda: defaultdict(float))
51 56  
52 57 for user_id, items in user_items.items():
  58 + # Cap the number of items per user to avoid a memory blow-up
  59 + if len(items) > 100:
  60 + # Sort by weight and keep only the top 100
  61 + items = sorted(items, key=lambda x: -x[1])[:100]
  62 +
53 63 # Build edges from every pair of items
54 64 for i in range(len(items)):
55 65 item_i, weight_i = items[i]
... ... @@ -61,106 +71,47 @@ def build_item_graph(df, behavior_weights):
61 71 edge_dict[item_i][item_j] += edge_weight
62 72 edge_dict[item_j][item_i] += edge_weight
63 73  
64   - return edge_dict
65   -
66   -
67   -def save_edge_file(edge_dict, output_path):
68   - """
69   - 保存边文件
  74 + logger.info(f"Item graph built, {len(edge_dict)} nodes in total")
70 75  
71   - Args:
72   - edge_dict: 边字典
73   - output_path: 输出路径
74   - """
  76 + # Save the edge file
  77 + logger.info(f"Saving edge file to {output_path}")
75 78 with open(output_path, 'w', encoding='utf-8') as f:
76 79 for item_id, neighbors in edge_dict.items():
77   - # Format: item_id \t neighbor1:weight1,neighbor2:weight2,...
78 80 neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()])
79 81 f.write(f'{item_id}\t{neighbor_str}\n')
80 82  
81   - print(f"Edge file saved to {output_path}")
  83 + logger.info("Edge file saved")
  84 + return len(edge_dict)
82 85  
83 86  
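For illustration (hypothetical item IDs and weights), one line of the edge file written above would read:

    1001\t1002:3.5000,1005:1.2500,1010:0.7500

The per-user cap above matters because the pair loop is quadratic in the number of items: a user with 5,000 interactions would contribute roughly 12.5 million item pairs, while the 100-item cap bounds this at 100*99/2 = 4,950 pairs per user.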
84   -def random_walk(graph, start_node, walk_length):
  87 +def train_word2vec_from_walks(walks_file, config, logger):
85 88 """
86   - Perform a single random walk
  89 + Train a Word2Vec model from the walks file
87 90  
88 91 Args:
89   - graph: graph structure {node: {neighbor: weight}}
90   - start_node: starting node
91   - walk_length: length of the walk
92   -
93   - Returns:
94   - The walk sequence
95   - """
96   - walk = [start_node]
97   -
98   - while len(walk) < walk_length:
99   - cur = walk[-1]
100   -
101   - if cur not in graph or not graph[cur]:
102   - break
103   -
104   - # Get neighbors and weights
105   - neighbors = list(graph[cur].keys())
106   - weights = list(graph[cur].values())
107   -
108   - # Normalize the weights
109   - total_weight = sum(weights)
110   - if total_weight == 0:
111   - break
112   -
113   - probs = [w / total_weight for w in weights]
114   -
115   - # Pick the next node at random, proportionally to weight
116   - next_node = np.random.choice(neighbors, p=probs)
117   - walk.append(next_node)
118   -
119   - return walk
120   -
121   -
122   -def generate_walks(graph, num_walks, walk_length):
123   - """
124   - 生成随机游走序列
125   -
126   - Args:
127   - graph: graph structure
128   - num_walks: number of walks per node
129   - walk_length: length of each walk
  92 + walks_file: path to the walks file
  93 + config: Word2Vec configuration
  94 + logger: logger instance
130 95  
131 96 Returns:
132   - List of walks
  97 + The trained Word2Vec model
133 98 """
134   - walks = []
135   - nodes = list(graph.keys())
136   -
137   - print(f"Generating {num_walks} walks per node, walk length {walk_length}...")
  99 + logger.info(f"Reading walk sequences from {walks_file}...")
138 100  
139   - for _ in range(num_walks):
140   - np.random.shuffle(nodes)
141   - for node in nodes:
142   - walk = random_walk(graph, node, walk_length)
  101 + # Read the walk sequences
  102 + sentences = []
  103 + with open(walks_file, 'r', encoding='utf-8') as f:
  104 + for line in f:
  105 + walk = line.strip().split()
143 106 if len(walk) >= 2:
144   - walks.append(walk)
145   -
146   - return walks
147   -
148   -
149   -def train_word2vec(walks, config):
150   - """
151   - Train a Word2Vec model
  107 + sentences.append(walk)
152 108  
153   - Args:
154   - walks: list of walk sequences
155   - config: Word2Vec configuration
156   -
157   - Returns:
158   - Word2Vec模型
159   - """
160   - print(f"Training Word2Vec with {len(walks)} walks...")
  109 + logger.info(f"Read {len(sentences)} walk sequences")
161 110
  111 + # Train Word2Vec
  112 + logger.info("Training the Word2Vec model...")
162 113 model = Word2Vec(
163   - sentences=walks,
  114 + sentences=sentences,
164 115 vector_size=config['vector_size'],
165 116 window=config['window_size'],
166 117 min_count=config['min_count'],
... ... @@ -170,21 +121,23 @@ def train_word2vec(walks, config):
170 121 seed=42
171 122 )
172 123  
173   - print(f"Training completed. Vocabulary size: {len(model.wv)}")
  124 + logger.info(f"Training complete. Vocabulary size: {len(model.wv)}")
174 125 return model
175 126  
176 127  
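Each line of the walks file is one walk, written as whitespace-separated node (item) IDs, which is what the line.strip().split() reader above expects. A hypothetical line:

    1001 1002 1005 1002 1010 1001 1007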
177   -def generate_similarities(model, top_n=50):
  128 +def generate_similarities(model, top_n, logger):
178 129 """
179   - Generate item similarities
  130 + Generate item similarities from the Word2Vec model
180 131  
181 132 Args:
182 133 model: Word2Vec model
183 134 top_n: Top N similar items
  135 + logger: logger instance
184 136  
185 137 Returns:
186 138 Dict[item_id, List[Tuple(similar_item_id, score)]]
187 139 """
  140 + logger.info("Generating similarities...")
188 141 result = {}
189 142  
190 143 for item_id in model.wv.index_to_key:
... ... @@ -194,9 +147,37 @@ def generate_similarities(model, top_n=50):
194 147 except KeyError:
195 148 continue
196 149  
  150 + logger.info(f"Generated similarities for {len(result)} items")
197 151 return result
198 152  
199 153  
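The loop above queries the embedding for each item's nearest neighbours; the elided body presumably relies on gensim's most_similar lookup, roughly as in this minimal sketch (an assumption, not the commit's exact code):

    sims = model.wv.most_similar(item_id, topn=top_n)  # e.g. [('1002', 0.93), ('1005', 0.87), ...]
    result[item_id] = [(sim_id, float(score)) for sim_id, score in sims]

A KeyError here is the case the surrounding try/except skips: items that never made it into the Word2Vec vocabulary.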
  154 +def save_results(result, output_file, name_mappings, logger):
  155 + """
  156 + Save similarity results to a file
  157 +
  158 + Args:
  159 + result: similarity dict
  160 + output_file: output file path
  161 + name_mappings: mapping from item ID to name
  162 + logger: logger instance
  163 + """
  164 + logger.info(f"Saving results to {output_file}...")
  165 +
  166 + with open(output_file, 'w', encoding='utf-8') as f:
  167 + for item_id, sims in result.items():
  168 + # Look up the item name
  169 + item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
  170 +
  171 + if not sims:
  172 + continue
  173 +
  174 + # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
  175 + sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
  176 + f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
  177 +
  178 + logger.info("Results saved")
  179 +
  180 +
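For illustration (hypothetical IDs, name and scores), one line of the result file would read:

    1001\tBlue T-Shirt\t1002:0.9312,1005:0.8744,1010:0.8120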
200 181 def main():
201 182 parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity')
202 183 parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'],
... ... @@ -225,6 +206,10 @@ def main():
225 206 help='Save graph edge file')
226 207 parser.add_argument('--debug', action='store_true',
227 208 help='Enable debug mode with detailed logging and readable output')
  209 + parser.add_argument('--use_softmax', action='store_true',
  210 + help='Use softmax-based alias sampling (default: False)')
  211 + parser.add_argument('--temperature', type=float, default=1.0,
  212 + help='Temperature for softmax (default: 1.0)')
228 213  
229 214 args = parser.parse_args()
230 215  
... ... @@ -242,10 +227,25 @@ def main():
242 227 'epochs': args.epochs,
243 228 'top_n': args.top_n,
244 229 'lookback_days': args.lookback_days,
245   - 'debug': args.debug
  230 + 'debug': args.debug,
  231 + 'use_softmax': args.use_softmax,
  232 + 'temperature': args.temperature
246 233 }
247 234 log_algorithm_params(logger, params)
248 235  
  236 + # Create a temporary working directory
  237 + temp_dir = os.path.join(OUTPUT_DIR, 'temp')
  238 + os.makedirs(temp_dir, exist_ok=True)
  239 +
  240 + date_str = datetime.now().strftime('%Y%m%d')
  241 + edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt')
  242 + walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt')
  243 +
  244 + # ============================================================
  245 + # Step 1: Fetch data from the database and build the edge file
  246 + # ============================================================
  247 + log_processing_step(logger, "Fetch data from database")
  248 +
249 249 # 创建数据库连接
250 250 logger.info("Connecting to the database...")
251 251 engine = create_db_connection(
... ... @@ -295,51 +295,67 @@ def main():
295 295 }
296 296 logger.debug(f"Behavior weights: {behavior_weights}")
297 297  
298   - # Build the item graph
299   - log_processing_step(logger, "Build item graph")
300   - graph = build_item_graph(df, behavior_weights)
301   - logger.info(f"Item graph built, {len(graph)} nodes in total")
302   -
303   - # Save the edge file (optional)
304   - if args.save_graph:
305   - edge_file = os.path.join(OUTPUT_DIR, f'item_graph_{datetime.now().strftime("%Y%m%d")}.txt')
306   - save_edge_file(graph, edge_file)
306   - logger.info(f"Graph edge file saved to {edge_file}")
  298 + # Build the edge file
  299 + log_processing_step(logger, "Build edge file")
  300 + num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger)
  301 +
  302 + # ============================================================
  303 + # Step 2: Run DeepWalk random walks
  304 + # ============================================================
  305 + log_processing_step(logger, "Run DeepWalk random walks")
  306 +
  307 + logger.info("Initializing DeepWalk...")
  308 + deepwalk = DeepWalk(
  309 + edge_file=edge_file,
  310 + node_tag_file=None, # no tag-based walks
  311 + use_softmax=args.use_softmax,
  312 + temperature=args.temperature,
  313 + p_tag_walk=0.0 # no tag-based walks
  314 + )
308 315  
309   - # Generate random walks
310   - log_processing_step(logger, "Generate random walks")
311   - walks = generate_walks(graph, args.num_walks, args.walk_length)
312   - logger.info(f"Generated {len(walks)} walk paths")
  316 + logger.info("Starting random walks...")
  317 + deepwalk.simulate_walks(
  318 + num_walks=args.num_walks,
  319 + walk_length=args.walk_length,
  320 + workers=args.workers,
  321 + output_file=walks_file
  322 + )
313 323  
314   - # Train the Word2Vec model
  324 + # ============================================================
  325 + # Step 3: Train the Word2Vec model
  326 + # ============================================================
315 327 log_processing_step(logger, "Train Word2Vec model")
  328 +
316 329 w2v_config = {
317 330 'vector_size': args.vector_size,
318 331 'window_size': args.window_size,
319 332 'min_count': args.min_count,
320 333 'workers': args.workers,
321 334 'epochs': args.epochs,
322   - 'sg': 1
  335 + 'sg': 1 # Skip-gram
323 336 }
324 337 logger.debug(f"Word2Vec config: {w2v_config}")
325 338  
326   - model = train_word2vec(walks, w2v_config)
327   - logger.info(f"Training complete. Vocabulary size: {len(model.wv)}")
  339 + model = train_word2vec_from_walks(walks_file, w2v_config, logger)
328 340  
329 341 # Save the model (optional)
330 342 if args.save_model:
331   - model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{datetime.now().strftime("%Y%m%d")}.model')
  343 + model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model')
332 344 model.save(model_path)
333 345 logger.info(f"Model saved to {model_path}")
334 346  
335   - # Generate similarities
  347 + # ============================================================
  348 + # Step 4: Generate similarities
  349 + # ============================================================
336 350 log_processing_step(logger, "Generate similarities")
337   - result = generate_similarities(model, top_n=args.top_n)
338   - logger.info(f"Generated similarities for {len(result)} items")
  351 + result = generate_similarities(model, args.top_n, logger)
339 352  
340   - # Write the results
  353 + # ============================================================
  354 + # Step 5: Save the results
  355 + # ============================================================
341 356 log_processing_step(logger, "Save results")
342   - output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{datetime.now().strftime("%Y%m%d")}.txt')
  357 +
  358 + output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt')
343 359  
344 360 # Fetch name mappings
345 361 name_mappings = {}
... ... @@ -347,23 +363,14 @@ def main():
347 363 logger.info("Fetching item name mappings...")
348 364 name_mappings = fetch_name_mappings(engine, debug=True)
349 365  
350   - logger.info(f"Writing results to {output_file}...")
351   - with open(output_file, 'w', encoding='utf-8') as f:
352   - for item_id, sims in result.items():
353   - # Look up names via name_mappings
354   - item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
355   - if item_name == 'Unknown' and 'item_name' in df.columns:
356   - item_name = df[df['item_id'].astype(str) == item_id]['item_name'].iloc[0] if len(df[df['item_id'].astype(str) == item_id]) > 0 else 'Unknown'
357   -
358   - if not sims:
359   - continue
360   -
361   - # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
362   - sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
363   - f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
  366 + save_results(result, output_file, name_mappings, logger)
364 367  
365   - logger.info(f"Done! Generated similarities for {len(result)} items")
366   - logger.info(f"Output saved to: {output_file}")
  368 + logger.info("✓ DeepWalk finished!")
  369 + logger.info(f" - Output file: {output_file}")
  370 + logger.info(f" - Items: {len(result)}")
  371 + if result:
  372 + avg_sims = sum(len(sims) for sims in result.values()) / len(result)
  373 + logger.info(f" - Avg similar items: {avg_sims:.1f}")
367 374  
369 376 # If debug mode is enabled, also save a human-readable format
369 376 if args.debug:
... ... @@ -374,8 +381,20 @@ def main():
374 381 name_mappings,
375 382 description='i2i:deepwalk'
376 383 )
  384 +
  385 + # Clean up temporary files (optional)
  386 + if not args.save_graph:
  387 + if os.path.exists(edge_file):
  388 + os.remove(edge_file)
  389 + logger.debug(f"Removed temp file: {edge_file}")
  390 + if os.path.exists(walks_file):
  391 + os.remove(walks_file)
  392 + logger.debug(f"Removed temp file: {walks_file}")
  393 +
  394 + print("✓ DeepWalk similarity computation finished")
  395 + print(f" - Output file: {output_file}")
  396 + print(f" - Items: {len(result)}")
377 397  
378 398  
379 399 if __name__ == '__main__':
380 400 main()
381   -
... ...
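A hypothetical invocation of the refactored script (flag names assumed from the argparse block above; anything not passed falls back to the I2I_CONFIG defaults):

    python offline_tasks/scripts/i2i_deepwalk.py --lookback_days 30 --top_n 50 \
        --use_softmax --temperature 0.5 --save_model --debug

With --use_softmax, neighbor transition probabilities are presumably softmax-normalized, p_i = exp(w_i / T) / sum_j exp(w_j / T), so a lower --temperature T concentrates the walks on high-weight edges.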