A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019).

ClusterGCN

Abstract

Graph convolutional network (GCN) has been successfully applied to many graph-based applications; however, training a large-scale GCN remains challenging. Current SGD-based algorithms suffer from either a high computational cost that grows exponentially with the number of GCN layers, or a large space requirement for keeping the entire graph and the embedding of each node in memory. In this paper, we propose Cluster-GCN, a novel GCN algorithm that is suitable for SGD-based training by exploiting the graph clustering structure. Cluster-GCN works as follows: at each step, it samples a block of nodes associated with a dense subgraph identified by a graph clustering algorithm, and restricts the neighborhood search to this subgraph. This simple but effective strategy leads to significantly improved memory and computational efficiency while achieving test accuracy comparable to previous algorithms. To test the scalability of our algorithm, we create a new Amazon2M dataset with 2 million nodes and 61 million edges, which is more than 5 times larger than the previous largest publicly available dataset (Reddit). For training a 3-layer GCN on this data, Cluster-GCN is faster than the previous state-of-the-art VR-GCN (1523 seconds vs 1961 seconds) and uses much less memory (2.2GB vs 11.2GB). Furthermore, for training a 4-layer GCN on this data, our algorithm can finish in around 36 minutes, while all existing GCN training algorithms fail to train due to out-of-memory issues. Moreover, Cluster-GCN allows us to train much deeper GCNs without much time and memory overhead, which leads to improved prediction accuracy: using a 5-layer Cluster-GCN, we achieve a state-of-the-art test F1 score of 99.36 on the PPI dataset, while the previous best result was 98.71.
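
The batching idea described in the abstract can be sketched in plain Python. This is a minimal illustration, not the repository's code: it assumes a node-to-cluster mapping has already been produced (by METIS or any other clustering method), represents the graph as a list of `(u, v)` edge tuples, and leaves out the GCN forward and backward pass. The helper names `induced_subgraph` and `cluster_batches` are hypothetical.

```python
import random
from collections import defaultdict


def induced_subgraph(edges, nodes):
    """Keep only the edges whose two endpoints both lie in `nodes`.

    Edges leaving the block are dropped, so neighbourhood aggregation for
    this batch never looks outside the sampled subgraph.
    """
    node_set = set(nodes)
    return [(u, v) for u, v in edges if u in node_set and v in node_set]


def cluster_batches(edges, membership, seed=42):
    """Yield (cluster_id, nodes, subgraph_edges) once per cluster, in shuffled order.

    `membership` maps node -> cluster id and stands in for the output of a
    graph clustering algorithm (hypothetical input format).
    """
    clusters = defaultdict(list)
    for node, cluster in membership.items():
        clusters[cluster].append(node)
    order = sorted(clusters)
    random.Random(seed).shuffle(order)  # visit clusters in a random order
    for cluster in order:
        nodes = sorted(clusters[cluster])
        yield cluster, nodes, induced_subgraph(edges, nodes)


# Toy example: a 6-node ring split into two clusters of three nodes.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
membership = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}
for cluster, nodes, sub_edges in cluster_batches(edges, membership):
    # A real trainer would run one SGD step of the GCN on this subgraph here.
    print(cluster, nodes, sub_edges)
```

In this toy graph the edges `(2, 3)` and `(5, 0)` cross the cluster boundary and are invisible to both batches. A good clustering keeps the number of such dropped edges small, which is why the paper relies on a graph clustering algorithm rather than arbitrary node blocks.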

This repository provides a PyTorch implementation of ClusterGCN as described in the paper:

Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks. Wei-Lin Chiang, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, Cho-Jui Hsieh. KDD, 2019. [Paper]

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx           1.11
tqdm               4.28.1
numpy              1.15.4
pandas             0.23.4
texttable          1.5.0
scipy              1.1.0
argparse           1.1.0
torch              0.4.1
torch-geometric    0.3.1
metis              0.2a.4
scikit-learn       0.20
torch_spline_conv  1.0.4
torch_sparse       0.2.2
torch_scatter      1.0.4
torch_cluster      1.1.5

Installing metis on Ubuntu:

sudo apt-get install libmetis-dev

Datasets

The code takes the **edge list** of the graph as a csv file. Every row indicates an edge between two nodes, given as two node identifiers separated by a comma. The first row is a header. Nodes should be indexed starting from 0. A sample graph for `Pubmed` is included in the `input/` directory. In addition to the edge list, there is a csv file with the sparse features and another one with the target variable.
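
A minimal sketch of reading such an edge list with Python's standard `csv` module. It assumes the two columns hold integer node ids and treats edges as undirected; both are assumptions rather than statements from this README. The sample header names (`id_1,id_2`) and the helper `read_edge_list` are hypothetical.

```python
import csv
import io
from collections import defaultdict


def read_edge_list(handle):
    """Parse an edge-list csv: one header row, then one `source,target` pair per row.

    Returns an undirected adjacency dict mapping node -> sorted list of neighbours.
    """
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    adjacency = defaultdict(set)
    for row in reader:
        if not row:
            continue  # tolerate blank lines
        u, v = int(row[0]), int(row[1])
        adjacency[u].add(v)
        adjacency[v].add(u)  # assumption: the graph is undirected
    return {node: sorted(neigh) for node, neigh in sorted(adjacency.items())}


sample = io.StringIO("id_1,id_2\n0,1\n1,2\n0,2\n2,3\n")
print(read_edge_list(sample))  # {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
```

Because nodes are indexed from 0, `max(adjacency) + 1` gives the node count, provided every node appears in at least one edge.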

The **feature matrix** is sparse and is stored as a csv. Features are indexed consecutively from 0. The feature matrix csv is structured as:

| NODE ID | FEATURE ID | Value |
|---------|------------|-------|
| 0 | 3 | 0.2 |
| 0 | 7 | 0.5 |
| 1 | 17 | 0.8 |
| 1 | 4 | 5.4 |
| 1 | 38 | 1.3 |
| ... | ... | ... |
| n | 3 | 0.9 |
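
To make the sparse layout concrete, the sketch below expands the `(node, feature, value)` triples into a dense list-of-lists matrix. It assumes both ids are 0-indexed integers and infers the matrix shape from the largest ids it sees; the helper name and sample header are hypothetical. For real datasets such as Pubmed, a sparse container (for example `scipy.sparse.coo_matrix`; SciPy is already among the requirements) is the better choice, and the node count should come from the edge list rather than from this file.

```python
import csv
import io


def read_sparse_features(handle):
    """Turn (node_id, feature_id, value) rows into a dense list-of-lists matrix.

    Node and feature ids are assumed to be 0-indexed and consecutive, so the
    shape is (max node id + 1, max feature id + 1); missing entries are 0.0.
    """
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    entries = [(int(r[0]), int(r[1]), float(r[2])) for r in reader if r]
    n_nodes = max(node for node, _, _ in entries) + 1
    n_features = max(feat for _, feat, _ in entries) + 1
    matrix = [[0.0] * n_features for _ in range(n_nodes)]
    for node, feat, value in entries:
        matrix[node][feat] = value
    return matrix


sample = io.StringIO("node_id,feature_id,value\n0,3,0.2\n0,1,0.5\n1,2,0.8\n")
features = read_sparse_features(sample)
print(features)  # [[0.0, 0.5, 0.0, 0.2], [0.0, 0.0, 0.8, 0.0]]
```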

The **target vector** is a csv with two columns and a header row: the first column contains the node identifiers and the second the targets. This csv is sorted by node identifier, and the target column contains the class memberships indexed from zero.

| NODE ID | Target |
|---------|--------|
| 0 | 3 |
| 1 | 1 |
| 2 | 0 |
| 3 | 1 |
| ... | ... |
| n | 3 |
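
A small sketch of loading this file while checking the two properties the format promises. It assumes every node has exactly one row, so "sorted by node identifier" is checked strictly as ids `0, 1, 2, ...`, and it treats "indexed from zero" as "the smallest label is 0", which is only a heuristic. The helper name and sample header are hypothetical.

```python
import csv
import io


def read_targets(handle):
    """Read `node_id,target` rows and return the targets as a list indexed by node."""
    reader = csv.reader(handle)
    next(reader)  # skip the header row
    rows = [(int(r[0]), int(r[1])) for r in reader if r]
    node_ids = [node for node, _ in rows]
    # Assumption: one row per node, listed in order 0..n-1.
    if node_ids != list(range(len(rows))):
        raise ValueError("target csv must list node ids 0..n-1 in order")
    targets = [label for _, label in rows]
    # Heuristic check that class labels start at zero.
    if min(targets) != 0:
        raise ValueError("class labels must be indexed from zero")
    return targets


sample = io.StringIO("node_id,target\n0,3\n1,1\n2,0\n3,1\n")
print(read_targets(sample))  # [3, 1, 0, 1]
```

With zero-indexed labels, the number of classes is simply `max(targets) + 1`.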

Options

The training of a ClusterGCN model is handled by the `src/main.py` script, which provides the following command line arguments.

Input and output options

  --edge-path       STR    Edge list csv.         Default is `input/edges.csv`.
  --features-path   STR    Features csv.         Default is `input/features.csv`.
  --target-path     STR    Target classes csv.    Default is `input/target.csv`.

Model options

  --clustering-method   STR     Clustering method.             Default is `metis`.
  --cluster-number      INT     Number of clusters.            Default is 10. 
  --seed                INT     Random seed.                   Default is 42.
  --epochs              INT     Number of training epochs.     Default is 200.
  --test-ratio          FLOAT   Training set ratio.            Default is 0.9.
  --learning-rate       FLOAT   Adam learning rate.            Default is 0.01.
  --dropout             FLOAT   Dropout rate value.            Default is 0.5.
  --layers              LST     Layer sizes.                   Default is [16, 16, 16]. 
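
One way to picture how these options map onto Python is the `argparse` sketch below. It is a hypothetical reconstruction from the table above, not the contents of `src/main.py`: the option names, help texts, and defaults come from this README, the `nargs="+"` handling of `--layers` is inferred from the `--layers 64 64` example, and the `choices` restriction on `--clustering-method` is an assumption based on the two methods this README mentions.

```python
import argparse


def build_parser():
    """Hypothetical argparse set-up mirroring the options documented above."""
    parser = argparse.ArgumentParser(description="Train a ClusterGCN model.")
    # Input and output options.
    parser.add_argument("--edge-path", default="input/edges.csv", help="Edge list csv.")
    parser.add_argument("--features-path", default="input/features.csv", help="Features csv.")
    parser.add_argument("--target-path", default="input/target.csv", help="Target classes csv.")
    # Model options; `choices` is an assumption, not taken from the repository.
    parser.add_argument("--clustering-method", default="metis", choices=["metis", "random"],
                        help="Clustering method.")
    parser.add_argument("--cluster-number", type=int, default=10, help="Number of clusters.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    parser.add_argument("--test-ratio", type=float, default=0.9, help="Training set ratio.")
    parser.add_argument("--learning-rate", type=float, default=0.01, help="Adam learning rate.")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate value.")
    # nargs="+" lets `--layers 64 64` produce the list [64, 64].
    parser.add_argument("--layers", type=int, nargs="+", default=[16, 16, 16], help="Layer sizes.")
    return parser


args = build_parser().parse_args(["--epochs", "100", "--layers", "64", "64"])
print(args.epochs, args.layers)  # 100 [64, 64]
```

Note that `argparse` converts dashes to underscores in attribute names, so `--cluster-number` is read back as `args.cluster_number`.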

Examples

The following commands learn a neural network and score it on the test set. Training a model on the default dataset:

$ python src/main.py

Training a ClusterGCN model for 100 epochs:

$ python src/main.py --epochs 100

Increasing the learning rate and the dropout:

$ python src/main.py --learning-rate 0.1 --dropout 0.9

Training a model with a different layer structure:

$ python src/main.py --layers 64 64

Training a model with random clustering:

$ python src/main.py --clustering-method random
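
The `random` option can be pictured with the sketch below. It assumes that random clustering means assigning each node to one of `--cluster-number` clusters uniformly at random; that is an assumption, so check `src/clustering.py` for the actual behaviour. The helpers and the ring-graph comparison are hypothetical and only illustrate why a structure-aware method such as METIS tends to keep far more edges inside clusters.

```python
import random


def random_clustering(nodes, cluster_number, seed=42):
    """Assign every node to one of `cluster_number` clusters uniformly at random.

    Unlike METIS, this ignores the graph structure entirely.
    """
    rng = random.Random(seed)
    return {node: rng.randrange(cluster_number) for node in nodes}


def intra_cluster_edge_fraction(edges, membership):
    """Share of edges whose two endpoints land in the same cluster."""
    kept = sum(1 for u, v in edges if membership[u] == membership[v])
    return kept / len(edges)


# A 100-node ring: edge (i, i+1), wrapping around at the end.
ring = [(i, (i + 1) % 100) for i in range(100)]
# Contiguous blocks of 10 nodes: a stand-in for a structure-aware clustering.
blocks = {i: i // 10 for i in range(100)}
random_membership = random_clustering(range(100), 10)

print(intra_cluster_edge_fraction(ring, blocks))  # 0.9
print(intra_cluster_edge_fraction(ring, random_membership))
```

With 10 random clusters, only about one edge in ten is expected to stay inside a cluster, so each training batch sees a much sparser subgraph than with contiguous blocks.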

License

Comments
  • Segmentation fault While running main.py on Ubuntu

    While running main.py, I get a segmentation fault on Ubuntu.

    python3 main.py --epochs 100

    +-------------------+----------------------------------------------------------+
    | Parameter         | Value                                                    |
    +===================+==========================================================+
    | Cluster number    | 10                                                       |
    +-------------------+----------------------------------------------------------+
    | Clustering method | metis                                                    |
    +-------------------+----------------------------------------------------------+
    | Dropout           | 0.500                                                    |
    +-------------------+----------------------------------------------------------+
    | Edge path         | /home/User/Desktop/ClusterGCN-master/input/edges.csv     |
    +-------------------+----------------------------------------------------------+
    | Epochs            | 100                                                      |
    +-------------------+----------------------------------------------------------+
    | Features path     | /home/User/Desktop/ClusterGCN-                           |
    |                   | master/input/features.csv                                |
    +-------------------+----------------------------------------------------------+
    | Layers            | [16, 16, 16]                                             |
    +-------------------+----------------------------------------------------------+
    | Learning rate     | 0.010                                                    |
    +-------------------+----------------------------------------------------------+
    | Seed              | 42                                                       |
    +-------------------+----------------------------------------------------------+
    | Target path       | /home/User/Desktop/ClusterGCN-                           |
    |                   | master/input//target.csv                                 |
    +-------------------+----------------------------------------------------------+
    | Test ratio        | 0.900                                                    |
    +-------------------+----------------------------------------------------------+

    Metis graph clustering started.

    Segmentation fault

    opened by alamsaqib 4
  • ImportError: No module named 'torch_spline_conv'

    I followed the installation instructions properly; however, the error above occurred.

    After checking the site-packages folder, I cannot find torch_spline_conv. I will google around to find out why that is happening, but thought you might have some insights.

    Any help is appreciated.

    The complete trace is as follows:

      File "src/main.py", line 4, in <module>
        from clustergcn import ClusterGCNTrainer
      File "/media/anuj/Softwares & Study Material/Study Material/MS Stuff/RA/ClusterGCN/src/clustergcn.py", line 5, in <module>
        from layers import StackedGCN
      File "/media/anuj/Softwares & Study Material/Study Material/MS Stuff/RA/ClusterGCN/src/layers.py", line 2, in <module>
        from torch_geometric.nn import GCNConv
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/__init__.py", line 1, in <module>
        from .conv import *  # noqa
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/conv/__init__.py", line 1, in <module>
        from .spline_conv import SplineConv
      File "/home/anuj/virtualenv-forest/gcn/lib/python3.5/site-packages/torch_geometric/nn/conv/spline_conv.py", line 3, in <module>
        from torch_spline_conv import SplineConv as Conv
    ImportError: No module named 'torch_spline_conv'
    
    
    opened by 1byxero 2
  • For ppi

    Hello. Thanks for your work and code. It's great that Cluster-GCN achieves such strong performance on the PPI dataset, but it seems that you have not open-sourced the code for PPI node classification.

    Do you first select the best model on the validation set and then evaluate it on the unseen test set? I notice that GraphStar is now the SOTA; however, they don't use the validation set and directly select the best model on the test set.

    Can you share the PPI code with us and describe how to split the dataset in the README? It's important for others who want to follow your great work.

    opened by guochengqian 2
  • Metis hits a Segmentation fault when running _METIS_PartGraphKway

    • I'm using the default test input files.

    • I've attached pdb screenshot during the run.

    • Environment: Ubuntu 18.04 Anaconda (Python 3.7.3),
      torch-geometric==1.3.0 torch-scatter==1.3.0 torch-sparse==0.4.0 torch-spline-conv==1.1.0 metis==0.2a.4

    [Screenshot: PDB error, 2019-07-04 13-56-16]

    [Screenshot: Requirements.txt, 2019-07-04 14-02-14]

    opened by poppingtonic 2
  • The error of metis, Segmentation fault (core dumped)

    I found that I can use the random model to partition the graph, but when using METIS, the code terminates abnormally. What could cause this? I changed "IDXTYPEWIDTH = os.getenv('METIS_IDXTYPEWIDTH', '32')" in metis.py (line 31) to "IDXTYPEWIDTH = os.getenv('METIS_IDXTYPEWIDTH', '64')", but it does not help.

    python src/main.py
    +-------------------+----------------------+
    | Parameter         | Value                |
    +===================+======================+
    | Cluster number    | 10                   |
    +-------------------+----------------------+
    | Clustering method | metis                |
    +-------------------+----------------------+
    | Dropout           | 0.500                |
    +-------------------+----------------------+
    | Edge path         | ./input/edges.csv    |
    +-------------------+----------------------+
    | Epochs            | 200                  |
    +-------------------+----------------------+
    | Features path     | ./input/features.csv |
    +-------------------+----------------------+
    | Layers            | [16, 16, 16]         |
    +-------------------+----------------------+
    | Learning rate     | 0.010                |
    +-------------------+----------------------+
    | Seed              | 42                   |
    +-------------------+----------------------+
    | Target path       | ./input/target.csv   |
    +-------------------+----------------------+
    | Test ratio        | 0.900                |
    +-------------------+----------------------+

    Metis graph clustering started.

    Segmentation fault (core dumped)

    opened by yiyang-wang 1
  • TypeError: object of type 'int' has no len()

    Hello, when I run main.py, I get the following error message:

      File "D:\anaconda3.4\lib\site-packages\pymetis\__init__.py", line 44, in _prepare_graph
        for i in range(len(adjacency)):
    TypeError: object of type 'int' has no len()

    I installed the pymetis package to work around the metis.dll problem, and this error occurs in pymetis\__init__.py. Do you know how to solve it?

    opened by tanjia123456 1
  • RuntimeError: Could not locate METIS dll.

    Hello, when I run main.py, the following error message appears:

    raise RuntimeError('Could not locate METIS dll. Please set the METIS_DLL environment variable to its full path.')
    RuntimeError: Could not locate METIS dll. Please set the METIS_DLL environment variable to its full path.

    Do you know how to solve it?

    opened by tanjia123456 1
  • Runtime error about metis

    At the beginning of training, where the full graph is partitioned, the call "metis.part_graph(self.graph, self.args.cluster_number)" throws an error:

    Traceback (most recent call last):
      File "C:/Users/xieRu/Desktop/ML/ClusterGCN/src/main.py", line 30, in <module>
        main()
      File "C:/Users/xieRu/Desktop/ML/ClusterGCN/src/main.py", line 19, in main
        clustering_machine.decompose()
      File "C:\Users\xieRu\Desktop\ML\ClusterGCN\src\clustering.py", line 38, in decompose
        self.metis_clustering()
      File "C:\Users\xieRu\Desktop\ML\ClusterGCN\src\clustering.py", line 56, in metis_clustering
        (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)
      File "D:\Program\Anaconda\lib\site-packages\metis.py", line 800, in part_graph
        _METIS_PartGraphKway(*args)
      File "D:\Program\Anaconda\lib\site-packages\metis.py", line 677, in _METIS_PartGraphKway
        adjwgt, nparts, tpwgts, ubvec, options, objval, part)
    OSError: exception: access violation writing 0x000001B0B9C0E000

    But when I test the metis package on its own as follows, it works:

        import metis
        from networkx import karate_club_graph

        zkc = karate_club_graph()
        graph_clustering = metis.part_graph(zkc)

    So, what happened?

    opened by ByskyXie 1
  • some question about code

    It seems that your code does not consider the connections between clusters or the normalization mentioned in the paper. Will you add these two options?

    opened by thunderbird0902 1
  • About installation

    Hi there, thank you for your great work; I finally got the code running. To make the installation section of README.md more precise and complete, you may want to add the following dependencies:

    • torch_spline_conv == 1.0.4
    • torch_sparse == 0.2.2
    • torch_scatter == 1.0.4
    • torch_cluster == 1.1.5 (strict)
    opened by dkdk-ddk 1
  • Cannot run main.py

    src/main.py --epochs 100
    +-------------------+----------------------+
    | Parameter         | Value                |
    +===================+======================+
    | Cluster number    | 10                   |
    +-------------------+----------------------+
    | Clustering method | metis                |
    +-------------------+----------------------+
    | Dropout           | 0.500                |
    +-------------------+----------------------+
    | Edge path         | ./input/edges.csv    |
    +-------------------+----------------------+
    | Epochs            | 100                  |
    +-------------------+----------------------+
    | Features path     | ./input/features.csv |
    +-------------------+----------------------+
    | Layers            | [16, 16, 16]         |
    +-------------------+----------------------+
    | Learning rate     | 0.010                |
    +-------------------+----------------------+
    | Seed              | 42                   |
    +-------------------+----------------------+
    | Target path       | ./input/target.csv   |
    +-------------------+----------------------+
    | Test ratio        | 0.900                |
    +-------------------+----------------------+

    Metis graph clustering started.

    Traceback (most recent call last):
      File "src/main.py", line 24, in <module>
        main()
      File "src/main.py", line 18, in main
        clustering_machine.decompose()
      File "/Users/linmiao/gits/ClusterGCN/src/clustering.py", line 38, in decompose
        self.metis_clustering()
      File "/Users/linmiao/gits/ClusterGCN/src/clustering.py", line 56, in metis_clustering
        (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)
      File "/usr/local/lib/python3.7/site-packages/metis.py", line 765, in part_graph
        graph = networkx_to_metis(graph)
      File "/usr/local/lib/python3.7/site-packages/metis.py", line 574, in networkx_to_metis
        for i in H.node:
    AttributeError: 'Graph' object has no attribute 'node'

    opened by linkerlin 1
  • issues about the metis algorithm

    (st, parts) = metis.part_graph(self.graph, self.args.cluster_number)

    Thanks for your awesome code. Could you please tell me how METIS performs the graph partition here? The self.graph passed in does not include any information about edge weights or feature attributes.

    opened by immortal13 2
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.