Blame view

collaboration/eval.py 3.07 KB
5ab1c29c   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
  #!/home/SanJunipero/anaconda3/bin/python
  # -*- coding:UTF-8 -*-
  import os,sys,json,re,time
  import numpy as np
  import pandas as pd
  from itertools import combinations
  import logging
  import traceback
  import cgitb
  from argparse import ArgumentParser
  
  # Item-to-item similarity index, populated by run_eval():
  # item id -> list of (similar item id, similarity score).
  sim_index = dict()

  max_fea = 20  # use at most this many historical interaction ids for recall
  max_recall_len = 1200  # keep only the top-N scored candidates per session
  
  def para_define(parser):
      """Register this script's command-line options on *parser*.

      Adds ``-s`` / ``--sim_index``: path of the similarity index file
      (string, defaults to the empty string).
      """
      option_names = ('-s', '--sim_index')
      parser.add_argument(*option_names, type=str, default='')
  
  def parse_sim_item_pair(x):
      """Parse an ``"<item_id>:<score>"`` token into ``(int, float)``.

      Only the first two ``:``-separated fields are used.
      """
      fields = x.split(':')
      item_id, score = fields[0], fields[1]
      return int(item_id), float(score)
  
  def parse_session_item_pair(x):
      """Parse a session token such as ``'"123":0.5'`` into ``(123, 0.5)``.

      The first and last characters of the id field are dropped before the
      int conversion (presumably surrounding quotes — confirm with the
      producer of the session data).
      """
      fields = x.split(':')
      quoted_id = fields[0]
      return (int(quoted_id[1:-1]), float(fields[1]))
  
  def _load_sim_index(path):
      """Load an item-to-item similarity file.

      Each line is ``<item_id><TAB><sim_id>:<score>,<sim_id>:<score>,...``.
      Lines that do not split into exactly two tab-separated fields are
      skipped, as are empty tokens (e.g. produced by a trailing comma).

      Returns:
          dict: int item id -> list of (int similar item id, float score),
          in file order. A repeated item id keeps its last line.
      """
      index = {}
      with open(path) as f:
          for line in f:
              segs = line.rstrip().split('\t')
              if len(segs) != 2:
                  continue
              k, vlist = segs
              index[int(k)] = [parse_sim_item_pair(x) for x in vlist.split(',') if x]
      return index


  def _score_candidates(history):
      """Score every item that is similar to some item in *history*.

      Args:
          history: iterable of (item id, weight) pairs.

      Returns:
          dict: candidate item id -> sum over history of weight * similarity,
          using the module-level ``sim_index``.
      """
      scores = {}
      for item_id, wei in history:
          for sim_item_id, sim_value in sim_index.get(item_id, []):
              scores[sim_item_id] = scores.get(sim_item_id, 0.0) + wei * sim_value
      return scores


  def _hit_info(target_item_id, ranked):
      """Build the statistics row for one session's ranked recall list.

      Returns ``(1, hit_pos, len(ranked), hit@25, hit@50, hit@100, hit@200,
      hit@400, hit@800, recalled)``. ``hit_pos`` is the 0-based rank of
      *target_item_id* in *ranked*, or ``max_recall_len`` when it was not
      recalled; in that case every flag is 0 regardless of the cutoffs.
      """
      hit_pos = next((idx for idx, (item_id, _) in enumerate(ranked)
                      if item_id == target_item_id), -1)
      found = hit_pos != -1
      if not found:
          # Misses are reported with this sentinel rank in column 2; the flags
          # below use `found` so a miss can never count as a hit.
          hit_pos = max_recall_len
      return (1, hit_pos, len(ranked),
              int(found and hit_pos < 25),
              int(found and hit_pos < 50),
              int(found and hit_pos < 100),
              int(found and hit_pos < 200),
              int(found and hit_pos < 400),
              int(found and hit_pos < 800),
              int(found),
      )


  def run_eval(FLAGS):
      """Evaluate item-to-item recall against sessions read from stdin.

      Loads the similarity index at ``FLAGS.sim_index`` into the module-level
      ``sim_index`` dict. Each stdin line is ``<uid><TAB><session>``. The
      first session item is the target. The next ``max_fea`` items are the
      history used to score candidates. The top ``max_recall_len`` candidates
      are searched for the target. Prints a column description, then the
      per-column mean and sum over all evaluated sessions.

      Blank, malformed or empty sessions are skipped. If no session is
      evaluated, a note is written to stderr and nothing else is printed.

      Args:
          FLAGS: parsed arguments; only ``FLAGS.sim_index`` (a file path) is read.
      """
      sim_index.update(_load_sim_index(FLAGS.sim_index))

      statis = []
      for line in sys.stdin:
          segs = line.strip().split('\t')
          if len(segs) < 2:
              # Blank or malformed line: indexing segs[1] would raise IndexError.
              continue
          # Drop the outer delimiters of the session field; presumably a
          # JSON-like object such as {"12":0.5,"7":1.0} — confirm with producer.
          session = segs[1][1:-1]
          if not session:
              continue
          session_list = [parse_session_item_pair(x) for x in session.split(',') if x]
          if not session_list:
              continue

          score_list = _score_candidates(session_list[1:1 + max_fea])
          sorted_score_list = sorted(score_list.items(),
                                     key=lambda kv: kv[1],
                                     reverse=True)[:max_recall_len]
          statis.append(_hit_info(session_list[0][0], sorted_score_list))

      if not statis:
          # mean(axis=0) of an empty array is a 0-d scalar and cannot be iterated.
          print(FLAGS.sim_index, 'no sessions evaluated', sep='\t', file=sys.stderr)
          return
      statis = np.array(statis)

      # NOTE: column 2 (hit_pos) holds max_recall_len, not -1, for misses.
      desc = '''(1, hit_pos, len(sorted_score_list),
              int(hit_pos != -1 and hit_pos < 25),
              int(hit_pos != -1 and hit_pos < 50),
              int(hit_pos != -1 and hit_pos < 100),
              int(hit_pos != -1 and hit_pos < 200),
              int(hit_pos != -1 and hit_pos < 400),
              int(hit_pos != -1 and hit_pos < 800),
              int(hit_pos != -1),
          )'''
      print(desc)

      np.set_printoptions(suppress=True)
      print(FLAGS.sim_index, 'mean', '\t'.join([str(x) for x in statis.mean(axis=0)]), sep='\t')
      print(FLAGS.sim_index, 'sum', '\t'.join([str(x) for x in statis.sum(axis=0)]), sep='\t')
  
  
  
  def main():
      """Script entry point: parse CLI flags, echo them, then run the evaluation.

      Unrecognised command-line arguments are ignored (parse_known_args).
      """
      # Text-format traceback hook for uncaught exceptions.
      # NOTE(review): cgitb is deprecated (PEP 594) and removed in Python 3.13.
      cgitb.enable(format='text')
      arg_parser = ArgumentParser()
      para_define(arg_parser)
      flags = arg_parser.parse_known_args()[0]
      print(flags)
      run_eval(flags)
  
  if __name__ == '__main__':
      # Run only when executed as a script, not when imported as a module.
      main()