prepare_data.py
1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json
import sys
def main(input_file, output_file, max_sentence_length):
    """Turn tab-separated ``<uid>\\t<json>`` records into lines of JSON keys.

    Every non-blank input line is split at its first tab into an id and a
    JSON document. The document's top-level keys, in the order ``json.loads``
    yields them, are joined with single spaces and written as one output
    line. Records with fewer than three keys are skipped. Key lists longer
    than ``max_sentence_length`` are cut into consecutive chunks of at most
    that many keys, one chunk per output line.

    Args:
        input_file: Path of the text file to read (decoded as UTF-8).
        output_file: Path of the text file to write (UTF-8, overwritten).
        max_sentence_length: Maximum number of keys per output line. Any
            value accepted by ``int()`` works, e.g. a command-line string.

    Raises:
        ValueError: If ``max_sentence_length`` is not a positive integer, if
            a non-blank line has no tab separator, or if the JSON part of a
            line cannot be parsed.
    """
    max_sentence_length = int(max_sentence_length)
    # Validate before opening anything: a zero step would make range() fail
    # only after the output file had been truncated, and a negative step
    # would silently produce no output at all.
    if max_sentence_length <= 0:
        raise ValueError(
            "max_sentence_length must be a positive integer, got %d"
            % max_sentence_length
        )

    # Explicit UTF-8 keeps the result independent of the platform locale.
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_number, line in enumerate(infile, start=1):
            # Skip blank lines.
            line = line.strip()
            if not line:
                continue

            # Split on the FIRST tab only; the JSON payload may itself
            # contain tabs as insignificant whitespace between tokens.
            _uid, separator, session = line.partition('\t')
            if not separator:
                raise ValueError(
                    "line %d: expected '<uid>\\t<json>' but found no tab"
                    % line_number
                )

            keys = list(json.loads(session).keys())
            if len(keys) < 3:
                continue

            # Emit the keys in chunks of at most max_sentence_length.
            for start in range(0, len(keys), max_sentence_length):
                chunk = keys[start:start + max_sentence_length]
                outfile.write(" ".join(chunk) + "\n")
if __name__ == "__main__":
    # Script entry point: exactly three positional arguments are expected
    # after the script name (input path, output path, max sentence length).
    if len(sys.argv) == 4:
        input_file, output_file, max_sentence_length = sys.argv[1:]
        main(input_file, output_file, max_sentence_length)
    else:
        # Wrong argument count: show the usage line and exit with failure.
        print("用法: python prepare_data.py <输入文件> <输出文件> <最大句子长度>")
        sys.exit(1)