Blame view

offline_tasks/scripts/debug_utils.py 11.7 KB
1721766b   tangwang   offline tasks
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
  """
  调试工具模块
  提供debug日志和明文输出功能
  """
  import os
  import json
  import logging
  from datetime import datetime
  
  
  def setup_debug_logger(script_name, debug=False):
      """
      Create (or reconfigure) the named logger used by an offline-task script.

      Any handlers left on the logger by a previous call are closed and removed,
      then a console handler is attached. When ``debug`` is true the level is
      DEBUG and an additional UTF-8 log file is written to
      ``<parent of this module's directory>/logs/debug/<script_name>_<timestamp>.log``;
      otherwise the level is INFO and no file is created.

      Args:
          script_name: logger name; also used as the log-file name prefix.
          debug: whether to enable DEBUG level and file logging.

      Returns:
          The configured ``logging.Logger``.
      """
      logger = logging.getLogger(script_name)

      # Detach AND close handlers from earlier calls: ``handlers.clear()`` alone
      # leaves previously attached FileHandlers holding open file descriptors.
      for stale_handler in list(logger.handlers):
          logger.removeHandler(stale_handler)
          stale_handler.close()

      level = logging.DEBUG if debug else logging.INFO
      logger.setLevel(level)

      formatter = logging.Formatter(
          '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
          datefmt='%Y-%m-%d %H:%M:%S'
      )

      # Console output
      console_handler = logging.StreamHandler()
      console_handler.setLevel(level)
      console_handler.setFormatter(formatter)
      logger.addHandler(console_handler)

      # File output (debug mode only)
      if debug:
          log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs', 'debug')
          os.makedirs(log_dir, exist_ok=True)

          log_file = os.path.join(
              log_dir,
              f"{script_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
          )
          file_handler = logging.FileHandler(log_file, encoding='utf-8')
          file_handler.setLevel(logging.DEBUG)
          file_handler.setFormatter(formatter)
          logger.addHandler(file_handler)

          logger.debug(f"Debug log file: {log_file}")

      return logger
  
  
  def log_dataframe_info(logger, df, name="DataFrame", sample_size=5):
      """
      Log a detailed summary of a pandas DataFrame at DEBUG level.

      Logs the row/column counts, column names, per-column dtypes, missing-value
      counts (only for columns that have any), the first ``sample_size`` rows and
      ``describe()`` statistics of the numeric columns.

      Args:
          logger: logger object to write to.
          df: pandas DataFrame to summarize.
          name: label printed in the header.
          sample_size: number of leading rows to print.
      """
      # Every message below is DEBUG; skip the to_string()/describe() work
      # entirely when it would be discarded anyway.
      if not logger.isEnabledFor(logging.DEBUG):
          return

      separator = '=' * 60
      row_count = len(df)

      logger.debug(f"\n{separator}")
      logger.debug(f"{name} 信息:")
      logger.debug(separator)
      logger.debug(f"总行数: {row_count}")
      logger.debug(f"总列数: {len(df.columns)}")
      logger.debug(f"列名: {list(df.columns)}")

      # Column dtypes
      logger.debug("\n数据类型:")
      for col, dtype in df.dtypes.items():
          logger.debug(f"  {col}: {dtype}")

      # Missing values; a non-zero count implies row_count > 0, so the division is safe.
      null_counts = df.isnull().sum()
      if null_counts.sum() > 0:
          logger.debug("\n缺失值统计:")
          for col, count in null_counts[null_counts > 0].items():
              logger.debug(f"  {col}: {count} ({count/row_count*100:.2f}%)")

      if row_count > 0:
          logger.debug(f"\n前{sample_size}行示例:")
          logger.debug(f"\n{df.head(sample_size).to_string()}")

          # 'number' selects every numeric width (e.g. int32, float32, uint8),
          # not only int64/float64.
          numeric_cols = df.select_dtypes(include='number').columns
          if len(numeric_cols) > 0:
              logger.debug("\n数值列统计:")
              logger.debug(f"\n{df[numeric_cols].describe().to_string()}")

      logger.debug(f"{separator}\n")
  
  
  def log_dict_stats(logger, data_dict, name="Dictionary", top_n=10):
      """
      Log summary statistics of a dictionary at DEBUG level.

      Logs the key count, the average ``len()`` of the values (computed over at
      most the first 1000 entries; values without ``__len__`` count as 1), and
      a preview of the first ``top_n`` entries. List values show their first 3
      elements and dict values their first 3 items, each followed by the total size.

      Args:
          logger: logger object to write to.
          data_dict: dictionary to summarize.
          name: label printed in the header.
          top_n: number of leading entries to preview.
      """
      from itertools import islice

      separator = '=' * 60
      logger.debug(f"\n{separator}")
      logger.debug(f"{name} 统计:")
      logger.debug(separator)
      logger.debug(f"总元素数: {len(data_dict)}")

      if len(data_dict) > 0:
          # Average value size over a sample; islice avoids copying the whole dict.
          try:
              sizes = [len(value) if hasattr(value, '__len__') else 1
                       for value in islice(data_dict.values(), 1000)]
              if sizes:
                  avg_items = sum(sizes) / len(sizes)
                  logger.debug(f"平均每个key的元素数: {avg_items:.2f}")
          except Exception as exc:
              # Best-effort statistic: report instead of silently swallowing,
              # and let KeyboardInterrupt/SystemExit propagate.
              logger.debug(f"平均元素数统计失败: {exc}")

          # Preview the first top_n entries
          logger.debug(f"\n前{top_n}个示例:")
          for key, value in islice(data_dict.items(), top_n):
              if isinstance(value, list):
                  logger.debug(f"  {key}: {value[:3]}... (total: {len(value)})")
              elif isinstance(value, dict):
                  preview = dict(islice(value.items(), 3))
                  logger.debug(f"  {key}: {preview}... (total: {len(value)})")
              else:
                  logger.debug(f"  {key}: {value}")

      logger.debug(f"{separator}\n")
  
  
  def _format_score(score):
      """Return ``score`` formatted with 4 decimals, or ``str(score)`` if it is not numeric."""
      try:
          return f"{score:.4f}"
      except (TypeError, ValueError):
          return str(score)


  def save_readable_index(output_file, index_data, name_mappings, description=""):
      """
      Write a human-readable text dump of an index for debugging.

      The file is written to
      ``<directory of output_file>/debug/<basename without extension>_readable.txt``
      (the ``debug`` directory is created if needed). It starts with a header
      (timestamp, optional description, entry count). Each entry follows with its
      key annotated via ``format_key_with_name`` and its items annotated with item
      names ('Unknown' when absent). A progress separator is written every 50 entries.

      Args:
          output_file: path of the real index file; only its directory and
              basename are used to place the readable copy.
          index_data: mapping key -> items. Items may be a list of
              ``(item_id, score)`` pairs (tuple or 2+-element list), a list of
              bare item ids, or a dict ``{item_id: score}``; anything else is
              written verbatim.
          name_mappings: name lookup tables, e.g. ``{'item': {id: name},
              'category': {...}, 'platform': {...}, ...}`` with string ids.
          description: optional free-text description for the header.

      Returns:
          Path of the readable file that was written.
      """
      debug_dir = os.path.join(os.path.dirname(output_file), 'debug')
      os.makedirs(debug_dir, exist_ok=True)

      # Readable file name: <basename without extension>_readable.txt
      name_without_ext = os.path.splitext(os.path.basename(output_file))[0]
      readable_file = os.path.join(debug_dir, f"{name_without_ext}_readable.txt")

      item_names = name_mappings.get('item', {})  # hoisted: consulted for every item below
      total = len(index_data)

      with open(readable_file, 'w', encoding='utf-8') as f:
          # Header
          f.write("="*80 + "\n")
          f.write("明文索引文件\n")
          f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
          if description:
              f.write(f"描述: {description}\n")
          f.write(f"总索引数: {total}\n")
          f.write("="*80 + "\n\n")

          for idx, (key, items) in enumerate(index_data.items(), 1):
              readable_key = format_key_with_name(key, name_mappings)
              f.write(f"\n[{idx}] {readable_key}\n")
              f.write("-" * 80 + "\n")

              if isinstance(items, list):
                  for i, item in enumerate(items, 1):
                      # (id, score) pairs may also arrive as lists, e.g. after a JSON round-trip.
                      if isinstance(item, (tuple, list)) and len(item) >= 2:
                          item_id, score = item[0], item[1]
                          item_name = item_names.get(str(item_id), 'Unknown')
                          f.write(f"  {i}. ID:{item_id}({item_name}) - Score:{_format_score(score)}\n")
                      else:
                          item_name = item_names.get(str(item), 'Unknown')
                          f.write(f"  {i}. ID:{item}({item_name})\n")
              elif isinstance(items, dict):
                  for i, (item_id, score) in enumerate(items.items(), 1):
                      item_name = item_names.get(str(item_id), 'Unknown')
                      f.write(f"  {i}. ID:{item_id}({item_name}) - Score:{_format_score(score)}\n")
              else:
                  f.write(f"  {items}\n")

              # Progress separator every 50 entries
              if idx % 50 == 0:
                  f.write("\n" + "="*80 + "\n")
                  f.write(f"已输出 {idx}/{total} 个索引\n")
                  f.write("="*80 + "\n")

      return readable_file
  
  
  def format_key_with_name(key, name_mappings):
      """
      Annotate the ID segments of an index key with human-readable names.

      A key without ':' is treated as a bare item id. Otherwise the key is split
      on ':'. Every segment after the first is looked up in a mapping chosen by
      the segment that precedes it:

      - 'category' or 'level' in the previous segment -> ``name_mappings['category']``
      - 'client_platform' -> ``name_mappings['client_platform']``
      - 'platform' -> ``name_mappings['platform']``
      - 'supplier' -> ``name_mappings['supplier']``
      - anything else -> ``name_mappings['item']``

      Numeric segments are always looked up. Platform/client-platform segments
      are looked up even when non-numeric, because those mappings are keyed by
      codes such as 'pc' or 'iOS'. A found name is appended as ``segment(name)``;
      otherwise the segment is kept unchanged.

      Args:
          key: original key (e.g. "interest:hot:platform:pc", "i2i:swing:12345").
          name_mappings: name lookup tables keyed by string id/code.

      Returns:
          The formatted key string.
      """
      key_str = str(key)
      if ':' not in key_str:
          # Bare item_id
          item_name = name_mappings.get('item', {}).get(key_str, '')
          return f"{key}({item_name})" if item_name else key_str

      parts = key_str.split(':')
      formatted_parts = [parts[0]]  # the leading segment is never annotated

      for prev_part, part in zip(parts, parts[1:]):
          if 'category' in prev_part or 'level' in prev_part:
              mapping_key = 'category'
          elif 'client_platform' in prev_part:
              # Must precede the 'platform' check, which would also match here.
              mapping_key = 'client_platform'
          elif 'platform' in prev_part:
              mapping_key = 'platform'
          elif 'supplier' in prev_part:
              mapping_key = 'supplier'
          else:
              mapping_key = 'item'  # presumably an item_id

          if part.isdigit() or mapping_key in ('platform', 'client_platform'):
              name = name_mappings.get(mapping_key, {}).get(part, '')
          else:
              name = ''
          formatted_parts.append(f"{part}({name})" if name else part)

      return ':'.join(formatted_parts)
  
  
  def _fetch_id_name_map(engine, query, label, debug):
      """
      Best-effort load of an ``{str(id): name}`` dict using ``query``.

      ``query`` must select ``id`` and ``name`` columns. Rows whose id or name
      is NULL are dropped so they do not surface as 'None'/NaN labels. Any
      failure returns an empty dict; success and failure messages are printed
      only when ``debug`` is true.
      """
      import pandas as pd

      try:
          df = pd.read_sql(query, engine)
          df = df.dropna(subset=['id', 'name'])
          mapping = dict(zip(df['id'].astype(str), df['name']))
      except Exception as e:
          if debug:
              print(f"✗ 获取{label}失败: {e}")
          return {}

      if debug:
          print(f"✓ 获取到 {len(mapping)} 个{label}")
      return mapping


  def fetch_name_mappings(engine, debug=False):
      """
      Build ID -> name lookup tables used to make debug output readable.

      The 'item', 'category' and 'supplier' tables are loaded from the database
      (prd_goods_sku rows with status 2/4/5, prd_category and sup_supplier). Each
      load is best-effort: a failed query leaves that table empty. The 'platform'
      and 'client_platform' tables are hard-coded.

      Args:
          engine: database connection/engine passed to ``pandas.read_sql``.
          debug: whether to print per-table success/failure messages.

      Returns:
          dict with keys 'item', 'category', 'platform', 'supplier' and
          'client_platform', each mapping a string id/code to a display name.
      """
      mappings = {
          'item': _fetch_id_name_map(
              engine,
              "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 5000000",
              '商品名称',
              debug,
          ),
          'category': _fetch_id_name_map(
              engine,
              "SELECT id, name FROM prd_category LIMIT 100000",
              '分类名称',
              debug,
          ),
          # Platform names: hard-coded common values, not loaded from the database.
          'platform': {
              'pc': 'PC端',
              'h5': 'H5移动端',
              'app': 'APP',
              'miniprogram': '小程序',
              'wechat': '微信'
          },
          'supplier': _fetch_id_name_map(
              engine,
              "SELECT id, name FROM sup_supplier LIMIT 100000",
              '供应商名称',
              debug,
          ),
          'client_platform': {
              'iOS': 'iOS',
              'Android': 'Android',
              'Web': 'Web',
              'H5': 'H5'
          },
      }
      return mappings
  
  
  def log_algorithm_params(logger, params_dict):
      """
      Log every algorithm parameter at DEBUG level, framed by separator lines.

      Args:
          logger: logger object to write to.
          params_dict: mapping of parameter name -> value; each pair is logged
              on its own line as "  name: value", in iteration order.
      """
      rule = '=' * 60
      logger.debug("\n" + rule)
      logger.debug("算法参数:")
      logger.debug(rule)
      for param_name, param_value in params_dict.items():
          logger.debug(f"  {param_name}: {param_value}")
      logger.debug(rule + "\n")
  
  
  def log_processing_step(logger, step_name, start_time=None):
      """
      Log a framed DEBUG banner announcing a processing step.

      Args:
          logger: logger object to write to.
          step_name: label of the step.
          start_time: optional ``datetime``; when truthy, the seconds elapsed
              between it and now are logged as well.
      """
      now = datetime.now()
      rule = '=' * 60

      logger.debug("\n" + rule)
      logger.debug(f"处理步骤: {step_name}")
      logger.debug(f"时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
      if start_time:
          seconds_taken = (now - start_time).total_seconds()
          logger.debug(f"耗时: {seconds_taken:.2f}秒")
      logger.debug(rule + "\n")