import random
import numpy as np
import networkx as nx
from joblib import Parallel, delayed
import itertools
from alias import create_alias_table, alias_sample
import argparse
import multiprocessing
import logging

def softmax(x, temperature=1.0):
    """
    Compute a temperature-scaled softmax, subtracting the max value to avoid overflow.
    """
    x = np.array(x)
    x_max = np.max(x)
    exp_x = np.exp((x - x_max) / temperature)  # temperature < 1 sharpens, > 1 flattens the distribution
    return exp_x / np.sum(exp_x)
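
# Illustrative example (assumed values, not part of the pipeline):
# softmax([1.0, 2.0, 3.0], temperature=0.5) puts most of the mass on the last entry,
# while temperature=2.0 returns a noticeably flatter distribution over the same weights.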

class DeepWalk:
    def __init__(self, edge_file, node_tag_file, use_softmax=True, temperature=1.0, p_tag_walk=0.5):
        """
        初始化DeepWalk实例,构建图和标签索引,预处理alias采样表
        """
        logging.info(f"Initializing DeepWalk with edge file: {edge_file} and node-tag file: {node_tag_file}")
        self.graph = self.build_graph_from_edge_file(edge_file)
        if node_tag_file:
            self.node_to_tags, self.tag_to_nodes = self.build_tag_index(node_tag_file)
        else:
            self.node_to_tags = None
            self.tag_to_nodes = None
            
        self.alias_nodes = {}
        self.p_tag_walk = p_tag_walk
        logging.info(f"Graph built with {self.graph.number_of_nodes()} nodes and {self.graph.number_of_edges()} edges.")

        if use_softmax:
            logging.info(f"Using softmax with temperature: {temperature}")
            self.preprocess_transition_probs__softmax(temperature)
        else:
            logging.info("Using standard alias sampling.")
            self.preprocess_transition_probs()

    def build_graph_from_edge_file(self, edge_file):
        """
        从edge文件构建图
        edge文件格式: bid1 \t bid2:weight1,bid2:weight2,...
        """
        G = nx.Graph()

        # Read the edge file line by line
        with open(edge_file, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue
                node, edges_str = parts
                try:
                    node = int(node)
                except ValueError:
                    continue
                edges = edges_str.split(',')

                for edge in edges:
                    try:
                        nbr, weight = edge.split(':')
                        nbr = int(nbr)
                        weight = float(weight)
                    except ValueError:
                        continue

                    # If the edge already exists, accumulate the weight;
                    # otherwise add a new weighted edge.
                    if G.has_edge(node, nbr):
                        G[node][nbr]['weight'] += weight
                    else:
                        G.add_edge(node, nbr, weight=weight)

        return G

    def build_tag_index(self, node_tag_file):
        """
        构建节点-标签的正排和倒排索引
        node_tag_file格式: book_id \t tag1,tag2,tag3
        """
        node_to_tags = {}
        tag_to_nodes = {}

        with open(node_tag_file, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue
                node, tags_str = parts
                try:
                    node = int(node)
                except ValueError:
                    continue
                # Keep only nodes that already appear in the interaction graph
                if node not in self.graph:
                    continue
                tags = tags_str.split(',')
                node_to_tags[node] = tags
                for tag in tags:
                    tag_to_nodes.setdefault(tag, []).append(node)

        return node_to_tags, tag_to_nodes

    def preprocess_transition_probs(self):
        """
        预处理节点的alias采样表,用于快速加权随机游走
        """
        G = self.graph

        for node in G.nodes():
            unnormalized_probs = [G[node][nbr].get('weight', 1.0) for nbr in G.neighbors(node)]
            norm_const = sum(unnormalized_probs)
            normalized_probs = [float(u_prob) / norm_const for u_prob in unnormalized_probs]
            self.alias_nodes[node] = create_alias_table(normalized_probs)

    def preprocess_transition_probs__softmax(self, temperature=1.0):
        """
        预处理节点的alias采样表,用于快速加权随机游走
        """
        G = self.graph

        for node in G.nodes():
            unnormalized_probs = [G[node][nbr].get('weight', 1.0) for nbr in G.neighbors(node)]
            normalized_probs = softmax(unnormalized_probs, temperature)
            self.alias_nodes[node] = create_alias_table(normalized_probs)

    def deepwalk_walk(self, walk_length, start_node):
        """
        执行一次DeepWalk随机游走,基于alias方法加速,支持通过标签游走
        """
        G = self.graph
        alias_nodes = self.alias_nodes
        walk = [start_node]

        while len(walk) < walk_length:
            cur = walk[-1]

            # With probability p_tag_walk, jump via a shared tag; otherwise follow an edge
            if self.node_to_tags and random.random() < self.p_tag_walk and cur in self.node_to_tags:
                next_walk = self.tag_based_walk(cur, walk)
            else:
                next_walk = self.neighbor_based_walk(cur, alias_nodes, walk)

            if next_walk is None:
                # Dead end (no neighbors or no tagged candidates): keep the partial walk
                break
            walk = next_walk

        return walk

    def neighbor_based_walk(self, cur, alias_nodes, walk):
        """
        基于邻居的随机游走
        """
        G = self.graph
        cur_nbrs = list(G.neighbors(cur))
        if len(cur_nbrs) > 0:
            idx = alias_sample(alias_nodes[cur][0], alias_nodes[cur][1])
            walk.append(cur_nbrs[idx])
        else:
            return None
        return walk

    def tag_based_walk(self, cur, walk):
        """
        基于标签的随机游走
        """
        tags = self.node_to_tags[cur]
        if not tags:
            return None

        # 随机选择一个tag
        chosen_tag = random.choice(tags)

        # 获取该tag下的节点列表
        nodes_with_tag = self.tag_to_nodes.get(chosen_tag, [])
        if not nodes_with_tag:
            return None

        # 随机选择一个节点
        chosen_node = random.choice(nodes_with_tag)
        walk.append(chosen_node)
        return walk

    def simulate_walks(self, num_walks, walk_length, workers, output_file):
        """
        多进程模拟多次随机游走,并将游走结果保存到文件
        """
        G = self.graph
        nodes = list(G.nodes())
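        # Each worker gets an equal share of passes; integer division means the total
        # number of passes is workers * (num_walks // workers), which can fall slightly
        # below num_walks.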
        num_walks_per_worker = max(1, num_walks // workers)
        logging.info(f"Starting simulation with {num_walks_per_worker} walks per node, walk length {walk_length}, using {workers} workers.")

        # Each worker runs its own share of walks. joblib's default backend ('loky') is
        # used here; backend='multiprocessing' can be passed to Parallel as an alternative.
        results = Parallel(n_jobs=workers)(
            delayed(self._simulate_walks)(nodes, num_walks_per_worker, walk_length)
            for _ in range(workers)
        )
        walks = list(itertools.chain(*results))

        # Save the walks to the output file
        self.save_walks_to_file(walks, output_file)

    def _simulate_walks(self, nodes, num_walks, walk_length):
        """
        Run num_walks passes of random walks, each pass starting once from every node.
        """
        logging.info(f"_simulate_walks started, num_walks:{num_walks}, walk_length:{walk_length}")
        walks = []
        for i in range(num_walks):
            logging.info(f"_simulate_walks run num_walks of {i}.")
            random.shuffle(nodes)
            for node in nodes:
                walks.append(self.deepwalk_walk(walk_length=walk_length, start_node=node))
        return walks

    def save_walks_to_file(self, walks, output_file):
        """
        将游走结果保存到文件,按Word2Vec的输入格式
        """
        logging.info(f"Saving walks to file: {output_file}")
        with open(output_file, 'w') as f:
            for walk in walks:
                walk_str = ' '.join(map(str, walk))
                f.write(walk_str + '\n')
        logging.info(f"Successfully saved {len(walks)} walks to {output_file}.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run DeepWalk with tag-based random walks")
    parser.add_argument('--edge-file', type=str, required=True, help="Path to the edge file") # ../../fetch_data/data/edge.txt.20240923
    parser.add_argument('--node-tag-file', type=str, help="Path to the node-tag file")
    parser.add_argument('--num-walks', type=int, default=100, help="Number of walks per node (default: 100)")
    parser.add_argument('--walk-length', type=int, default=40, help="Length of each walk (default: 40)")
    parser.add_argument('--workers', type=int, default=max(1, multiprocessing.cpu_count() - 1), help="Number of workers (default: CPU cores - 1, at least 1)")
    parser.add_argument('--use-softmax', action='store_true', help="Use softmax-based alias sampling (default: False)")
    parser.add_argument('--temperature', type=float, default=1.0, help="Temperature for softmax (default: 1.0)")
    parser.add_argument('--p-tag-walk', type=float, default=0.2, help="Probability to walk through tag-based neighbors (default: 0.2)")
    parser.add_argument('--output-file', type=str, required=True, help="Path to save the walks file")

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

    # Initialize the DeepWalk instance with the edge file and the node-tag file
    deepwalk = DeepWalk(
        edge_file=args.edge_file,
        node_tag_file=args.node_tag_file,
        use_softmax=args.use_softmax,
        temperature=args.temperature,
        p_tag_walk=args.p_tag_walk
    )

    # Run the random walks and save the results to the output file
    deepwalk.simulate_walks(
        num_walks=args.num_walks,
        walk_length=args.walk_length,
        workers=args.workers,
        output_file=args.output_file
    )
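
    # Example invocation (paths are hypothetical):
    #   python deepwalk.py --edge-file edges.txt --node-tag-file tags.txt \
    #       --num-walks 100 --walk-length 40 --use-softmax --output-file walks.txt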