Blame view

graphembedding/session_w2v/prepare_data.py 1.56 KB
5ab1c29c   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  import json

  import sys

  

  

  def main(input_file, output_file, max_sentence_length):
      """Convert per-user session records into space-separated sentences.

      Every non-blank input line must have the form ``<uid><TAB><json object>``.
      The keys of the JSON object, in their original order, are the tokens of
      that session. Sessions with fewer than 3 keys are skipped. Longer
      sessions are cut into consecutive chunks of at most
      ``max_sentence_length`` tokens, and each chunk is written as one line
      of tokens joined by single spaces.

      Args:
          input_file: Path of the tab-separated input file (read as UTF-8).
          output_file: Path of the sentence file to create (written as UTF-8).
          max_sentence_length: Maximum tokens per output line; an int or a
              string that ``int()`` accepts (e.g. a command-line argument).

      Raises:
          ValueError: If ``max_sentence_length`` is not a positive integer,
              if a non-blank line has no tab separator, or if the session
              payload is not valid JSON (``json.JSONDecodeError``).
      """
      max_sentence_length = int(max_sentence_length)
      # Validate before opening anything: a zero step makes range() fail with
      # an unrelated message, a negative step silently yields no sentences,
      # and either way the output file would already have been truncated.
      if max_sentence_length < 1:
          raise ValueError(
              "max_sentence_length must be a positive integer, got {}".format(
                  max_sentence_length
              )
          )

      # Explicit UTF-8 so non-ASCII keys do not depend on the platform locale.
      with open(input_file, 'r', encoding='utf-8') as infile, \
              open(output_file, 'w', encoding='utf-8') as outfile:
          for line_no, line in enumerate(infile, 1):
              # Skip blank lines.
              line = line.strip()
              if not line:
                  continue

              # Split on the FIRST tab only: tabs are legal JSON whitespace,
              # so the payload itself may contain more of them.
              _uid, sep, session = line.partition('\t')
              if not sep:
                  raise ValueError(
                      "line {}: expected '<uid><TAB><json>', got {!r}".format(
                          line_no, line
                      )
                  )

              data = json.loads(session)
              keys = list(data.keys())
              # Sessions with fewer than 3 items are dropped.
              if len(keys) < 3:
                  continue

              # Cut sessions longer than the limit into consecutive chunks
              # and write each chunk as its own sentence.
              for start in range(0, len(keys), max_sentence_length):
                  chunk = keys[start:start + max_sentence_length]
                  outfile.write(" ".join(chunk) + "\n")

  

  if __name__ == "__main__":
      # Exactly three positional arguments must follow the script name:
      # input path, output path and maximum sentence length.
      cli_args = sys.argv[1:]
      if len(cli_args) != 3:
          print("用法: python prepare_data.py <输入文件> <输出文件> <最大句子长度>")
          sys.exit(1)

      # The length is passed through as a string; main() converts it.
      input_file, output_file, max_sentence_length = cli_args
      main(input_file, output_file, max_sentence_length)