5ab1c29c
tangwang
first commit
|
1
2
3
4
|
"""
i2i - Session Word2Vec算法实现
基于用户会话序列训练Word2Vec模型,获取物品向量相似度
"""
|
5ab1c29c
tangwang
first commit
|
5
6
|
import pandas as pd
import json
|
e89d7a84
tangwang
deepwalk refactor...
|
7
|
import os
|
5ab1c29c
tangwang
first commit
|
8
9
10
11
12
13
|
import argparse
from datetime import datetime
from collections import defaultdict
from gensim.models import Word2Vec
import numpy as np
from db_service import create_db_connection
|
c9f77c8f
tangwang
deepwalk refactor...
|
14
|
from config.offline_config import (
|
5ab1c29c
tangwang
first commit
|
15
16
17
|
DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)
|
c9f77c8f
tangwang
deepwalk refactor...
|
18
|
from scripts.debug_utils import (
|
14f3dcbe
tangwang
offline tasks
|
19
20
21
22
|
setup_debug_logger, log_dataframe_info, log_dict_stats,
save_readable_index, fetch_name_mappings, log_algorithm_params,
log_processing_step
)
|
5ab1c29c
tangwang
first commit
|
23
24
|
|
9832fef6
tangwang
offline tasks
|
25
|
def prepare_session_data(df, max_session_length=50, min_session_length=2, logger=None):
    """
    Build training sessions by cutting each user's item history into fixed-size chunks.

    Rows are ordered by ``user_id`` and ``create_time``. Each user's item ids
    (cast to ``str``) are then split into consecutive, non-overlapping chunks of
    at most ``max_session_length`` items, and chunks shorter than
    ``min_session_length`` are dropped. The original author notes that this
    fixed-length chunking is intended for low-frequency B2B traffic.

    Args:
        df: DataFrame with columns: user_id, item_id, create_time
        max_session_length: Maximum number of items per chunk; must be >= 1
        min_session_length: Minimum number of items a chunk needs to be kept
        logger: Logger instance for debugging

    Returns:
        List of sessions, each session is a list of item_ids (as strings)

    Raises:
        ValueError: If max_session_length is smaller than 1
    """
    # Fail early with a clear message: a zero step would otherwise surface as an
    # opaque range() error, and a negative one would silently produce no sessions.
    if max_session_length < 1:
        raise ValueError(f"max_session_length must be >= 1, got {max_session_length}")

    if logger:
        logger.debug(f"开始准备会话数据(固定长度分块):max_length={max_session_length}, min_length={min_session_length}")

    sessions = []

    # sort_values returns a new frame, so the caller's DataFrame is left untouched.
    ordered = df.sort_values(['user_id', 'create_time'])

    for _, user_df in ordered.groupby('user_id'):
        item_sequence = user_df['item_id'].astype(str).tolist()

        # Non-overlapping chunks; only the trailing chunk can be shorter than
        # max_session_length, and it is kept only if it is long enough.
        for start in range(0, len(item_sequence), max_session_length):
            chunk = item_sequence[start:start + max_session_length]
            if len(chunk) >= min_session_length:
                sessions.append(chunk)

    if logger:
        if sessions:
            session_lengths = [len(s) for s in sessions]
            logger.debug(f"生成 {len(sessions)} 个会话")
            logger.debug(f"会话长度统计:最小={min(session_lengths)}, 最大={max(session_lengths)}, "
                         f"平均={sum(session_lengths)/len(session_lengths):.2f}")
        else:
            logger.warning("未生成任何会话!")

    return sessions
|
14f3dcbe
tangwang
offline tasks
|
77
|
def train_word2vec(sessions, config, logger=None):
    """
    Fit a Word2Vec model on the given item sessions.

    Args:
        sessions: List of sessions (each a list of item tokens)
        config: Mapping providing the keys ``vector_size``, ``window_size``,
            ``min_count``, ``workers``, ``sg`` and ``epochs``
        logger: Logger instance for debugging; when omitted, progress is
            printed to stdout instead

    Returns:
        The trained Word2Vec model
    """
    if not logger:
        print(f"Training Word2Vec with {len(sessions)} sessions...")
    else:
        logger.info(f"训练Word2Vec模型,共 {len(sessions)} 个会话")
        logger.debug(f"模型参数:vector_size={config['vector_size']}, window={config['window_size']}, "
                     f"min_count={config['min_count']}, epochs={config['epochs']}")

    # Collect the constructor arguments first; the seed is fixed at 42.
    w2v_kwargs = {
        'sentences': sessions,
        'vector_size': config['vector_size'],
        'window': config['window_size'],
        'min_count': config['min_count'],
        'workers': config['workers'],
        'sg': config['sg'],
        'epochs': config['epochs'],
        'seed': 42,
    }
    model = Word2Vec(**w2v_kwargs)

    if not logger:
        print(f"Training completed. Vocabulary size: {len(model.wv)}")
    else:
        logger.info(f"训练完成。词汇表大小:{len(model.wv)}")

    return model
|
14f3dcbe
tangwang
offline tasks
|
114
|
def generate_similarities(model, top_n=50, logger=None):
    """
    Collect the nearest neighbours of every item in the model vocabulary.

    Args:
        model: Trained Word2Vec model
        top_n: Number of similar items requested per item
        logger: Logger instance for debugging

    Returns:
        Dict[item_id, List[Tuple(similar_item_id, score)]] with scores as
        plain Python floats
    """
    result = {}

    if logger:
        logger.info(f"生成Top {top_n} 相似物品")

    for vocab_key in model.wv.index_to_key:
        try:
            neighbours = model.wv.most_similar(vocab_key, topn=top_n)
            scored = []
            for neighbour_id, similarity in neighbours:
                scored.append((neighbour_id, float(similarity)))
            result[vocab_key] = scored
        except KeyError:
            # Keys the model cannot resolve are left out of the result.
            continue

    if logger:
        logger.info(f"生成了 {len(result)} 个物品的相似度")

    return result
def _build_item_name_lookup(df):
    """Map str(item_id) -> item_name, using the first row seen for each item."""
    if 'item_name' not in df.columns:
        return {}
    first_rows = df.drop_duplicates(subset='item_id', keep='first')
    return dict(zip(first_rows['item_id'].astype(str), first_rows['item_name']))


def _resolve_item_name(item_id, name_mappings, fallback_names):
    """Pick a display name from name_mappings, then the event data, else 'Unknown'."""
    item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
    if item_name == 'Unknown':
        candidate = fallback_names.get(item_id)
        # The SQL uses a LEFT JOIN, so the name may be NULL; never emit "None"/"nan".
        if candidate is not None and not pd.isna(candidate):
            item_name = candidate
    return item_name


def main():
    """
    Command-line entry point for the session Word2Vec i2i job.

    Steps: parse CLI arguments, load user behaviour events from the database,
    chunk them into sessions, train Word2Vec, compute the top-N similar items
    per item and write them to a tab-separated file with one line per item:
    ``item_id \\t item_name \\t sim_id1:score1,sim_id2:score2,...``.
    With ``--debug`` a readable index is additionally saved via
    ``save_readable_index``; with ``--save_model`` the model is saved too.
    """
    w2v_defaults = I2I_CONFIG['session_w2v']

    parser = argparse.ArgumentParser(description='Run Session Word2Vec for i2i similarity')
    parser.add_argument('--window_size', type=int, default=w2v_defaults['window_size'],
                        help='Window size for Word2Vec')
    parser.add_argument('--vector_size', type=int, default=w2v_defaults['vector_size'],
                        help='Vector size for Word2Vec')
    parser.add_argument('--min_count', type=int, default=w2v_defaults['min_count'],
                        help='Minimum word count')
    parser.add_argument('--workers', type=int, default=w2v_defaults['workers'],
                        help='Number of workers')
    parser.add_argument('--epochs', type=int, default=w2v_defaults['epochs'],
                        help='Number of epochs')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                        help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                        help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--max_session_length', type=int, default=50,
                        help='Maximum session length for chunking (default: 50)')
    parser.add_argument('--min_session_length', type=int, default=2,
                        help='Minimum session length to keep (default: 2)')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file path')
    parser.add_argument('--save_model', action='store_true',
                        help='Save Word2Vec model')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode with detailed logging and readable output')
    args = parser.parse_args()

    # Set up the logger and record the run parameters.
    logger = setup_debug_logger('i2i_session_w2v', debug=args.debug)
    params = {
        'window_size': args.window_size,
        'vector_size': args.vector_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'top_n': args.top_n,
        'lookback_days': args.lookback_days,
        'max_session_length': args.max_session_length,
        'min_session_length': args.min_session_length,
        'debug': args.debug
    }
    log_algorithm_params(logger, params)

    # Create the database connection.
    logger.info("连接数据库...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    # Resolve the time window to load.
    start_date, end_date = get_time_range(args.lookback_days)
    logger.info(f"获取数据范围:{start_date} 到 {end_date}")

    # SQL query: per-user behaviour sequence.
    # NOTE(review): the dates are interpolated into the SQL text. They come from
    # get_time_range(<int>) rather than free-form user input; confirm that helper
    # only ever returns date strings before exposing it to other callers.
    sql_query = f"""
        SELECT
            se.anonymous_id AS user_id,
            se.item_id,
            se.create_time,
            pgs.name AS item_name
        FROM
            sensors_events se
        LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
        WHERE
            se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
            AND se.create_time >= '{start_date}'
            AND se.create_time <= '{end_date}'
            AND se.item_id IS NOT NULL
            AND se.anonymous_id IS NOT NULL
        ORDER BY
            se.anonymous_id,
            se.create_time
    """
    logger.info("执行SQL查询...")
    df = pd.read_sql(sql_query, engine)
    logger.info(f"获取到 {len(df)} 条记录")

    # Normalise id types: item ids as int, user ids as str.
    df['item_id'] = df['item_id'].astype(int)
    df['user_id'] = df['user_id'].astype(str)

    log_dataframe_info(logger, df, "用户行为数据")

    # Convert create_time to datetime so ordering is chronological.
    df['create_time'] = pd.to_datetime(df['create_time'])

    # Prepare session data.
    log_processing_step(logger, "准备会话数据")
    sessions = prepare_session_data(
        df,
        max_session_length=args.max_session_length,
        min_session_length=args.min_session_length,
        logger=logger
    )
    logger.info(f"生成 {len(sessions)} 个会话")

    # Train the Word2Vec model (sg=1 is hard-coded here).
    log_processing_step(logger, "训练Word2Vec模型")
    w2v_config = {
        'vector_size': args.vector_size,
        'window_size': args.window_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'sg': 1
    }
    model = train_word2vec(sessions, w2v_config, logger=logger)

    # One date tag for every artefact of this run, so a run that crosses
    # midnight cannot produce mismatched file names.
    date_tag = datetime.now().strftime("%Y%m%d")

    # Optionally save the model.
    if args.save_model:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        model_path = os.path.join(OUTPUT_DIR, f'session_w2v_model_{date_tag}.model')
        model.save(model_path)
        logger.info(f"模型已保存到 {model_path}")

    # Generate similarities.
    log_processing_step(logger, "生成相似度")
    result = generate_similarities(model, top_n=args.top_n, logger=logger)

    # Write the result.
    log_processing_step(logger, "保存结果")
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_session_w2v_{date_tag}.txt')
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Name mappings are only fetched in debug mode; otherwise names come from
    # the event data loaded above.
    name_mappings = {}
    if args.debug:
        logger.info("获取物品名称映射...")
        name_mappings = fetch_name_mappings(engine, debug=True)

    logger.info(f"写入结果到 {output_file}...")

    # Build the fallback lookup once instead of rescanning df for every item.
    event_item_names = _build_item_name_lookup(df)

    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            if not sims:
                continue
            item_name = _resolve_item_name(item_id, name_mappings, event_item_names)
            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
            sim_str = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in sims)
            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')

    logger.info(f"完成!为 {len(result)} 个物品生成了相似度")
    logger.info(f"输出保存到:{output_file}")

    # In debug mode, additionally save a readable version of the index.
    if args.debug:
        log_processing_step(logger, "保存Debug可读格式")
        save_readable_index(
            output_file,
            result,
            name_mappings,
            description='i2i:session_w2v'
        )
|
5ab1c29c
tangwang
first commit
|
310
311
312
313
|
# Run the offline job only when this file is executed as a script.
if __name__ == "__main__":
    main()
|