# i2i_deepwalk.py
"""
i2i - DeepWalk算法实现
基于用户-物品图结构训练DeepWalk模型,获取物品向量相似度
复用 graphembedding/deepwalk/ 的高效实现
"""
import pandas as pd
import argparse
import os
import sys
from datetime import datetime
from collections import defaultdict
from gensim.models import Word2Vec
from db_service import create_db_connection
from config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)
from debug_utils import (
    setup_debug_logger, log_dataframe_info,
    save_readable_index, fetch_name_mappings, log_algorithm_params,
    log_processing_step
)

# 导入 DeepWalk 实现
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), 'deepwalk'))
from deepwalk import DeepWalk


def build_edge_file_from_db(df, behavior_weights, output_path, logger):
    """
    Build the item co-occurrence graph from raw behavior rows and save it
    as an edge file.

    Edge file format: item_id \t neighbor_id1:weight1,neighbor_id2:weight2,...

    Args:
        df: DataFrame with columns: user_id, item_id, event_type
        behavior_weights: mapping event_type -> weight (unknown events default to 1.0)
        output_path: path of the edge file to write
        logger: logger instance

    Returns:
        int: number of nodes (distinct items) in the graph
    """
    logger.info("开始构建物品图...")

    # Collect (item_id, weight) interactions per user.
    # itertuples is far faster than iterrows for large frames and yields
    # the same per-row values.
    user_items = defaultdict(list)

    for row in df.itertuples(index=False):
        weight = behavior_weights.get(row.event_type, 1.0)
        user_items[row.user_id].append((str(row.item_id), weight))

    logger.info(f"共有 {len(user_items)} 个用户")

    # Co-occurrence edges: every pair of items interacted with by the same user.
    edge_dict = defaultdict(lambda: defaultdict(float))

    for user_id, items in user_items.items():
        # Cap items per user to bound the O(k^2) pairing below (memory/time).
        if len(items) > 100:
            # Keep only the 100 highest-weight interactions.
            items = sorted(items, key=lambda x: -x[1])[:100]

        # Pair items to create undirected, weight-accumulating edges.
        for i in range(len(items)):
            item_i, weight_i = items[i]
            for j in range(i + 1, len(items)):
                item_j, weight_j = items[j]

                # Edge weight is the mean of the two interaction weights,
                # accumulated symmetrically in both directions.
                edge_weight = (weight_i + weight_j) / 2.0
                edge_dict[item_i][item_j] += edge_weight
                edge_dict[item_j][item_i] += edge_weight

    logger.info(f"构建物品图完成,共 {len(edge_dict)} 个节点")

    # Persist adjacency lists, one node per line.
    logger.info(f"保存边文件到 {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        for item_id, neighbors in edge_dict.items():
            neighbor_str = ','.join([f'{nbr}:{weight:.4f}' for nbr, weight in neighbors.items()])
            f.write(f'{item_id}\t{neighbor_str}\n')

    logger.info(f"边文件保存完成")
    return len(edge_dict)


def train_word2vec_from_walks(walks_file, config, logger):
    """
    Train a Word2Vec model on random-walk sequences read from a file.

    Args:
        walks_file: path to the walk file (one whitespace-separated walk per line)
        config: dict of Word2Vec hyperparameters
        logger: logger instance

    Returns:
        trained gensim Word2Vec model
    """
    logger.info(f"从 {walks_file} 读取游走序列...")

    # Load walks, dropping degenerate ones with fewer than two nodes.
    with open(walks_file, 'r', encoding='utf-8') as fh:
        sentences = [tokens for tokens in (line.split() for line in fh) if len(tokens) >= 2]

    logger.info(f"共读取 {len(sentences)} 条游走序列")

    # Fit skip-gram / CBOW embeddings over the walk corpus.
    logger.info("开始训练Word2Vec模型...")
    model = Word2Vec(
        sentences=sentences,
        vector_size=config['vector_size'],
        window=config['window_size'],
        min_count=config['min_count'],
        workers=config['workers'],
        sg=config['sg'],
        epochs=config['epochs'],
        seed=42  # fixed seed for reproducibility
    )

    logger.info(f"训练完成。词汇表大小:{len(model.wv)}")
    return model


def generate_similarities(model, top_n, logger):
    """
    Compute the top-N most similar items for every item in the model vocabulary.

    Args:
        model: trained Word2Vec model
        top_n: number of similar items to keep per item
        logger: logger instance

    Returns:
        dict mapping item_id -> list of (similar_item_id, score) tuples
    """
    logger.info("生成相似度...")
    sims_by_item = {}

    for key in model.wv.index_to_key:
        try:
            neighbors = model.wv.most_similar(key, topn=top_n)
        except KeyError:
            # Item missing from the vocabulary; skip it.
            continue
        sims_by_item[key] = [(other, float(score)) for other, score in neighbors]

    logger.info(f"为 {len(sims_by_item)} 个物品生成了相似度")
    return sims_by_item


def save_results(result, output_file, name_mappings, logger):
    """
    Write item similarities to a tab-separated text file.

    Output format: item_id \t item_name \t sim_id1:score1,sim_id2:score2,...

    Args:
        result: dict item_id -> list of (similar_item_id, score)
        output_file: destination file path
        name_mappings: dict of int item_id -> human-readable name
        logger: logger instance
    """
    logger.info(f"保存结果到 {output_file}...")

    with open(output_file, 'w', encoding='utf-8') as out:
        for item_id, sims in result.items():
            if not sims:
                # Items with no neighbors produce no output line.
                continue

            # Names are keyed by int ids; non-numeric ids fall back to 'Unknown'.
            if item_id.isdigit():
                item_name = name_mappings.get(int(item_id), 'Unknown')
            else:
                item_name = 'Unknown'

            pairs = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in sims)
            out.write(f'{item_id}\t{item_name}\t{pairs}\n')

    logger.info(f"结果保存完成")


def main():
    """Entry point: fetch user behavior from the DB, build an item graph,
    run DeepWalk random walks, train Word2Vec on the walks, and write
    top-N item-to-item similarities to a text file.
    """
    parser = argparse.ArgumentParser(description='Run DeepWalk for i2i similarity')
    parser.add_argument('--num_walks', type=int, default=I2I_CONFIG['deepwalk']['num_walks'],
                       help='Number of walks per node')
    parser.add_argument('--walk_length', type=int, default=I2I_CONFIG['deepwalk']['walk_length'],
                       help='Walk length')
    parser.add_argument('--window_size', type=int, default=I2I_CONFIG['deepwalk']['window_size'],
                       help='Window size for Word2Vec')
    parser.add_argument('--vector_size', type=int, default=I2I_CONFIG['deepwalk']['vector_size'],
                       help='Vector size for Word2Vec')
    parser.add_argument('--min_count', type=int, default=I2I_CONFIG['deepwalk']['min_count'],
                       help='Minimum word count')
    parser.add_argument('--workers', type=int, default=I2I_CONFIG['deepwalk']['workers'],
                       help='Number of workers')
    parser.add_argument('--epochs', type=int, default=I2I_CONFIG['deepwalk']['epochs'],
                       help='Number of epochs')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                       help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                       help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file path')
    parser.add_argument('--save_model', action='store_true',
                       help='Save Word2Vec model')
    parser.add_argument('--save_graph', action='store_true',
                       help='Save graph edge file')
    parser.add_argument('--debug', action='store_true',
                       help='Enable debug mode with detailed logging and readable output')
    parser.add_argument('--use_softmax', action='store_true',
                       help='Use softmax-based alias sampling (default: False)')
    parser.add_argument('--temperature', type=float, default=1.0,
                       help='Temperature for softmax (default: 1.0)')
    
    args = parser.parse_args()
    
    # Set up the logger (verbose when --debug is on)
    logger = setup_debug_logger('i2i_deepwalk', debug=args.debug)
    
    # Record the effective algorithm parameters for traceability
    params = {
        'num_walks': args.num_walks,
        'walk_length': args.walk_length,
        'window_size': args.window_size,
        'vector_size': args.vector_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'top_n': args.top_n,
        'lookback_days': args.lookback_days,
        'debug': args.debug,
        'use_softmax': args.use_softmax,
        'temperature': args.temperature
    }
    log_algorithm_params(logger, params)
    
    # Create a temp directory for intermediate files (edge file, walks)
    temp_dir = os.path.join(OUTPUT_DIR, 'temp')
    os.makedirs(temp_dir, exist_ok=True)
    
    date_str = datetime.now().strftime('%Y%m%d')
    edge_file = os.path.join(temp_dir, f'item_graph_{date_str}.txt')
    walks_file = os.path.join(temp_dir, f'walks_{date_str}.txt')
    
    # ============================================================
    # Step 1: fetch data from the database and build the edge file
    # ============================================================
    log_processing_step(logger, "从数据库获取数据")
    
    # Create the database connection
    logger.info("连接数据库...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )
    
    # Resolve the lookback window into concrete dates
    start_date, end_date = get_time_range(args.lookback_days)
    logger.info(f"获取数据范围:{start_date} 到 {end_date}")
    
    # SQL query: user behavior events joined with SKU names
    # NOTE(review): dates are interpolated into the SQL string; they come
    # from get_time_range (not user-controlled), but parameterized queries
    # would be safer — confirm before exposing lookback input elsewhere.
    sql_query = f"""
    SELECT 
        se.anonymous_id AS user_id,
        se.item_id,
        se.event AS event_type,
        pgs.name AS item_name
    FROM 
        sensors_events se
    LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE 
        se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    """
    
    logger.info("执行SQL查询...")
    df = pd.read_sql(sql_query, engine)
    logger.info(f"获取到 {len(df)} 条记录")
    
    # Log dataframe shape/columns for debugging
    log_dataframe_info(logger, df, "用户行为数据")
    
    # Per-event-type weights: stronger signals (purchase, contactFactory)
    # contribute heavier edges to the item graph
    behavior_weights = {
        'click': 1.0,
        'contactFactory': 5.0,
        'addToPool': 2.0,
        'addToCart': 3.0,
        'purchase': 10.0
    }
    logger.debug(f"行为权重: {behavior_weights}")
    
    # Build the item-graph edge file from the behavior rows
    log_processing_step(logger, "构建边文件")
    num_nodes = build_edge_file_from_db(df, behavior_weights, edge_file, logger)
    
    # ============================================================
    # Step 2: run DeepWalk random walks over the item graph
    # ============================================================
    log_processing_step(logger, "执行DeepWalk随机游走")
    
    logger.info("初始化DeepWalk...")
    deepwalk = DeepWalk(
        edge_file=edge_file,
        node_tag_file=None,  # tag-based walking disabled
        use_softmax=args.use_softmax,
        temperature=args.temperature,
        p_tag_walk=0.0  # tag-based walking disabled
    )
    
    logger.info("开始随机游走...")
    deepwalk.simulate_walks(
        num_walks=args.num_walks,
        walk_length=args.walk_length,
        workers=args.workers,
        output_file=walks_file
    )
    
    # ============================================================
    # Step 3: train the Word2Vec model on the walk sequences
    # ============================================================
    log_processing_step(logger, "训练Word2Vec模型")
    
    w2v_config = {
        'vector_size': args.vector_size,
        'window_size': args.window_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'sg': 1  # Skip-gram
    }
    logger.debug(f"Word2Vec配置: {w2v_config}")
    
    model = train_word2vec_from_walks(walks_file, w2v_config, logger)
    
    # Persist the trained model (optional, --save_model)
    if args.save_model:
        model_path = os.path.join(OUTPUT_DIR, f'deepwalk_model_{date_str}.model')
        model.save(model_path)
        logger.info(f"模型已保存到 {model_path}")
    
    # ============================================================
    # Step 4: derive top-N item similarities from the embeddings
    # ============================================================
    log_processing_step(logger, "生成相似度")
    result = generate_similarities(model, args.top_n, logger)
    
    # ============================================================
    # Step 5: write the similarity index to disk
    # ============================================================
    log_processing_step(logger, "保存结果")
    
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_deepwalk_{date_str}.txt')
    
    # Fetch item-id -> name mappings (only needed for readable debug output)
    name_mappings = {}
    if args.debug:
        logger.info("获取物品名称映射...")
        name_mappings = fetch_name_mappings(engine, debug=True)
    
    save_results(result, output_file, name_mappings, logger)
    
    logger.info(f"✓ DeepWalk完成!")
    logger.info(f"  - 输出文件: {output_file}")
    logger.info(f"  - 商品数: {len(result)}")
    if result:
        avg_sims = sum(len(sims) for sims in result.values()) / len(result)
        logger.info(f"  - 平均相似商品数: {avg_sims:.1f}")
    
    # In debug mode, also save a human-readable version of the index
    if args.debug:
        log_processing_step(logger, "保存Debug可读格式")
        save_readable_index(
            output_file,
            result,
            name_mappings,
            description='i2i:deepwalk'
        )
    
    # Clean up temporary files (edge file kept when --save_graph is set)
    if not args.save_graph:
        if os.path.exists(edge_file):
            os.remove(edge_file)
            logger.debug(f"已删除临时文件: {edge_file}")
    if os.path.exists(walks_file):
        os.remove(walks_file)
        logger.debug(f"已删除临时文件: {walks_file}")
    
    print(f"✓ DeepWalk相似度计算完成")
    print(f"  - 输出文件: {output_file}")
    print(f"  - 商品数: {len(result)}")

if __name__ == '__main__':
    main()