# load_index_to_redis.py
"""
将生成的索引加载到Redis
用于在线推荐系统查询
"""
import argparse
import logging
import os
import sys
from datetime import datetime

import redis

from offline_tasks.config.offline_config import REDIS_CONFIG, OUTPUT_DIR

# Script-wide logging: INFO and above, each line prefixed with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger shared by every function in this script.
logger = logging.getLogger(__name__)


def load_index_file(file_path, redis_client, key_prefix, expire_seconds=None,
                    batch_size=1000):
    """
    Load a tab-separated index file into Redis as plain string keys.

    Each non-blank line must contain exactly two tab-separated fields,
    ``<key_suffix>\\t<value>``. It is stored as the string key
    ``{key_prefix}:{key_suffix}``. Lines with any other number of fields are
    logged and skipped.

    Writes go through a non-transactional pipeline and are flushed every
    ``batch_size`` commands. The TTL is applied atomically with ``SET ... EX``,
    so there is no separate ``EXPIRE`` round-trip per key.

    Args:
        file_path: Path of the index file (read as UTF-8).
        redis_client: redis-py client; must provide ``pipeline()``.
        key_prefix: Prefix prepended (with ``:``) to every key suffix.
        expire_seconds: TTL in seconds. None, 0 or a negative value means the
            keys are written without an expiry.
        batch_size: Number of buffered SET commands per pipeline flush
            (values below 1 are treated as 1).

    Returns:
        Number of records written; 0 if the file does not exist.
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return 0

    # SET EX rejects non-positive values, so only pass a real TTL through.
    ttl = expire_seconds if expire_seconds and expire_seconds > 0 else None
    batch_size = max(1, int(batch_size))

    count = 0
    pending = 0
    # transaction=False: plain batching, no MULTI/EXEC wrapping needed.
    pipe = redis_client.pipeline(transaction=False)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) != 2:
                logger.warning(f"Invalid line format: {line}")
                continue

            key_suffix, value = parts
            pipe.set(f"{key_prefix}:{key_suffix}", value, ex=ttl)
            pending += 1
            count += 1

            if pending >= batch_size:
                pipe.execute()
                pending = 0

            if count % 1000 == 0:
                logger.info(f"Loaded {count} records...")

    # Flush the final partial batch.
    if pending:
        pipe.execute()

    return count


def load_i2i_indices(redis_client, date_str=None, expire_days=7):
    """
    Load every i2i similarity index for one date into Redis.

    For each index type (swing, session_w2v, deepwalk, content_name,
    content_pic), the file ``OUTPUT_DIR/i2i_<type>_<date>.txt`` is loaded
    under the key prefix ``i2i:<type>``. Missing files are logged and skipped.

    Args:
        redis_client: Redis client handed through to ``load_index_file``.
        date_str: Date as YYYYMMDD; a falsy value means today's local date.
        expire_days: Key lifetime in days; a falsy value means no expiry.
    """
    day = date_str or datetime.now().strftime('%Y%m%d')
    if expire_days:
        ttl = expire_days * 24 * 3600
    else:
        ttl = None

    index_types = ('swing', 'session_w2v', 'deepwalk', 'content_name', 'content_pic')
    for kind in index_types:
        path = os.path.join(OUTPUT_DIR, f'i2i_{kind}_{day}.txt')
        if os.path.exists(path):
            logger.info(f"Loading {kind} indices...")
            loaded = load_index_file(path, redis_client, f"i2i:{kind}", ttl)
            logger.info(f"Loaded {loaded} {kind} indices")
        else:
            logger.warning(f"File not found: {path}, skipping...")


def load_interest_indices(redis_client, date_str=None, expire_days=7):
    """
    Load the interest-point aggregation indices for one date into Redis.

    For each list type (hot, cart, new, global), the file
    ``OUTPUT_DIR/interest_aggregation_<type>_<date>.txt`` is loaded under the
    key prefix ``interest:<type>``. Missing files are logged and skipped.

    Args:
        redis_client: Redis client handed through to ``load_index_file``.
        date_str: Date as YYYYMMDD; a falsy value means today's local date.
        expire_days: Key lifetime in days; a falsy value means no expiry.
    """
    day = date_str if date_str else datetime.now().strftime('%Y%m%d')
    ttl = expire_days * 24 * 3600 if expire_days else None

    for list_kind in ('hot', 'cart', 'new', 'global'):
        path = os.path.join(
            OUTPUT_DIR, f'interest_aggregation_{list_kind}_{day}.txt'
        )
        if not os.path.exists(path):
            logger.warning(f"File not found: {path}, skipping...")
            continue

        logger.info(f"Loading {list_kind} interest indices...")
        loaded = load_index_file(
            path, redis_client, f"interest:{list_kind}", ttl
        )
        logger.info(f"Loaded {loaded} {list_kind} indices")


def main(argv=None):
    """
    Command-line entry point: connect to Redis and load the selected indices.

    Args:
        argv: Argument list to parse; None (the default) parses sys.argv[1:].

    Returns:
        Process exit code: 0 on success, 1 if Redis is unreachable or a Redis
        command fails while flushing/loading.
    """
    parser = argparse.ArgumentParser(description='Load recommendation indices to Redis')
    parser.add_argument('--redis-host', type=str, default=REDIS_CONFIG.get('host', 'localhost'),
                        help='Redis host')
    parser.add_argument('--redis-port', type=int, default=REDIS_CONFIG.get('port', 6379),
                        help='Redis port')
    parser.add_argument('--redis-db', type=int, default=REDIS_CONFIG.get('db', 0),
                        help='Redis database')
    parser.add_argument('--redis-password', type=str, default=REDIS_CONFIG.get('password'),
                        help='Redis password')
    parser.add_argument('--date', type=str, default=None,
                        help='Date string (YYYYMMDD), default is today')
    parser.add_argument('--expire-days', type=int, default=7,
                        help='Expire days for Redis keys')
    # store_true with default=True can never be switched off, so each loader
    # also gets a --no-* flag that writes False to the same destination. The
    # positive flags are kept (effectively no-ops) for existing invocations.
    parser.add_argument('--load-i2i', action='store_true', default=True,
                        help='Load i2i indices')
    parser.add_argument('--no-load-i2i', dest='load_i2i', action='store_false',
                        help='Skip loading i2i indices')
    parser.add_argument('--load-interest', action='store_true', default=True,
                        help='Load interest indices')
    parser.add_argument('--no-load-interest', dest='load_interest', action='store_false',
                        help='Skip loading interest indices')
    parser.add_argument('--flush-db', action='store_true',
                        help='Flush database before loading (危险操作!)')

    args = parser.parse_args(argv)

    # Fail fast on a malformed date; otherwise every file lookup would
    # silently miss and the run would still report success.
    if args.date:
        try:
            if len(args.date) != 8 or not args.date.isdigit():
                raise ValueError(args.date)
            datetime.strptime(args.date, '%Y%m%d')
        except ValueError:
            parser.error(f"--date must be YYYYMMDD, got {args.date!r}")
    if args.expire_days < 0:
        parser.error("--expire-days must be >= 0 (0 means keys never expire)")

    # Create the Redis connection
    logger.info("Connecting to Redis...")
    redis_client = redis.Redis(
        host=args.redis_host,
        port=args.redis_port,
        db=args.redis_db,
        password=args.redis_password,
        decode_responses=True
    )

    # Verify connectivity before touching any data
    try:
        redis_client.ping()
        logger.info("Redis connection successful")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")
        return 1

    try:
        # Flush the database first (only when explicitly requested)
        if args.flush_db:
            logger.warning("Flushing Redis database...")
            redis_client.flushdb()
            logger.info("Database flushed")

        # Load i2i indices
        if args.load_i2i:
            logger.info("\n" + "="*80)
            logger.info("Loading i2i indices")
            logger.info("="*80)
            load_i2i_indices(redis_client, args.date, args.expire_days)

        # Load interest-point indices
        if args.load_interest:
            logger.info("\n" + "="*80)
            logger.info("Loading interest aggregation indices")
            logger.info("="*80)
            load_interest_indices(redis_client, args.date, args.expire_days)
    except redis.RedisError:
        # Report a failed load as a non-zero exit code with the traceback logged.
        logger.exception("Redis error while loading indices")
        return 1

    logger.info("\n" + "="*80)
    logger.info("All indices loaded successfully!")
    logger.info("="*80)

    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())