# i2i_swing.py (7.82 KB)
"""
i2i - Swing算法实现
基于用户行为的物品相似度计算
参考item_sim.py的数据格式,适配真实数据
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import pandas as pd
import math
from collections import defaultdict
import argparse
import json
from datetime import datetime, timedelta
from db_service import create_db_connection
from offline_tasks.config.offline_config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)


def calculate_time_weight(event_time, reference_time, decay_factor=0.95, days_unit=30):
    """
    Return an exponential time-decay weight for a single event.

    The weight is ``decay_factor ** (elapsed_days / days_unit)``, where
    ``elapsed_days`` is the whole-day difference ``reference_time - event_time``.

    Args:
        event_time: timestamp of the event; a missing value (None/NaN/NaT) is allowed.
        reference_time: the point in time decay is measured from (e.g. "now").
        decay_factor: multiplicative decay applied per ``days_unit`` days.
        days_unit: length of one decay period, in days.

    Returns:
        float weight; 1.0 for missing timestamps and for events that lie after
        ``reference_time``.
    """
    # A missing timestamp is treated as undecayed.
    if pd.isna(event_time):
        return 1.0

    elapsed_days = (reference_time - event_time).days
    # Events dated after the reference point are also left undecayed.
    return 1.0 if elapsed_days < 0 else math.pow(decay_factor, elapsed_days / days_unit)


def swing_algorithm(df, alpha=0.5, time_decay=True, decay_factor=0.95):
    """
    Swing algorithm: item-to-item similarity from co-interacting user pairs.

    For each item pair (i, j), let U_ij be the set of users who interacted with
    both items. Pairs with fewer than two such users are skipped. Otherwise

        sim(i, j) = sum over unordered pairs {u, v} in U_ij of 1 / (alpha + |I_u & I_v|)

    where I_u is the set of distinct items user u interacted with. Because u and v
    both touched i and j, |I_u & I_v| >= 2, so the denominator is never zero for
    alpha >= 0.

    Args:
        df: DataFrame with at least ``user_id`` and ``item_id`` columns. Other
            columns (e.g. ``weight``, ``create_time``) are ignored, and the frame
            is not modified.
        alpha: smoothing term added to the Swing denominator.
        time_decay: kept for interface compatibility; has no effect on scores.
        decay_factor: kept for interface compatibility; has no effect on scores.

    Returns:
        Dict[item_id, List[Tuple(similar_item_id, score)]]. Each list is sorted by
        descending score. Items with no qualifying partner are absent.

    NOTE(review): an earlier version multiplied ``df['weight']`` by a time-decay
    factor in place, which mutated the caller's DataFrame. It then summed that
    weight into a per-item total that was never read, so behaviour weights and
    time decay never influenced the scores. That dead work was removed. A
    weighted Swing variant would need to be designed explicitly.
    """
    # Inverted indexes: user -> items and item -> users.
    # Zipping the two columns is far cheaper than DataFrame.iterrows().
    user_items = defaultdict(set)
    item_users = defaultdict(set)
    for user_id, item_id in zip(df['user_id'], df['item_id']):
        user_items[user_id].add(item_id)
        item_users[item_id].add(user_id)

    print(f"Total users: {len(user_items)}, Total items: {len(item_users)}")

    # First-seen position of every item. Candidate partners are visited in this
    # order so that result dicts are filled in the same order as a full item
    # scan would fill them. This keeps tie ordering in the sorted output stable.
    item_order = {item_id: pos for pos, item_id in enumerate(item_users)}

    item_sim_dict = defaultdict(dict)

    for item_i, users_i in item_users.items():
        # Only items sharing at least one user with item_i can have a score, so
        # collect them via the user index instead of scanning every item.
        candidates = set()
        for user in users_i:
            candidates.update(user_items[user])

        for item_j in sorted(candidates, key=item_order.__getitem__):
            # Score each unordered pair once, and never pair an item with itself.
            if item_i >= item_j:
                continue

            common_users = users_i & item_users[item_j]
            if len(common_users) < 2:
                continue

            common_users_list = list(common_users)
            sim_score = 0.0
            for idx_u, user_u in enumerate(common_users_list):
                items_u = user_items[user_u]
                for user_v in common_users_list[idx_u + 1:]:
                    # Swing term: user pairs that overlap on fewer items
                    # contribute more evidence that i and j are related.
                    sim_score += 1.0 / (alpha + len(items_u & user_items[user_v]))

            item_sim_dict[item_i][item_j] = sim_score
            item_sim_dict[item_j][item_i] = sim_score

    # Sort each item's neighbours by descending similarity (no normalisation).
    return {
        item_id: sorted(sims.items(), key=lambda pair: -pair[1])
        for item_id, sims in item_sim_dict.items()
    }


def main():
    """
    Command-line entry point: fetch recent user behaviour, run Swing, write results.

    Steps:
      1. Parse CLI options.
      2. Load behaviour events for the look-back window from the database.
      3. Map event types to behaviour weights.
      4. Run ``swing_algorithm``.
      5. Write one line per item in the form
         item_id<TAB>item_name<TAB>sim_id1:score1,sim_id2:score2,...
    """
    parser = argparse.ArgumentParser(description='Run Swing algorithm for i2i similarity')
    parser.add_argument('--alpha', type=float, default=I2I_CONFIG['swing']['alpha'],
                       help='Alpha parameter for Swing algorithm')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                       help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                       help=f'Number of days to look back for user behavior (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--time_decay', action='store_true', default=True,
                       help='Use time decay for behavior weights')
    # '--time_decay' is store_true with default=True, so on its own it can never
    # be switched off. This flag writes False to the same destination.
    parser.add_argument('--no_time_decay', dest='time_decay', action='store_false',
                       help='Disable time decay for behavior weights')
    parser.add_argument('--decay_factor', type=float, default=0.95,
                       help='Time decay factor')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file path')

    args = parser.parse_args()

    # Open the database connection.
    print("Connecting to database...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    # Resolve the time window.
    start_date, end_date = get_time_range(args.lookback_days)
    print(f"Fetching data from {start_date} to {end_date}...")

    # Behaviour events in the window. The dates are interpolated directly; they
    # come from get_time_range(), driven by an int-typed CLI option, not from
    # free-form input. NOTE(review): switch to bound parameters if that changes.
    sql_query = f"""
    SELECT 
        se.anonymous_id AS user_id,
        se.item_id,
        se.event AS event_type,
        se.create_time,
        pgs.name AS item_name
    FROM 
        sensors_events se
    LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE 
        se.event IN ('contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    ORDER BY 
        se.create_time
    """

    print("Executing SQL query...")
    df = pd.read_sql(sql_query, engine)
    print(f"Fetched {len(df)} records")

    df['create_time'] = pd.to_datetime(df['create_time'])

    # Per-event behaviour weights.
    behavior_weights = {
        'contactFactory': 5.0,
        'addToPool': 2.0,
        'addToCart': 3.0,
        'purchase': 10.0
    }
    # Event types not listed above fall back to weight 1.0.
    df['weight'] = df['event_type'].map(behavior_weights).fillna(1.0)

    print("Running Swing algorithm...")
    result = swing_algorithm(
        df,
        alpha=args.alpha,
        time_decay=args.time_decay,
        decay_factor=args.decay_factor
    )

    # item_id -> first non-null item_name, built from the grouped Series itself.
    # groupby() sorts its keys, so pairing it positionally with unique() (which
    # keeps order of appearance) would attach names to the wrong items.
    # Items whose names are all null fall back to 'Unknown' below.
    item_name_map = df.groupby('item_id')['item_name'].first().dropna().to_dict()

    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_swing_{datetime.now().strftime("%Y%m%d")}.txt')
    # Create the target directory if needed so open() does not fail.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Writing results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            item_name = item_name_map.get(item_id, 'Unknown')

            # Keep only the N most similar items.
            top_sims = sims[:args.top_n]
            if not top_sims:
                continue

            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
            sim_str = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in top_sims)
            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')

    print(f"Done! Generated i2i similarities for {len(result)} items")
    print(f"Output saved to: {output_file}")


if __name__ == "__main__":
    # Run the offline job only when executed as a script, not when imported.
    main()