Blame view

offline_tasks/scripts/add_names_to_swing.py 4.53 KB
5b61955e   tangwang   offline tasks: me...
1
2
3
4
5
6
7
8
9
10
11
  """
  Swing算法输出结果添加name映射
  输入格式: item_id \t similar_item_id1:score1,similar_item_id2:score2,...
  输出格式: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...
  """
  import sys
  import os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  
  import argparse
  from datetime import datetime
12118125   tangwang   offline tasks: me...
12
  from offline_tasks.scripts.debug_utils import setup_debug_logger, load_name_mappings_from_file
5b61955e   tangwang   offline tasks: me...
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
  
  
  def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False,
                                unknown_name='Unknown'):
      """
      Annotate a Swing algorithm result file with human-readable item names.

      Each non-blank input line is expected to look like
      ``item_id<TAB>sim_id1:score1,sim_id2:score2,...`` and is rewritten as
      ``item_id:name<TAB>sim_id1:name1:score1,sim_id2:name2:score2,...``.

      Args:
          input_file: Path of the Swing output file to read (UTF-8).
          output_file: Path of the annotated file to write (UTF-8, truncated first).
          name_mappings: Dict whose ``'item'`` entry maps item-ID strings to names;
              a missing ``'item'`` entry is treated as an empty mapping.
          logger: Optional logger; when None nothing is logged.
          debug: When True and a logger is given, log progress every 1000 lines.
          unknown_name: Placeholder written for IDs absent from the mapping.

      Returns:
          A ``(processed_lines, skipped_lines)`` tuple. Blank lines are ignored and
          counted in neither; lines that do not split into exactly two tab-separated
          fields are counted as skipped and not written.
      """
      if logger:
          logger.info(f"处理文件: {input_file}")
          logger.info(f"输出到: {output_file}")

      item_names = name_mappings.get('item', {})

      processed_lines = 0
      skipped_lines = 0
      skipped_pairs = 0  # similar-item entries without a ':' separator

      with open(input_file, 'r', encoding='utf-8') as fin, \
           open(output_file, 'w', encoding='utf-8') as fout:

          for raw_line in fin:
              line = raw_line.strip()
              if not line:
                  continue

              parts = line.split('\t')
              if len(parts) != 2:
                  skipped_lines += 1
                  continue

              # Strip so IDs padded with spaces still match the mapping keys.
              item_id = parts[0].strip()
              item_name = item_names.get(item_id, unknown_name)

              sim_items = []
              for sim_pair in parts[1].split(','):
                  pair = sim_pair.strip()
                  if not pair:
                      # Empty entry, e.g. from a trailing comma.
                      continue
                  if ':' not in pair:
                      skipped_pairs += 1
                      continue

                  # rsplit: only the last ':' separates the score from the ID.
                  sim_id, score = pair.rsplit(':', 1)
                  sim_id = sim_id.strip()
                  sim_name = item_names.get(sim_id, unknown_name)

                  # Output entry format: item_id:name:score
                  sim_items.append(f"{sim_id}:{sim_name}:{score.strip()}")

              fout.write(f"{item_id}:{item_name}\t{','.join(sim_items)}\n")
              processed_lines += 1

              # Periodic progress report in debug mode.
              if debug and logger and processed_lines % 1000 == 0:
                  logger.debug(f"已处理 {processed_lines} 行")

      if logger:
          logger.info("处理完成:")
          logger.info(f"  成功处理: {processed_lines} 行")
          logger.info(f"  跳过: {skipped_lines} 行")
          if skipped_pairs:
              logger.info(f"  跳过无效相似项: {skipped_pairs} 个")

      return processed_lines, skipped_lines
  
  
  def main():
      """
      Command-line entry point: add item names to a Swing result file.

      Arguments: ``input_file`` (required), ``output_file`` (optional, defaults to
      ``<input_dir>/<input_stem>_readable.txt``) and ``--debug``. Name mappings come
      from ``load_name_mappings_from_file``; the annotation itself is done by
      ``add_names_to_swing_result``.

      Exits with status 1 when the input is not a regular file, when the output
      path resolves to the input path, or when the item-name mapping is empty or
      missing, so that calling shells and schedulers can detect the failure.
      """
      parser = argparse.ArgumentParser(description='Add names to Swing algorithm output')
      parser.add_argument('input_file', type=str,
                          help='Input file path (Swing output)')
      parser.add_argument('output_file', type=str, nargs='?', default=None,
                          help='Output file path (if not specified, will add _readable suffix)')
      parser.add_argument('--debug', action='store_true',
                          help='Enable debug mode with detailed logging')

      args = parser.parse_args()

      logger = setup_debug_logger('add_names_to_swing', debug=args.debug)

      # No output given: write "<stem>_readable.txt" next to the input file.
      if args.output_file is None:
          input_dir = os.path.dirname(args.input_file)
          name_without_ext = os.path.splitext(os.path.basename(args.input_file))[0]
          args.output_file = os.path.join(input_dir, f"{name_without_ext}_readable.txt")

      logger.info(f"输入文件: {args.input_file}")
      logger.info(f"输出文件: {args.output_file}")

      # isfile rather than exists: a directory would otherwise pass and fail in open().
      if not os.path.isfile(args.input_file):
          logger.error(f"输入文件不存在: {args.input_file}")
          sys.exit(1)

      # The output is opened with 'w' while the input is still being read, so
      # writing to the same file would truncate the input before it is processed.
      if os.path.realpath(args.output_file) == os.path.realpath(args.input_file):
          logger.error(f"输出文件与输入文件相同: {args.output_file}")
          sys.exit(1)

      # Load the ID -> name mappings from the local file.
      logger.info("加载ID到名称的映射...")
      name_mappings = load_name_mappings_from_file(debug=args.debug)

      if not name_mappings or not name_mappings.get('item'):
          logger.error("映射文件为空或加载失败")
          logger.error("请先运行: python3 scripts/fetch_item_attributes.py")
          sys.exit(1)

      logger.info(f"加载了 {len(name_mappings['item'])} 个商品名称")

      add_names_to_swing_result(
          args.input_file,
          args.output_file,
          name_mappings,
          logger=logger,
          debug=args.debug
      )

      logger.info("完成!")
  
  
  if __name__ == "__main__":
      # Run the CLI only when executed as a script, not when imported as a module.
      main()