"""
i2i - Swing算法实现
基于用户行为的物品相似度计算
参考item_sim.py的数据格式,适配真实数据
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import pandas as pd
import math
from collections import defaultdict
import argparse
import json
from datetime import datetime, timedelta
from db_service import create_db_connection
from offline_tasks.config.offline_config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)
from offline_tasks.scripts.debug_utils import (
    setup_debug_logger, log_dataframe_info, log_dict_stats,
    save_readable_index, fetch_name_mappings, log_algorithm_params,
    log_processing_step
)


def calculate_time_weight(event_time, reference_time, decay_factor=0.95, days_unit=30):
    """
    计算时间衰减权重
    
    Args:
        event_time: 事件发生时间
        reference_time: 参考时间(通常是当前时间)
        decay_factor: 衰减因子
        days_unit: 衰减单位(天)
    
    Returns:
        时间权重
    """
    if pd.isna(event_time):
        return 1.0
    
    time_diff = (reference_time - event_time).days
    if time_diff < 0:
        return 1.0
    
    # Exponential decay: decay_factor raised to the number of elapsed periods
    periods = time_diff / days_unit
    weight = math.pow(decay_factor, periods)
    return weight
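
# Worked example: with decay_factor=0.95 and days_unit=30, an event that happened
# 60 days before reference_time gets weight 0.95 ** (60 / 30) = 0.9025.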


def swing_algorithm(df, alpha=0.5, time_decay=True, decay_factor=0.95, logger=None, debug=False):
    """
    Swing算法实现
    
    Args:
        df: DataFrame with columns: user_id, item_id, weight, create_time
        alpha: Swing算法的alpha参数
        time_decay: 是否使用时间衰减
        decay_factor: 时间衰减因子
        logger: 日志记录器
        debug: 是否开启debug模式
    
    Returns:
        Dict[item_id, List[Tuple(similar_item_id, score)]]
    """
    start_time = datetime.now()
    if logger:
        logger.debug(f"开始Swing算法计算,参数: alpha={alpha}, time_decay={time_decay}")
    
    # Apply time decay to behavior weights if enabled
    reference_time = datetime.now()
    if time_decay and 'create_time' in df.columns:
        if logger:
            logger.debug("应用时间衰减...")
        df['time_weight'] = df['create_time'].apply(
            lambda x: calculate_time_weight(x, reference_time, decay_factor)
        )
        df['weight'] = df['weight'] * df['time_weight']
        if logger and debug:
            logger.debug(f"时间权重统计: min={df['time_weight'].min():.4f}, max={df['time_weight'].max():.4f}, avg={df['time_weight'].mean():.4f}")
    
    # Build user-item and item-user inverted indexes
    if logger:
        log_processing_step(logger, "Step 1: build user-item inverted indexes")
    
    user_items = defaultdict(set)
    item_users = defaultdict(set)
    item_freq = defaultdict(float)  # weighted item popularity (accumulated below but not used in the similarity computation)
    
    for _, row in df.iterrows():
        user_id = row['user_id']
        item_id = row['item_id']
        weight = row['weight']
        
        user_items[user_id].add(item_id)
        item_users[item_id].add(user_id)
        item_freq[item_id] += weight
    
    if logger:
        logger.info(f"总用户数: {len(user_items)}, 总商品数: {len(item_users)}")
        if debug:
            log_dict_stats(logger, dict(list(user_items.items())[:1000]), "用户-商品倒排索引(采样)", top_n=5)
            log_dict_stats(logger, dict(list(item_users.items())[:1000]), "商品-用户倒排索引(采样)", top_n=5)
    
    # Compute item-item similarities
    if logger:
        log_processing_step(logger, "Step 2: compute Swing item similarities")
    
    item_sim_dict = defaultdict(lambda: defaultdict(float))
    
    # Iterate over item pairs
    processed_pairs = 0
    total_items = len(item_users)
    
    for idx_i, item_i in enumerate(item_users):
        users_i = item_users[item_i]
        
        # Pair item_i with every other item; pairs with too few common users are skipped below
        for item_j in item_users:
            if item_i >= item_j:  # handle each unordered pair only once
                continue
            
            users_j = item_users[item_j]
            common_users = users_i & users_j
            
            if len(common_users) < 2:
                continue
            
            # Compute the Swing similarity for this item pair
            sim_score = 0.0
            common_users_list = list(common_users)
            
            for idx_u in range(len(common_users_list)):
                user_u = common_users_list[idx_u]
                items_u = user_items[user_u]
                
                for idx_v in range(idx_u + 1, len(common_users_list)):
                    user_v = common_users_list[idx_v]
                    items_v = user_items[user_v]
                    
                    # Items that both user u and user v interacted with
                    common_items = items_u & items_v
                    
                    # Swing formula: each common-user pair (u, v) contributes 1 / (alpha + |I_u ∩ I_v|)
                    sim_score += 1.0 / (alpha + len(common_items))
            
            item_sim_dict[item_i][item_j] = sim_score
            item_sim_dict[item_j][item_i] = sim_score
            processed_pairs += 1
        
        # Debug: report progress
        if logger and debug and (idx_i + 1) % 50 == 0:
            logger.debug(f"Processed {idx_i + 1}/{total_items} items ({(idx_i+1)/total_items*100:.1f}%)")
    
    if logger:
        logger.info(f"计算了 {processed_pairs} 对商品相似度")
    
    # Sort each item's similar items (note: no normalization is applied)
    if logger:
        log_processing_step(logger, "Step 3: sort and assemble results")
    
    result = {}
    for item_i in item_sim_dict:
        sims = item_sim_dict[item_i]
        
        # Sort by similarity, descending
        sorted_sims = sorted(sims.items(), key=lambda x: -x[1])
        result[item_i] = sorted_sims
    
    if logger:
        total_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Swing算法完成: {len(result)} 个商品有相似推荐")
        logger.info(f"总耗时: {total_time:.2f}秒")
        
        # Count similar items per item
        sim_counts = [len(sims) for sims in result.values()]
        if sim_counts:
            logger.info(f"Similar-item count stats: min={min(sim_counts)}, max={max(sim_counts)}, avg={sum(sim_counts)/len(sim_counts):.2f}")
        
        # Show a small sample of the results
        if debug:
            sample_results = list(result.items())[:3]
            for item_i, sims in sample_results:
                logger.debug(f"  Top 5 similar items for item {item_i}: {sims[:5]}")
    
    return result
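
# Minimal usage sketch (toy data; in production the DataFrame comes from the SQL
# query in main()):
#   toy = pd.DataFrame({
#       'user_id': ['u1', 'u1', 'u2', 'u2', 'u3', 'u3'],
#       'item_id': ['a', 'b', 'a', 'b', 'a', 'c'],
#       'weight':  [1.0] * 6,
#   })
#   sims = swing_algorithm(toy, alpha=0.5, time_decay=False)
#   # Items 'a' and 'b' share users u1 and u2, who both interacted with {'a', 'b'},
#   # so sims['a'] == [('b', 1.0 / (0.5 + 2))] == [('b', 0.4)].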


def main():
    parser = argparse.ArgumentParser(description='Run Swing algorithm for i2i similarity')
    parser.add_argument('--alpha', type=float, default=I2I_CONFIG['swing']['alpha'],
                       help='Alpha parameter for Swing algorithm')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                       help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                       help=f'Number of days to look back for user behavior (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--time_decay', action='store_true', default=False,
                       help='Use time decay for behavior weights (default: False for B2B low-frequency scenarios)')
    parser.add_argument('--decay_factor', type=float, default=0.95,
                       help='Time decay factor')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file path')
    parser.add_argument('--debug', action='store_true',
                       help='Enable debug mode with detailed logging and readable output')
    
    args = parser.parse_args()
    
    # Set up logging
    logger = setup_debug_logger('i2i_swing', debug=args.debug)
    
    # Log the run parameters
    log_algorithm_params(logger, {
        'alpha': args.alpha,
        'top_n': args.top_n,
        'lookback_days': args.lookback_days,
        'time_decay': args.time_decay,
        'decay_factor': args.decay_factor,
        'debug': args.debug
    })
    
    # Create the database connection
    logger.info("Connecting to database...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )
    
    # Determine the time range
    start_date, end_date = get_time_range(args.lookback_days)
    logger.info(f"Fetching data from {start_date} to {end_date}")
    
    # SQL query: fetch user behavior data
    sql_query = f"""
    SELECT 
        se.anonymous_id AS user_id,
        se.item_id,
        se.event AS event_type,
        se.create_time,
        pgs.name AS item_name
    FROM 
        sensors_events se
    LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE 
        se.event IN ('contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    ORDER BY 
        se.create_time
    """
    
    try:
        logger.info("执行SQL查询...")
        df = pd.read_sql(sql_query, engine)
        logger.info(f"获取到 {len(df)} 条记录")
        
        # Debug: 显示数据详情
        if args.debug:
            log_dataframe_info(logger, df, "用户行为数据", sample_size=10)
    except Exception as e:
        logger.error(f"获取数据失败: {e}")
        return
    
    if len(df) == 0:
        logger.warning("没有找到数据")
        return
    
    # Convert create_time to datetime
    df['create_time'] = pd.to_datetime(df['create_time'])
    
    # Behavior weights per event type
    behavior_weights = {
        'contactFactory': 5.0,
        'addToPool': 2.0,
        'addToCart': 3.0,
        'purchase': 10.0
    }
    
    # Add the weight column
    df['weight'] = df['event_type'].map(behavior_weights).fillna(1.0)
    
    if logger and args.debug:
        logger.debug(f"行为类型分布:")
        event_counts = df['event_type'].value_counts()
        for event, count in event_counts.items():
            logger.debug(f"  {event}: {count} ({count/len(df)*100:.2f}%)")
    
    # Run the Swing algorithm
    logger.info("Running Swing algorithm...")
    result = swing_algorithm(
        df,
        alpha=args.alpha,
        time_decay=args.time_decay,
        decay_factor=args.decay_factor,
        logger=logger,
        debug=args.debug
    )
    
    # Map item_id to item name (keys as strings, consistent with name_mappings)
    item_name_map = {
        str(item_id): name
        for item_id, name in df.groupby('item_id')['item_name'].first().items()
    }
    
    # Write results
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_swing_{datetime.now().strftime("%Y%m%d")}.txt')

    logger.info(f"Saving results to: {output_file}")
    output_count = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            # item_name_map keys are strings, so convert item_id before lookup
            item_name = item_name_map.get(str(item_id), 'Unknown')
            
            # Keep only the top N most similar items
            top_sims = sims[:args.top_n]
            
            if not top_sims:
                continue
            
            # Format: item_id \t item_name \t similar_item_id1:score1,similar_item_id2:score2,...
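            # e.g. (hypothetical values): 1001\tSample Widget\t1002:0.4000,1003:0.2500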
            sim_str = ','.join([f'{sim_id}:{score:.4f}' for sim_id, score in top_sims])
            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
            output_count += 1
    
    logger.info(f"输出了 {output_count} 个商品的推荐")
    
    # Debug mode: generate a human-readable index file
    if args.debug:
        logger.info("Debug mode: generating readable index file...")
        try:
            # Fetch ID-to-name mappings
            logger.debug("Fetching ID-to-name mappings...")
            name_mappings = fetch_name_mappings(engine, debug=True)
            
            # Prepare the index data (merge the item_name_map built above)
            # item_name_map keys are already str, so update() works directly
            name_mappings['item'].update(item_name_map)

            logger.debug(f"name_mappings['item'] contains {len(name_mappings['item'])} item names")
            
            index_data = {}
            for item_id, sims in result.items():
                top_sims = sims[:args.top_n]
                if top_sims:
                    index_data[f"i2i:swing:{item_id}"] = top_sims
            
            # Save the human-readable index file
            readable_file = save_readable_index(
                output_file,
                index_data,
                name_mappings,
                description=f"Swing算法 i2i相似度推荐 (alpha={args.alpha}, lookback_days={args.lookback_days})"
            )
            logger.info(f"明文索引文件: {readable_file}")
        except Exception as e:
            logger.error(f"生成明文文件失败: {e}", exc_info=True)
    
    logger.info("完成!")


if __name__ == '__main__':
    main()