5ab1c29c
tangwang
first commit
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
#!/home/SanJunipero/anaconda3/bin/python
# -*- coding:UTF-8 -*-
import os,sys,json,re,time
import numpy as np
import pandas as pd
from itertools import combinations
import logging
import traceback
import cgitb
from argparse import ArgumentParser
# Similarity index filled by run_eval: item id (int) -> list of
# (similar item id, similarity value) pairs.
sim_index = {}
max_fea = 20 # use at most this many historical interaction ids for recall
# Cap on the ranked candidate list; run_eval also records this value as the
# hit position when the target item is not recalled.
max_recall_len = 1200
def para_define(parser):
    """Register this script's command-line options on *parser*.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument``); it is extended in place.

    Returns:
        None.
    """
    parser.add_argument(
        '-s', '--sim_index', type=str, default='',
        help='path of the similarity index file to evaluate',
    )
def parse_sim_item_pair(x):
    """Parse one ``"<item_id>:<value>"`` token of the similarity index.

    Args:
        x: Text of the form ``"12:0.5"``. Only the first two
            colon-separated fields are read; any further fields are ignored.

    Returns:
        A ``(int, float)`` tuple built from the first and second field.

    Raises:
        ValueError: If a field cannot be converted to ``int`` / ``float``.
        IndexError: If *x* contains no ``':'``.
    """
    # Keep the parameter untouched instead of rebinding it to the split list.
    fields = x.split(':')
    return (int(fields[0]), float(fields[1]))
def parse_session_item_pair(x):
    """Parse one ``'"<item_id>":<weight>'`` token of a session.

    The id field is wrapped in one delimiter character on each side (for
    example quotes); those two characters are dropped before the id is
    converted to ``int``.

    Args:
        x: Text such as ``'"123":0.5'``. Only the first two colon-separated
            fields are read.

    Returns:
        A ``(int, float)`` tuple of item id and weight.

    Raises:
        ValueError: If the id or the weight cannot be converted.
        IndexError: If *x* contains no ``':'``.
    """
    fields = x.split(':')
    wrapped_id = fields[0]
    try:
        item_id = int(wrapped_id[1:-1])
    except ValueError:
        # Tokens produced with ", " / ": " separators carry padding around
        # the wrapped id; retry once with that whitespace removed.  Inputs
        # that parsed before never reach this branch.
        item_id = int(wrapped_id.strip()[1:-1])
    return (item_id, float(fields[1]))
def run_eval(FLAGS):
    """Evaluate recall of a similarity index against sessions read from stdin.

    The file named by ``FLAGS.sim_index`` is read line by line; every line
    with exactly two tab-separated fields ``key<TAB>id:value,id:value,...``
    is stored in the module-level ``sim_index`` dict (other lines are
    skipped).

    Each stdin line is ``uid<TAB>session``.  The first and last character of
    the session column are dropped and the remainder is split on ``','`` into
    tokens parsed by ``parse_session_item_pair``.  The first item of a session
    is the target; the following ``max_fea`` items are used to score
    candidates (``weight * similarity`` summed per candidate).  The top
    ``max_recall_len`` candidates are ranked and the target's position is
    recorded, with ``max_recall_len`` standing in for "not recalled".

    The column description plus the per-column mean and sum over all
    evaluated sessions are printed to stdout.  If no session could be
    evaluated, a notice is written to stderr and nothing else is printed.

    Args:
        FLAGS: Parsed arguments; only ``FLAGS.sim_index`` (a file path) is
            read.

    Returns:
        None.
    """
    # Load the similarity index into the shared module-level dict.
    with open(FLAGS.sim_index) as f:
        for line in f:
            segs = line.rstrip().split('\t')
            if len(segs) != 2:
                continue
            k, vlist = segs
            sim_index[int(k)] = [parse_sim_item_pair(x) for x in vlist.split(',')]

    # Cut-offs for the "hit within top-N" indicator columns.
    thresholds = (25, 50, 100, 200, 400, 800, max_recall_len)
    statis = []
    for line in sys.stdin:
        segs = line.strip().split('\t')
        if len(segs) < 2:
            # Blank or malformed line: there is no session column to read.
            continue
        session = segs[1][1:-1]
        if not session:
            continue
        session_list = [parse_session_item_pair(x) for x in session.split(',')]
        target_item_id = session_list[0][0]

        # Accumulate weight * similarity for every candidate reachable from
        # the history items that follow the target.
        score_list = {}
        for item_id, wei in session_list[1:1 + max_fea]:
            for sim_item_id, sim_value in sim_index.get(item_id, []):
                score_list[sim_item_id] = score_list.get(sim_item_id, 0.0) + wei * sim_value

        sorted_score_list = sorted(
            score_list.items(), key=lambda kv: kv[1], reverse=True
        )[:max_recall_len]

        # The ranked list never exceeds max_recall_len entries, so a found
        # position is always below it; a miss is recorded as max_recall_len.
        hit_pos = next(
            (idx for idx, (cand, _) in enumerate(sorted_score_list)
             if cand == target_item_id),
            max_recall_len,
        )
        statis.append(
            (1, hit_pos, len(sorted_score_list))
            + tuple(int(hit_pos < t) for t in thresholds)
        )

    if not statis:
        # Column-wise mean/sum are undefined without rows; report and stop
        # instead of failing while formatting the statistics.
        print(FLAGS.sim_index, 'no sessions to evaluate', sep='\t', file=sys.stderr)
        return

    statis = np.array(statis)
    # Column legend; a miss is stored as max_recall_len, so the "!= -1"
    # wording below corresponds to "hit_pos < max_recall_len" in the data.
    desc = '''(1, hit_pos, len(sorted_score_list),
int(hit_pos != -1 and hit_pos < 25),
int(hit_pos != -1 and hit_pos < 50),
int(hit_pos != -1 and hit_pos < 100),
int(hit_pos != -1 and hit_pos < 200),
int(hit_pos != -1 and hit_pos < 400),
int(hit_pos != -1 and hit_pos < 800),
int(hit_pos != -1),
)'''
    print(desc)
    np.set_printoptions(suppress=True)
    print(FLAGS.sim_index, 'mean', '\t'.join([str(x) for x in statis.mean(axis=0)]), sep='\t')
    print(FLAGS.sim_index, 'sum', '\t'.join([str(x) for x in statis.sum(axis=0)]), sep='\t')
def main():
    """Script entry point: parse command-line options and run the evaluation.

    Unknown command-line arguments are tolerated (``parse_known_args``) and
    ignored.  The parsed options are printed before the evaluation starts.
    """
    # Plain-text tracebacks for uncaught exceptions.
    # NOTE(review): cgitb is deprecated since Python 3.11 and removed in
    # 3.13 - confirm the target interpreter before upgrading.
    cgitb.enable(format='text')

    # op config
    parser = ArgumentParser()
    para_define(parser)
    FLAGS, _unparsed = parser.parse_known_args()
    print(FLAGS)

    run_eval(FLAGS)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|