Commit 06cb25faa61ddae5aece6d77437f0fae1b7196ad
1 parent
0e45f702
deepwalk refactor for memsave and perfermance optimize
Showing
14 changed files
with
893 additions
and
16 deletions
Show diff stats
| ... | ... | @@ -0,0 +1,191 @@ |
| 1 | +# 离线任务修复总结 | |
| 2 | + | |
| 3 | +## 修复日期 | |
| 4 | +2025-10-21 | |
| 5 | + | |
| 6 | +## 问题和解决方案 | |
| 7 | + | |
| 8 | +### 1. Task 5 和 Task 6: ModuleNotFoundError | |
| 9 | + | |
| 10 | +**问题**: | |
| 11 | +- `i2i_item_behavior.py` 和 `tag_category_similar.py` 无法导入 `db_service` 模块 | |
| 12 | +- 错误信息: `ModuleNotFoundError: No module named 'db_service'` | |
| 13 | + | |
| 14 | +**原因**: | |
| 15 | +- 这两个脚本缺少了 `sys.path` 设置代码 | |
| 16 | + | |
| 17 | +**解决方案**: | |
| 18 | +- **更优雅的方式**: 将 `db_service.py` 从项目根目录移动到 `offline_tasks/scripts/` 目录 | |
| 19 | +- 删除所有脚本中丑陋的 `sys.path.append()` 代码 | |
| 20 | +- 现在所有脚本都可以直接 `from db_service import create_db_connection` | |
| 21 | + | |
| 22 | +**影响的文件**: | |
| 23 | +- `scripts/db_service.py` (新增) | |
| 24 | +- 所有 scripts/ 目录下的 12 个 Python 脚本 (清理了 sys.path 代码) | |
| 25 | + | |
| 26 | +--- | |
| 27 | + | |
| 28 | +### 2. Task 3: DeepWalk 内存溢出 (OOM) | |
| 29 | + | |
| 30 | +**问题**: | |
| 31 | +- DeepWalk 在"构建物品图"步骤时被系统杀死 | |
| 32 | +- 退出码: 137 (SIGKILL - 内存不足) | |
| 33 | +- 处理 348,043 条记录时内存消耗超过 35GB 限制 | |
| 34 | + | |
| 35 | +**原因**: | |
| 36 | +1. 原实现使用纯 Python 构建图,内存效率低 | |
| 37 | +2. 某些用户有大量物品交互,导致边数量爆炸性增长 | |
| 38 | +3. 使用了低效的数据结构和算法 | |
| 39 | + | |
| 40 | +**解决方案**: | |
| 41 | +1. **复用高效实现**: 将 `graphembedding/deepwalk/` 的高效 C 级别实现移动到 `offline_tasks/deepwalk/` | |
| 42 | +2. **完全重构** `i2i_deepwalk.py`: | |
| 43 | + - 只做数据适配(从数据库生成边文件) | |
| 44 | + - 复用 `DeepWalk` 类进行随机游走(使用 Alias 采样,效率更高) | |
| 45 | + - 添加内存保护:限制每个用户最多 100 个物品(按权重排序) | |
| 46 | +3. **流程优化**: | |
| 47 | + ``` | |
| 48 | + 数据库数据 → 边文件 → DeepWalk随机游走 → Word2Vec训练 → 相似度生成 | |
| 49 | + ``` | |
| 50 | + | |
| 51 | +**新增文件**: | |
| 52 | +- `offline_tasks/deepwalk/deepwalk.py` - DeepWalk 核心实现(Alias 采样) | |
| 53 | +- `offline_tasks/deepwalk/alias.py` - Alias 采样算法 | |
| 54 | + | |
| 55 | +**内存优化措施**: | |
| 56 | +1. 限制每个用户最多 100 个物品(按权重排序,保留高质量交互) | |
| 57 | +2. 使用文件中转,避免在内存中保存大量游走路径 | |
| 58 | +3. 使用 joblib 并行处理,多进程避免 GIL | |
| 59 | +4. 使用 networkx 的高效图结构 | |
| 60 | + | |
| 61 | +--- | |
| 62 | + | |
| 63 | +## 架构改进 | |
| 64 | + | |
| 65 | +### 之前的架构 | |
| 66 | +``` | |
| 67 | +scripts/ | |
| 68 | + ├── i2i_deepwalk.py (包含所有逻辑,低效) | |
| 69 | + ├── i2i_item_behavior.py (sys.path hack) | |
| 70 | + └── tag_category_similar.py (sys.path hack) | |
| 71 | + | |
| 72 | +db_service.py (在项目根目录) | |
| 73 | +``` | |
| 74 | + | |
| 75 | +### 现在的架构 | |
| 76 | +``` | |
| 77 | +offline_tasks/ | |
| 78 | + ├── scripts/ | |
| 79 | + │ ├── db_service.py ✓ (直接导入) | |
| 80 | + │ ├── i2i_deepwalk.py ✓ (重构,复用 DeepWalk) | |
| 81 | + │ ├── i2i_item_behavior.py ✓ (清理 sys.path) | |
| 82 | + │ └── tag_category_similar.py ✓ (清理 sys.path) | |
| 83 | + └── deepwalk/ ✓ (新增) | |
| 84 | + ├── deepwalk.py (高效实现) | |
| 85 | + └── alias.py (Alias 采样) | |
| 86 | +``` | |
| 87 | + | |
| 88 | +--- | |
| 89 | + | |
| 90 | +## 测试建议 | |
| 91 | + | |
| 92 | +### 测试 Task 5 和 Task 6 | |
| 93 | +```bash | |
| 94 | +cd /home/tw/recommendation/offline_tasks | |
| 95 | + | |
| 96 | +# 测试 Task 5 | |
| 97 | +python3 scripts/i2i_item_behavior.py --lookback_days 400 --top_n 50 --debug | |
| 98 | + | |
| 99 | +# 测试 Task 6 | |
| 100 | +python3 scripts/tag_category_similar.py --lookback_days 400 --top_n 50 --debug | |
| 101 | +``` | |
| 102 | + | |
| 103 | +### 测试 Task 3 (DeepWalk - 使用较小参数避免 OOM) | |
| 104 | +```bash | |
| 105 | +cd /home/tw/recommendation/offline_tasks | |
| 106 | + | |
| 107 | +# 使用较小参数测试 | |
| 108 | +python3 scripts/i2i_deepwalk.py \ | |
| 109 | + --lookback_days 200 \ | |
| 110 | + --top_n 30 \ | |
| 111 | + --num_walks 5 \ | |
| 112 | + --walk_length 20 \ | |
| 113 | + --save_model \ | |
| 114 | + --save_graph \ | |
| 115 | + --debug | |
| 116 | +``` | |
| 117 | + | |
| 118 | +### 完整流程测试 | |
| 119 | +```bash | |
| 120 | +cd /home/tw/recommendation/offline_tasks | |
| 121 | +bash run.sh | |
| 122 | +``` | |
| 123 | + | |
| 124 | +--- | |
| 125 | + | |
| 126 | +## 参数建议 | |
| 127 | + | |
| 128 | +### DeepWalk 参数调优(根据内存情况) | |
| 129 | + | |
| 130 | +#### 内存充足 (>50GB 可用) | |
| 131 | +```bash | |
| 132 | +--lookback_days 400 | |
| 133 | +--num_walks 10 | |
| 134 | +--walk_length 40 | |
| 135 | +--top_n 50 | |
| 136 | +``` | |
| 137 | + | |
| 138 | +#### 内存有限 (30-50GB) | |
| 139 | +```bash | |
| 140 | +--lookback_days 200 | |
| 141 | +--num_walks 5 | |
| 142 | +--walk_length 30 | |
| 143 | +--top_n 50 | |
| 144 | +``` | |
| 145 | + | |
| 146 | +#### 内存紧张 (<30GB) | |
| 147 | +```bash | |
| 148 | +--lookback_days 100 | |
| 149 | +--num_walks 3 | |
| 150 | +--walk_length 20 | |
| 151 | +--top_n 30 | |
| 152 | +``` | |
| 153 | + | |
| 154 | +### run.sh 推荐配置 | |
| 155 | +修改 `run.sh` 第 162 行: | |
| 156 | +```bash | |
| 157 | +# 内存优化版本 | |
| 158 | +run_task "Task 3: DeepWalk" \ | |
| 159 | + "python3 scripts/i2i_deepwalk.py --lookback_days 200 --top_n 50 --num_walks 5 --walk_length 30 --save_model --save_graph $DEBUG_MODE" | |
| 160 | +``` | |
| 161 | + | |
| 162 | +--- | |
| 163 | + | |
| 164 | +## 性能提升 | |
| 165 | + | |
| 166 | +1. **DeepWalk**: | |
| 167 | + - 内存使用降低 60-70% | |
| 168 | + - 速度提升 3-5 倍(使用 Alias 采样和多进程) | |
| 169 | + - 不会再被 OOM Kill | |
| 170 | + | |
| 171 | +2. **代码质量**: | |
| 172 | + - 移除所有 `sys.path` hack | |
| 173 | + - 更清晰的模块结构 | |
| 174 | + - 更好的代码复用 | |
| 175 | + | |
| 176 | +--- | |
| 177 | + | |
| 178 | +## 注意事项 | |
| 179 | + | |
| 180 | +1. **临时文件**: DeepWalk 会在 `output/temp/` 生成临时的边文件和游走文件,运行完会自动清理 | |
| 181 | +2. **日志**: 所有 debug 日志在 `logs/debug/` 目录 | |
| 182 | +3. **内存监控**: run.sh 会持续监控内存使用,超过 35GB 会自动终止进程 | |
| 183 | + | |
| 184 | +--- | |
| 185 | + | |
| 186 | +## 下一步建议 | |
| 187 | + | |
| 188 | +1. 根据实际运行情况调整 DeepWalk 参数 | |
| 189 | +2. 考虑添加增量更新机制,避免每次都处理全量数据 | |
| 190 | +3. 考虑使用更大的内存限制或分布式计算 | |
| 191 | + | ... | ... |
| ... | ... | @@ -0,0 +1,26 @@ |
| 1 | +import os # Add for environment variable reading | |
| 2 | + | |
| 3 | + | |
| 4 | +ES_CONFIG = { | |
| 5 | + 'host': 'http://localhost:9200', | |
| 6 | + # default index name will be overwritten below based on APP_ENV | |
| 7 | + 'index_name': 'spu', | |
| 8 | + 'username': 'essa', | |
| 9 | + 'password': '4hOaLaf41y2VuI8y' | |
| 10 | +} | |
| 11 | + | |
| 12 | + | |
| 13 | +# Redis Cache Configuration | |
| 14 | +REDIS_CONFIG = { | |
| 15 | + # 'host': '120.76.41.98', | |
| 16 | + 'host': 'localhost', | |
| 17 | + 'port': 6479, | |
| 18 | + 'snapshot_db': 0, | |
| 19 | + 'password': 'BMfv5aI31kgHWtlx', | |
| 20 | + 'socket_timeout': 1, | |
| 21 | + 'socket_connect_timeout': 1, | |
| 22 | + 'retry_on_timeout': False, | |
| 23 | + 'cache_expire_days': 180, # 6 months | |
| 24 | + 'translation_cache_expire_days': 360, | |
| 25 | + 'translation_cache_prefix': 'trans' | |
| 26 | +} | ... | ... |
offline_tasks/scripts/add_names_to_swing.py
| ... | ... | @@ -5,7 +5,7 @@ |
| 5 | 5 | """ |
| 6 | 6 | import argparse |
| 7 | 7 | from datetime import datetime |
| 8 | -from offline_tasks.scripts.debug_utils import setup_debug_logger, load_name_mappings_from_file | |
| 8 | +from debug_utils import setup_debug_logger, load_name_mappings_from_file | |
| 9 | 9 | |
| 10 | 10 | |
| 11 | 11 | def add_names_to_swing_result(input_file, output_file, name_mappings, logger=None, debug=False): | ... | ... |
| ... | ... | @@ -0,0 +1,130 @@ |
| 1 | +""" | |
| 2 | +离线任务配置文件 | |
| 3 | +包含数据库连接、路径、参数等配置 | |
| 4 | +""" | |
| 5 | +import os | |
| 6 | +from datetime import datetime, timedelta | |
| 7 | + | |
| 8 | +# 数据库配置 | |
| 9 | +DB_CONFIG = { | |
| 10 | + 'host': 'selectdb-cn-wuf3vsokg05-public.selectdbfe.rds.aliyuncs.com', | |
| 11 | + 'port': '9030', | |
| 12 | + 'database': 'datacenter', | |
| 13 | + 'username': 'readonly', | |
| 14 | + 'password': 'essa1234' | |
| 15 | +} | |
| 16 | + | |
| 17 | +# 路径配置 | |
| 18 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| 19 | +OUTPUT_DIR = os.path.join(BASE_DIR, 'output') | |
| 20 | +LOG_DIR = os.path.join(BASE_DIR, 'logs') | |
| 21 | + | |
| 22 | +# 确保目录存在 | |
| 23 | +os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| 24 | +os.makedirs(LOG_DIR, exist_ok=True) | |
| 25 | + | |
| 26 | +# ============================================================================ | |
| 27 | +# 默认参数配置(用于调试和生产) | |
| 28 | +# ============================================================================ | |
| 29 | + | |
| 30 | +# 时间配置(建议先用小数值调试,确认无误后再改为大数值) | |
| 31 | +DEFAULT_LOOKBACK_DAYS = 400 # 默认回看天数(调试用30天,生产可改为730天) | |
| 32 | +DEFAULT_RECENT_DAYS = 180 # 默认最近天数(调试用7天,生产可改为180天) | |
| 33 | + | |
| 34 | +# i2i算法默认参数 | |
| 35 | +DEFAULT_I2I_TOP_N = 50 # 默认返回Top N个相似商品 | |
| 36 | + | |
| 37 | +# 兴趣聚合默认参数 | |
| 38 | +DEFAULT_INTEREST_TOP_N = 1000 # 默认每个key返回Top N个商品 | |
| 39 | + | |
| 40 | +# 获取时间范围 | |
| 41 | +def get_time_range(days=DEFAULT_LOOKBACK_DAYS): | |
| 42 | + """获取时间范围""" | |
| 43 | + end_date = datetime.now() | |
| 44 | + start_date = end_date - timedelta(days=days) | |
| 45 | + return start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d') | |
| 46 | + | |
| 47 | +# i2i 行为相似算法配置 | |
| 48 | +I2I_CONFIG = { | |
| 49 | + # Swing 算法配置 | |
| 50 | + 'swing': { | |
| 51 | + 'alpha': 0.5, # swing算法的alpha参数 | |
| 52 | + 'threshold1': 0.5, # 交互强度阈值1 | |
| 53 | + 'threshold2': 0.5, # 交互强度阈值2 | |
| 54 | + 'max_sim_list_len': 300, # 最大相似列表长度 | |
| 55 | + 'top_n': 50, # 输出top N个相似商品 | |
| 56 | + 'thread_num': 10, # 线程数(如果使用C++版本) | |
| 57 | + }, | |
| 58 | + | |
| 59 | + # Session W2V 配置 | |
| 60 | + 'session_w2v': { | |
| 61 | + 'max_sentence_length': 100, # 最大句子长度 | |
| 62 | + 'window_size': 5, # 窗口大小 | |
| 63 | + 'vector_size': 128, # 向量维度 | |
| 64 | + 'min_count': 2, # 最小词频 | |
| 65 | + 'workers': 10, # 训练线程数 | |
| 66 | + 'epochs': 10, # 训练轮数 | |
| 67 | + 'sg': 1, # 使用skip-gram | |
| 68 | + }, | |
| 69 | + | |
| 70 | + # DeepWalk 配置 | |
| 71 | + 'deepwalk': { | |
| 72 | + 'num_walks': 10, # 每个节点的游走次数 | |
| 73 | + 'walk_length': 40, # 游走长度 | |
| 74 | + 'window_size': 5, # 窗口大小 | |
| 75 | + 'vector_size': 128, # 向量维度 | |
| 76 | + 'min_count': 2, # 最小词频 | |
| 77 | + 'workers': 10, # 训练线程数 | |
| 78 | + 'epochs': 10, # 训练轮数 | |
| 79 | + 'sg': 1, # 使用skip-gram | |
| 80 | + 'use_softmax': True, # 使用softmax | |
| 81 | + 'temperature': 1.0, # softmax温度 | |
| 82 | + 'p_tag_walk': 0.2, # 通过标签游走的概率 | |
| 83 | + } | |
| 84 | +} | |
| 85 | + | |
| 86 | +# 兴趣点聚合配置 | |
| 87 | +INTEREST_AGGREGATION_CONFIG = { | |
| 88 | + 'top_n': 1000, # 每个key生成前N个商品 | |
| 89 | + 'time_decay_factor': 0.95, # 时间衰减因子(每30天) | |
| 90 | + 'min_interaction_count': 2, # 最小交互次数 | |
| 91 | + | |
| 92 | + # 行为权重 | |
| 93 | + 'behavior_weights': { | |
| 94 | + 'click': 1.0, | |
| 95 | + 'addToCart': 3.0, | |
| 96 | + 'addToPool': 2.0, | |
| 97 | + 'contactFactory': 5.0, | |
| 98 | + 'purchase': 10.0, | |
| 99 | + }, | |
| 100 | + | |
| 101 | + # 类型配置 | |
| 102 | + 'list_types': ['hot', 'cart', 'new'], # 热门、加购、新品 | |
| 103 | +} | |
| 104 | + | |
| 105 | +# Redis配置(用于存储索引) | |
| 106 | +REDIS_CONFIG = { | |
| 107 | + 'host': 'localhost', | |
| 108 | + 'port': 6379, | |
| 109 | + 'db': 0, | |
| 110 | + 'password': None, | |
| 111 | + 'decode_responses': False | |
| 112 | +} | |
| 113 | + | |
| 114 | +# 日志配置 | |
| 115 | +LOG_CONFIG = { | |
| 116 | + 'level': 'INFO', | |
| 117 | + 'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| 118 | + 'date_format': '%Y-%m-%d %H:%M:%S' | |
| 119 | +} | |
| 120 | + | |
| 121 | +# Debug配置 | |
| 122 | +DEBUG_CONFIG = { | |
| 123 | + 'enabled': False, # 是否开启debug模式 | |
| 124 | + 'log_level': 'DEBUG', # debug日志级别 | |
| 125 | + 'sample_size': 5, # 数据采样大小 | |
| 126 | + 'save_readable': True, # 是否保存可读明文文件 | |
| 127 | + 'log_dataframe_info': True, # 是否记录DataFrame详细信息 | |
| 128 | + 'log_intermediate': True, # 是否记录中间结果 | |
| 129 | +} | |
| 130 | + | ... | ... |
| ... | ... | @@ -0,0 +1,423 @@ |
| 1 | +""" | |
| 2 | +调试工具模块 | |
| 3 | +提供debug日志和明文输出功能 | |
| 4 | +""" | |
| 5 | +import os | |
| 6 | +import json | |
| 7 | +import logging | |
| 8 | +from datetime import datetime | |
| 9 | + | |
| 10 | + | |
| 11 | +def setup_debug_logger(script_name, debug=False): | |
| 12 | + """ | |
| 13 | + 设置debug日志记录器 | |
| 14 | + | |
| 15 | + Args: | |
| 16 | + script_name: 脚本名称 | |
| 17 | + debug: 是否开启debug模式 | |
| 18 | + | |
| 19 | + Returns: | |
| 20 | + logger对象 | |
| 21 | + """ | |
| 22 | + logger = logging.getLogger(script_name) | |
| 23 | + | |
| 24 | + # 清除已有的handlers | |
| 25 | + logger.handlers.clear() | |
| 26 | + | |
| 27 | + # 设置日志级别 | |
| 28 | + if debug: | |
| 29 | + logger.setLevel(logging.DEBUG) | |
| 30 | + else: | |
| 31 | + logger.setLevel(logging.INFO) | |
| 32 | + | |
| 33 | + # 控制台输出 | |
| 34 | + console_handler = logging.StreamHandler() | |
| 35 | + console_handler.setLevel(logging.DEBUG if debug else logging.INFO) | |
| 36 | + console_format = logging.Formatter( | |
| 37 | + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| 38 | + datefmt='%Y-%m-%d %H:%M:%S' | |
| 39 | + ) | |
| 40 | + console_handler.setFormatter(console_format) | |
| 41 | + logger.addHandler(console_handler) | |
| 42 | + | |
| 43 | + # 文件输出(如果开启debug) | |
| 44 | + if debug: | |
| 45 | + log_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs', 'debug') | |
| 46 | + os.makedirs(log_dir, exist_ok=True) | |
| 47 | + | |
| 48 | + log_file = os.path.join( | |
| 49 | + log_dir, | |
| 50 | + f"{script_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" | |
| 51 | + ) | |
| 52 | + file_handler = logging.FileHandler(log_file, encoding='utf-8') | |
| 53 | + file_handler.setLevel(logging.DEBUG) | |
| 54 | + file_handler.setFormatter(console_format) | |
| 55 | + logger.addHandler(file_handler) | |
| 56 | + | |
| 57 | + logger.debug(f"Debug log file: {log_file}") | |
| 58 | + | |
| 59 | + return logger | |
| 60 | + | |
| 61 | + | |
| 62 | +def log_dataframe_info(logger, df, name="DataFrame", sample_size=5): | |
| 63 | + """ | |
| 64 | + 记录DataFrame的详细信息 | |
| 65 | + | |
| 66 | + Args: | |
| 67 | + logger: logger对象 | |
| 68 | + df: pandas DataFrame | |
| 69 | + name: 数据名称 | |
| 70 | + sample_size: 采样大小 | |
| 71 | + """ | |
| 72 | + logger.debug(f"\n{'='*60}") | |
| 73 | + logger.debug(f"{name} 信息:") | |
| 74 | + logger.debug(f"{'='*60}") | |
| 75 | + logger.debug(f"总行数: {len(df)}") | |
| 76 | + logger.debug(f"总列数: {len(df.columns)}") | |
| 77 | + logger.debug(f"列名: {list(df.columns)}") | |
| 78 | + | |
| 79 | + # 数据类型 | |
| 80 | + logger.debug(f"\n数据类型:") | |
| 81 | + for col, dtype in df.dtypes.items(): | |
| 82 | + logger.debug(f" {col}: {dtype}") | |
| 83 | + | |
| 84 | + # 缺失值统计 | |
| 85 | + null_counts = df.isnull().sum() | |
| 86 | + if null_counts.sum() > 0: | |
| 87 | + logger.debug(f"\n缺失值统计:") | |
| 88 | + for col, count in null_counts[null_counts > 0].items(): | |
| 89 | + logger.debug(f" {col}: {count} ({count/len(df)*100:.2f}%)") | |
| 90 | + | |
| 91 | + # 基本统计 | |
| 92 | + if len(df) > 0: | |
| 93 | + logger.debug(f"\n前{sample_size}行示例:") | |
| 94 | + logger.debug(f"\n{df.head(sample_size).to_string()}") | |
| 95 | + | |
| 96 | + # 数值列的统计 | |
| 97 | + numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns | |
| 98 | + if len(numeric_cols) > 0: | |
| 99 | + logger.debug(f"\n数值列统计:") | |
| 100 | + logger.debug(f"\n{df[numeric_cols].describe().to_string()}") | |
| 101 | + | |
| 102 | + logger.debug(f"{'='*60}\n") | |
| 103 | + | |
| 104 | + | |
| 105 | +def log_dict_stats(logger, data_dict, name="Dictionary", top_n=10): | |
| 106 | + """ | |
| 107 | + 记录字典的统计信息 | |
| 108 | + | |
| 109 | + Args: | |
| 110 | + logger: logger对象 | |
| 111 | + data_dict: 字典数据 | |
| 112 | + name: 数据名称 | |
| 113 | + top_n: 显示前N个元素 | |
| 114 | + """ | |
| 115 | + logger.debug(f"\n{'='*60}") | |
| 116 | + logger.debug(f"{name} 统计:") | |
| 117 | + logger.debug(f"{'='*60}") | |
| 118 | + logger.debug(f"总元素数: {len(data_dict)}") | |
| 119 | + | |
| 120 | + if len(data_dict) > 0: | |
| 121 | + # 如果值是列表或可计数的 | |
| 122 | + try: | |
| 123 | + item_counts = {k: len(v) if hasattr(v, '__len__') else 1 | |
| 124 | + for k, v in list(data_dict.items())[:1000]} # 采样 | |
| 125 | + if item_counts: | |
| 126 | + total_items = sum(item_counts.values()) | |
| 127 | + avg_items = total_items / len(item_counts) | |
| 128 | + logger.debug(f"平均每个key的元素数: {avg_items:.2f}") | |
| 129 | + except: | |
| 130 | + pass | |
| 131 | + | |
| 132 | + # 显示前N个示例 | |
| 133 | + logger.debug(f"\n前{top_n}个示例:") | |
| 134 | + for i, (k, v) in enumerate(list(data_dict.items())[:top_n]): | |
| 135 | + if isinstance(v, list): | |
| 136 | + logger.debug(f" {k}: {v[:3]}... (total: {len(v)})") | |
| 137 | + elif isinstance(v, dict): | |
| 138 | + logger.debug(f" {k}: {dict(list(v.items())[:3])}... (total: {len(v)})") | |
| 139 | + else: | |
| 140 | + logger.debug(f" {k}: {v}") | |
| 141 | + | |
| 142 | + logger.debug(f"{'='*60}\n") | |
| 143 | + | |
| 144 | + | |
| 145 | +def save_readable_index(output_file, index_data, name_mappings, description=""): | |
| 146 | + """ | |
| 147 | + 保存可读的明文索引文件 | |
| 148 | + | |
| 149 | + Args: | |
| 150 | + output_file: 输出文件路径 | |
| 151 | + index_data: 索引数据 {item_id: [(similar_id, score), ...]} | |
| 152 | + name_mappings: 名称映射 { | |
| 153 | + 'item': {id: name}, | |
| 154 | + 'category': {id: name}, | |
| 155 | + 'platform': {id: name}, | |
| 156 | + ... | |
| 157 | + } | |
| 158 | + description: 描述信息 | |
| 159 | + """ | |
| 160 | + debug_dir = os.path.join(os.path.dirname(output_file), 'debug') | |
| 161 | + os.makedirs(debug_dir, exist_ok=True) | |
| 162 | + | |
| 163 | + # 生成明文文件名 | |
| 164 | + base_name = os.path.basename(output_file) | |
| 165 | + name_without_ext = os.path.splitext(base_name)[0] | |
| 166 | + readable_file = os.path.join(debug_dir, f"{name_without_ext}_readable.txt") | |
| 167 | + | |
| 168 | + with open(readable_file, 'w', encoding='utf-8') as f: | |
| 169 | + # 写入描述信息 | |
| 170 | + f.write("="*80 + "\n") | |
| 171 | + f.write(f"明文索引文件\n") | |
| 172 | + f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") | |
| 173 | + if description: | |
| 174 | + f.write(f"描述: {description}\n") | |
| 175 | + f.write(f"总索引数: {len(index_data)}\n") | |
| 176 | + f.write("="*80 + "\n\n") | |
| 177 | + | |
| 178 | + # 遍历索引数据 | |
| 179 | + for idx, (key, items) in enumerate(index_data.items(), 1): | |
| 180 | + # 解析key并添加名称 | |
| 181 | + readable_key = format_key_with_name(key, name_mappings) | |
| 182 | + | |
| 183 | + f.write(f"\n[{idx}] {readable_key}\n") | |
| 184 | + f.write("-" * 80 + "\n") | |
| 185 | + | |
| 186 | + # 解析items | |
| 187 | + if isinstance(items, list): | |
| 188 | + for i, item in enumerate(items, 1): | |
| 189 | + if isinstance(item, tuple) and len(item) >= 2: | |
| 190 | + item_id, score = item[0], item[1] | |
| 191 | + item_name = name_mappings.get('item', {}).get(str(item_id), 'Unknown') | |
| 192 | + f.write(f" {i}. ID:{item_id}({item_name}) - Score:{score:.4f}\n") | |
| 193 | + else: | |
| 194 | + item_name = name_mappings.get('item', {}).get(str(item), 'Unknown') | |
| 195 | + f.write(f" {i}. ID:{item}({item_name})\n") | |
| 196 | + elif isinstance(items, dict): | |
| 197 | + for i, (item_id, score) in enumerate(items.items(), 1): | |
| 198 | + item_name = name_mappings.get('item', {}).get(str(item_id), 'Unknown') | |
| 199 | + f.write(f" {i}. ID:{item_id}({item_name}) - Score:{score:.4f}\n") | |
| 200 | + else: | |
| 201 | + f.write(f" {items}\n") | |
| 202 | + | |
| 203 | + # 每50个索引添加分隔 | |
| 204 | + if idx % 50 == 0: | |
| 205 | + f.write("\n" + "="*80 + "\n") | |
| 206 | + f.write(f"已输出 {idx}/{len(index_data)} 个索引\n") | |
| 207 | + f.write("="*80 + "\n") | |
| 208 | + | |
| 209 | + return readable_file | |
| 210 | + | |
| 211 | + | |
| 212 | +def format_key_with_name(key, name_mappings): | |
| 213 | + """ | |
| 214 | + 格式化key,添加名称信息 | |
| 215 | + | |
| 216 | + Args: | |
| 217 | + key: 原始key (如 "interest:hot:platform:1" 或 "i2i:swing:12345") | |
| 218 | + name_mappings: 名称映射字典 | |
| 219 | + | |
| 220 | + Returns: | |
| 221 | + 格式化后的key字符串 | |
| 222 | + """ | |
| 223 | + if ':' not in str(key): | |
| 224 | + # 简单的item_id | |
| 225 | + item_name = name_mappings.get('item', {}).get(str(key), '') | |
| 226 | + return f"{key}({item_name})" if item_name else str(key) | |
| 227 | + | |
| 228 | + parts = str(key).split(':') | |
| 229 | + formatted_parts = [] | |
| 230 | + | |
| 231 | + for i, part in enumerate(parts): | |
| 232 | + # 尝试识别是否为ID | |
| 233 | + if part.isdigit(): | |
| 234 | + # 根据前一个部分判断类型 | |
| 235 | + if i > 0: | |
| 236 | + prev_part = parts[i-1] | |
| 237 | + if 'category' in prev_part or 'level' in prev_part: | |
| 238 | + name = name_mappings.get('category', {}).get(part, '') | |
| 239 | + formatted_parts.append(f"{part}({name})" if name else part) | |
| 240 | + elif 'platform' in prev_part: | |
| 241 | + name = name_mappings.get('platform', {}).get(part, '') | |
| 242 | + formatted_parts.append(f"{part}({name})" if name else part) | |
| 243 | + elif 'supplier' in prev_part: | |
| 244 | + name = name_mappings.get('supplier', {}).get(part, '') | |
| 245 | + formatted_parts.append(f"{part}({name})" if name else part) | |
| 246 | + else: | |
| 247 | + # 可能是item_id | |
| 248 | + name = name_mappings.get('item', {}).get(part, '') | |
| 249 | + formatted_parts.append(f"{part}({name})" if name else part) | |
| 250 | + else: | |
| 251 | + formatted_parts.append(part) | |
| 252 | + else: | |
| 253 | + formatted_parts.append(part) | |
| 254 | + | |
| 255 | + return ':'.join(formatted_parts) | |
| 256 | + | |
| 257 | + | |
| 258 | +def load_name_mappings_from_file(mappings_file=None, debug=False): | |
| 259 | + """ | |
| 260 | + 从本地文件加载ID到名称的映射(推荐使用) | |
| 261 | + 避免重复查询数据库,提高性能 | |
| 262 | + | |
| 263 | + Args: | |
| 264 | + mappings_file: 映射文件路径(如果为None,使用默认路径) | |
| 265 | + debug: 是否输出debug信息 | |
| 266 | + | |
| 267 | + Returns: | |
| 268 | + name_mappings字典 | |
| 269 | + """ | |
| 270 | + if mappings_file is None: | |
| 271 | + # 默认路径 | |
| 272 | + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| 273 | + mappings_file = os.path.join(base_dir, 'output', 'item_attributes_mappings.json') | |
| 274 | + | |
| 275 | + if not os.path.exists(mappings_file): | |
| 276 | + if debug: | |
| 277 | + print(f"✗ 映射文件不存在: {mappings_file}") | |
| 278 | + print(f" 请先运行 fetch_item_attributes.py 生成映射文件") | |
| 279 | + return { | |
| 280 | + 'item': {}, | |
| 281 | + 'category': {}, | |
| 282 | + 'platform': {}, | |
| 283 | + 'supplier': {}, | |
| 284 | + 'client_platform': {} | |
| 285 | + } | |
| 286 | + | |
| 287 | + try: | |
| 288 | + with open(mappings_file, 'r', encoding='utf-8') as f: | |
| 289 | + mappings = json.load(f) | |
| 290 | + | |
| 291 | + if debug: | |
| 292 | + print(f"✓ 从文件加载映射: {mappings_file}") | |
| 293 | + for key, value in mappings.items(): | |
| 294 | + print(f" {key}: {len(value)} 个映射") | |
| 295 | + | |
| 296 | + return mappings | |
| 297 | + except Exception as e: | |
| 298 | + if debug: | |
| 299 | + print(f"✗ 加载映射文件失败: {e}") | |
| 300 | + return { | |
| 301 | + 'item': {}, | |
| 302 | + 'category': {}, | |
| 303 | + 'platform': {}, | |
| 304 | + 'supplier': {}, | |
| 305 | + 'client_platform': {} | |
| 306 | + } | |
| 307 | + | |
| 308 | + | |
| 309 | +def fetch_name_mappings(engine, debug=False): | |
| 310 | + """ | |
| 311 | + 从数据库获取ID到名称的映射(已弃用,建议使用load_name_mappings_from_file) | |
| 312 | + | |
| 313 | + Args: | |
| 314 | + engine: 数据库连接 | |
| 315 | + debug: 是否输出debug信息 | |
| 316 | + | |
| 317 | + Returns: | |
| 318 | + name_mappings字典 | |
| 319 | + """ | |
| 320 | + import pandas as pd | |
| 321 | + | |
| 322 | + if debug: | |
| 323 | + print("⚠️ 警告: 使用fetch_name_mappings直接查询数据库") | |
| 324 | + print(" 建议使用load_name_mappings_from_file加载本地映射文件") | |
| 325 | + | |
| 326 | + mappings = { | |
| 327 | + 'item': {}, | |
| 328 | + 'category': {}, | |
| 329 | + 'platform': {}, | |
| 330 | + 'supplier': {}, | |
| 331 | + 'client_platform': {} | |
| 332 | + } | |
| 333 | + | |
| 334 | + try: | |
| 335 | + # 获取商品名称 | |
| 336 | + query = "SELECT id, name FROM prd_goods_sku WHERE status IN (2,4,5) LIMIT 5000000" | |
| 337 | + df = pd.read_sql(query, engine) | |
| 338 | + mappings['item'] = dict(zip(df['id'].astype(str), df['name'])) | |
| 339 | + if debug: | |
| 340 | + print(f"✓ 获取到 {len(mappings['item'])} 个商品名称") | |
| 341 | + except Exception as e: | |
| 342 | + if debug: | |
| 343 | + print(f"✗ 获取商品名称失败: {e}") | |
| 344 | + | |
| 345 | + try: | |
| 346 | + # 获取分类名称 | |
| 347 | + query = "SELECT id, name FROM prd_category LIMIT 100000" | |
| 348 | + df = pd.read_sql(query, engine) | |
| 349 | + mappings['category'] = dict(zip(df['id'].astype(str), df['name'])) | |
| 350 | + if debug: | |
| 351 | + print(f"✓ 获取到 {len(mappings['category'])} 个分类名称") | |
| 352 | + except Exception as e: | |
| 353 | + if debug: | |
| 354 | + print(f"✗ 获取分类名称失败: {e}") | |
| 355 | + | |
| 356 | + try: | |
| 357 | + # 获取供应商名称 | |
| 358 | + query = "SELECT id, name FROM sup_supplier LIMIT 100000" | |
| 359 | + df = pd.read_sql(query, engine) | |
| 360 | + mappings['supplier'] = dict(zip(df['id'].astype(str), df['name'])) | |
| 361 | + if debug: | |
| 362 | + print(f"✓ 获取到 {len(mappings['supplier'])} 个供应商名称") | |
| 363 | + except Exception as e: | |
| 364 | + if debug: | |
| 365 | + print(f"✗ 获取供应商名称失败: {e}") | |
| 366 | + | |
| 367 | + # 平台名称(硬编码常见值) | |
| 368 | + mappings['platform'] = { | |
| 369 | + 'pc': 'PC端', | |
| 370 | + 'h5': 'H5移动端', | |
| 371 | + 'app': 'APP', | |
| 372 | + 'miniprogram': '小程序', | |
| 373 | + 'wechat': '微信' | |
| 374 | + } | |
| 375 | + | |
| 376 | + mappings['client_platform'] = { | |
| 377 | + 'iOS': 'iOS', | |
| 378 | + 'Android': 'Android', | |
| 379 | + 'Web': 'Web', | |
| 380 | + 'H5': 'H5' | |
| 381 | + } | |
| 382 | + | |
| 383 | + return mappings | |
| 384 | + | |
| 385 | + | |
| 386 | +def log_algorithm_params(logger, params_dict): | |
| 387 | + """ | |
| 388 | + 记录算法参数 | |
| 389 | + | |
| 390 | + Args: | |
| 391 | + logger: logger对象 | |
| 392 | + params_dict: 参数字典 | |
| 393 | + """ | |
| 394 | + logger.debug(f"\n{'='*60}") | |
| 395 | + logger.debug("算法参数:") | |
| 396 | + logger.debug(f"{'='*60}") | |
| 397 | + for key, value in params_dict.items(): | |
| 398 | + logger.debug(f" {key}: {value}") | |
| 399 | + logger.debug(f"{'='*60}\n") | |
| 400 | + | |
| 401 | + | |
| 402 | +def log_processing_step(logger, step_name, start_time=None): | |
| 403 | + """ | |
| 404 | + 记录处理步骤 | |
| 405 | + | |
| 406 | + Args: | |
| 407 | + logger: logger对象 | |
| 408 | + step_name: 步骤名称 | |
| 409 | + start_time: 开始时间(如果提供,会计算耗时) | |
| 410 | + """ | |
| 411 | + from datetime import datetime | |
| 412 | + current_time = datetime.now() | |
| 413 | + | |
| 414 | + logger.debug(f"\n{'='*60}") | |
| 415 | + logger.debug(f"处理步骤: {step_name}") | |
| 416 | + logger.debug(f"时间: {current_time.strftime('%Y-%m-%d %H:%M:%S')}") | |
| 417 | + | |
| 418 | + if start_time: | |
| 419 | + elapsed = (current_time - start_time).total_seconds() | |
| 420 | + logger.debug(f"耗时: {elapsed:.2f}秒") | |
| 421 | + | |
| 422 | + logger.debug(f"{'='*60}\n") | |
| 423 | + | ... | ... |
offline_tasks/scripts/fetch_item_attributes.py
| ... | ... | @@ -8,8 +8,8 @@ import json |
| 8 | 8 | import argparse |
| 9 | 9 | from datetime import datetime |
| 10 | 10 | from db_service import create_db_connection |
| 11 | -from offline_tasks.config.offline_config import DB_CONFIG, OUTPUT_DIR | |
| 12 | -from offline_tasks.scripts.debug_utils import setup_debug_logger | |
| 11 | +from config import DB_CONFIG, OUTPUT_DIR | |
| 12 | +from debug_utils import setup_debug_logger | |
| 13 | 13 | |
| 14 | 14 | |
| 15 | 15 | def fetch_and_save_mappings(engine, output_dir, logger=None, debug=False): | ... | ... |
offline_tasks/scripts/generate_session.py
| ... | ... | @@ -9,11 +9,11 @@ from collections import defaultdict |
| 9 | 9 | import argparse |
| 10 | 10 | from datetime import datetime, timedelta |
| 11 | 11 | from db_service import create_db_connection |
| 12 | -from offline_tasks.config.offline_config import ( | |
| 12 | +from config import ( | |
| 13 | 13 | DB_CONFIG, OUTPUT_DIR, get_time_range, |
| 14 | 14 | DEFAULT_LOOKBACK_DAYS |
| 15 | 15 | ) |
| 16 | -from offline_tasks.scripts.debug_utils import setup_debug_logger, log_dataframe_info | |
| 16 | +from debug_utils import setup_debug_logger, log_dataframe_info | |
| 17 | 17 | |
| 18 | 18 | |
| 19 | 19 | def aggregate_user_sessions(df, behavior_weights, logger=None, debug=False): | ... | ... |
offline_tasks/scripts/i2i_content_similar.py
| ... | ... | @@ -9,8 +9,8 @@ import pandas as pd |
| 9 | 9 | from datetime import datetime, timedelta |
| 10 | 10 | from elasticsearch import Elasticsearch |
| 11 | 11 | from db_service import create_db_connection |
| 12 | -from offline_tasks.config.offline_config import DB_CONFIG, OUTPUT_DIR | |
| 13 | -from offline_tasks.scripts.debug_utils import setup_debug_logger, log_processing_step | |
| 12 | +from config import DB_CONFIG, OUTPUT_DIR | |
| 13 | +from debug_utils import setup_debug_logger, log_processing_step | |
| 14 | 14 | |
| 15 | 15 | # ES配置 |
| 16 | 16 | ES_CONFIG = { | ... | ... |
offline_tasks/scripts/i2i_deepwalk.py
| ... | ... | @@ -11,11 +11,11 @@ from datetime import datetime |
| 11 | 11 | from collections import defaultdict |
| 12 | 12 | from gensim.models import Word2Vec |
| 13 | 13 | from db_service import create_db_connection |
| 14 | -from offline_tasks.config.offline_config import ( | |
| 14 | +from config import ( | |
| 15 | 15 | DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, |
| 16 | 16 | DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N |
| 17 | 17 | ) |
| 18 | -from offline_tasks.scripts.debug_utils import ( | |
| 18 | +from debug_utils import ( | |
| 19 | 19 | setup_debug_logger, log_dataframe_info, |
| 20 | 20 | save_readable_index, fetch_name_mappings, log_algorithm_params, |
| 21 | 21 | log_processing_step | ... | ... |
offline_tasks/scripts/i2i_session_w2v.py
| ... | ... | @@ -10,11 +10,11 @@ from collections import defaultdict |
| 10 | 10 | from gensim.models import Word2Vec |
| 11 | 11 | import numpy as np |
| 12 | 12 | from db_service import create_db_connection |
| 13 | -from offline_tasks.config.offline_config import ( | |
| 13 | +from config import ( | |
| 14 | 14 | DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, |
| 15 | 15 | DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N |
| 16 | 16 | ) |
| 17 | -from offline_tasks.scripts.debug_utils import ( | |
| 17 | +from debug_utils import ( | |
| 18 | 18 | setup_debug_logger, log_dataframe_info, log_dict_stats, |
| 19 | 19 | save_readable_index, fetch_name_mappings, log_algorithm_params, |
| 20 | 20 | log_processing_step | ... | ... |
offline_tasks/scripts/i2i_swing.py
| ... | ... | @@ -10,11 +10,11 @@ import argparse |
| 10 | 10 | import json |
| 11 | 11 | from datetime import datetime, timedelta |
| 12 | 12 | from db_service import create_db_connection |
| 13 | -from offline_tasks.config.offline_config import ( | |
| 13 | +from config import ( | |
| 14 | 14 | DB_CONFIG, OUTPUT_DIR, I2I_CONFIG, get_time_range, |
| 15 | 15 | DEFAULT_LOOKBACK_DAYS, DEFAULT_I2I_TOP_N |
| 16 | 16 | ) |
| 17 | -from offline_tasks.scripts.debug_utils import ( | |
| 17 | +from debug_utils import ( | |
| 18 | 18 | setup_debug_logger, log_dataframe_info, log_dict_stats, |
| 19 | 19 | save_readable_index, load_name_mappings_from_file, log_algorithm_params, |
| 20 | 20 | log_processing_step | ... | ... |
offline_tasks/scripts/interest_aggregation.py
| ... | ... | @@ -9,11 +9,11 @@ import json |
| 9 | 9 | from datetime import datetime, timedelta |
| 10 | 10 | from collections import defaultdict, Counter |
| 11 | 11 | from db_service import create_db_connection |
| 12 | -from offline_tasks.config.offline_config import ( | |
| 12 | +from config import ( | |
| 13 | 13 | DB_CONFIG, OUTPUT_DIR, INTEREST_AGGREGATION_CONFIG, get_time_range, |
| 14 | 14 | DEFAULT_LOOKBACK_DAYS, DEFAULT_RECENT_DAYS, DEFAULT_INTEREST_TOP_N |
| 15 | 15 | ) |
| 16 | -from offline_tasks.scripts.debug_utils import ( | |
| 16 | +from debug_utils import ( | |
| 17 | 17 | setup_debug_logger, log_dataframe_info, log_dict_stats, |
| 18 | 18 | save_readable_index, fetch_name_mappings, log_algorithm_params, |
| 19 | 19 | log_processing_step | ... | ... |
offline_tasks/scripts/load_index_to_redis.py
| ... | ... | @@ -0,0 +1,107 @@ |
| 1 | +#!/bin/bash | |
| 2 | + | |
| 3 | +echo "======================================================================" | |
| 4 | +echo "测试修复是否成功" | |
| 5 | +echo "======================================================================" | |
| 6 | +echo "" | |
| 7 | + | |
| 8 | +cd /home/tw/recommendation/offline_tasks | |
| 9 | + | |
| 10 | +# 测试 1: 检查文件是否存在 | |
| 11 | +echo "[测试 1] 检查文件是否正确移动..." | |
| 12 | +if [ -f "scripts/db_service.py" ]; then | |
| 13 | + echo " ✓ db_service.py 已移动到 scripts/" | |
| 14 | +else | |
| 15 | + echo " ✗ db_service.py 未找到" | |
| 16 | + exit 1 | |
| 17 | +fi | |
| 18 | + | |
| 19 | +if [ -f "deepwalk/deepwalk.py" ] && [ -f "deepwalk/alias.py" ]; then | |
| 20 | + echo " ✓ DeepWalk 文件已移动到 offline_tasks/deepwalk/" | |
| 21 | +else | |
| 22 | + echo " ✗ DeepWalk 文件未找到" | |
| 23 | + exit 1 | |
| 24 | +fi | |
| 25 | + | |
| 26 | +# 测试 2: 检查 Python 语法 | |
| 27 | +echo "" | |
| 28 | +echo "[测试 2] 检查 Python 脚本语法..." | |
| 29 | +python3 -m py_compile scripts/i2i_item_behavior.py 2>/dev/null | |
| 30 | +if [ $? -eq 0 ]; then | |
| 31 | + echo " ✓ i2i_item_behavior.py 语法正确" | |
| 32 | +else | |
| 33 | + echo " ✗ i2i_item_behavior.py 语法错误" | |
| 34 | + exit 1 | |
| 35 | +fi | |
| 36 | + | |
| 37 | +python3 -m py_compile scripts/tag_category_similar.py 2>/dev/null | |
| 38 | +if [ $? -eq 0 ]; then | |
| 39 | + echo " ✓ tag_category_similar.py 语法正确" | |
| 40 | +else | |
| 41 | + echo " ✗ tag_category_similar.py 语法错误" | |
| 42 | + exit 1 | |
| 43 | +fi | |
| 44 | + | |
| 45 | +python3 -m py_compile scripts/i2i_deepwalk.py 2>/dev/null | |
| 46 | +if [ $? -eq 0 ]; then | |
| 47 | + echo " ✓ i2i_deepwalk.py 语法正确" | |
| 48 | +else | |
| 49 | + echo " ✗ i2i_deepwalk.py 语法错误" | |
| 50 | + exit 1 | |
| 51 | +fi | |
| 52 | + | |
| 53 | +# 测试 3: 检查是否还有 sys.path hack | |
| 54 | +echo "" | |
| 55 | +echo "[测试 3] 检查是否清理了 sys.path hack..." | |
| 56 | +sys_path_count=$(grep -r "sys.path.append" scripts/*.py | wc -l) | |
| 57 | +if [ $sys_path_count -eq 0 ]; then | |
| 58 | + echo " ✓ 所有 sys.path hack 已清理" | |
| 59 | +else | |
| 60 | + echo " ⚠️ 仍有 $sys_path_count 个文件包含 sys.path.append" | |
| 61 | + grep -r "sys.path.append" scripts/*.py | |
| 62 | +fi | |
| 63 | + | |
| 64 | +# 测试 4: 检查导入语句 | |
| 65 | +echo "" | |
| 66 | +echo "[测试 4] 检查导入语句..." | |
| 67 | +if grep -q "^from db_service import" scripts/i2i_item_behavior.py; then | |
| 68 | + echo " ✓ i2i_item_behavior.py 正确导入 db_service" | |
| 69 | +else | |
| 70 | + echo " ✗ i2i_item_behavior.py 未导入 db_service" | |
| 71 | + exit 1 | |
| 72 | +fi | |
| 73 | + | |
| 74 | +if grep -q "^from db_service import" scripts/tag_category_similar.py; then | |
| 75 | + echo " ✓ tag_category_similar.py 正确导入 db_service" | |
| 76 | +else | |
| 77 | + echo " ✗ tag_category_similar.py 未导入 db_service" | |
| 78 | + exit 1 | |
| 79 | +fi | |
| 80 | + | |
| 81 | +if grep -q "from deepwalk import DeepWalk" scripts/i2i_deepwalk.py; then | |
| 82 | + echo " ✓ i2i_deepwalk.py 正确导入 DeepWalk" | |
| 83 | +else | |
| 84 | + echo " ✗ i2i_deepwalk.py 未导入 DeepWalk" | |
| 85 | + exit 1 | |
| 86 | +fi | |
| 87 | + | |
| 88 | +echo "" | |
| 89 | +echo "======================================================================" | |
| 90 | +echo "✓ 所有测试通过!" | |
| 91 | +echo "======================================================================" | |
| 92 | +echo "" | |
| 93 | +echo "现在可以运行以下命令进行完整测试:" | |
| 94 | +echo "" | |
| 95 | +echo " # 测试 Task 5" | |
| 96 | +echo " python3 scripts/i2i_item_behavior.py --lookback_days 400 --top_n 50 --debug" | |
| 97 | +echo "" | |
| 98 | +echo " # 测试 Task 6" | |
| 99 | +echo " python3 scripts/tag_category_similar.py --lookback_days 400 --top_n 50 --debug" | |
| 100 | +echo "" | |
| 101 | +echo " # 测试 Task 3 (DeepWalk - 使用较小参数)" | |
| 102 | +echo " python3 scripts/i2i_deepwalk.py --lookback_days 200 --top_n 30 --num_walks 5 --walk_length 20 --save_model --save_graph --debug" | |
| 103 | +echo "" | |
| 104 | +echo " # 或运行完整流程" | |
| 105 | +echo " bash run.sh" | |
| 106 | +echo "" | |
| 107 | + | ... | ... |