Blame view

offline_tasks/scripts/add_names_to_swing.py 4.64 KB
5b61955e   tangwang   offline tasks: me...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
  """
  Add item-name mappings to Swing algorithm output.

  Input format:  item_id \t similar_item_id1:score1,similar_item_id2:score2,...
  Output format: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...

  Item names are fetched from the database configured in DB_CONFIG (via
  fetch_name_mappings); IDs without a known name are written as 'Unknown'.
  """
  import sys
  import os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  
  import argparse
  from datetime import datetime
  from db_service import create_db_connection
  from offline_tasks.config.offline_config import DB_CONFIG
  from offline_tasks.scripts.debug_utils import setup_debug_logger, fetch_name_mappings
  
  
  def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False):
      """
      Annotate a Swing output file with human-readable item names.

      Each non-empty input line is expected to look like
      ``item_id<TAB>sim_id1:score1,sim_id2:score2,...`` and is rewritten as
      ``item_id:name<TAB>sim_id1:name1:score1,sim_id2:name2:score2,...``.

      Lines that do not split into exactly two tab-separated fields are skipped
      and counted. Blank lines are ignored without being counted. Similar-item
      entries that contain no ':' (for example the empty entry left by a
      trailing comma) are dropped.

      Args:
          input_file: Path of the Swing output file to read (UTF-8).
          output_file: Path of the annotated file to write (UTF-8, overwritten).
          name_mappings: Dict whose 'item' entry maps item IDs (looked up as
              str) to names. If the 'item' entry is missing, every name
              becomes 'Unknown'.
          logger: Optional logger for progress and summary messages.
          debug: When True and a logger is given, log progress every 1000
              processed lines.

      Returns:
          Tuple ``(processed_lines, skipped_lines)``.
      """
      if logger:
          logger.info(f"处理文件: {input_file}")
          logger.info(f"输出到: {output_file}")

      item_names = name_mappings.get('item', {})

      # Names come from the database and may contain tabs or line breaks. Either
      # one would split a record or shift its fields in the tab-separated,
      # one-record-per-line output, so replace them with spaces.
      line_breakers = str.maketrans({'\t': ' ', '\r': ' ', '\n': ' '})

      def lookup_name(raw_id):
          """Return a single-line display name for raw_id, or 'Unknown'."""
          name = item_names.get(str(raw_id))
          if name is None:
              # Covers both absent IDs and IDs explicitly mapped to None
              # (the latter previously leaked into the output as "None").
              return 'Unknown'
          return str(name).translate(line_breakers)

      processed_lines = 0
      skipped_lines = 0

      with open(input_file, 'r', encoding='utf-8') as fin, \
           open(output_file, 'w', encoding='utf-8') as fout:
          for raw_line in fin:
              line = raw_line.strip()
              if not line:
                  continue

              parts = line.split('\t')
              if len(parts) != 2:
                  skipped_lines += 1
                  continue

              item_id, sim_items_str = parts

              sim_items = []
              for sim_pair in sim_items_str.split(','):
                  if ':' not in sim_pair:
                      continue
                  # rsplit: only the last ':' separates the score from the ID.
                  sim_id, score = sim_pair.rsplit(':', 1)
                  sim_items.append(f"{sim_id}:{lookup_name(sim_id)}:{score}")

              fout.write(f"{item_id}:{lookup_name(item_id)}\t{','.join(sim_items)}\n")
              processed_lines += 1

              if debug and logger and processed_lines % 1000 == 0:
                  logger.debug(f"已处理 {processed_lines} 行")

      if logger:
          logger.info("处理完成:")
          logger.info(f"  成功处理: {processed_lines} 行")
          logger.info(f"  跳过: {skipped_lines} 行")

      return processed_lines, skipped_lines
  
  
  def main():
      """
      Command-line entry point: add item names to a Swing output file.

      Arguments are the input file path, an optional output file path and
      ``--debug``. If the output path is omitted it defaults to
      ``<input dir>/<input stem>_readable.txt``. Name mappings are fetched from
      the database configured in ``DB_CONFIG``.

      Exits with status 1 (instead of returning normally with status 0) when
      the input file does not exist, or when the output path refers to the
      same file as the input.
      """
      parser = argparse.ArgumentParser(description='Add names to Swing algorithm output')
      parser.add_argument('input_file', type=str,
                         help='Input file path (Swing output)')
      parser.add_argument('output_file', type=str, nargs='?', default=None,
                         help='Output file path (if not specified, will add _readable suffix)')
      parser.add_argument('--debug', action='store_true',
                         help='Enable debug mode with detailed logging')

      args = parser.parse_args()

      logger = setup_debug_logger('add_names_to_swing', debug=args.debug)

      # Derive a default output path next to the input file.
      if args.output_file is None:
          input_dir = os.path.dirname(args.input_file)
          name_without_ext = os.path.splitext(os.path.basename(args.input_file))[0]
          args.output_file = os.path.join(input_dir, f"{name_without_ext}_readable.txt")

      logger.info(f"输入文件: {args.input_file}")
      logger.info(f"输出文件: {args.output_file}")

      # isfile() rather than exists(): a directory would pass exists() and only
      # fail later, when it is opened for reading.
      if not os.path.isfile(args.input_file):
          logger.error(f"输入文件不存在: {args.input_file}")
          # Non-zero exit so batch pipelines notice the failure.
          sys.exit(1)

      # The output is opened with mode 'w', which truncates it before the input
      # is read. Writing onto the input file would destroy the data.
      if os.path.exists(args.output_file) and os.path.samefile(args.input_file, args.output_file):
          logger.error(f"输出文件不能与输入文件相同: {args.output_file}")
          sys.exit(1)

      logger.info("连接数据库...")
      engine = create_db_connection(
          DB_CONFIG['host'],
          DB_CONFIG['port'],
          DB_CONFIG['database'],
          DB_CONFIG['username'],
          DB_CONFIG['password']
      )

      logger.info("获取ID到名称的映射...")
      name_mappings = fetch_name_mappings(engine, debug=args.debug)
      # .get() so a missing 'item' entry is reported as 0 names instead of
      # raising KeyError; the processing step tolerates it the same way.
      logger.info(f"获取到 {len(name_mappings.get('item', {}))} 个商品名称")

      add_names_to_swing_result(
          args.input_file,
          args.output_file,
          name_mappings,
          logger=logger,
          debug=args.debug
      )

      logger.info("完成!")
  
  
  if __name__ == "__main__":
      # Run the CLI only when executed as a script, not when imported.
      main()