# i2i_session_w2v.py
"""
i2i - Session Word2Vec算法实现
基于用户会话序列训练Word2Vec模型,获取物品向量相似度
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import pandas as pd
import json
import argparse
from datetime import datetime
from collections import defaultdict
from gensim.models import Word2Vec
import numpy as np
from db_service import create_db_connection
from offline_tasks.config.offline_config import (
    DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range,
    DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N
)


def prepare_session_data(df, session_gap_minutes=30):
    """
    Split each user's interactions into time-gapped sessions.

    Rows are sorted by ``user_id`` then ``create_time``. For every user a new
    session starts at the first interaction, and again whenever the time since
    the previous interaction is strictly greater than ``session_gap_minutes``.
    Sessions with fewer than two items are dropped, because a single item gives
    Word2Vec no co-occurrence context.

    Args:
        df: DataFrame with columns ``user_id``, ``item_id`` and ``create_time``
            (``create_time`` is expected to be datetime-like).
        session_gap_minutes: maximum gap, in minutes, allowed between two
            consecutive interactions of the same session (a gap exactly equal
            to this value stays in the same session).

    Returns:
        List of sessions, each a list of item ids converted to ``str``,
        ordered by user id and then chronologically.
    """
    max_gap = pd.Timedelta(minutes=session_gap_minutes)
    sessions = []

    def _flush(session):
        # Keep only sessions that provide at least one co-occurring pair.
        if len(session) >= 2:
            sessions.append(session)

    ordered = df.sort_values(['user_id', 'create_time'])

    # Iterate over column values rather than DataFrame.iterrows(): iterrows
    # materialises a Series per row, which is slow on large event logs and
    # can upcast dtypes row-wise.
    for _, user_df in ordered.groupby('user_id'):
        current_session = []
        last_time = None
        for item_id, current_time in zip(user_df['item_id'], user_df['create_time']):
            # A NaT difference compares False here, so such rows extend the
            # current session (same as the original total_seconds() check).
            if last_time is None or current_time - last_time > max_gap:
                _flush(current_session)
                current_session = [str(item_id)]
            else:
                current_session.append(str(item_id))
            last_time = current_time
        _flush(current_session)

    return sessions


def train_word2vec(sessions, config):
    """
    Fit a gensim Word2Vec model on item sessions.

    Args:
        sessions: list of sessions, each a list of item-id tokens.
        config: mapping providing ``vector_size``, ``window_size``,
            ``min_count``, ``workers``, ``sg`` and ``epochs``; ``window_size``
            is passed to gensim as ``window``. The random seed is fixed at 42.

    Returns:
        The trained Word2Vec model.
    """
    print(f"Training Word2Vec with {len(sessions)} sessions...")

    # Translate the project's config keys into gensim keyword arguments.
    w2v_kwargs = dict(
        vector_size=config['vector_size'],
        window=config['window_size'],
        min_count=config['min_count'],
        workers=config['workers'],
        sg=config['sg'],
        epochs=config['epochs'],
    )
    trained = Word2Vec(sentences=sessions, seed=42, **w2v_kwargs)

    vocab_size = len(trained.wv)
    print(f"Training completed. Vocabulary size: {vocab_size}")
    return trained


def generate_similarities(model, top_n=50):
    """
    Collect the nearest neighbours of every item in the model's vocabulary.

    Args:
        model: trained Word2Vec model; its ``wv`` vectors are queried.
        top_n: number of neighbours requested per item.

    Returns:
        Dict mapping item_id -> list of (similar_item_id, score) tuples, with
        scores converted to plain ``float``. Items for which the lookup raises
        ``KeyError`` are left out.
    """
    vectors = model.wv
    similarities = {}

    for key in vectors.index_to_key:
        try:
            neighbours = vectors.most_similar(key, topn=top_n)
        except KeyError:
            # Best-effort: skip any key the vectors cannot score.
            continue
        similarities[key] = [(other, float(sim)) for other, sim in neighbours]

    return similarities


def _build_item_name_map(df):
    """
    Map each item id (as ``str``, matching the session tokens) to a name.

    Uses the first non-null ``item_name`` per ``item_id``; items whose names
    are all null are omitted so callers can apply their own fallback.
    """
    names = df.groupby('item_id')['item_name'].first().dropna()
    return {str(item_id): name for item_id, name in names.items()}


def _sanitize_field(value):
    """Replace characters that would break the tab-separated output format."""
    return str(value).replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


def _write_similarities(output_file, result, item_name_map):
    """
    Write ``item_id \\t item_name \\t sim_id1:score1,sim_id2:score2,...`` lines.

    Items with an empty neighbour list are skipped; items missing from
    ``item_name_map`` are written with the name ``'Unknown'``.

    Returns:
        Number of item lines written.
    """
    written = 0
    with open(output_file, 'w', encoding='utf-8') as f:
        for item_id, sims in result.items():
            if not sims:
                continue
            item_name = _sanitize_field(item_name_map.get(item_id, 'Unknown'))
            sim_str = ','.join(f'{sim_id}:{score:.4f}' for sim_id, score in sims)
            f.write(f'{item_id}\t{item_name}\t{sim_str}\n')
            written += 1
    return written


def main():
    """
    Run the Session Word2Vec i2i job end to end.

    Fetches user behaviour events from the database, splits them into
    sessions, trains Word2Vec, and writes the top-N similar items per item to
    a tab-separated file (optionally also saving the model).
    """
    parser = argparse.ArgumentParser(description='Run Session Word2Vec for i2i similarity')
    parser.add_argument('--window_size', type=int, default=I2I_CONFIG['session_w2v']['window_size'],
                       help='Window size for Word2Vec')
    parser.add_argument('--vector_size', type=int, default=I2I_CONFIG['session_w2v']['vector_size'],
                       help='Vector size for Word2Vec')
    parser.add_argument('--min_count', type=int, default=I2I_CONFIG['session_w2v']['min_count'],
                       help='Minimum word count')
    parser.add_argument('--workers', type=int, default=I2I_CONFIG['session_w2v']['workers'],
                       help='Number of workers')
    parser.add_argument('--epochs', type=int, default=I2I_CONFIG['session_w2v']['epochs'],
                       help='Number of epochs')
    parser.add_argument('--top_n', type=int, default=DEFAULT_I2I_TOP_N,
                       help=f'Top N similar items to output (default: {DEFAULT_I2I_TOP_N})')
    parser.add_argument('--lookback_days', type=int, default=DEFAULT_LOOKBACK_DAYS,
                       help=f'Number of days to look back (default: {DEFAULT_LOOKBACK_DAYS})')
    parser.add_argument('--session_gap', type=int, default=30,
                       help='Session gap in minutes')
    parser.add_argument('--output', type=str, default=None,
                       help='Output file path')
    parser.add_argument('--save_model', action='store_true',
                       help='Save Word2Vec model')
    # NOTE(review): --debug is parsed but not used anywhere in this script.
    parser.add_argument('--debug', action='store_true',
                       help='Enable debug mode with detailed logging and readable output')

    args = parser.parse_args()

    # Compute the date tag once so model and output names agree even if the
    # job runs across midnight.
    run_date = datetime.now().strftime("%Y%m%d")

    # Create the database connection
    print("Connecting to database...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    # Resolve the query time window
    start_date, end_date = get_time_range(args.lookback_days)
    print(f"Fetching data from {start_date} to {end_date}...")

    # Fetch user behaviour sequences. The interpolated dates come from
    # get_time_range(), not free-form user input.
    sql_query = f"""
    SELECT 
        se.anonymous_id AS user_id,
        se.item_id,
        se.create_time,
        pgs.name AS item_name
    FROM 
        sensors_events se
    LEFT JOIN prd_goods_sku pgs ON se.item_id = pgs.id
    WHERE 
        se.event IN ('click', 'contactFactory', 'addToPool', 'addToCart', 'purchase')
        AND se.create_time >= '{start_date}'
        AND se.create_time <= '{end_date}'
        AND se.item_id IS NOT NULL
        AND se.anonymous_id IS NOT NULL
    ORDER BY 
        se.anonymous_id,
        se.create_time
    """

    print("Executing SQL query...")
    df = pd.read_sql(sql_query, engine)
    print(f"Fetched {len(df)} records")

    # Convert create_time to datetime so time gaps can be computed
    df['create_time'] = pd.to_datetime(df['create_time'])

    # Split events into sessions
    print("Preparing session data...")
    sessions = prepare_session_data(df, session_gap_minutes=args.session_gap)
    print(f"Generated {len(sessions)} sessions")

    if not sessions:
        # gensim cannot build a vocabulary from an empty corpus; fail with a
        # clear message (non-zero exit status) instead of an opaque error.
        sys.exit("No sessions with at least 2 items were generated; nothing to train.")

    # Train the Word2Vec model (sg=1 is fixed here rather than read from args)
    w2v_config = {
        'vector_size': args.vector_size,
        'window_size': args.window_size,
        'min_count': args.min_count,
        'workers': args.workers,
        'epochs': args.epochs,
        'sg': 1
    }

    model = train_word2vec(sessions, w2v_config)

    # Optionally persist the model
    if args.save_model:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        model_path = os.path.join(OUTPUT_DIR, f'session_w2v_model_{run_date}.model')
        model.save(model_path)
        print(f"Model saved to {model_path}")

    # Compute item-to-item similarities
    print("Generating similarities...")
    result = generate_similarities(model, top_n=args.top_n)

    # Map item ids to names. The previous version zipped per-row ids with
    # per-unique-item names, which attached names to the wrong items.
    item_name_map = _build_item_name_map(df)

    # Write results
    output_file = args.output or os.path.join(OUTPUT_DIR, f'i2i_session_w2v_{run_date}.txt')
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Writing results to {output_file}...")
    written = _write_similarities(output_file, result, item_name_map)

    print(f"Done! Generated i2i similarities for {written} items")
    print(f"Output saved to: {output_file}")


# Run the offline job only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()