Commit 0e45f7029178f20715ff1b7724250e407013c3d9
1 parent: 1088c261
DeepWalk refactor for memory savings and performance optimization
Showing 1 changed file with 157 additions and 138 deletions
offline_tasks/scripts/i2i_deepwalk.py
@@ -1,40 +1,44 @@
 """
 i2i - DeepWalk algorithm implementation
 Trains a DeepWalk model on the user-item graph structure to obtain item vector similarities
+Reuses the efficient implementation in graphembedding/deepwalk/
 """
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 import pandas as pd
 import argparse
+import os
+import sys
 from datetime import datetime
 from collections import defaultdict
 from gensim.models import Word2Vec
-import numpy as np
 from db_service import create_db_connection
 from offline_tasks.config.offline_config import (
     DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
     DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
 )
 from offline_tasks.scripts.debug_utils import (
-    setup_debug_logger, log_dataframe_info, log_dict_stats,
+    setup_debug_logger, log_dataframe_info,
     save_readable_index, fetch_name_mappings, log_algorithm_params,
     log_processing_step
 )
 
+# Import the DeepWalk implementation
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk'))
+from deepwalk import DeepWalk
+
 
-def build_item_graph(df, behavior_weights):
+def build_edge_file_from_db(df, behavior_weights, output_path, logger):
     """
-    Build the item graph (based on users' common interactions)
+    Build the edge file from database data
+    Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,...
 
     Args:
         df: DataFrame with columns: user_id, item_id, event_type
         behavior_weights: dict of behavior weights
-
-    Returns:
-        edge_dict: {item_id: {neighbor_id: weight}}
+        output_path: output path for the edge file
+        logger: logger instance
     """
+    logger.info("Building item graph...")
+
     # Build the user-item lists
     user_items = defaultdict(list)
 
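Note on the edge-file format declared in the new docstring: it is plain text, one node per line, which keeps the graph on disk between stages instead of on the Python heap. A minimal reader that inverts it back into an adjacency dict (the helper name load_edge_file is ours, not part of this commit):

from collections import defaultdict

def load_edge_file(path):
    """Read 'item_id \t nbr1:w1,nbr2:w2,...' lines into {item: {nbr: weight}}."""
    graph = defaultdict(dict)
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            item_id, neighbor_str = line.rstrip('\n').split('\t')
            for pair in neighbor_str.split(','):
                nbr, weight = pair.rsplit(':', 1)
                graph[item_id][nbr] = float(weight)
    return graph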
@@ -43,13 +47,19 @@ def build_item_graph(df, behavior_weights):
         item_id = str(row['item_id'])
         event_type = row['event_type']
         weight = behavior_weights.get(event_type, 1.0)
-
         user_items[user_id].append((item_id, weight))
 
+    logger.info(f"{len(user_items)} users in total")
+
     # Build the item-graph edges
     edge_dict = defaultdict(lambda: defaultdict(float))
 
     for user_id, items in user_items.items():
+        # Cap the number of items per user to avoid a memory blow-up
+        if len(items) > 100:
+            # Sort by weight and keep only the top 100
+            items = sorted(items, key=lambda x: -x[1])[:100]
+
         # Pair items up to build edges
         for i in range(len(items)):
             item_i, weight_i = items[i]
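This cap is where most of the memory saving comes from: pair generation below is quadratic, k * (k - 1) / 2 edges per user, so one user with 10,000 events would otherwise emit roughly 50 million edge updates on their own. A quick check of that arithmetic:

def pair_count(k, cap=100):
    # Pairs emitted per user after the top-`cap` trim shown above.
    k = min(k, cap)
    return k * (k - 1) // 2

assert pair_count(10_000) == 4_950        # capped: at most 4,950 pairs per user
assert 10_000 * 9_999 // 2 == 49_995_000  # uncapped: ~50M pairs for one heavy user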
@@ -61,106 +71,47 @@ def build_item_graph(df, behavior_weights):
                 edge_dict[item_i][item_j] += edge_weight
                 edge_dict[item_j][item_i] += edge_weight
 
-    return edge_dict
-
-
-def save_edge_file(edge_dict, output_path):
-    """
-    Save the edge file
+    logger.info(f"Item graph built, {len(edge_dict)} nodes in total")
 
-    Args:
-        edge_dict: edge dict
-        output_path: output path
-    """
+    # Save the edge file
+    logger.info(f"Saving edge file to {output_path}")
     with open(output_path, 'w', encoding='utf-8') as f:
         for item_id, neighbors in edge_dict.items():
-            # Format: item_id \t neighbor1:weight1,neighbor2:weight2,...
             neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()])
             f.write(f'{item_id}\t{neighbor_str}\n')
 
-    print(f"Edge file saved to {output_path}")
+    logger.info(f"Edge file saved")
+    return len(edge_dict)
 
 
-def random_walk(graph, start_node, walk_length):
+def train_word2vec_from_walks(walks_file, config, logger):
     """
-    Perform a random walk
+    Train a Word2Vec model from the walks file
 
     Args:
-        graph: graph structure {node: {neighbor: weight}}
-        start_node: start node
-        walk_length: walk length
-
-    Returns:
-        the walk sequence
-    """
-    walk = [start_node]
-
-    while len(walk) < walk_length:
-        cur = walk[-1]
-
-        if cur not in graph or not graph[cur]:
-            break
-
-        # Get neighbors and weights
-        neighbors = list(graph[cur].keys())
-        weights = list(graph[cur].values())
-
-        # Normalize the weights
-        total_weight = sum(weights)
-        if total_weight == 0:
-            break
-
-        probs = [w / total_weight for w in weights]
-
-        # Randomly pick the next node, weighted by edge weight
-        next_node = np.random.choice(neighbors, p=probs)
-        walk.append(next_node)
-
-    return walk
-
-
-def generate_walks(graph, num_walks, walk_length):
-    """
-    Generate random-walk sequences
-
-    Args:
-        graph: graph structure
-        num_walks: number of walks per node
-        walk_length: walk length
+        walks_file: path to the walk-sequence file
+        config: Word2Vec config
+        logger: logger instance
 
     Returns:
-        List of walks
+        Word2Vec model
     """
-    walks = []
-    nodes = list(graph.keys())
-
-    print(f"Generating {num_walks} walks per node, walk length {walk_length}...")
+    logger.info(f"Reading walk sequences from {walks_file}...")
 
-    for _ in range(num_walks):
-        np.random.shuffle(nodes)
-        for node in nodes:
-            walk = random_walk(graph, node, walk_length)
+    # Read the walk sequences
+    sentences = []
+    with open(walks_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            walk = line.strip().split()
             if len(walk) >= 2:
-                walks.append(walk)
-
-    return walks
-
-
-def train_word2vec(walks, config):
-    """
-    Train a Word2Vec model
+                sentences.append(walk)
 
-    Args:
-        walks: list of walk sequences
-        config: Word2Vec config
-
-    Returns:
-        Word2Vec model
-    """
-    print(f"Training Word2Vec with {len(walks)} walks...")
+    logger.info(f"Read {len(sentences)} walk sequences in total")
 
+    # Train Word2Vec
+    logger.info("Training Word2Vec model...")
     model = Word2Vec(
-        sentences=walks,
+        sentences=sentences,
         vector_size=config['vector_size'],
         window=config['window_size'],
         min_count=config['min_count'],
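The train_word2vec_from_walks introduced in this hunk still materializes every walk in a Python list before training. If the walks file outgrows memory, gensim's LineSentence could stream it straight from disk, since the file is already one whitespace-delimited walk per line. A sketch under that assumption, with the file name and hyperparameters illustrative:

from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence

# LineSentence re-reads the file on each epoch, so the corpus never has to
# fit in memory; one-token walks produce no skip-gram pairs in any case.
model = Word2Vec(
    sentences=LineSentence('walks.txt'),
    vector_size=128, window=5, min_count=1,
    workers=4, epochs=5, sg=1, seed=42,
)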
@@ -170,21 +121,23 @@ def train_word2vec(walks, config):
         seed=42
     )
 
-    print(f"Training completed. Vocabulary size: {len(model.wv)}")
+    logger.info(f"Training complete. Vocabulary size: {len(model.wv)}")
     return model
 
 
-def generate_similarities(model, top_n=50):
+def generate_similarities(model, top_n, logger):
     """
-    Generate item similarities
+    Generate item similarities from the Word2Vec model
 
     Args:
         model: Word2Vec model
         top_n: Top N similar items
+        logger: logger instance
 
     Returns:
         Dict[item_id, List[Tuple(similar_item_id, score)]]
     """
+    logger.info("Generating similarities...")
     result = {}
 
     for item_id in model.wv.index_to_key:
@@ -194,9 +147,37 @@ def generate_similarities(model, top_n=50):
         except KeyError:
             continue
 
+    logger.info(f"Generated similarities for {len(result)} items")
     return result
 
 
+def save_results(result, output_file, name_mappings, logger):
+    """
+    Save the similarity results to a file
+
+    Args:
+        result: similarity dict
+        output_file: output file path
+        name_mappings: ID-to-name mapping
+        logger: logger instance
+    """
+    logger.info(f"Saving results to {output_file}...")
+
+    with open(output_file, 'w', encoding='utf-8') as f:
+        for item_id, sims in result.items():
+            # Look up the item name
+            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
+
+            if not sims:
+                continue
+
+            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
+            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
+            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
+
+    logger.info(f"Results saved")
+
+
 def main():
     parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity')
     parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'],
@@ -225,6 +206,10 @@ def main():
                         help='Save graph edge file')
     parser.add_argument('--debug', action='store_true',
                         help='Enable debug mode with detailed logging and readable output')
+    parser.add_argument('--use_softmax', action='store_true',
+                        help='Use softmax-based alias sampling (default: False)')
+    parser.add_argument('--temperature', type=float, default=1.0,
+                        help='Temperature for softmax (default: 1.0)')
 
     args = parser.parse_args()
 
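These two flags are only plumbed through the CLI here; the DeepWalk class that consumes them lives outside this diff. A rough sketch of what temperature-scaled softmax transition probabilities typically mean, with every name below illustrative rather than the actual graphembedding/deepwalk internals:

import math
import random

def softmax_probs(weights, temperature=1.0):
    # temperature < 1 sharpens the walk toward heavy edges, > 1 flattens it
    # toward uniform; subtracting the max keeps exp() from overflowing.
    scaled = [w / temperature for w in weights]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

neighbors, edge_weights = ['a', 'b', 'c'], [3.0, 1.0, 0.5]
next_node = random.choices(neighbors, weights=softmax_probs(edge_weights, temperature=0.5))[0]
# A real implementation would precompute an alias table per node for O(1) draws.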
@@ -242,10 +227,25 @@ def main():
         'epochs': args.epochs,
         'top_n': args.top_n,
         'lookback_days': args.lookback_days,
-        'debug': args.debug
+        'debug': args.debug,
+        'use_softmax': args.use_softmax,
+        'temperature': args.temperature
     }
     log_algorithm_params(logger, params)
 
+    # Create a temp directory
+    temp_dir = os.path.join(OUTPUT_DIR, 'temp')
+    os.makedirs(temp_dir, exist_ok=True)
+
+    date_str = datetime.now().strftime('%Y%m%d')
+    edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt')
+    walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt')
+
+    # ============================================================
+    # Step 1: fetch data from the database and build the edge file
+    # ============================================================
+    log_processing_step(logger, "Fetching data from the database")
+
     # Create the database connection
     logger.info("Connecting to the database...")
     engine = create_db_connection(
@@ -295,51 +295,67 @@ def main():
     }
     logger.debug(f"Behavior weights: {behavior_weights}")
 
-    # Build the item graph
-    log_processing_step(logger, "Building item graph")
-    graph = build_item_graph(df, behavior_weights)
-    logger.info(f"Item graph built, {len(graph)} nodes in total")
-
-    # Save the edge file (optional)
-    if args.save_graph:
-        edge_file = os.path.join(OUTPUT_DIR, f'item_graph_{datetime.now().strftime("%Y%m%d")}.txt')
-        save_edge_file(graph, edge_file)
-        logger.info(f"Graph edge file saved to {edge_file}")
+    # Build the edge file
+    log_processing_step(logger, "Building edge file")
+    num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger)
+
+    # ============================================================
+    # Step 2: run the DeepWalk random walks
+    # ============================================================
+    log_processing_step(logger, "Running DeepWalk random walks")
+
+    logger.info("Initializing DeepWalk...")
+    deepwalk = DeepWalk(
+        edge_file=edge_file,
+        node_tag_file=None,  # no tag-based walks
+        use_softmax=args.use_softmax,
+        temperature=args.temperature,
+        p_tag_walk=0.0  # no tag-based walks
+    )
 
-    # Generate the random walks
-    log_processing_step(logger, "Generating random walks")
-    walks = generate_walks(graph, args.num_walks, args.walk_length)
-    logger.info(f"Generated {len(walks)} walk paths")
+    logger.info("Starting random walks...")
+    deepwalk.simulate_walks(
+        num_walks=args.num_walks,
+        walk_length=args.walk_length,
+        workers=args.workers,
+        output_file=walks_file
+    )
 
-    # Train the Word2Vec model
+    # ============================================================
+    # Step 3: train the Word2Vec model
+    # ============================================================
     log_processing_step(logger, "Training Word2Vec model")
+
     w2v_config = {
         'vector_size': args.vector_size,
         'window_size': args.window_size,
         'min_count': args.min_count,
         'workers': args.workers,
         'epochs': args.epochs,
-        'sg': 1
+        'sg': 1  # Skip-gram
     }
     logger.debug(f"Word2Vec config: {w2v_config}")
 
-    model = train_word2vec(walks, w2v_config)
-    logger.info(f"Training complete. Vocabulary size: {len(model.wv)}")
+    model = train_word2vec_from_walks(walks_file, w2v_config, logger)
 
     # Save the model (optional)
     if args.save_model:
-        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{datetime.now().strftime("%Y%m%d")}.model')
+        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model')
         model.save(model_path)
         logger.info(f"Model saved to {model_path}")
 
-    # Generate the similarities
+    # ============================================================
+    # Step 4: generate the similarities
+    # ============================================================
     log_processing_step(logger, "Generating similarities")
-    result = generate_similarities(model, top_n=args.top_n)
-    logger.info(f"Generated similarities for {len(result)} items")
+    result = generate_similarities(model, args.top_n, logger)
 
-    # Write the results
+    # ============================================================
+    # Step 5: save the results
+    # ============================================================
     log_processing_step(logger, "Saving results")
-    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{datetime.now().strftime("%Y%m%d")}.txt')
+
+    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt')
 
     # Fetch name mappings
     name_mappings = {}
@@ -347,23 +363,14 @@ def main():
         logger.info("Fetching item name mappings...")
         name_mappings = fetch_name_mappings(engine, debug=True)
 
-    logger.info(f"Writing results to {output_file}...")
-    with open(output_file, 'w', encoding='utf-8') as f:
-        for item_id, sims in result.items():
-            # Use name_mappings to look up the name
-            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
-            if item_name == 'Unknown' and 'item_name' in df.columns:
-                item_name = df[df['item_id'].astype(str) == item_id]['item_name'].iloc[0] if len(df[df['item_id'].astype(str) == item_id]) > 0 else 'Unknown'
-
-            if not sims:
-                continue
-
-            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
-            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
-            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
+    save_results(result, output_file, name_mappings, logger)
 
-    logger.info(f"Done! Generated similarities for {len(result)} items")
-    logger.info(f"Output saved to {output_file}")
+    logger.info(f"✓ DeepWalk finished!")
+    logger.info(f"  - Output file: {output_file}")
+    logger.info(f"  - Item count: {len(result)}")
+    if result:
+        avg_sims = sum(len(sims) for sims in result.values()) / len(result)
+        logger.info(f"  - Avg similar items per item: {avg_sims:.1f}")
 
     # If debug mode is enabled, save a readable format
     if args.debug:
@@ -374,8 +381,20 @@ def main():
             name_mappings,
             description='i2i:deepwalk'
         )
+
+    # Clean up temp files (optional)
+    if not args.save_graph:
+        if os.path.exists(edge_file):
+            os.remove(edge_file)
+            logger.debug(f"Removed temp file: {edge_file}")
+        if os.path.exists(walks_file):
+            os.remove(walks_file)
+            logger.debug(f"Removed temp file: {walks_file}")
+
+    print(f"✓ DeepWalk similarity computation finished")
+    print(f"  - Output file: {output_file}")
+    print(f"  - Item count: {len(result)}")
 
 
 if __name__ == '__main__':
     main()
-
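For reference, a plausible invocation exercising the new flags. The values are illustrative, and the exact spellings of --walk_length and --top_n are inferred from the args usage, since their add_argument lines fall outside the hunks shown above:

python offline_tasks/scripts/i2i_deepwalk.py \
    --num_walks 10 --walk_length 20 --top_n 50 \
    --use_softmax --temperature 0.5 \
    --save_graph --debug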