# i2i_session_w2v.py
"""
i2i - Session Word2Vec算法实现
基于用户会话序列训练Word2Vec模型,获取物品向量相似度
"""
import argparse
import json
import os
from collections import defaultdict
from datetime import datetime

import numpy as np
import pandas as pd
from gensim.models import Word2Vec

from db_service import create_db_connection
from offline_tasks.config.offline_config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)
from offline_tasks.scripts.debug_utils import (
    setup_debug_logger, log_dataframe_info, log_dict_stats,
    save_readable_index, fetch_name_mappings, log_algorithm_params,
    log_processing_step
)


def prepare_session_data(df, max_session_length=50, min_session_length=2, logger=None):
    """
    Split each user's time-ordered behaviour sequence into sessions.

    Sessions are non-overlapping, fixed-length chunks of a user's item
    sequence (not time-gap sessions), which suits low-frequency B2B traffic.

    Args:
        df: DataFrame with columns ``user_id``, ``item_id`` and ``create_time``.
        max_session_length: Maximum number of items per chunk; must be >= 1.
        min_session_length: Chunks with fewer items than this are dropped.
        logger: Optional logger for debug statistics.

    Returns:
        List of sessions; each session is a list of item IDs converted to str.

    Raises:
        ValueError: If ``max_session_length`` is less than 1.
    """
    if max_session_length < 1:
        # A step of 0 makes range() fail with an opaque message, and a negative
        # step silently yields no chunks at all, so reject it up front.
        raise ValueError(f"max_session_length must be >= 1, got {max_session_length}")

    sessions = []

    if logger:
        logger.debug(f"开始准备会话数据(固定长度分块):max_length={max_session_length}, min_length={min_session_length}")

    # Sort by user, then event time, so each user's sequence is chronological.
    df = df.sort_values(['user_id', 'create_time'])

    for _, user_df in df.groupby('user_id'):
        # Item IDs as strings so they can be used directly as Word2Vec tokens.
        item_sequence = user_df['item_id'].astype(str).tolist()

        # A sequence shorter than the minimum cannot yield any session.
        if len(item_sequence) < min_session_length:
            continue

        # Non-overlapping fixed-size chunks; only the last one may be shorter.
        user_sessions = [
            item_sequence[i:i + max_session_length]
            for i in range(0, len(item_sequence), max_session_length)
        ]

        # Keep only chunks that meet the minimum length (drops a short tail).
        sessions.extend(s for s in user_sessions if len(s) >= min_session_length)

    if logger:
        if sessions:
            session_lengths = [len(s) for s in sessions]
            logger.debug(f"生成 {len(sessions)} 个会话")
            logger.debug(f"会话长度统计:最小={min(session_lengths)}, 最大={max(session_lengths)}, "
                        f"平均={sum(session_lengths)/len(session_lengths):.2f}")
        else:
            logger.warning("未生成任何会话!")

    return sessions


def train_word2vec(sessions, config, logger=None):
    """
    Train a gensim Word2Vec model on item-ID sessions.

    Args:
        sessions: List of sessions; each session is a list of item-ID strings
            used as Word2Vec "sentences".
        config: Mapping of Word2Vec settings. Required keys: 'vector_size',
            'window_size' (passed to gensim as ``window``), 'min_count',
            'workers', 'sg', 'epochs'. Optional key: 'seed' (defaults to 42,
            the value previously hard-coded).
        logger: Optional logger; when None, progress is printed to stdout.

    Returns:
        The trained ``Word2Vec`` model.

    Note:
        Per the gensim docs, a fixed seed alone is not fully reproducible when
        ``workers > 1`` (thread scheduling jitter) or without a fixed
        PYTHONHASHSEED.
    """
    seed = config.get('seed', 42)

    if logger:
        logger.info(f"训练Word2Vec模型,共 {len(sessions)} 个会话")
        logger.debug(f"模型参数:vector_size={config['vector_size']}, window={config['window_size']}, "
                    f"min_count={config['min_count']}, epochs={config['epochs']}")
    else:
        print(f"Training Word2Vec with {len(sessions)} sessions...")

    model = Word2Vec(
        sentences=sessions,
        vector_size=config['vector_size'],
        window=config['window_size'],
        min_count=config['min_count'],
        workers=config['workers'],
        sg=config['sg'],
        epochs=config['epochs'],
        seed=seed
    )

    if logger:
        logger.info(f"训练完成。词汇表大小:{len(model.wv)}")
    else:
        print(f"Training completed. Vocabulary size: {len(model.wv)}")
    return model


def generate_similarities(model, top_n=50, logger=None):
    """
    Compute the top-N most similar items for every word in the model vocabulary.

    Args:
        model: Trained Word2Vec model (only ``model.wv`` is used).
        top_n: Number of neighbours to keep per item.
        logger: Optional logger for progress messages.

    Returns:
        Dict mapping item_id -> list of (similar_item_id, float score), in the
        order returned by ``most_similar``. Items whose lookup raises KeyError
        are omitted.
    """
    if logger:
        logger.info(f"生成Top {top_n} 相似物品")

    vectors = model.wv
    similarities = {}
    for key in vectors.index_to_key:
        try:
            neighbours = vectors.most_similar(key, topn=top_n)
        except KeyError:
            # Key missing from the vector store; skip it.
            continue
        similarities[key] = [(other, float(sim)) for other, sim in neighbours]

    if logger:
        logger.info(f"生成了 {len(similarities)} 个物品的相似度")

    return similarities


def _build_item_name_lookup(df):
    """Map each ``str(item_id)`` in *df* to its first non-null ``item_name``.

    Built once so the result writer looks names up in O(1) instead of
    re-scanning the whole DataFrame for every vocabulary item.
    """
    if 'item_name' not in df.columns or df.empty:
        return {}
    names = df[['item_id', 'item_name']].dropna(subset=['item_name'])
    names = names.assign(item_id=names['item_id'].astype(str))
    names = names.drop_duplicates(subset='item_id', keep='first')
    return dict(zip(names['item_id'], names['item_name']))


def _clean_field(value):
    """Make *value* safe to embed as one field of a tab-separated line."""
    return str(value).replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


def main():
    """
    CLI entry point: load user behaviour events from the database, train a
    session Word2Vec model, and write top-N item-to-item similarities.

    Output line format:
        item_id \\t item_name \\t sim_id1:score1,sim_id2:score2,...
    """
    parser = argparse.ArgumentParser(description='Run Session Word2Vec for i2i similarity')
    parser.add_argument('--window_size', type=int, default=I2I_CONFIG['session_w2v']['window_size'],
                       help='Window size for Word2Vec')
    parser.add_argument('--vector_size', type=int, default=I2I_CONFIG['session_w2v']['vector_size'],
                       help='Vector size for Word2Vec')
    parser.add_argument('--min_count', type=int, default=I2I_CONFIG['session_w2v']['min_count'],
                       help='Minimum word count')
    parser.add_argument('--workers', type=int, default=I2I_CONFIG['session_w2v']['workers'],
                       help='Number of workers')
    parser.add_argument('--epochs', type=int, default=I2I_CONFIG['session_w2v']['epochs'],
                       help='Number of epochs')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                       help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                       help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--max_session_length', type=int, default=50,
                       help='Maximum session length for chunking (default: 50)')
    parser.add_argument('--min_session_length', type=int, default=2,
                       help='Minimum session length to keep (default: 2)')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file path')
    parser.add_argument('--save_model', action='store_true',
                       help='Save Word2Vec model')
    parser.add_argument('--debug', action='store_true',
                       help='Enable debug mode with detailed logging and readable output')

    args = parser.parse_args()

    # Compute the date stamp once so model and result files agree even if the
    # run crosses midnight.
    run_date = datetime.now().strftime("%Y%m%d")

    logger = setup_debug_logger('i2i_session_w2v', debug=args.debug)

    params = {
        'window_size': args.window_size,
        'vector_size': args.vector_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'top_n': args.top_n,
        'lookback_days': args.lookback_days,
        'max_session_length': args.max_session_length,
        'min_session_length': args.min_session_length,
        'debug': args.debug
    }
    log_algorithm_params(logger, params)

    logger.info("连接数据库...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    start_date, end_date = get_time_range(args.lookback_days)
    logger.info(f"获取数据范围:{start_date} 到 {end_date}")

    # The interpolated dates come from get_time_range(), not free-form user input.
    sql_query = f"""
    SELECT 
        se.anonymous_id AS user_id,
        se.item_id,
        se.create_time,
        pgs.name AS item_name
    FROM 
        sensors_events se
    LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE 
        se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    ORDER BY 
        se.anonymous_id,
        se.create_time
    """

    logger.info("执行SQL查询...")
    df = pd.read_sql(sql_query, engine)
    logger.info(f"获取到 {len(df)} 条记录")

    log_dataframe_info(logger, df, "用户行为数据")

    df['create_time'] = pd.to_datetime(df['create_time'])

    log_processing_step(logger, "准备会话数据")
    sessions = prepare_session_data(
        df,
        max_session_length=args.max_session_length,
        min_session_length=args.min_session_length,
        logger=logger
    )
    logger.info(f"生成 {len(sessions)} 个会话")

    if not sessions:
        # Word2Vec cannot build a vocabulary from zero sessions; fail with a
        # clear message and non-zero exit instead of letting gensim fail later.
        logger.error("没有可用于训练的会话,退出")
        raise SystemExit(1)

    log_processing_step(logger, "训练Word2Vec模型")
    w2v_config = {
        'vector_size': args.vector_size,
        'window_size': args.window_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'sg': 1  # skip-gram
    }

    model = train_word2vec(sessions, w2v_config, logger=logger)

    if args.save_model:
        model_path = os.path.join(OUTPUT_DIR, f'session_w2v_model_{run_date}.model')
        model.save(model_path)
        logger.info(f"模型已保存到 {model_path}")

    log_processing_step(logger, "生成相似度")
    result = generate_similarities(model, top_n=args.top_n, logger=logger)

    log_processing_step(logger, "保存结果")
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_session_w2v_{run_date}.txt')

    # Name mappings are only fetched in debug mode; otherwise names come from
    # the item_name column joined in by the SQL query.
    name_mappings = {}
    if args.debug:
        logger.info("获取物品名称映射...")
        name_mappings = fetch_name_mappings(engine, debug=True)

    item_names = _build_item_name_lookup(df)

    logger.info(f"写入结果到 {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            if not sims:
                continue

            # name_mappings is looked up by int item_id; fall back to the SQL name.
            item_name = name_mappings.get(int(item_id), 'Unknown') if item_id.isdigit() else 'Unknown'
            if item_name == 'Unknown':
                item_name = item_names.get(item_id, 'Unknown')

            sim_str = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in sims)
            f.write(f'{item_id}\t{_clean_field(item_name)}\t{sim_str}\n')

    logger.info(f"完成!为 {len(result)} 个物品生成了相似度")
    logger.info(f"输出保存到:{output_file}")

    if args.debug:
        log_processing_step(logger, "保存Debug可读格式")
        save_readable_index(
            output_file,
            result,
            name_mappings,
            description='i2i:session_w2v'
        )


if __name__ == "__main__":
    # Script entry point; nothing runs when the module is imported.
    main()