prepare_data.py 1.56 KB
import json
import sys


def main(input_file, output_file, max_sentence_length):
    """
    主函数,读取输入文件,处理每一行json,将结果写入输出文件。
    
    参数:
    - input_file: 输入文件路径
    - output_file: 输出文件路径
    - max_sentence_length: 最大句子长度
    """
    max_sentence_length = int(max_sentence_length)
    with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
        for line in infile:
            # 去除空行
            line = line.strip()
            if not line:
                continue
            
            # 处理当前行
            uid, session = line.split('\t')
            data = json.loads(session)
            keys = list(data.keys())
            if len(keys) < 3:
                continue
            
            # 如果keys数量超出最大句子长度,则按最大句子长度拆分
            sentences = [keys[i:i + max_sentence_length] for i in range(0, len(keys), max_sentence_length)]
            
            # 写入每个分割后的句子到输出文件
            for sentence in sentences:
                outfile.write(" ".join(sentence) + "\n")

if __name__ == "__main__":
    # 从命令行读取参数
    if len(sys.argv) != 4:
        print("用法: python prepare_data.py <输入文件> <输出文件> <最大句子长度>")
        sys.exit(1)

    input_file = sys.argv[1]
    output_file = sys.argv[2]
    max_sentence_length = sys.argv[3]

    main(input_file, output_file, max_sentence_length)