# load_index_to_redis.py 9.06 KB
"""
将生成的索引加载到Redis
用于在线推荐系统查询
"""
import redis
import argparse
import logging
import os
import sys
from datetime import datetime

# 添加父目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.offline_config import REDIS_CONFIG, OUTPUT_DIR

def setup_logger():
    """Return the 'load_index_to_redis' logger, configuring it on first use.

    INFO-and-above records are written both to the console and to a per-day
    file ``logs/load_index_to_redis_YYYYMMDD.log`` (the ``logs`` directory is
    relative to the current working directory and is created if missing).
    Subsequent calls return the same logger without attaching extra handlers.
    """
    log_dir = 'logs'
    os.makedirs(log_dir, exist_ok=True)

    named_logger = logging.getLogger('load_index_to_redis')
    named_logger.setLevel(logging.INFO)
    if named_logger.handlers:
        # Already configured by an earlier call; don't duplicate output.
        return named_logger

    day_stamp = datetime.now().strftime("%Y%m%d")
    shared_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    outputs = (
        logging.FileHandler(
            os.path.join(log_dir, f'load_index_to_redis_{day_stamp}.log'),
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    )
    for handler in outputs:
        handler.setLevel(logging.INFO)
        handler.setFormatter(shared_format)
        named_logger.addHandler(handler)

    return named_logger


logger = setup_logger()


def _normalize_item_id(raw_id):
    """Strip a float-style ".0" suffix from an integral item ID ("60678.0" -> "60678").

    Only IDs that are an optionally signed run of decimal digits followed by a
    '.' and nothing but zeros are rewritten. Everything else (e.g. "1.5",
    "abc.def") is returned unchanged. The check is done on the string, not via
    float(), so large IDs keep every digit and fractional IDs are never
    truncated into a different item's ID.
    """
    candidate = raw_id.strip()
    int_part, dot, frac_part = candidate.partition('.')
    digits = int_part[1:] if int_part[:1] in ('-', '+') else int_part
    if dot and digits.isdecimal() and not frac_part.strip('0'):
        return str(int(int_part))
    return raw_id


def load_index_file(file_path, redis_client, key_prefix, expire_seconds=None, batch_size=1000):
    """
    Load a tab-separated index file into Redis.

    Each non-empty line needs at least 2 tab-separated fields. The first field
    is the key suffix and the last field is the stored value:
      - 2 fields: item_id \t similar_items
      - 3 fields: item_id \t item_name \t similar_items (recommended)
    Lines are stored as ``SET {key_prefix}:{item_id} value``.

    Args:
        file_path: Path of the index file.
        redis_client: Redis client; must support ``pipeline(transaction=False)``.
        key_prefix: Redis key prefix.
        expire_seconds: TTL in seconds. None, 0 or a negative value means no expiry.
        batch_size: Number of SET commands buffered per pipeline round trip
            (values < 1 are treated as 1).

    Returns:
        The number of records loaded (0 if the file does not exist).
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return 0

    # SET ... EX writes the value and its TTL atomically in one command, so a
    # crash can't leave a key without an expiry (separate SET + EXPIRE could).
    ttl = int(expire_seconds) if expire_seconds and expire_seconds > 0 else None
    batch_size = max(1, int(batch_size))

    # A non-transactional pipeline batches commands to avoid one network round
    # trip per record while still executing them in file order.
    pipe = redis_client.pipeline(transaction=False)
    pending = 0
    count = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) < 2:
                logger.warning(f"Invalid line format (expected at least 2 fields): {line}")
                continue

            redis_key = f"{key_prefix}:{_normalize_item_id(parts[0])}"
            pipe.set(redis_key, parts[-1], ex=ttl)
            pending += 1
            count += 1

            if pending >= batch_size:
                pipe.execute()
                pending = 0

            if count % 1000 == 0:
                logger.info(f"Loaded {count} records...")

    # Flush the final partial batch.
    if pending:
        pipe.execute()

    return count


def load_cpp_swing_index(redis_client, expire_days=7):
    """
    Load the C++ Swing item-similarity index into Redis.

    The file is read from ``<parent of OUTPUT_DIR>/collaboration/output/swing_similar.txt``.
    Records are written under the ``item:similar:swing_cpp`` key prefix.

    Args:
        redis_client: Redis client, passed through to ``load_index_file``.
        expire_days: Key TTL in days; a falsy value means no expiry.

    Returns:
        The number of records loaded (0 if the file is missing).
    """
    # normpath removes a trailing separator first. Without it,
    # os.path.dirname("/x/output/") returns "/x/output" rather than the
    # intended parent "/x".
    base_dir = os.path.dirname(os.path.normpath(OUTPUT_DIR))
    file_path = os.path.join(base_dir, 'collaboration', 'output', 'swing_similar.txt')

    if not os.path.exists(file_path):
        logger.warning(f"C++ Swing file not found: {file_path}, skipping...")
        return 0

    expire_seconds = expire_days * 24 * 3600 if expire_days else None

    logger.info(f"Loading C++ Swing indices from {file_path}...")
    count = load_index_file(
        file_path,
        redis_client,
        "item:similar:swing_cpp",
        expire_seconds
    )
    logger.info(f"Loaded {count} C++ Swing indices")
    return count


def load_i2i_indices(redis_client, date_str=None, expire_days=7):
    """
    Load every daily i2i similarity index for one date into Redis.

    For each type in (swing, session_w2v, deepwalk, content_name, content_pic,
    item_behavior), reads ``OUTPUT_DIR/i2i_<type>_<date_str>.txt``. Its records
    are written under the ``item:similar:<type>`` key prefix. A missing file
    is logged as a warning and skipped.

    Args:
        redis_client: Redis client, passed through to ``load_index_file``.
        date_str: Date as YYYYMMDD; a falsy value means today (local time).
        expire_days: Key TTL in days; a falsy value means no expiry.
    """
    date_str = date_str or datetime.now().strftime('%Y%m%d')

    expire_seconds = None
    if expire_days:
        expire_seconds = expire_days * 24 * 3600

    for kind in ('swing', 'session_w2v', 'deepwalk', 'content_name', 'content_pic', 'item_behavior'):
        index_path = os.path.join(OUTPUT_DIR, f'i2i_{kind}_{date_str}.txt')

        if os.path.exists(index_path):
            logger.info(f"Loading {kind} indices...")
            loaded = load_index_file(index_path, redis_client, f"item:similar:{kind}", expire_seconds)
            logger.info(f"Loaded {loaded} {kind} indices")
        else:
            logger.warning(f"File not found: {index_path}, skipping...")


def load_interest_indices(redis_client, date_str=None, expire_days=7):
    """
    Load the daily interest-aggregation lists for one date into Redis.

    For each list type in (hot, cart, new, global), reads
    ``OUTPUT_DIR/interest_aggregation_<type>_<date_str>.txt``. Its records are
    written under the ``interest:<type>`` key prefix. A missing file is logged
    as a warning and skipped.

    Args:
        redis_client: Redis client, passed through to ``load_index_file``.
        date_str: Date as YYYYMMDD; a falsy value means today (local time).
        expire_days: Key TTL in days; a falsy value means no expiry.
    """
    if date_str:
        day = date_str
    else:
        day = datetime.now().strftime('%Y%m%d')

    ttl_seconds = expire_days * 24 * 3600 if expire_days else None

    for list_type in ('hot', 'cart', 'new', 'global'):
        source = os.path.join(OUTPUT_DIR, f'interest_aggregation_{list_type}_{day}.txt')
        if not os.path.exists(source):
            logger.warning(f"File not found: {source}, skipping...")
            continue

        logger.info(f"Loading {list_type} interest indices...")
        n_loaded = load_index_file(source, redis_client, f"interest:{list_type}", ttl_seconds)
        logger.info(f"Loaded {n_loaded} {list_type} indices")


def main():
    """Parse CLI options, connect to Redis, and load the selected indices.

    Returns:
        The process exit status: 0 after loading, 1 if the Redis ping fails.
        An invalid ``--date`` aborts via ``parser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser(description='Load recommendation indices to Redis')
    parser.add_argument('--redis-host', type=str, default=REDIS_CONFIG['host'],
                       help='Redis host')
    parser.add_argument('--redis-port', type=int, default=REDIS_CONFIG['port'],
                       help='Redis port')
    parser.add_argument('--redis-db', type=int, default=REDIS_CONFIG['db'],
                       help='Redis database')
    parser.add_argument('--redis-password', type=str, default=REDIS_CONFIG['password'],
                       help='Redis password')
    parser.add_argument('--date', type=str, default=None,
                       help='Date string (YYYYMMDD), default is today')
    parser.add_argument('--expire-days', type=int, default=7,
                       help='Expire days for Redis keys')
    # A store_true flag with default=True can never turn anything off. Each
    # loader therefore gets a --no-* opt-out that writes the same dest. The
    # original flags are kept so existing invocations still parse.
    parser.add_argument('--load-i2i', dest='load_i2i', action='store_true', default=True,
                       help='Load i2i indices')
    parser.add_argument('--no-load-i2i', dest='load_i2i', action='store_false',
                       help='Skip C++ Swing and i2i indices')
    parser.add_argument('--load-interest', dest='load_interest', action='store_true', default=True,
                       help='Load interest indices')
    parser.add_argument('--no-load-interest', dest='load_interest', action='store_false',
                       help='Skip interest aggregation indices')
    parser.add_argument('--flush-db', action='store_true',
                       help='Flush database before loading (危险操作!)')

    args = parser.parse_args()

    # Reject malformed dates up front. Otherwise every file lookup silently
    # misses and the run still reports success.
    if args.date:
        try:
            datetime.strptime(args.date, '%Y%m%d')
            date_ok = len(args.date) == 8
        except ValueError:
            date_ok = False
        if not date_ok:
            parser.error(f"--date must be in YYYYMMDD format, got {args.date!r}")

    # Create the Redis connection.
    logger.info("Connecting to Redis...")
    redis_client = redis.Redis(
        host=args.redis_host,
        port=args.redis_port,
        db=args.redis_db,
        password=args.redis_password,
        decode_responses=True
    )

    # Verify connectivity before doing any work.
    try:
        redis_client.ping()
        logger.info("Redis connection successful")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")
        return 1

    # Optionally wipe the selected database first (destructive).
    if args.flush_db:
        logger.warning("Flushing Redis database...")
        redis_client.flushdb()
        logger.info("Database flushed")

    # The C++ Swing index and the Python i2i indices share the --load-i2i switch.
    if args.load_i2i:
        logger.info("\n" + "="*80)
        logger.info("Loading C++ Swing indices")
        logger.info("="*80)
        load_cpp_swing_index(redis_client, args.expire_days)

        logger.info("\n" + "="*80)
        logger.info("Loading i2i indices")
        logger.info("="*80)
        load_i2i_indices(redis_client, args.date, args.expire_days)

    # Interest-aggregation indices.
    if args.load_interest:
        logger.info("\n" + "="*80)
        logger.info("Loading interest aggregation indices")
        logger.info("="*80)
        load_interest_indices(redis_client, args.date, args.expire_days)

    logger.info("\n" + "="*80)
    logger.info("All indices loaded successfully!")
    logger.info("="*80)

    return 0


if __name__ == '__main__':
    sys.exit(main())