Blame view

offline_tasks/scripts/i2i_session_w2v.py 10.9 KB
5ab1c29c   tangwang   first commit
1
2
3
4
  """
  i2i - Session Word2Vec算法实现
  基于用户会话序列训练Word2Vec模型,获取物品向量相似度
  """
5ab1c29c   tangwang   first commit
5
6
7
8
9
10
11
12
13
14
15
16
  import argparse
  import json
  import os
  from collections import defaultdict
  from datetime import datetime

  import numpy as np
  import pandas as pd
  from gensim.models import Word2Vec

  from db_service import create_db_connection
  from offline_tasks.config.offline_config import (
      DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
      DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
  )
14f3dcbe   tangwang   offline tasks
17
18
19
20
21
  from offline_tasks.scripts.debug_utils import (
      setup_debug_logger, log_dataframe_info, log_dict_stats,
      save_readable_index, fetch_name_mappings, log_algorithm_params,
      log_processing_step
  )
5ab1c29c   tangwang   first commit
22
23
  
  
9832fef6   tangwang   offline tasks
24
  def prepare_session_data(df, max_session_length=50, min_session_length=2, logger=None):
      """
      Build Word2Vec training sessions by fixed-length chunking (suits low-frequency B2B traffic).

      Rows are sorted by (user_id, create_time). Each user's item sequence is
      then cut into consecutive, non-overlapping chunks of at most
      ``max_session_length`` items. Time gaps between events are not used.
      Users whose whole sequence is shorter than ``min_session_length`` are
      skipped, and any chunk shorter than ``min_session_length`` is dropped.
      Only the trailing chunk of a user can be short.

      Args:
          df: DataFrame with columns: user_id, item_id, create_time
          max_session_length: Maximum number of items per session chunk; must be >= 1.
          min_session_length: Minimum number of items a chunk needs to be kept.
          logger: Optional logger instance for debug statistics.

      Returns:
          List of sessions; each session is a list of item_ids converted to str.

      Raises:
          ValueError: if ``max_session_length`` is smaller than 1.
      """
      # A zero step would make range() fail with an opaque message, and a
      # negative one would silently produce no sessions at all.
      if max_session_length < 1:
          raise ValueError(f"max_session_length must be >= 1, got {max_session_length}")

      sessions = []

      if logger:
          logger.debug(f"开始准备会话数据(固定长度分块):max_length={max_session_length}, min_length={min_session_length}")

      # Sort by user and time so each group is in chronological order.
      df = df.sort_values(['user_id', 'create_time'])

      for _, user_df in df.groupby('user_id'):
          # Items are tokenized as strings for Word2Vec.
          item_sequence = user_df['item_id'].astype(str).tolist()

          # Too short to form even one session.
          if len(item_sequence) < min_session_length:
              continue

          # Non-overlapping chunks of at most max_session_length items.
          user_sessions = [
              item_sequence[i:i + max_session_length]
              for i in range(0, len(item_sequence), max_session_length)
          ]

          # Drop a trailing chunk that is too short.
          user_sessions = [s for s in user_sessions if len(s) >= min_session_length]

          sessions.extend(user_sessions)

      if logger:
          if sessions:
              session_lengths = [len(s) for s in sessions]
              logger.debug(f"生成 {len(sessions)} 个会话")
              logger.debug(f"会话长度统计:最小={min(session_lengths)}, 最大={max(session_lengths)}, "
                          f"平均={sum(session_lengths)/len(session_lengths):.2f}")
          else:
              logger.warning("未生成任何会话!")

      return sessions
  
  
14f3dcbe   tangwang   offline tasks
76
  def train_word2vec(sessions, config, logger=None):
      """
      Fit a gensim Word2Vec model on the given item sessions.

      Args:
          sessions: List of sessions, each a list of item-id tokens.
          config: Dict providing vector_size, window_size, min_count, workers,
              sg and epochs.
          logger: Optional logger; when it is falsy, progress is printed to stdout.

      Returns:
          The trained Word2Vec model (seed fixed at 42).
      """
      session_count = len(sessions)
      if not logger:
          print(f"Training Word2Vec with {session_count} sessions...")
      else:
          logger.info(f"训练Word2Vec模型,共 {session_count} 个会话")
          logger.debug(f"模型参数:vector_size={config['vector_size']}, window={config['window_size']}, "
                      f"min_count={config['min_count']}, epochs={config['epochs']}")

      # Our config uses 'window_size'; gensim names that argument 'window'.
      hyperparams = dict(
          vector_size=config['vector_size'],
          window=config['window_size'],
          min_count=config['min_count'],
          workers=config['workers'],
          sg=config['sg'],
          epochs=config['epochs'],
          seed=42,
      )
      model = Word2Vec(sentences=sessions, **hyperparams)

      vocab_size = len(model.wv)
      if not logger:
          print(f"Training completed. Vocabulary size: {vocab_size}")
      else:
          logger.info(f"训练完成。词汇表大小:{vocab_size}")

      return model
  
  
14f3dcbe   tangwang   offline tasks
113
  def generate_similarities(model, top_n=50, logger=None):
      """
      Look up the top-N nearest neighbours of every token in the model vocabulary.

      Args:
          model: Trained Word2Vec model; only ``model.wv`` is used.
          top_n: Number of similar items to keep per item.
          logger: Optional logger for progress messages.

      Returns:
          Dict[item_id, List[Tuple(similar_item_id, score)]] with scores as floats.
      """
      if logger:
          logger.info(f"生成Top {top_n} 相似物品")

      similar_by_item = {}
      for token in model.wv.index_to_key:
          try:
              ranked = model.wv.most_similar(token, topn=top_n)
              similar_by_item[token] = [(neighbour, float(similarity)) for neighbour, similarity in ranked]
          except KeyError:
              # Defensive only: tokens come straight from the model's vocabulary.
              pass

      if logger:
          logger.info(f"生成了 {len(similar_by_item)} 个物品的相似度")

      return similar_by_item
  
  
  def _build_item_name_lookup(df):
      """Map each item_id (as ``str``) to the first non-null ``item_name`` found in ``df``.

      Built once so that writing results costs one dict lookup per item. This
      avoids re-filtering the whole DataFrame for every item.

      Args:
          df: DataFrame with an ``item_id`` column and, optionally, ``item_name``.

      Returns:
          Dict[str, name]; empty when ``df`` has no ``item_name`` column.
      """
      if 'item_name' not in df.columns:
          return {}
      lookup = {}
      named = df[df['item_name'].notna()]
      for item_id, item_name in zip(named['item_id'].astype(str), named['item_name']):
          # setdefault keeps the first occurrence in df order.
          lookup.setdefault(item_id, item_name)
      return lookup


  def _tsv_field(value):
      """Render ``value`` as a single field of a tab-separated, one-record-per-line file.

      Tabs and line breaks (e.g. inside product names) are replaced with a
      space so they cannot split a field or a record.
      """
      return str(value).replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


  def main():
      """Run the Session-Word2Vec i2i job end to end.

      The job runs these steps in order:
      - Read user behavior events for the look-back window.
      - Chunk each user's item sequence into sessions.
      - Train Word2Vec on the sessions.
      - Write each item's top-N similar items to a tab-separated file, one line
        per item: item_id<TAB>item_name<TAB>sim_id1:score1,sim_id2:score2,...

      Exits with status 1 if no training sessions can be built.
      """
      parser = argparse.ArgumentParser(description='Run Session Word2Vec for i2i similarity')
      parser.add_argument('--window_size', type=int, default=I2I_CONFIG['session_w2v']['window_size'],
                         help='Window size for Word2Vec')
      parser.add_argument('--vector_size', type=int, default=I2I_CONFIG['session_w2v']['vector_size'],
                         help='Vector size for Word2Vec')
      parser.add_argument('--min_count', type=int, default=I2I_CONFIG['session_w2v']['min_count'],
                         help='Minimum word count')
      parser.add_argument('--workers', type=int, default=I2I_CONFIG['session_w2v']['workers'],
                         help='Number of workers')
      parser.add_argument('--epochs', type=int, default=I2I_CONFIG['session_w2v']['epochs'],
                         help='Number of epochs')
      parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                         help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
      parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                         help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
      parser.add_argument('--max_session_length', type=int, default=50,
                         help='Maximum session length for chunking (default: 50)')
      parser.add_argument('--min_session_length', type=int, default=2,
                         help='Minimum session length to keep (default: 2)')
      parser.add_argument('--output', type=str, default=None,
                         help='Output file path')
      parser.add_argument('--save_model', action='store_true',
                         help='Save Word2Vec model')
      parser.add_argument('--debug', action='store_true',
                         help='Enable debug mode with detailed logging and readable output')

      args = parser.parse_args()

      # Set up the logger.
      logger = setup_debug_logger('i2i_session_w2v', debug=args.debug)

      # Record the algorithm parameters for this run.
      params = {
          'window_size': args.window_size,
          'vector_size': args.vector_size,
          'min_count': args.min_count,
          'workers': args.workers,
          'epochs': args.epochs,
          'top_n': args.top_n,
          'lookback_days': args.lookback_days,
          'max_session_length': args.max_session_length,
          'min_session_length': args.min_session_length,
          'debug': args.debug
      }
      log_algorithm_params(logger, params)

      # Create the database connection.
      logger.info("连接数据库...")
      engine = create_db_connection(
          DB_CONFIG['host'],
          DB_CONFIG['port'],
          DB_CONFIG['database'],
          DB_CONFIG['username'],
          DB_CONFIG['password']
      )

      # Time window for the query.
      start_date, end_date = get_time_range(args.lookback_days)
      logger.info(f"获取数据范围:{start_date} 到 {end_date}")

      # SQL query: user behavior sequences. The interpolated dates come from
      # get_time_range() rather than free-form input (presumably formatted
      # date strings; confirm if that helper ever changes).
      sql_query = f"""
      SELECT 
          se.anonymous_id AS user_id,
          se.item_id,
          se.create_time,
          pgs.name AS item_name
      FROM 
          sensors_events se
      LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
      WHERE 
          se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
          AND se.create_time >= '{start_date}'
          AND se.create_time <= '{end_date}'
          AND se.item_id IS NOT NULL
          AND se.anonymous_id IS NOT NULL
      ORDER BY 
          se.anonymous_id,
          se.create_time
      """

      logger.info("执行SQL查询...")
      df = pd.read_sql(sql_query, engine)
      logger.info(f"获取到 {len(df)} 条记录")

      # Log DataFrame details.
      log_dataframe_info(logger, df, "用户行为数据")

      # Convert create_time to datetime.
      df['create_time'] = pd.to_datetime(df['create_time'])

      # Build sessions.
      log_processing_step(logger, "准备会话数据")
      sessions = prepare_session_data(
          df,
          max_session_length=args.max_session_length,
          min_session_length=args.min_session_length,
          logger=logger
      )
      logger.info(f"生成 {len(sessions)} 个会话")
      if not sessions:
          # Word2Vec cannot build a vocabulary from zero sentences; fail fast
          # with a clear message instead of an opaque gensim error.
          logger.error("未生成任何会话,无法训练Word2Vec模型,任务终止")
          raise SystemExit(1)

      # Train the Word2Vec model (sg=1 selects skip-gram).
      log_processing_step(logger, "训练Word2Vec模型")
      w2v_config = {
          'vector_size': args.vector_size,
          'window_size': args.window_size,
          'min_count': args.min_count,
          'workers': args.workers,
          'epochs': args.epochs,
          'sg': 1
      }

      model = train_word2vec(sessions, w2v_config, logger=logger)

      # One date stamp for every artifact of this run, so the model and index
      # file names cannot disagree if the job crosses midnight.
      run_date = datetime.now().strftime("%Y%m%d")

      # Optionally save the model.
      if args.save_model:
          model_path = os.path.join(OUTPUT_DIR, f'session_w2v_model_{run_date}.model')
          model.save(model_path)
          logger.info(f"模型已保存到 {model_path}")

      # Compute similarities.
      log_processing_step(logger, "生成相似度")
      result = generate_similarities(model, top_n=args.top_n, logger=logger)

      # Write results.
      log_processing_step(logger, "保存结果")
      output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_session_w2v_{run_date}.txt')

      # Name mappings (debug mode only) for the standard output format.
      name_mappings = {}
      if args.debug:
          logger.info("获取物品名称映射...")
          name_mappings = fetch_name_mappings(engine, debug=True)

      # Fallback names from the query result, built once (previously the whole
      # DataFrame was re-filtered for every item: O(items x rows)).
      df_item_names = _build_item_name_lookup(df)

      logger.info(f"写入结果到 {output_file}...")
      with open(output_file, 'w', encoding='utf-8') as f:
          for item_id, sims in result.items():
              if not sims:
                  continue

              # Prefer name_mappings; fall back to the name from the query result.
              item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
              if item_name == 'Unknown':
                  item_name = df_item_names.get(item_id, 'Unknown')

              # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
              sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in sims])
              f.write(f'{item_id}\t{_tsv_field(item_name)}\t{sim_str}\n')

      logger.info(f"完成!为 {len(result)} 个物品生成了相似度")
      logger.info(f"输出保存到:{output_file}")

      # In debug mode, also save a human-readable version.
      if args.debug:
          log_processing_step(logger, "保存Debug可读格式")
          save_readable_index(
              output_file,
              result,
              name_mappings,
              description='i2i:session_w2v'
          )
5ab1c29c   tangwang   first commit
305
306
307
308
  
  
  if __name__ == '__main__':
      # Run the offline i2i job only when executed as a script, not on import.
      main()