Commit 1088c2610d5b675bb560fd4d3ada5fd5be4d0f9e
1 parent 8095cb00
mv files
Showing 14 changed files with 369 additions and 38 deletions
@@ -0,0 +1,55 @@
+import numpy as np
+
+
+def create_alias_table(area_ratio):
+    """
+    Build the accept/alias tables for alias-method sampling.
+
+    :param area_ratio: list of probabilities with sum(area_ratio) = 1
+    :return: accept, alias
+    """
+    l = len(area_ratio)
+    area_ratio = [prop * l for prop in area_ratio]
+    accept, alias = [0] * l, [0] * l
+    small, large = [], []
+
+    for i, prob in enumerate(area_ratio):
+        if prob < 1.0:
+            small.append(i)
+        else:
+            large.append(i)
+
+    while small and large:
+        small_idx, large_idx = small.pop(), large.pop()
+        accept[small_idx] = area_ratio[small_idx]
+        alias[small_idx] = large_idx
+        area_ratio[large_idx] = area_ratio[large_idx] - \
+            (1 - area_ratio[small_idx])
+        if area_ratio[large_idx] < 1.0:
+            small.append(large_idx)
+        else:
+            large.append(large_idx)
+
+    while large:
+        large_idx = large.pop()
+        accept[large_idx] = 1
+    while small:
+        small_idx = small.pop()
+        accept[small_idx] = 1
+
+    return accept, alias
+
+
+def alias_sample(accept, alias):
+    """
+    Draw one index in O(1) from the tables built by create_alias_table.
+
+    :param accept: acceptance probability of each bucket
+    :param alias: alias index of each bucket
+    :return: sampled index
+    """
+    N = len(accept)
+    i = int(np.random.random() * N)
+    r = np.random.random()
+    if r < accept[i]:
+        return i
+    else:
+        return alias[i]
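For context: the alias method spends O(n) preprocessing so that each categorical draw costs O(1). A minimal usage sketch, assuming the file above is saved as alias.py (the walker below imports it under that name):

    import numpy as np
    from alias import create_alias_table, alias_sample

    probs = [0.5, 0.3, 0.2]                    # must sum to 1
    accept, alias = create_alias_table(probs)

    counts = np.zeros(len(probs))
    for _ in range(10000):
        counts[alias_sample(accept, alias)] += 1
    print(counts / counts.sum())               # approaches [0.5, 0.3, 0.2]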
@@ -0,0 +1,266 @@
+import random
+import numpy as np
+import networkx as nx
+from joblib import Parallel, delayed
+import itertools
+from alias import create_alias_table, alias_sample
+from tqdm import tqdm
+import argparse
+import multiprocessing
+import logging
+import os
+
+def softmax(x, temperature=1.0):
+    """
+    Temperature-scaled softmax, with max subtraction to prevent overflow.
+    """
+    x = np.array(x)
+    x_max = np.max(x)
+    exp_x = np.exp((x - x_max) / temperature)  # temperature flattens or sharpens the distribution
+    return exp_x / np.sum(exp_x)
+
+class DeepWalk:
+    def __init__(self, edge_file, node_tag_file, use_softmax=True, temperature=1.0, p_tag_walk=0.5):
+        """
+        Initialize the DeepWalk instance: build the graph and tag indexes,
+        then precompute the alias sampling tables.
+        """
+        logging.info(f"Initializing DeepWalk with edge file: {edge_file} and node-tag file: {node_tag_file}")
+        self.graph = self.build_graph_from_edge_file(edge_file)
+        if node_tag_file:
+            self.node_to_tags, self.tag_to_nodes = self.build_tag_index(node_tag_file)
+        else:
+            self.node_to_tags = None
+            self.tag_to_nodes = None
+
+        self.alias_nodes = {}
+        self.p_tag_walk = p_tag_walk
+        logging.info(f"Graph built with {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges.")
+
+        if use_softmax:
+            logging.info(f"Using softmax with temperature: {temperature}")
+            self.preprocess_transition_probs__softmax(temperature)
+        else:
+            logging.info("Using standard alias sampling.")
+            self.preprocess_transition_probs()
+
+    def build_graph_from_edge_file(self, edge_file):
+        """
+        Build the graph from an edge file.
+        Edge file format: bid \t nbr1:weight1,nbr2:weight2,...
+        """
+        G = nx.Graph()
+
+        # Read the edge file line by line
+        with open(edge_file, 'r') as f:
+            for line in f:
+                parts = line.strip().split('\t')
+                if len(parts) != 2:
+                    continue
+                node, edges_str = parts
+                edges = edges_str.split(',')
+
+                for edge in edges:
+                    nbr, weight = edge.split(':')
+                    try:
+                        node, nbr = int(node), int(nbr)
+                    except ValueError:
+                        continue
+                    weight = float(weight)
+
+                    # If the edge already exists, accumulate the new weight;
+                    # otherwise add the edge directly.
+                    if G.has_edge(node, nbr):
+                        G[node][nbr]['weight'] += weight
+                    else:
+                        G.add_edge(node, nbr, weight=weight)
+
+        return G
+
+    def build_tag_index(self, node_tag_file):
+        """
+        Build the forward (node -> tags) and inverted (tag -> nodes) indexes.
+        node_tag_file format: book_id \t tag1,tag2,tag3
+        """
+        node_to_tags = {}
+        tag_to_nodes = {}
+
+        with open(node_tag_file, 'r') as f:
+            for line in f:
+                parts = line.strip().split('\t')
+                if len(parts) != 2:
+                    continue
+                node, tags_str = parts
+                try:
+                    node = int(node)
+                except ValueError:
+                    continue
+                # Keep only nodes with user behavior, i.e. nodes present in the graph
+                if node not in self.graph:
+                    continue
+                tags = tags_str.split(',')
+                node_to_tags[node] = tags
+                for tag in tags:
+                    tag_to_nodes.setdefault(tag, []).append(node)
+
+        return node_to_tags, tag_to_nodes
+
+    def preprocess_transition_probs(self):
+        """
+        Precompute per-node alias tables for fast weighted random walks.
+        """
+        G = self.graph
+
+        for node in G.nodes():
+            unnormalized_probs = [G[node][nbr].get('weight', 1.0) for nbr in G.neighbors(node)]
+            norm_const = sum(unnormalized_probs)
+            normalized_probs = [float(u_prob) / norm_const for u_prob in unnormalized_probs]
+            self.alias_nodes[node] = create_alias_table(normalized_probs)
+
+    def preprocess_transition_probs__softmax(self, temperature=1.0):
+        """
+        Same as above, but normalizes edge weights with a temperature-scaled
+        softmax instead of a plain sum.
+        """
+        G = self.graph
+
+        for node in G.nodes():
+            unnormalized_probs = [G[node][nbr].get('weight', 1.0) for nbr in G.neighbors(node)]
+            normalized_probs = softmax(unnormalized_probs, temperature)
+            self.alias_nodes[node] = create_alias_table(normalized_probs)
+
+    def deepwalk_walk(self, walk_length, start_node):
+        """
+        Perform one DeepWalk random walk, accelerated by alias sampling;
+        optionally hops through shared tags.
+        """
+        G = self.graph
+        alias_nodes = self.alias_nodes
+        walk = [start_node]
+
+        while len(walk) < walk_length:
+            cur = walk[-1]
+
+            # With probability p_tag_walk, hop via a shared tag;
+            # otherwise take a weighted step to a neighbor.
+            if self.node_to_tags and random.random() < self.p_tag_walk and cur in self.node_to_tags:
+                step = self.tag_based_walk(cur, walk)
+            else:
+                step = self.neighbor_based_walk(cur, alias_nodes, walk)
+
+            # A None step means a dead end (isolated node or empty tag);
+            # keep the partial walk instead of discarding it.
+            if step is None:
+                break
+
+        return walk
+
+    def neighbor_based_walk(self, cur, alias_nodes, walk):
+        """
+        One weighted step through the neighbor list, drawn via alias sampling.
+        """
+        G = self.graph
+        cur_nbrs = list(G.neighbors(cur))
+        if len(cur_nbrs) > 0:
+            idx = alias_sample(alias_nodes[cur][0], alias_nodes[cur][1])
+            walk.append(cur_nbrs[idx])
+        else:
+            return None
+        return walk
+
+    def tag_based_walk(self, cur, walk):
+        """
+        One step through a shared tag: pick one of the current node's tags
+        uniformly, then a random node carrying that tag.
+        """
+        tags = self.node_to_tags[cur]
+        if not tags:
+            return None
+
+        # Pick a tag uniformly at random
+        chosen_tag = random.choice(tags)
+
+        # Fetch the nodes carrying that tag
+        nodes_with_tag = self.tag_to_nodes.get(chosen_tag, [])
+        if not nodes_with_tag:
+            return None
+
+        # Pick one of them uniformly at random
+        chosen_node = random.choice(nodes_with_tag)
+        walk.append(chosen_node)
+        return walk
+
+    def simulate_walks(self, num_walks, walk_length, workers, output_file):
+        """
+        Simulate random walks across multiple processes and save the result.
+        """
+        G = self.graph
+        nodes = list(G.nodes())
+        num_walks_per_worker = max(1, num_walks // workers)
+        logging.info(f"Starting simulation with {num_walks_per_worker} walks per node, walk length {walk_length}, using {workers} workers.")
+
+        # joblib alternatives: backend='multiprocessing' or backend='loky'
+        results = Parallel(n_jobs=workers)(
+            delayed(self._simulate_walks)(nodes, num_walks_per_worker, walk_length)
+            for _ in range(workers)
+        )
+        walks = list(itertools.chain(*results))
+
+        # Persist the walks
+        self.save_walks_to_file(walks, output_file)
+
+    def _simulate_walks(self, nodes, num_walks, walk_length):
+        """
+        Run num_walks passes of random walks over all nodes.
+        """
+        logging.info(f"_simulate_walks started, num_walks:{num_walks}, walk_length:{walk_length}")
+        walks = []
+        for i in range(num_walks):
+            logging.info(f"_simulate_walks pass {i + 1}/{num_walks}.")
+            random.shuffle(nodes)
+            for node in nodes:
+                walks.append(self.deepwalk_walk(walk_length=walk_length, start_node=node))
+        return walks
+
+    def save_walks_to_file(self, walks, output_file):
+        """
+        Save walks in Word2Vec input format: one walk per line,
+        space-separated node ids.
+        """
+        logging.info(f"Saving walks to file: {output_file}")
+        with open(output_file, 'w') as f:
+            for walk in walks:
+                walk_str = ' '.join(map(str, walk))
+                f.write(walk_str + '\n')
+        logging.info(f"Successfully saved {len(walks)} walks to {output_file}.")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run DeepWalk with tag-based random walks")
+    parser.add_argument('--edge-file', type=str, required=True, help="Path to the edge file")  # e.g. ../../fetch_data/data/edge.txt.20240923
+    parser.add_argument('--node-tag-file', type=str, help="Path to the node-tag file")
+    parser.add_argument('--num-walks', type=int, default=100, help="Number of walks per node (default: 100)")
+    parser.add_argument('--walk-length', type=int, default=40, help="Length of each walk (default: 40)")
+    parser.add_argument('--workers', type=int, default=multiprocessing.cpu_count() - 1, help="Number of workers (default: CPU cores - 1)")
+    parser.add_argument('--use-softmax', action='store_true', help="Use softmax-based alias sampling (default: False)")
+    parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for softmax (default: 1.0)")
+    parser.add_argument('--p-tag-walk', type=float, default=0.2, help="Probability to walk through tag-based neighbors (default: 0.2)")
+    parser.add_argument('--output-file', type=str, required=True, help="Path to save the walks file")
+
+    args = parser.parse_args()
+
+    # Configure logging
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+
+    # Build the DeepWalk instance from the edge and node-tag files
+    deepwalk = DeepWalk(
+        edge_file=args.edge_file,
+        node_tag_file=args.node_tag_file,
+        use_softmax=args.use_softmax,
+        temperature=args.temperature,
+        p_tag_walk=args.p_tag_walk
+    )
+
+    # Run the walks and save the result
+    deepwalk.simulate_walks(
+        num_walks=args.num_walks,
+        walk_length=args.walk_length,
+        workers=args.workers,
+        output_file=args.output_file
+    )
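The saved walks are in Word2Vec input format. A hypothetical next step, not part of this commit: training item embeddings on the output with gensim, assuming gensim 4.x is installed and walks.txt stands in for the --output-file path.

    from gensim.models import Word2Vec
    from gensim.models.word2vec import LineSentence

    # Each line of walks.txt is one walk: space-separated node ids.
    model = Word2Vec(
        LineSentence('walks.txt'),   # placeholder path from --output-file
        vector_size=128,             # embedding dimension
        window=5,                    # context window within a walk
        min_count=1,
        sg=1,                        # skip-gram, as in the original DeepWalk paper
        workers=4,
    )
    model.wv.save_word2vec_format('item_vectors.txt')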
offline_tasks/scripts/add_names_to_swing.py
@@ -3,10 +3,6 @@
 Input format: item_id \t similar_item_id1:score1,similar_item_id2:score2,...
 Output format: item_id:name \t similar_item_id1:name1:score1,similar_item_id2:name2:score2,...
 """
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 import argparse
 from datetime import datetime
 from offline_tasks.scripts.debug_utils import setup_debug_logger, load_name_mappings_from_file
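This commit strips the same three-line sys.path patch from several scripts below, so `offline_tasks.scripts.*` imports now resolve only when the repo root is importable, e.g. via `python -m` from the repo root. A hypothetical launcher sketch under that assumption, with the repo path as a placeholder:

    import runpy
    import sys

    sys.path.insert(0, '/path/to/repo')  # placeholder: directory containing offline_tasks/
    runpy.run_module('offline_tasks.scripts.add_names_to_swing', run_name='__main__')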
@@ -0,0 +1,48 @@
+"""
+Database connection service module.
+Provides a unified database connection interface.
+"""
+from sqlalchemy import create_engine
+from urllib.parse import quote_plus
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def create_db_connection(host, port, database, username, password):
+    """
+    Create a database connection.
+
+    Args:
+        host: database host address
+        port: port
+        database: database name
+        username: user name
+        password: password
+
+    Returns:
+        A SQLAlchemy engine object
+    """
+    try:
+        # URL-encode the password to handle special characters
+        encoded_password = quote_plus(password)
+
+        # Build the connection string
+        connection_string = f'mysql+pymysql://{username}:{encoded_password}@{host}:{port}/{database}'
+
+        # Create the engine
+        engine = create_engine(
+            connection_string,
+            pool_pre_ping=True,   # validate pooled connections before use
+            pool_recycle=3600,    # recycle connections after an hour
+            echo=False
+        )
+
+        logger.info(f"Database connection created successfully: {host}:{port}/{database}")
+        return engine
+
+    except Exception as e:
+        logger.error(f"Failed to create database connection: {e}")
+        raise
+
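A hypothetical usage sketch for the helper above; the module name, credentials, and table are placeholders, and it assumes pymysql and pandas are installed:

    import pandas as pd
    from db_connection import create_db_connection  # placeholder module name

    engine = create_db_connection(
        host='127.0.0.1', port=3306, database='recsys',
        username='reader', password='p@ss:word',  # special chars survive via quote_plus
    )
    df = pd.read_sql('SELECT item_id, name FROM items LIMIT 10', engine)
    print(df.head())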
offline_tasks/scripts/fetch_item_attributes.py
offline_tasks/scripts/generate_session.py
@@ -3,10 +3,6 @@
 Read user behavior from the database and generate a session file for the C++ Swing algorithm.
 Output format: uid \t {"item_id":score,"item_id":score,...}
 """
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 import pandas as pd
 import json
 from collections import defaultdict
offline_tasks/scripts/i2i_content_similar.py
@@ -4,10 +4,6 @@ i2i - ES-vector-based content similarity index
 1. Similarity based on name text vectors
 2. Similarity based on image vectors
 """
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 import json
 import pandas as pd
 from datetime import datetime, timedelta
offline_tasks/scripts/i2i_item_behavior.py
offline_tasks/scripts/i2i_session_w2v.py
offline_tasks/scripts/i2i_swing.py
@@ -3,10 +3,6 @@ i2i - Swing algorithm implementation
 Item-similarity computation based on user behavior.
 Follows the data format of item_sim.py, adapted to production data.
 """
-import sys
-import os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
 import pandas as pd
 import math
 from collections import defaultdict
offline_tasks/scripts/interest_aggregation.py
offline_tasks/scripts/load_index_to_redis.py
offline_tasks/scripts/tag_category_similar.py
offline_tasks/scripts/test_es_connection.py