Blame view

offline_tasks/scripts/i2i_swing.py 7.82 KB
5ab1c29c   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
  """
  i2i - Swing算法实现
  基于用户行为的物品相似度计算
  参考item_sim.py的数据格式,适配真实数据
  """
  import sys
  import os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  
  import pandas as pd
  import math
  from collections import defaultdict
  import argparse
  import json
  from datetime import datetime, timedelta
  from db_service import create_db_connection
  from offline_tasks.config.offline_config import (
      DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
      DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
  )
  
  
  def calculate_time_weight(event_time, reference_time, decay_factor=0.95, days_unit=30):
      """
      Return an exponential time-decay weight for a single event.

      The weight is ``decay_factor ** (age_in_days / days_unit)``, where the
      age is the whole-day part of ``reference_time - event_time``.

      Args:
          event_time: When the event happened; may be missing (None/NaT/NaN).
          reference_time: The moment the age is measured against.
          decay_factor: Multiplier applied once per elapsed ``days_unit``.
          days_unit: Length of one decay period, in days.

      Returns:
          The decay weight as a float. Missing timestamps and events dated
          after ``reference_time`` are left undecayed (1.0).
      """
      if not pd.isna(event_time):
          elapsed_days = (reference_time - event_time).days
          if elapsed_days >= 0:
              # Fractional periods are allowed: 15 days with a 30-day unit
              # decays by decay_factor ** 0.5.
              return math.pow(decay_factor, elapsed_days / days_unit)

      # Missing or future-dated events carry full weight.
      return 1.0
  
  
  def swing_algorithm(df, alpha=0.5, time_decay=True, decay_factor=0.95):
      """
      Compute Swing item-to-item similarity from user-item interactions.

      Two items are scored only when at least two users interacted with both.
      For such a pair the score is the sum, over every unordered pair (u, v)
      of those shared users, of ``1 / (alpha + |items(u) & items(v)|)``.

      The score depends only on which users touched which items. Behavior
      weights and time decay were computed by earlier versions of this
      function but never entered the score, so ``time_decay`` and
      ``decay_factor`` are accepted purely for backward compatibility and the
      ``weight`` / ``create_time`` columns are not read. The input frame is
      never modified.

      Args:
          df: DataFrame with at least the columns user_id and item_id
              (weight and create_time may be present and are ignored).
          alpha: Smoothing term added to the shared-item count in the
              denominator.
          time_decay: Unused; kept so existing callers keep working.
          decay_factor: Unused; kept so existing callers keep working.

      Returns:
          Dict[item_id, List[Tuple(similar_item_id, score)]], each list sorted
          by descending score. Items without any scored partner are absent.
      """
      from itertools import combinations

      # Build the user -> items and item -> users inverted indexes. Reading
      # the two columns directly (instead of iterrows) keeps ids in their
      # original type: iterrows upcasts integer ids to float whenever every
      # column of the frame is numeric.
      user_items = defaultdict(set)
      item_users = defaultdict(set)
      for user_id, item_id in zip(df['user_id'], df['item_id']):
          user_items[user_id].add(item_id)
          item_users[item_id].add(user_id)

      print(f"Total users: {len(user_items)}, Total items: {len(item_users)}")

      item_sim_dict = defaultdict(dict)

      # NOTE: every item pair is examined, so the cost grows quadratically
      # with the number of distinct items.
      for item_i, users_i in item_users.items():
          for item_j, users_j in item_users.items():
              # Each unordered pair is handled once and mirrored below.
              if item_i >= item_j:
                  continue

              common_users = users_i & users_j
              # Swing needs at least one *pair* of shared users.
              if len(common_users) < 2:
                  continue

              sim_score = 0.0
              for user_u, user_v in combinations(list(common_users), 2):
                  # Users who overlap on many items contribute less.
                  shared_items = user_items[user_u] & user_items[user_v]
                  sim_score += 1.0 / (alpha + len(shared_items))

              item_sim_dict[item_i][item_j] = sim_score
              item_sim_dict[item_j][item_i] = sim_score

      # Rank each item's neighbours from most to least similar.
      return {
          item_id: sorted(sims.items(), key=lambda pair: -pair[1])
          for item_id, sims in item_sim_dict.items()
      }
  
  
  def main():
      """
      Command-line entry point: load behavior events, run Swing, write results.

      Steps:
          1. Parse CLI options (alpha, top_n, lookback window, time decay,
             output path).
          2. Query behavior events for the lookback window from the database.
          3. Map each event type to a behavior weight.
          4. Run ``swing_algorithm`` on the events.
          5. Write one line per item:
             ``item_id \\t item_name \\t sim_id1:score1,sim_id2:score2,...``

      The output defaults to ``OUTPUT_DIR/i2i_swing_<YYYYMMDD>.txt``; the
      target directory is created if it does not exist.
      """
      parser = argparse.ArgumentParser(description='Run Swing algorithm for i2i similarity')
      parser.add_argument('--alpha', type=float, default=I2I_CONFIG['swing']['alpha'],
                         help='Alpha parameter for Swing algorithm')
      parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                         help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
      parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                         help=f'Number of days to look back for user behavior (default: {DEFAULT_LOOKBACK_DAYS})')
      parser.add_argument('--time_decay', dest='time_decay', action='store_true', default=True,
                         help='Use time decay for behavior weights')
      # store_true with default=True can never be switched off, so provide an
      # explicit opt-out that writes to the same destination.
      parser.add_argument('--no_time_decay', dest='time_decay', action='store_false',
                         help='Disable time decay for behavior weights')
      parser.add_argument('--decay_factor', type=float, default=0.95,
                         help='Time decay factor')
      parser.add_argument('--output', type=str, default=None,
                         help='Output file path')

      args = parser.parse_args()

      # Open the database connection.
      print("Connecting to database...")
      engine = create_db_connection(
          DB_CONFIG['host'],
          DB_CONFIG['port'],
          DB_CONFIG['database'],
          DB_CONFIG['username'],
          DB_CONFIG['password']
      )

      # Resolve the lookback window.
      start_date, end_date = get_time_range(args.lookback_days)
      print(f"Fetching data from {start_date} to {end_date}...")

      # Fetch user behavior events. The dates are interpolated directly into
      # the SQL text; they come from get_time_range, not from user input.
      sql_query = f"""
      SELECT 
          se.anonymous_id AS user_id,
          se.item_id,
          se.event AS event_type,
          se.create_time,
          pgs.name AS item_name
      FROM 
          sensors_events se
      LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
      WHERE 
          se.event IN ('contactFactory', 'addToPool', 'addToCart', 'purchase')
          AND se.create_time >= '{start_date}'
          AND se.create_time <= '{end_date}'
          AND se.item_id IS NOT NULL
          AND se.anonymous_id IS NOT NULL
      ORDER BY 
          se.create_time
      """

      print("Executing SQL query...")
      df = pd.read_sql(sql_query, engine)
      print(f"Fetched {len(df)} records")

      # Make sure create_time is a datetime column.
      df['create_time'] = pd.to_datetime(df['create_time'])

      # Relative importance of each behavior type.
      behavior_weights = {
          'contactFactory': 5.0,
          'addToPool': 2.0,
          'addToCart': 3.0,
          'purchase': 10.0
      }

      # Unknown event types fall back to a weight of 1.0.
      df['weight'] = df['event_type'].map(behavior_weights).fillna(1.0)

      # Run the Swing algorithm.
      print("Running Swing algorithm...")
      result = swing_algorithm(
          df,
          alpha=args.alpha,
          time_decay=args.time_decay,
          decay_factor=args.decay_factor
      )

      # Map item_id -> item_name. Taking ids and names from the same groupby
      # keeps them aligned; pairing unique() (appearance order) with
      # groupby().first() (sorted key order) would attach names to wrong ids.
      item_name_map = df.groupby('item_id')['item_name'].first().to_dict()

      # Resolve the output location and make sure its directory exists.
      output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_swing_{datetime.now().strftime("%Y%m%d")}.txt')
      output_dir = os.path.dirname(output_file)
      if output_dir:
          os.makedirs(output_dir, exist_ok=True)

      print(f"Writing results to {output_file}...")
      with open(output_file, 'w', encoding='utf-8') as f:
          for item_id, sims in result.items():
              # The LEFT JOIN can leave names empty; don't emit "nan"/"None".
              item_name = item_name_map.get(item_id)
              if pd.isna(item_name):
                  item_name = 'Unknown'

              # Keep only the N most similar items.
              top_sims = sims[:args.top_n]

              if not top_sims:
                  continue

              # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
              sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in top_sims])
              f.write(f'{item_id}\t{item_name}\t{sim_str}\n')

      print(f"Done! Generated i2i similarities for {len(result)} items")
      print(f"Output saved to: {output_file}")
  
  
  # Run the command-line entry point only when this file is executed as a
  # script, not when it is imported as a module.
  if __name__ == '__main__':
      main()