Blame view

offline_tasks/scripts/load_index_to_redis.py 6.32 KB
5ab1c29c   tangwang   first commit
1
2
3
4
  """
  将生成的索引加载到Redis
  用于在线推荐系统查询
  """
5ab1c29c   tangwang   first commit
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  import argparse
  import logging
  import os
  import sys
  from datetime import datetime

  import redis

  from offline_tasks.config.offline_config import REDIS_CONFIG, OUTPUT_DIR
  
  # Configure logging at import time: INFO level, with a timestamp and the
  # level name prepended to every message.
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(levelname)s - %(message)s'
  )
  # Module-level logger shared by the functions in this script.
  logger = logging.getLogger(__name__)
  
  
  def load_index_file(file_path, redis_client, key_prefix, expire_seconds=None):
      """
      Load a tab-separated index file into Redis.

      Every non-empty line must consist of exactly two TAB-separated fields,
      ``key_suffix`` and ``value``. The value is stored as a string under the
      key ``<key_prefix>:<key_suffix>``. Lines with any other number of fields
      are logged as warnings and skipped.

      Writes are queued on a non-transactional pipeline and flushed every
      1000 records, so the loader does not pay one network round trip per
      key. The TTL is attached to the SET command itself, so a key can never
      be left behind without its expiry if the process dies mid-load.

      Args:
          file_path: Path of the index file to read (UTF-8 text).
          redis_client: Redis client; must provide ``pipeline()``.
          key_prefix: Prefix prepended to every Redis key.
          expire_seconds: Key TTL in seconds; ``None`` or ``0`` means the
              keys never expire.

      Returns:
          Number of records written, or 0 when the file does not exist.
      """
      # EAFP: open directly instead of exists()-then-open, which is racy.
      try:
          index_file = open(file_path, 'r', encoding='utf-8')
      except FileNotFoundError:
          logger.error(f"File not found: {file_path}")
          return 0

      batch_size = 1000
      # A falsy TTL (None or 0) means "no expiry"; Redis rejects EX 0.
      ttl = expire_seconds or None
      count = 0

      with index_file, redis_client.pipeline(transaction=False) as pipe:
          for line in index_file:
              line = line.strip()
              if not line:
                  continue

              parts = line.split('\t')
              if len(parts) != 2:
                  logger.warning(f"Invalid line format: {line}")
                  continue

              key_suffix, value = parts
              # SET with EX stores the value and its TTL in a single command.
              pipe.set(f"{key_prefix}:{key_suffix}", value, ex=ttl)
              count += 1

              if count % batch_size == 0:
                  pipe.execute()
                  logger.info(f"Loaded {count} records...")

          # Flush whatever is left in the final, partially filled batch.
          pipe.execute()

      return count
  
  
  def load_i2i_indices(redis_client, date_str=None, expire_days=7):
      """
      Load the item-to-item (i2i) similarity indices into Redis.

      For each i2i type the file ``i2i_<type>_<date>.txt`` under OUTPUT_DIR is
      loaded with the key prefix ``i2i:<type>``. Types whose file is missing
      are skipped with a warning.

      Args:
          redis_client: Redis client.
          date_str: Date string in YYYYMMDD format; a falsy value means today.
          expire_days: Key lifetime in days; a falsy value means no expiry.
      """
      day = date_str or datetime.now().strftime('%Y%m%d')

      # Convert the lifetime from days to seconds (None = never expire).
      ttl_seconds = None
      if expire_days:
          ttl_seconds = expire_days * 24 * 3600

      # i2i index types
      for algo in ('swing', 'session_w2v', 'deepwalk', 'content_name', 'content_pic'):
          index_path = os.path.join(OUTPUT_DIR, f'i2i_{algo}_{day}.txt')

          if os.path.exists(index_path):
              logger.info(f"Loading {algo} indices...")
              loaded = load_index_file(
                  index_path, redis_client, f"i2i:{algo}", ttl_seconds
              )
              logger.info(f"Loaded {loaded} {algo} indices")
          else:
              logger.warning(f"File not found: {index_path}, skipping...")
  
  
  def load_interest_indices(redis_client, date_str=None, expire_days=7):
      """
      Load the interest-aggregation indices into Redis.

      For each list type the file ``interest_aggregation_<type>_<date>.txt``
      under OUTPUT_DIR is loaded with the key prefix ``interest:<type>``.
      Types whose file is missing are skipped with a warning.

      Args:
          redis_client: Redis client.
          date_str: Date string in YYYYMMDD format; a falsy value means today.
          expire_days: Key lifetime in days; a falsy value means no expiry.
      """
      if not date_str:
          date_str = datetime.now().strftime('%Y%m%d')

      # Key lifetime in seconds (None = never expire).
      ttl = (expire_days * 24 * 3600) if expire_days else None

      # Interest index types, paired with the file each one is read from.
      kinds = ['hot', 'cart', 'new', 'global']
      for kind in kinds:
          path = os.path.join(
              OUTPUT_DIR, f'interest_aggregation_{kind}_{date_str}.txt'
          )

          # Guard clause: nothing to load for this type.
          if not os.path.exists(path):
              logger.warning(f"File not found: {path}, skipping...")
              continue

          logger.info(f"Loading {kind} interest indices...")
          n_loaded = load_index_file(path, redis_client, f"interest:{kind}", ttl)
          logger.info(f"Loaded {n_loaded} {kind} indices")
  
  
  def main():
      """
      Command-line entry point: connect to Redis and load the index files.

      Parses the command-line options, verifies the Redis connection with a
      PING, optionally flushes the selected database, then loads the i2i and
      interest-aggregation indices.

      Both index families are loaded by default. ``--no-load-i2i`` and
      ``--no-load-interest`` turn the corresponding step off; the original
      ``--load-i2i`` / ``--load-interest`` flags are still accepted.

      Returns:
          Process exit code: 0 on success, 1 if Redis cannot be reached.
      """
      parser = argparse.ArgumentParser(description='Load recommendation indices to Redis')
      parser.add_argument('--redis-host', type=str, default=REDIS_CONFIG.get('host', 'localhost'),
                         help='Redis host')
      parser.add_argument('--redis-port', type=int, default=REDIS_CONFIG.get('port', 6379),
                         help='Redis port')
      parser.add_argument('--redis-db', type=int, default=REDIS_CONFIG.get('db', 0),
                         help='Redis database')
      parser.add_argument('--redis-password', type=str, default=REDIS_CONFIG.get('password'),
                         help='Redis password')
      parser.add_argument('--date', type=str, default=None,
                         help='Date string (YYYYMMDD), default is today')
      parser.add_argument('--expire-days', type=int, default=7,
                         help='Expire days for Redis keys')
      # store_true with default=True can never be switched off, so each flag
      # gets a store_false counterpart writing to the same destination.
      parser.add_argument('--load-i2i', dest='load_i2i', action='store_true', default=True,
                         help='Load i2i indices (default: enabled)')
      parser.add_argument('--no-load-i2i', dest='load_i2i', action='store_false',
                         help='Skip loading i2i indices')
      parser.add_argument('--load-interest', dest='load_interest', action='store_true',
                         default=True,
                         help='Load interest indices (default: enabled)')
      parser.add_argument('--no-load-interest', dest='load_interest', action='store_false',
                         help='Skip loading interest indices')
      parser.add_argument('--flush-db', action='store_true',
                         help='Flush database before loading (危险操作!)')

      args = parser.parse_args()

      # Create the Redis connection
      logger.info("Connecting to Redis...")
      redis_client = redis.Redis(
          host=args.redis_host,
          port=args.redis_port,
          db=args.redis_db,
          password=args.redis_password,
          decode_responses=True
      )

      # Test the connection before doing any work; any failure here is
      # reported and turned into a non-zero exit code.
      try:
          redis_client.ping()
          logger.info("Redis connection successful")
      except Exception as e:
          logger.error(f"Failed to connect to Redis: {e}")
          return 1

      # Flush the database if requested (destructive!)
      if args.flush_db:
          logger.warning("Flushing Redis database...")
          redis_client.flushdb()
          logger.info("Database flushed")

      # Load the i2i indices
      if args.load_i2i:
          logger.info("\n" + "="*80)
          logger.info("Loading i2i indices")
          logger.info("="*80)
          load_i2i_indices(redis_client, args.date, args.expire_days)

      # Load the interest-aggregation indices
      if args.load_interest:
          logger.info("\n" + "="*80)
          logger.info("Loading interest aggregation indices")
          logger.info("="*80)
          load_interest_indices(redis_client, args.date, args.expire_days)

      logger.info("\n" + "="*80)
      logger.info("All indices loaded successfully!")
      logger.info("="*80)

      return 0
  
  
  # When run as a script, execute main() and use its return value as the
  # process exit status.
  if __name__ == '__main__':
      sys.exit(main())