Blame view

offline_tasks/scripts/i2i_session_w2v.py 11 KB
5ab1c29c   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
  """
  i2i - Session Word2Vec算法实现
  基于用户会话序列训练Word2Vec模型,获取物品向量相似度
  """
  import sys
  import os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  
  import pandas as pd
  import json
  import argparse
  from datetime import datetime
  from collections import defaultdict
  from gensim.models import Word2Vec
  import numpy as np
  from db_service import create_db_connection
  from offline_tasks.config.offline_config import (
      DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
      DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
  )
14f3dcbe   tangwang   offline tasks
21
22
23
24
25
  from offline_tasks.scripts.debug_utils import (
      setup_debug_logger, log_dataframe_info, log_dict_stats,
      save_readable_index, fetch_name_mappings, log_algorithm_params,
      log_processing_step
  )
5ab1c29c   tangwang   first commit
26
27
  
  
9832fef6   tangwang   offline tasks
28
  def prepare_session_data(df, max_session_length=50, min_session_length=2, logger=None):
      """
      Build training sessions by cutting each user's event history into fixed-size chunks.

      Rows are ordered by (user_id, create_time). For every user the item ids
      are converted to strings and split into consecutive, non-overlapping
      chunks of at most ``max_session_length`` items. Users whose whole
      sequence is shorter than ``min_session_length`` are skipped, and any
      chunk shorter than ``min_session_length`` (only the trailing chunk of a
      user can be) is dropped. The original author notes this fixed-length
      chunking is intended for low-frequency B2B traffic.

      Args:
          df: DataFrame with columns: user_id, item_id, create_time
          max_session_length: Maximum number of items per chunk; must be >= 1
          min_session_length: Minimum number of items a chunk needs to be kept
          logger: Logger instance for debugging

      Returns:
          List of sessions, each session is a list of item_ids (as strings)

      Raises:
          ValueError: If max_session_length is smaller than 1
      """
      # A step of 0 makes range() fail with an unrelated message and a negative
      # step silently yields no chunks at all, so reject both up front.
      if max_session_length < 1:
          raise ValueError(f"max_session_length must be >= 1, got {max_session_length}")

      if logger:
          logger.debug(f"开始准备会话数据(固定长度分块):max_length={max_session_length}, min_length={min_session_length}")

      sessions = []

      # Order by user and time so each user's items come out chronologically.
      ordered = df.sort_values(['user_id', 'create_time'])

      for _, user_df in ordered.groupby('user_id'):
          item_sequence = user_df['item_id'].astype(str).tolist()

          # Too little history to form even one session.
          if len(item_sequence) < min_session_length:
              continue

          # Non-overlapping chunks; a too-short trailing chunk is discarded.
          for start in range(0, len(item_sequence), max_session_length):
              chunk = item_sequence[start:start + max_session_length]
              if len(chunk) >= min_session_length:
                  sessions.append(chunk)

      if logger:
          if sessions:
              session_lengths = [len(s) for s in sessions]
              logger.debug(f"生成 {len(sessions)} 个会话")
              logger.debug(f"会话长度统计:最小={min(session_lengths)}, 最大={max(session_lengths)}, "
                          f"平均={sum(session_lengths)/len(session_lengths):.2f}")
          else:
              logger.warning("未生成任何会话!")

      return sessions
  
  
14f3dcbe   tangwang   offline tasks
80
  def train_word2vec(sessions, config, logger=None):
      """
      Train a Word2Vec model on item sessions.

      Args:
          sessions: List of sessions (each a list of item-id tokens)
          config: Word2Vec settings; must contain the keys 'vector_size',
              'window_size', 'min_count', 'workers', 'sg' and 'epochs'
          logger: Logger instance for debugging; when omitted, progress is
              printed to stdout instead

      Returns:
          The trained Word2Vec model

      Raises:
          RuntimeError: If sessions is empty
      """
      # Fail early with a clear message instead of letting the training call
      # fail later on an empty corpus.
      if len(sessions) == 0:
          raise RuntimeError("Cannot train Word2Vec: no sessions were provided")

      if logger:
          logger.info(f"训练Word2Vec模型,共 {len(sessions)} 个会话")
          logger.debug(f"模型参数:vector_size={config['vector_size']}, window={config['window_size']}, "
                      f"min_count={config['min_count']}, epochs={config['epochs']}")
      else:
          print(f"Training Word2Vec with {len(sessions)} sessions...")

      # NOTE(review): the fixed seed aids reproducibility; per the gensim docs a
      # fully deterministic run also needs workers=1 — confirm if that matters.
      model = Word2Vec(
          sentences=sessions,
          vector_size=config['vector_size'],
          window=config['window_size'],
          min_count=config['min_count'],
          workers=config['workers'],
          sg=config['sg'],
          epochs=config['epochs'],
          seed=42
      )

      if logger:
          logger.info(f"训练完成。词汇表大小:{len(model.wv)}")
      else:
          print(f"Training completed. Vocabulary size: {len(model.wv)}")

      return model
  
  
14f3dcbe   tangwang   offline tasks
117
  def generate_similarities(model, top_n=50, logger=None):
      """
      Collect the most similar items for every item in the model vocabulary.

      Args:
          model: Trained Word2Vec model (its ``wv`` attribute is queried)
          top_n: Number of similar items requested per item
          logger: Logger instance for debugging

      Returns:
          Dict[item_id, List[Tuple(similar_item_id, score)]] with scores
          converted to plain Python floats
      """
      if logger:
          logger.info(f"生成Top {top_n} 相似物品")

      vectors = model.wv
      result = {}
      skipped = 0

      for item_id in vectors.index_to_key:
          try:
              neighbours = vectors.most_similar(item_id, topn=top_n)
          except KeyError:
              # Best-effort: keep going, but count the miss so it is reported
              # below instead of disappearing silently.
              skipped += 1
              continue
          result[item_id] = [(sim_id, float(score)) for sim_id, score in neighbours]

      if logger:
          if skipped:
              logger.warning(f"Skipped {skipped} items that raised KeyError during similarity lookup")
          logger.info(f"生成了 {len(result)} 个物品的相似度")

      return result
  
  
  def _build_arg_parser():
      """Create the command-line parser; defaults are read from the offline config."""
      w2v_defaults = I2I_CONFIG['session_w2v']
      parser = argparse.ArgumentParser(description='Run Session Word2Vec for i2i similarity')
      parser.add_argument('--window_size', type=int, default=w2v_defaults['window_size'],
                         help='Window size for Word2Vec')
      parser.add_argument('--vector_size', type=int, default=w2v_defaults['vector_size'],
                         help='Vector size for Word2Vec')
      parser.add_argument('--min_count', type=int, default=w2v_defaults['min_count'],
                         help='Minimum word count')
      parser.add_argument('--workers', type=int, default=w2v_defaults['workers'],
                         help='Number of workers')
      parser.add_argument('--epochs', type=int, default=w2v_defaults['epochs'],
                         help='Number of epochs')
      parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                         help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
      parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                         help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
      parser.add_argument('--max_session_length', type=int, default=50,
                         help='Maximum session length for chunking (default: 50)')
      parser.add_argument('--min_session_length', type=int, default=2,
                         help='Minimum session length to keep (default: 2)')
      parser.add_argument('--output', type=str, default=None,
                         help='Output file path')
      parser.add_argument('--save_model', action='store_true',
                         help='Save Word2Vec model')
      parser.add_argument('--debug', action='store_true',
                         help='Enable debug mode with detailed logging and readable output')
      return parser


  def _first_item_names(df):
      """Map str(item_id) to the first non-null item_name in df ({} if the column is absent)."""
      if 'item_name' not in df.columns:
          return {}
      named = df.loc[df['item_name'].notna(), ['item_id', 'item_name']]
      named = named.assign(item_id=named['item_id'].astype(str)).drop_duplicates(subset='item_id')
      return dict(zip(named['item_id'], named['item_name']))


  def _resolve_item_name(item_id, name_mappings, fallback_names):
      """Pick a display name: name_mappings first, then names seen in the event data, else 'Unknown'."""
      item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
      if item_name == 'Unknown':
          item_name = fallback_names.get(item_id, 'Unknown')
      return item_name


  def _write_similarity_file(output_file, result, name_mappings, fallback_names):
      """Write one line per item: item_id, item_name and 'sim_id:score' pairs, tab-separated."""
      # Create the target directory so a fresh environment does not fail on open().
      output_dir = os.path.dirname(output_file)
      if output_dir:
          os.makedirs(output_dir, exist_ok=True)

      with open(output_file, 'w', encoding='utf-8') as f:
          for item_id, sims in result.items():
              # Skip empty rows before doing any name lookup work.
              if not sims:
                  continue
              item_name = _resolve_item_name(item_id, name_mappings, fallback_names)
              # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
              sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
              f.write(f'{item_id}\t{item_name}\t{sim_str}\n')


  def main():
      """
      Command-line entry point for the Session Word2Vec i2i job.

      Steps: parse arguments, load behaviour events for the look-back window
      from the database, chunk them into sessions, train Word2Vec, compute
      top-N similar items per item and write them to a tab-separated file.
      With --save_model the model is saved under OUTPUT_DIR; with --debug the
      name mappings are fetched and a readable index is saved as well.
      """
      args = _build_arg_parser().parse_args()

      # Set up the logger
      logger = setup_debug_logger('i2i_session_w2v', debug=args.debug)

      # Record the algorithm parameters
      params = {
          'window_size': args.window_size,
          'vector_size': args.vector_size,
          'min_count': args.min_count,
          'workers': args.workers,
          'epochs': args.epochs,
          'top_n': args.top_n,
          'lookback_days': args.lookback_days,
          'max_session_length': args.max_session_length,
          'min_session_length': args.min_session_length,
          'debug': args.debug
      }
      log_algorithm_params(logger, params)

      # Create the database connection
      logger.info("连接数据库...")
      engine = create_db_connection(
          DB_CONFIG['host'],
          DB_CONFIG['port'],
          DB_CONFIG['database'],
          DB_CONFIG['username'],
          DB_CONFIG['password']
      )

      # Resolve the time range to load
      start_date, end_date = get_time_range(args.lookback_days)
      logger.info(f"获取数据范围:{start_date} 到 {end_date}")

      # SQL query - fetch the user behaviour sequence.
      # NOTE(review): start_date/end_date are interpolated into the SQL text.
      # They come from get_time_range(), not from free-form input — presumably
      # safe; switch to bound parameters if that ever changes.
      sql_query = f"""
      SELECT
          se.anonymous_id AS user_id,
          se.item_id,
          se.create_time,
          pgs.name AS item_name
      FROM
          sensors_events se
      LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
      WHERE
          se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
          AND se.create_time >= '{start_date}'
          AND se.create_time <= '{end_date}'
          AND se.item_id IS NOT NULL
          AND se.anonymous_id IS NOT NULL
      ORDER BY
          se.anonymous_id,
          se.create_time
      """

      logger.info("执行SQL查询...")
      df = pd.read_sql(sql_query, engine)
      logger.info(f"获取到 {len(df)} 条记录")

      # Log a summary of the loaded data
      log_dataframe_info(logger, df, "用户行为数据")

      # Convert create_time to datetime
      df['create_time'] = pd.to_datetime(df['create_time'])

      # Prepare the session data
      log_processing_step(logger, "准备会话数据")
      sessions = prepare_session_data(
          df,
          max_session_length=args.max_session_length,
          min_session_length=args.min_session_length,
          logger=logger
      )
      logger.info(f"生成 {len(sessions)} 个会话")

      # Train the Word2Vec model
      log_processing_step(logger, "训练Word2Vec模型")
      w2v_config = {
          'vector_size': args.vector_size,
          'window_size': args.window_size,
          'min_count': args.min_count,
          'workers': args.workers,
          'epochs': args.epochs,
          'sg': 1
      }
      model = train_word2vec(sessions, w2v_config, logger=logger)

      # One date stamp for all artefacts of this run, so the file names agree
      # even if the job crosses midnight.
      date_stamp = datetime.now().strftime("%Y%m%d")

      # Save the model (optional)
      if args.save_model:
          os.makedirs(OUTPUT_DIR, exist_ok=True)
          model_path = os.path.join(OUTPUT_DIR, f'session_w2v_model_{date_stamp}.model')
          model.save(model_path)
          logger.info(f"模型已保存到 {model_path}")

      # Generate similarities
      log_processing_step(logger, "生成相似度")
      result = generate_similarities(model, top_n=args.top_n, logger=logger)

      # Write the results
      log_processing_step(logger, "保存结果")
      output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_session_w2v_{date_stamp}.txt')

      # Name mappings are only fetched in debug mode; otherwise names fall
      # back to the item_name column of the event data.
      name_mappings = {}
      if args.debug:
          logger.info("获取物品名称映射...")
          name_mappings = fetch_name_mappings(engine, debug=True)

      logger.info(f"写入结果到 {output_file}...")
      # Build the fallback name table once instead of re-filtering the whole
      # DataFrame for every output item.
      fallback_names = _first_item_names(df)
      _write_similarity_file(output_file, result, name_mappings, fallback_names)

      logger.info(f"完成!为 {len(result)} 个物品生成了相似度")
      logger.info(f"输出保存到:{output_file}")

      # In debug mode, also save a readable version of the index
      if args.debug:
          log_processing_step(logger, "保存Debug可读格式")
          save_readable_index(
              output_file,
              result,
              name_mappings,
              description='i2i:session_w2v'
          )
5ab1c29c   tangwang   first commit
309
310
311
312
  
  
  # Run the job only when this file is executed as a script, not when imported.
  if __name__ == '__main__':
      main()