5b61955e
tangwang
offline tasks: me...
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
"""
给Swing算法输出结果添加name映射
输入格式: item_id \t similar_item_id1:score1,similar_item_id2:score2,...
输出格式: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import argparse
from datetime import datetime
from db_service import create_db_connection
from offline_tasks.config.offline_config import DB_CONFIG
from offline_tasks.scripts.debug_utils import setup_debug_logger, fetch_name_mappings
def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False):
    """Annotate a Swing similarity result file with item names.

    Every non-blank input line is expected to have the form
    ``item_id<TAB>sim_id1:score1,sim_id2:score2,...`` and is written out as
    ``item_id:name<TAB>sim_id1:name1:score1,sim_id2:name2:score2,...``.
    IDs that are missing from the mapping get the name ``Unknown``.
    Lines that do not split into exactly two tab-separated fields are
    counted as skipped; similar-item entries without a ``:`` are dropped.

    Args:
        input_file: Path of the Swing output file to read (UTF-8).
        output_file: Path of the annotated file to write (UTF-8, overwritten).
        name_mappings: Dict whose ``'item'`` entry maps item-id strings to
            names; a missing ``'item'`` entry is treated as an empty mapping.
        logger: Optional logger; nothing is logged when it is falsy.
        debug: When true and a logger is given, log progress every 1000
            processed lines.

    Raises:
        ValueError: If ``input_file`` and ``output_file`` resolve to the
            same path.
    """
    # Opening the output in 'w' mode truncates it immediately, so writing onto
    # the input path would wipe the data before a single line has been read.
    if os.path.realpath(input_file) == os.path.realpath(output_file):
        raise ValueError(
            f"input_file and output_file must be different paths: {input_file}"
        )

    if logger:
        logger.info(f"处理文件: {input_file}")
        logger.info(f"输出到: {output_file}")

    item_names = name_mappings.get('item', {})
    processed_lines = 0
    skipped_lines = 0

    with open(input_file, 'r', encoding='utf-8') as fin, \
            open(output_file, 'w', encoding='utf-8') as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) != 2:
                skipped_lines += 1
                continue
            item_id, sim_items_str = parts

            item_name = item_names.get(item_id, 'Unknown')

            # Rebuild the similar-item list as id:name:score entries.
            sim_items = []
            for sim_pair in sim_items_str.split(','):
                if ':' not in sim_pair:
                    continue
                # Split on the last ':' so the score is always the final field.
                sim_id, score = sim_pair.rsplit(':', 1)
                sim_name = item_names.get(sim_id, 'Unknown')
                sim_items.append(f"{sim_id}:{sim_name}:{score}")

            sim_items_output = ','.join(sim_items)
            fout.write(f"{item_id}:{item_name}\t{sim_items_output}\n")
            processed_lines += 1

            # Progress report, debug mode only.
            if debug and logger and processed_lines % 1000 == 0:
                logger.debug(f"已处理 {processed_lines} 行")

    if logger:
        logger.info("处理完成:")
        logger.info(f"  成功处理: {processed_lines} 行")
        logger.info(f"  跳过: {skipped_lines} 行")
def main():
    """Command-line entry point: add item names to a Swing output file.

    Parses the command line, derives the output path when none is given
    (``<input name without extension>_readable.txt`` in the input's
    directory), creates a database connection from ``DB_CONFIG``, fetches
    the ID-to-name mappings and delegates to ``add_names_to_swing_result``.

    Returns:
        int: Process exit status; 0 on success, 1 when the input file does
        not exist.
    """
    parser = argparse.ArgumentParser(description='Add names to Swing algorithm output')
    parser.add_argument('input_file', type=str,
                        help='Input file path (Swing output)')
    parser.add_argument('output_file', type=str, nargs='?', default=None,
                        help='Output file path (if not specified, will add _readable suffix)')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode with detailed logging')
    args = parser.parse_args()

    # Set up logging.
    logger = setup_debug_logger('add_names_to_swing', debug=args.debug)

    # Derive the output path from the input path when it was not given.
    if args.output_file is None:
        input_dir = os.path.dirname(args.input_file)
        input_basename = os.path.basename(args.input_file)
        name_without_ext = os.path.splitext(input_basename)[0]
        args.output_file = os.path.join(input_dir, f"{name_without_ext}_readable.txt")

    logger.info(f"输入文件: {args.input_file}")
    logger.info(f"输出文件: {args.output_file}")

    # Bail out before touching the database when there is nothing to read.
    # A non-zero status lets calling scripts detect the failure.
    if not os.path.exists(args.input_file):
        logger.error(f"输入文件不存在: {args.input_file}")
        return 1

    # Create the database connection.
    logger.info("连接数据库...")
    engine = create_db_connection(
        DB_CONFIG['host'],
        DB_CONFIG['port'],
        DB_CONFIG['database'],
        DB_CONFIG['username'],
        DB_CONFIG['password']
    )

    # Fetch the ID-to-name mappings.
    logger.info("获取ID到名称的映射...")
    name_mappings = fetch_name_mappings(engine, debug=args.debug)
    # Tolerate a missing 'item' entry, as add_names_to_swing_result does.
    item_names = name_mappings.get('item', {})
    logger.info(f"获取到 {len(item_names)} 个商品名称")

    # Process the file.
    add_names_to_swing_result(
        args.input_file,
        args.output_file,
        name_mappings,
        logger=logger,
        debug=args.debug
    )
    logger.info("完成!")
    return 0


if __name__ == '__main__':
    sys.exit(main())
|