add_names_to_swing.py
4.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
给Swing算法输出结果添加name映射
输入格式: item_id \t similar_item_id1:score1,similar_item_id2:score2,...
输出格式: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...
"""
import argparse
from datetime import datetime
from offline_tasks.scripts.debug_utils import setup_debug_logger, load_name_mappings_from_file
def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False):
    """
    Annotate Swing similarity output with item names.

    Each well-formed input line has the form
    ``item_id<TAB>sim_id1:score1,sim_id2:score2,...`` and is written out as
    ``item_id:name<TAB>sim_id1:name1:score1,sim_id2:name2:score2,...``.
    IDs with no entry in the mapping get the name ``Unknown``.

    Blank lines are ignored. Lines that do not split into exactly two
    tab-separated fields are counted as skipped and are not written.
    Similar-item entries that contain no ``:`` are dropped.

    Args:
        input_file: Path of the Swing output file to read (UTF-8).
        output_file: Path of the annotated file to write (UTF-8, overwritten).
        name_mappings: Dict whose ``'item'`` entry maps item-ID strings to
            names. ``None``, or a missing/``None`` ``'item'`` entry, is
            tolerated: every name then falls back to ``Unknown``.
        logger: Optional logger; when omitted nothing is logged.
        debug: When true and a logger is given, log progress every 1000
            processed lines.
    """
    if logger:
        logger.info(f"处理文件: {input_file}")
        logger.info(f"输出到: {output_file}")

    # Degrade to an empty mapping instead of raising AttributeError, which is
    # consistent with how individual missing IDs already become 'Unknown'.
    item_names = (name_mappings or {}).get('item') or {}

    processed_lines = 0
    skipped_lines = 0

    with open(input_file, 'r', encoding='utf-8') as fin, \
            open(output_file, 'w', encoding='utf-8') as fout:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) != 2:
                skipped_lines += 1
                continue

            # Both fields are already str, so no conversion is needed for lookup.
            item_id, sim_items_str = parts
            item_name = item_names.get(item_id, 'Unknown')

            sim_items = []
            for sim_pair in sim_items_str.split(','):
                # Split on the LAST ':' so only the trailing field is the score.
                sim_id, sep, score = sim_pair.rpartition(':')
                if not sep:
                    # Entry has no ':' at all; drop it.
                    continue
                sim_name = item_names.get(sim_id, 'Unknown')
                # Output entry format: item_id:name:score
                sim_items.append(f"{sim_id}:{sim_name}:{score}")

            sim_items_output = ','.join(sim_items)
            fout.write(f"{item_id}:{item_name}\t{sim_items_output}\n")
            processed_lines += 1

            # Progress report in debug mode.
            if debug and logger and processed_lines % 1000 == 0:
                logger.debug(f"已处理 {processed_lines} 行")

    if logger:
        logger.info("处理完成:")
        logger.info(f" 成功处理: {processed_lines} 行")
        logger.info(f" 跳过: {skipped_lines} 行")
def main():
    """
    Command-line entry point: annotate a Swing output file with item names.

    Parses ``input_file``, an optional ``output_file`` and ``--debug``.
    When no output path is given, the output is written next to the input
    as ``<input name without extension>_readable.txt``. The function logs
    an error and returns without writing anything if the input file does
    not exist or if the loaded name mappings have no ``'item'`` entries.
    """
    # Function-scope import so the path handling below does not depend on a
    # module-level `import os` being present.
    import os

    parser = argparse.ArgumentParser(description='Add names to Swing algorithm output')
    parser.add_argument('input_file', type=str,
                        help='Input file path (Swing output)')
    parser.add_argument('output_file', type=str, nargs='?', default=None,
                        help='Output file path (if not specified, will add _readable suffix)')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode with detailed logging')
    args = parser.parse_args()

    # Set up logging.
    logger = setup_debug_logger('add_names_to_swing', debug=args.debug)

    # Derive the output path from the input path when none was given.
    if args.output_file is None:
        input_dir = os.path.dirname(args.input_file)
        input_basename = os.path.basename(args.input_file)
        name_without_ext = os.path.splitext(input_basename)[0]
        args.output_file = os.path.join(input_dir, f"{name_without_ext}_readable.txt")

    logger.info(f"输入文件: {args.input_file}")
    logger.info(f"输出文件: {args.output_file}")

    # Make sure the input file exists before doing any work.
    if not os.path.exists(args.input_file):
        logger.error(f"输入文件不存在: {args.input_file}")
        return

    # Load the ID-to-name mappings from the local mapping file.
    logger.info("加载ID到名称的映射...")
    name_mappings = load_name_mappings_from_file(debug=args.debug)
    if not name_mappings or not name_mappings.get('item'):
        logger.error("映射文件为空或加载失败")
        logger.error("请先运行: python3 scripts/fetch_item_attributes.py")
        return
    logger.info(f"加载了 {len(name_mappings['item'])} 个商品名称")

    # Rewrite the Swing output with names attached.
    add_names_to_swing_result(
        args.input_file,
        args.output_file,
        name_mappings,
        logger=logger,
        debug=args.debug
    )
    logger.info("完成!")


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()