# add_names_to_swing.py
"""
给Swing算法输出结果添加name映射
输入格式: item_id \t similar_item_id1:score1,similar_item_id2:score2,...
输出格式: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import argparse
from datetime import datetime
from offline_tasks.scripts.debug_utils import setup_debug_logger, load_name_mappings_from_file


# Characters that would break the line/column structure of the output
# (one record per line, item column and similar-item list separated by a tab).
_LAYOUT_BREAKING_CHARS = str.maketrans({'\t': ' ', '\r': ' ', '\n': ' '})


def _clean_name(name):
    """Return ``name`` as a string that is safe to embed in one output line.

    Tabs, carriage returns and newlines are replaced with spaces so that a
    mapped name can never split a record across columns or lines.
    """
    return str(name).translate(_LAYOUT_BREAKING_CHARS)


def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False):
    """
    Annotate Swing output with human-readable item names.

    Each non-empty input line is expected to look like
    ``item_id<TAB>sim_id1:score1,sim_id2:score2,...`` and is rewritten as
    ``item_id:name<TAB>sim_id1:name1:score1,...``. IDs without a known name
    are labelled ``Unknown``. Lines that do not contain exactly one tab are
    skipped and counted; similar-item entries without a ``:`` are dropped.

    Args:
        input_file: path of the Swing result file to read (UTF-8).
        output_file: path of the annotated file to write (UTF-8, overwritten).
        name_mappings: dict whose ``'item'`` entry maps item ID strings to
            names; a missing or ``None`` entry means no names are known.
        logger: optional logger for progress and summary messages.
        debug: when True (and a logger is given), log progress every
            1000 processed lines.
    """
    if logger:
        logger.info(f"处理文件: {input_file}")
        logger.info(f"输出到: {output_file}")

    # `or {}` also covers an explicit None entry, not just a missing key.
    item_names = name_mappings.get('item') or {}

    def lookup(raw_id):
        # Name for one ID, sanitized for the line-oriented output format.
        return _clean_name(item_names.get(raw_id, 'Unknown'))

    processed_lines = 0
    skipped_lines = 0

    with open(input_file, 'r', encoding='utf-8') as fin, \
         open(output_file, 'w', encoding='utf-8') as fout:

        for line in fin:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) != 2:
                skipped_lines += 1
                continue

            # Strip stray spaces so " 123" and "123" resolve to the same name.
            item_id = parts[0].strip()
            sim_items_str = parts[1]

            sim_items = []
            for sim_pair in sim_items_str.split(','):
                sim_pair = sim_pair.strip()
                if ':' not in sim_pair:
                    # Empty entries (e.g. a trailing comma) or malformed pairs.
                    continue
                # rsplit: the score is always the last ':'-separated field.
                sim_id, score = sim_pair.rsplit(':', 1)
                sim_id = sim_id.strip()
                # Output format per entry: item_id:name:score
                sim_items.append(f"{sim_id}:{lookup(sim_id)}:{score.strip()}")

            fout.write(f"{item_id}:{lookup(item_id)}\t{','.join(sim_items)}\n")
            processed_lines += 1

            if debug and logger and processed_lines % 1000 == 0:
                logger.debug(f"已处理 {processed_lines} 行")

    if logger:
        logger.info("处理完成:")
        logger.info(f"  成功处理: {processed_lines} 行")
        logger.info(f"  跳过: {skipped_lines} 行")


def main():
    """Command-line entry point: annotate a Swing result file with item names.

    Usage: ``add_names_to_swing.py INPUT [OUTPUT] [--debug]``. When OUTPUT is
    omitted it defaults to ``<INPUT stem>_readable.txt`` in INPUT's directory.
    Exits with status 1 on a setup error (input is not a file, input and
    output resolve to the same path, or the name mappings are empty) so that
    shell pipelines can detect the failure.
    """
    parser = argparse.ArgumentParser(description='Add names to Swing algorithm output')
    parser.add_argument('input_file', type=str,
                        help='Input file path (Swing output)')
    parser.add_argument('output_file', type=str, nargs='?', default=None,
                        help='Output file path (if not specified, will add _readable suffix)')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode with detailed logging')

    args = parser.parse_args()

    logger = setup_debug_logger('add_names_to_swing', debug=args.debug)

    # Derive a default output path next to the input file.
    if args.output_file is None:
        input_dir = os.path.dirname(args.input_file)
        stem = os.path.splitext(os.path.basename(args.input_file))[0]
        args.output_file = os.path.join(input_dir, f"{stem}_readable.txt")

    logger.info(f"输入文件: {args.input_file}")
    logger.info(f"输出文件: {args.output_file}")

    # isfile rather than exists: a directory would pass exists() and then
    # crash when opened for reading.
    if not os.path.isfile(args.input_file):
        logger.error(f"输入文件不存在: {args.input_file}")
        sys.exit(1)

    # Opening the output with 'w' truncates it; if it is the input file the
    # data would be destroyed before it is read.
    if os.path.realpath(args.input_file) == os.path.realpath(args.output_file):
        logger.error(f"输出文件不能与输入文件相同: {args.output_file}")
        sys.exit(1)

    logger.info("加载ID到名称的映射...")
    name_mappings = load_name_mappings_from_file(debug=args.debug)

    if not name_mappings or not name_mappings.get('item'):
        logger.error("映射文件为空或加载失败")
        logger.error("请先运行: python3 scripts/fetch_item_attributes.py")
        sys.exit(1)

    logger.info(f"加载了 {len(name_mappings['item'])} 个商品名称")

    add_names_to_swing_result(
        args.input_file,
        args.output_file,
        name_mappings,
        logger=logger,
        debug=args.debug
    )

    logger.info("完成!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()