"""
i2i - DeepWalk algorithm implementation
Trains a DeepWalk model on the user-item graph structure to obtain item-vector similarities.
Reuses the efficient implementation under graphembedding/deepwalk/.
"""

import pandas as pd
import argparse
import os
import sys
from datetime import datetime
from collections import defaultdict
from gensim.models import Word2Vec
from db_service import create_db_connection
from config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)
from debug_utils import (
    setup_debug_logger, log_dataframe_info,
    save_readable_index, fetch_name_mappings, log_algorithm_params,
    log_processing_step
)

# Import the DeepWalk implementation
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk'))
from deepwalk import DeepWalk
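
# The config module itself is not shown in this file. Based on the keys read in
# main() below, I2I_CONFIG['deepwalk'] is assumed to look roughly like the sketch
# below (key names are real, the values are illustrative placeholders only):
#   I2I_CONFIG = {
#       'deepwalk': {
#           'num_walks': 10, 'walk_length': 40, 'window_size': 5,
#           'vector_size': 64, 'min_count': 1, 'workers': 4, 'epochs': 5,
#       }
#   }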


def build_edge_file_from_db(df, behavior_weights, output_path, logger):
    """
    Build the edge file from database data.
    Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,...

    Args:
        df: DataFrame with columns: user_id, item_id, event_type
        behavior_weights: dict mapping event type to behavior weight
        output_path: output path for the edge file
        logger: logger object
    """
    logger.info("Building the item graph...")

    # Build the per-user item lists
    user_items = defaultdict(list)
    for _, row in df.iterrows():
        user_id = row['user_id']
        item_id = str(row['item_id'])
        event_type = row['event_type']
        weight = behavior_weights.get(event_type, 1.0)
        user_items[user_id].append((item_id, weight))

    logger.info(f"Found {len(user_items)} users")

    # Build the item-graph edges
    edge_dict = defaultdict(lambda: defaultdict(float))
    for user_id, items in user_items.items():
        # Cap the number of items per user to keep memory under control
        if len(items) > 100:
            # Sort by weight and keep only the top 100
            items = sorted(items, key=lambda x: -x[1])[:100]
        # Pair the items up to build edges
        for i in range(len(items)):
            item_i, weight_i = items[i]
            for j in range(i + 1, len(items)):
                item_j, weight_j = items[j]
                # The edge weight is the average of the two item weights
                edge_weight = (weight_i + weight_j) / 2.0
                edge_dict[item_i][item_j] += edge_weight
                edge_dict[item_j][item_i] += edge_weight

    logger.info(f"Item graph built: {len(edge_dict)} nodes")

    # Save the edge file
    logger.info(f"Saving edge file to {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        for item_id, neighbors in edge_dict.items():
            neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()])
            f.write(f'{item_id}\t{neighbor_str}\n')
    logger.info("Edge file saved")
    return len(edge_dict)
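
# Worked example (illustrative, not part of the pipeline): if a user clicked item
# 101 (weight 1.0) and purchased item 202 (weight 10.0), the pair contributes an
# edge weight of (1.0 + 10.0) / 2 = 5.5 in both directions, so the edge file would
# contain lines such as:
#   101\t202:5.5000
#   202\t101:5.5000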


def train_word2vec_from_walks(walks_file, config, logger):
    """
    Train a Word2Vec model from the walk file.

    Args:
        walks_file: path to the random-walk sequence file
        config: Word2Vec configuration
        logger: logger object
    Returns:
        Word2Vec model
    """
    logger.info(f"Reading walk sequences from {walks_file}...")

    # Read the walk sequences
    sentences = []
    with open(walks_file, 'r', encoding='utf-8') as f:
        for line in f:
            walk = line.strip().split()
            if len(walk) >= 2:
                sentences.append(walk)

    logger.info(f"Read {len(sentences)} walk sequences")

    # Train Word2Vec
    logger.info("Training the Word2Vec model...")
    model = Word2Vec(
        sentences=sentences,
        vector_size=config['vector_size'],
        window=config['window_size'],
        min_count=config['min_count'],
        workers=config['workers'],
        sg=config['sg'],
        epochs=config['epochs'],
        seed=42
    )
    logger.info(f"Training finished. Vocabulary size: {len(model.wv)}")
    return model
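
# Illustrative excerpt of the walks file consumed above (one space-separated walk
# per line; the item IDs are made-up examples):
#   101 202 305 101 417
#   202 513 202 101 305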


def generate_similarities(model, top_n, logger):
    """
    Generate item similarities from the Word2Vec model.

    Args:
        model: Word2Vec model
        top_n: top N similar items
        logger: logger object
    Returns:
        Dict[item_id, List[Tuple(similar_item_id, score)]]
    """
    logger.info("Generating similarities...")
    result = {}
    for item_id in model.wv.index_to_key:
        try:
            similar_items = model.wv.most_similar(item_id, topn=top_n)
            result[item_id] = [(sim_id, float(score)) for sim_id, score in similar_items]
        except KeyError:
            continue
    logger.info(f"Generated similarities for {len(result)} items")
    return result
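
# For reference (illustrative values): each entry of the returned dict looks like
#   result['101'] == [('202', 0.93), ('305', 0.88), ...]
# i.e. item IDs paired with cosine-similarity scores, as returned by gensim's
# most_similar().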


def save_results(result, output_file, name_mappings, logger):
    """
    Save the similarity results to a file.
    Args:
        result: similarity dict
        output_file: output file path
        name_mappings: mapping from item ID to item name
        logger: logger object
    """
    logger.info(f"Saving results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            # Look up the item name
            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
            if not sims:
                continue
            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
    logger.info("Results saved")


def main():
    parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity')
    parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'],
                        help='Number of walks per node')
    parser.add_argument('--walk_length', type=int, default=I2I_CONFIG['deepwalk']['walk_length'],
                        help='Walk length')
    parser.add_argument('--window_size', type=int, default=I2I_CONFIG['deepwalk']['window_size'],
                        help='Window size for Word2Vec')
    parser.add_argument('--vector_size', type=int, default=I2I_CONFIG['deepwalk']['vector_size'],
                        help='Vector size for Word2Vec')
    parser.add_argument('--min_count', type=int, default=I2I_CONFIG['deepwalk']['min_count'],
                        help='Minimum word count')
    parser.add_argument('--workers', type=int, default=I2I_CONFIG['deepwalk']['workers'],
                        help='Number of workers')
    parser.add_argument('--epochs', type=int, default=I2I_CONFIG['deepwalk']['epochs'],
                        help='Number of epochs')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                        help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                        help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file path')
    parser.add_argument('--save_model', action='store_true',
                        help='Save Word2Vec model')
    parser.add_argument('--save_graph', action='store_true',
                        help='Save graph edge file')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode with detailed logging and readable output')
    parser.add_argument('--use_softmax', action='store_true',
                        help='Use softmax-based alias sampling (default: False)')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Temperature for softmax (default: 1.0)')
    args = parser.parse_args()

    # Set up the logger
    logger = setup_debug_logger('i2i_deepwalk', debug=args.debug)

    # Log the algorithm parameters
    params = {
        'num_walks': args.num_walks,
        'walk_length': args.walk_length,
        'window_size': args.window_size,
        'vector_size': args.vector_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'top_n': args.top_n,
        'lookback_days': args.lookback_days,
        'debug': args.debug,
        'use_softmax': args.use_softmax,
        'temperature': args.temperature
    }
    log_algorithm_params(logger, params)

    # Create the temp directory
    temp_dir = os.path.join(OUTPUT_DIR, 'temp')
    os.makedirs(temp_dir, exist_ok=True)
    date_str = datetime.now().strftime('%Y%m%d')
    edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt')
    walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt')

    # ============================================================
    # Step 1: fetch data from the database and build the edge file
    # ============================================================
    log_processing_step(logger, "Fetch data from database")

    # Create the database connection
    logger.info("Connecting to the database...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )
    # Get the time range
    start_date, end_date = get_time_range(args.lookback_days)
    logger.info(f"Fetching data from {start_date} to {end_date}")

    # SQL query - fetch user behavior data
    sql_query = f"""
    SELECT
        se.anonymous_id AS user_id,
        se.item_id,
        se.event AS event_type,
        pgs.name AS item_name
    FROM
        sensors_events se
        LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE
        se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    """

    logger.info("Executing SQL query...")
    df = pd.read_sql(sql_query, engine)
    logger.info(f"Fetched {len(df)} records")

    # Log data info
    log_dataframe_info(logger, df, "user behavior data")

    # Define the behavior weights
    behavior_weights = {
        'click': 1.0,
        'contactFactory': 5.0,
        'addToPool': 2.0,
        'addToCart': 3.0,
        'purchase': 10.0
    }
    logger.debug(f"Behavior weights: {behavior_weights}")

    # Build the edge file
    log_processing_step(logger, "Build edge file")
    num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger)

    # ============================================================
    # Step 2: run DeepWalk random walks
    # ============================================================
    log_processing_step(logger, "Run DeepWalk random walks")
    logger.info("Initializing DeepWalk...")
    deepwalk = DeepWalk(
        edge_file=edge_file,
        node_tag_file=None,  # tag-based walks are not used
        use_softmax=args.use_softmax,
        temperature=args.temperature,
        p_tag_walk=0.0  # tag-based walks are not used
    )
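
    # Note (assumption; the deepwalk module itself is not shown here): with
    # --use_softmax, transition probabilities out of node i are presumably built by
    # softmaxing the edge weights w_ij with the given temperature T, i.e.
    #   p(j | i) = exp(w_ij / T) / sum_k exp(w_ik / T)
    # and then sampled via alias tables; otherwise the raw weights are normalized.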

    logger.info("Starting random walks...")
    deepwalk.simulate_walks(
        num_walks=args.num_walks,
        walk_length=args.walk_length,
        workers=args.workers,
        output_file=walks_file
    )

    # ============================================================
    # Step 3: train the Word2Vec model
    # ============================================================
    log_processing_step(logger, "Train Word2Vec model")

    w2v_config = {
        'vector_size': args.vector_size,
        'window_size': args.window_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'sg': 1  # Skip-gram
    }
    logger.debug(f"Word2Vec config: {w2v_config}")

    model = train_word2vec_from_walks(walks_file, w2v_config, logger)

    # Save the model (optional)
    if args.save_model:
        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model')
        model.save(model_path)
        logger.info(f"Model saved to {model_path}")

    # ============================================================
    # Step 4: generate similarities
    # ============================================================
    log_processing_step(logger, "Generate similarities")
    result = generate_similarities(model, args.top_n, logger)

    # ============================================================
    # Step 5: save results
    # ============================================================
    log_processing_step(logger, "Save results")
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt')

    # Fetch the name mappings
    name_mappings = {}
    if args.debug:
        logger.info("Fetching item name mappings...")
        name_mappings = fetch_name_mappings(engine, debug=True)

    save_results(result, output_file, name_mappings, logger)

    logger.info("✓ DeepWalk finished!")
    logger.info(f"  - Output file: {output_file}")
    logger.info(f"  - Item count: {len(result)}")
    if result:
        avg_sims = sum(len(sims) for sims in result.values()) / len(result)
        logger.info(f"  - Average similar items per item: {avg_sims:.1f}")

    # If debug mode is enabled, also save a human-readable format
    if args.debug:
        log_processing_step(logger, "Save debug-readable output")
        save_readable_index(
            output_file,
            result,
            name_mappings,
            description='i2i:deepwalk'
        )

    # Clean up the temp files (optional)
    if not args.save_graph:
        if os.path.exists(edge_file):
            os.remove(edge_file)
            logger.debug(f"Removed temp file: {edge_file}")
        if os.path.exists(walks_file):
            os.remove(walks_file)
            logger.debug(f"Removed temp file: {walks_file}")

    print("✓ DeepWalk similarity computation finished")
    print(f"  - Output file: {output_file}")
    print(f"  - Item count: {len(result)}")


if __name__ == '__main__':
    main()