load_index_to_redis.py
8.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
"""
将生成的索引加载到Redis
用于在线推荐系统查询
"""
import argparse
import logging
import os
import sys
from datetime import datetime
from decimal import Decimal, InvalidOperation

import redis

from config.offline_config import REDIS_CONFIG, OUTPUT_DIR
def setup_logger():
    """Build (once) and return the ``load_index_to_redis`` logger.

    The logger is set to INFO and gets two handlers, both at INFO level and
    sharing one formatter: a UTF-8 file handler writing to
    ``logs/load_index_to_redis_<YYYYMMDD>.log`` (relative to the current
    working directory), followed by a console handler. If the logger already
    has handlers it is returned as-is, so repeated calls do not stack
    duplicate handlers.
    """
    log_directory = 'logs'
    os.makedirs(log_directory, exist_ok=True)

    named_logger = logging.getLogger('load_index_to_redis')
    named_logger.setLevel(logging.INFO)
    if named_logger.handlers:
        # Already configured by an earlier call.
        return named_logger

    log_path = os.path.join(log_directory, f'load_index_to_redis_{datetime.now().strftime("%Y%m%d")}.log')
    shared_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # File handler first, console handler second.
    for handler in (logging.FileHandler(log_path, encoding='utf-8'), logging.StreamHandler()):
        handler.setLevel(logging.INFO)
        handler.setFormatter(shared_format)
        named_logger.addHandler(handler)

    return named_logger


# Module-wide logger shared by every loader function in this script.
logger = setup_logger()
def _normalize_key_suffix(raw_id):
    """Strip a float-formatting artifact (``"60678.0"`` -> ``"60678"``) from an ID.

    Only IDs containing a '.' whose value is an exact, finite integer are
    rewritten. Parsing goes through ``Decimal`` rather than ``float``, so
    large IDs keep every digit. Non-integral values such as ``"1.5"`` and
    non-numeric strings are returned unchanged instead of being truncated.
    """
    if '.' not in raw_id:
        return raw_id
    try:
        as_decimal = Decimal(raw_id)
    except (InvalidOperation, ValueError):
        return raw_id
    if as_decimal.is_finite() and as_decimal == as_decimal.to_integral_value():
        return str(int(as_decimal))
    return raw_id


def load_index_file(file_path, redis_client, key_prefix, expire_seconds=None, batch_size=1000):
    """
    Load a tab-separated index file into Redis as plain string keys.

    Each non-empty line must have at least two tab-separated fields. The
    first field is the ID and the last field is the stored value, so both
    layouts are supported:
        2 fields: item_id \t similar_items
        3 fields: item_id \t item_name \t similar_items  (preferred)
    Each record is written as ``SET <key_prefix>:<id> <value>``. IDs like
    ``"60678.0"`` are normalized to ``"60678"``.

    Writes are sent through a non-transactional pipeline in batches of
    ``batch_size`` commands. The TTL is applied in the same ``SET ... EX``
    command, so no key is ever left without its expiry.

    Args:
        file_path: Path of the index file (read as UTF-8).
        redis_client: redis-py client used to create the pipeline.
        key_prefix: Redis key prefix; the final key is ``f"{key_prefix}:{id}"``.
        expire_seconds: TTL in seconds; ``None``/0 means the keys never expire.
        batch_size: Maximum number of queued commands per pipeline round trip.

    Returns:
        Number of records written (0 when the file does not exist).
    """
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return 0

    ttl = expire_seconds if expire_seconds else None
    batch_size = max(1, batch_size)

    count = 0
    pending = 0
    pipe = redis_client.pipeline(transaction=False)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('\t')
            if len(parts) < 2:
                logger.warning(f"Invalid line format (expected at least 2 fields): {line}")
                continue

            redis_key = f"{key_prefix}:{_normalize_key_suffix(parts[0])}"
            pipe.set(redis_key, parts[-1], ex=ttl)
            pending += 1
            count += 1

            # Flush full batches so memory stays bounded on large files.
            if pending >= batch_size:
                pipe.execute()
                pending = 0
            if count % 1000 == 0:
                logger.info(f"Loaded {count} records...")

    if pending:
        pipe.execute()
    return count
def load_cpp_swing_index(redis_client, expire_days=7):
    """Load the C++ Swing item-similarity output into Redis.

    The source file is
    ``<dirname(OUTPUT_DIR)>/collaboration/output/swing_similar.txt``. Its
    records are written via ``load_index_file`` under the key prefix
    ``item:similar:swing_cpp``.

    Args:
        redis_client: Redis client to write to.
        expire_days: TTL in days for the written keys; a falsy value means
            no expiry.

    Returns:
        Number of records loaded, or 0 when the file is missing.
    """
    swing_dir = os.path.join(os.path.dirname(OUTPUT_DIR), 'collaboration', 'output')
    file_path = os.path.join(swing_dir, 'swing_similar.txt')
    if not os.path.exists(file_path):
        logger.warning(f"C++ Swing file not found: {file_path}, skipping...")
        return 0

    expire_seconds = None
    if expire_days:
        expire_seconds = expire_days * 24 * 3600

    logger.info(f"Loading C++ Swing indices from {file_path}...")
    loaded = load_index_file(file_path, redis_client, "item:similar:swing_cpp", expire_seconds)
    logger.info(f"Loaded {loaded} C++ Swing indices")
    return loaded
def load_i2i_indices(redis_client, date_str=None, expire_days=7):
    """
    Load every dated i2i similarity index file from ``OUTPUT_DIR`` into Redis.

    For each i2i type ``T`` the file ``OUTPUT_DIR/i2i_T_<date_str>.txt`` is
    loaded under the key prefix ``item:similar:T``. Missing files are logged
    and skipped.

    Args:
        redis_client: Redis client to write to.
        date_str: Date string in YYYYMMDD format; a falsy value means today.
        expire_days: TTL in days for the written keys; a falsy value means
            no expiry.

    Returns:
        Total number of records loaded across all i2i types (0 if no file
        was found). This matches ``load_cpp_swing_index``, which also
        returns its count.
    """
    if not date_str:
        date_str = datetime.now().strftime('%Y%m%d')
    expire_seconds = expire_days * 24 * 3600 if expire_days else None

    i2i_types = ['swing', 'session_w2v', 'deepwalk', 'content_name', 'content_pic', 'item_behavior']

    total = 0
    for i2i_type in i2i_types:
        file_path = os.path.join(OUTPUT_DIR, f'i2i_{i2i_type}_{date_str}.txt')
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}, skipping...")
            continue

        logger.info(f"Loading {i2i_type} indices...")
        # Keys are namespaced per i2i type, e.g. item:similar:swing:<item_id>.
        count = load_index_file(
            file_path,
            redis_client,
            f"item:similar:{i2i_type}",
            expire_seconds
        )
        logger.info(f"Loaded {count} {i2i_type} indices")
        total += count
    return total
def load_interest_indices(redis_client, date_str=None, expire_days=7):
    """
    Load the dated interest-aggregation list files from ``OUTPUT_DIR`` into Redis.

    For each list type ``T`` (hot, cart, new, global) the file
    ``OUTPUT_DIR/interest_aggregation_T_<date_str>.txt`` is loaded under the
    key prefix ``interest:T``. Missing files are logged and skipped.

    Args:
        redis_client: Redis client to write to.
        date_str: Date string in YYYYMMDD format; a falsy value means today.
        expire_days: TTL in days for the written keys; a falsy value means
            no expiry.

    Returns:
        Total number of records loaded across all list types (0 if no file
        was found). This matches ``load_cpp_swing_index``, which also
        returns its count.
    """
    if not date_str:
        date_str = datetime.now().strftime('%Y%m%d')
    expire_seconds = expire_days * 24 * 3600 if expire_days else None

    list_types = ['hot', 'cart', 'new', 'global']

    total = 0
    for list_type in list_types:
        file_path = os.path.join(OUTPUT_DIR, f'interest_aggregation_{list_type}_{date_str}.txt')
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}, skipping...")
            continue

        logger.info(f"Loading {list_type} interest indices...")
        count = load_index_file(
            file_path,
            redis_client,
            f"interest:{list_type}",
            expire_seconds
        )
        logger.info(f"Loaded {count} {list_type} indices")
        total += count
    return total
def main():
    """Parse CLI arguments, connect to Redis and load every enabled index family.

    Returns:
        Process exit code: 0 after loading, 1 when the Redis ping fails.
        Invalid arguments terminate through ``parser.error`` (exit code 2).
    """
    parser = argparse.ArgumentParser(description='Load recommendation indices to Redis')
    parser.add_argument('--redis-host', type=str, default=REDIS_CONFIG.get('host', 'localhost'),
                        help='Redis host')
    parser.add_argument('--redis-port', type=int, default=REDIS_CONFIG.get('port', 6379),
                        help='Redis port')
    parser.add_argument('--redis-db', type=int, default=REDIS_CONFIG.get('db', 0),
                        help='Redis database')
    parser.add_argument('--redis-password', type=str, default=REDIS_CONFIG.get('password'),
                        help='Redis password')
    parser.add_argument('--date', type=str, default=None,
                        help='Date string (YYYYMMDD), default is today')
    parser.add_argument('--expire-days', type=int, default=7,
                        help='Expire days for Redis keys')
    # A store_true flag with default=True can never be switched off, so each
    # loader also gets a --no-... switch that writes False to the same dest.
    # The original --load-* flags stay accepted for existing callers.
    parser.add_argument('--load-i2i', dest='load_i2i', action='store_true', default=True,
                        help='Load i2i indices')
    parser.add_argument('--no-load-i2i', dest='load_i2i', action='store_false',
                        help='Skip the C++ Swing and i2i indices')
    parser.add_argument('--load-interest', dest='load_interest', action='store_true', default=True,
                        help='Load interest indices')
    parser.add_argument('--no-load-interest', dest='load_interest', action='store_false',
                        help='Skip the interest aggregation indices')
    parser.add_argument('--flush-db', action='store_true',
                        help='Flush database before loading (危险操作!)')
    args = parser.parse_args()

    # Reject a malformed date up front. Otherwise every dated file lookup
    # misses, all loaders skip, and the run still reports success.
    if args.date:
        try:
            if len(args.date) != 8 or not args.date.isdigit():
                raise ValueError(args.date)
            datetime.strptime(args.date, '%Y%m%d')
        except ValueError:
            parser.error(f"--date must be a valid YYYYMMDD date, got {args.date!r}")
    # A negative TTL is never meaningful here: Redis would expire or reject the keys.
    if args.expire_days < 0:
        parser.error(f"--expire-days must be >= 0 (0 disables expiry), got {args.expire_days}")

    def log_section(title):
        """Log ``title`` framed by '=' rules."""
        logger.info("\n" + "=" * 80)
        logger.info(title)
        logger.info("=" * 80)

    # Create the Redis connection.
    logger.info("Connecting to Redis...")
    redis_client = redis.Redis(
        host=args.redis_host,
        port=args.redis_port,
        db=args.redis_db,
        password=args.redis_password,
        decode_responses=True
    )

    # Fail fast (exit code 1) if Redis is unreachable.
    try:
        redis_client.ping()
        logger.info("Redis connection successful")
    except Exception as e:
        logger.error(f"Failed to connect to Redis: {e}")
        return 1

    # Optionally wipe the whole selected DB first (destructive).
    if args.flush_db:
        logger.warning("Flushing Redis database...")
        redis_client.flushdb()
        logger.info("Database flushed")

    if args.load_i2i:
        log_section("Loading C++ Swing indices")
        load_cpp_swing_index(redis_client, args.expire_days)

        log_section("Loading i2i indices")
        load_i2i_indices(redis_client, args.date, args.expire_days)

    if args.load_interest:
        log_section("Loading interest aggregation indices")
        load_interest_indices(redis_client, args.date, args.expire_days)

    log_section("All indices loaded successfully!")
    return 0
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)