[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Overview

Graph Stochastic Attention (GSAT)

The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism, to appear in ICML 2022.

Introduction

Commonly used attention mechanisms do not impose any constraints during training (besides normalization), and thus may lack interpretability. GSAT is a novel attention mechanism for building interpretable graph learning models. It injects stochasticity to learn attention, where a higher attention weight means a higher probability of the corresponding edge being kept during training. Such a mechanism will push the model to learn higher attention weights for edges that are important for prediction accuracy, which provides interpretability. To further improve the interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps model generalization. Fig. 1 shows the architecture of GSAT.

Figure 1. The architecture of GSAT.

Installation

We have tested our code on Python 3.9 with PyTorch 1.10.0, PyG 2.0.3 and CUDA 11.3. Please follow the following steps to create a virtual environment and install the required packages.

Create a virtual environment:

conda create --name gsat python=3.9
conda activate gsat

Install dependencies:

conda install -y pytorch==1.10.0 torchvision cudatoolkit=11.3 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
pip install -r requirements.txt

In case a lower CUDA version is required, please use the following command to install dependencies:

conda install -y pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
pip install -r requirements.txt

Run Examples

We provide examples with minimal code to run GSAT in ./example/example.ipynb. We have tested the provided examples on Ba-2Motifs (GIN), Mutag (GIN) and OGBG-Molhiv (PNA). Yet, to implement GSAT* one needs to load a pre-trained model first in the provided example.

It should be able to run on other datasets as well, but some hard-coded hyperparameters might need to be changed accordingly. To reproduce results for other datasets, please follow the instructions in the following section.

Reproduce Results

We provide the source code to reproduce the results in our paper. The results of GSAT can be reproduced by running run_gsat.py. To reproduce GSAT*, one needs to run pretrain_clf.py first and change the configuration file accordingly (from_scratch: false).

To pre-train a classifier:

cd ./src
python pretrain_clf.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

To train GSAT:

cd ./src
python run_gsat.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

dataset_name can be choosen from ba_2motifs, mutag, mnist, Graph-SST2, spmotif_0.5, spmotif_0.7, spmotif_0.9, ogbg_molhiv, ogbg_moltox21, ogbg_molbace, ogbg_molbbbp, ogbg_molclintox, ogbg_molsider.

model_name can be choosen from GIN, PNA.

GPU_id is the id of the GPU to use. To use CPU, please set it to -1.

Training Logs

Standard output provides basic training logs, while more detailed logs and interpretation visualizations can be found on tensorboard:

tensorboard --logdir=./data/[dataset_name]/logs

Hyperparameter Settings

All settings can be found in ./src/configs.

Instructions on Acquiring Datasets

  • Ba_2Motifs

    • Raw data files can be downloaded automatically, provided by PGExplainer and DIG.
  • Spurious-Motif

    • Raw data files can be generated automatically, provide by DIR.
  • OGBG-Mol

    • Raw data files can be downloaded automatically, provided by OGBG.
  • Mutag

    • Raw data files need to be downloaded here, provided by PGExplainer.
    • Unzip Mutagenicity.zip and Mutagenicity.pkl.zip.
    • Put the raw data files in ./data/mutag/raw.
  • Graph-SST2

    • Raw data files need to be downloaded here, provided by DIG.
    • Unzip the downloaded Graph-SST2.zip.
    • Put the raw data files in ./data/Graph-SST2/raw.
  • MNIST-75sp

    • Raw data files need to be generated following the instruction here.
    • Put the generated files in ./data/mnist/raw.

FAQ

Does GSAT encourage sparsity?

No, GSAT doesn't encourage generating sparse subgraphs. We find r = 0.7 (Eq.(9) in our paper) can generally work well for all datasets in our experiments, which means during training roughly 70% of edges will be kept (kind of still large). This is because GSAT doesn't try to provide interpretability by finding a small/sparse subgraph of the original input graph, which is what previous works normally do and will hurt performance significantly for inhrently interpretable models (as shown in Fig. 7 in the paper). By contrast, GSAT provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

How to choose the value of r?

A grid search in [0.5, 0.6, 0.7, 0.8, 0.9] is recommended, but r = 0.7 is a good starting point. Note that in practice we would decay the value of r gradually during training from 0.9 to the chosen value.

p or α to implement Eq.(9)?

Recall in Fig. 1, p is the probability of dropping an edge, while α is the sampled result from Bern(p). In our provided implementation, as an empirical choice, α is used to implement Eq.(9) (the Gumbel-softmax trick makes α essentially continuous in practice). We find that when α is used it may provide more regularization and makes the model more robust to hyperparameters. Nonetheless, using p can achieve the same performance, but it needs some more tuning.

Can you show an example of how GSAT works?

Below we show an example from the ba_2motifs dataset, which is to distinguish five-node cycle motifs (left) and house motifs (right). To make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of those critical edges to be relatively large (ideally close to 1). Otherwise, those critical edges may be dropped too frequently and thus result in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of other non-critical edges to be close to r, which is set to be 0.7 in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (the less likely it can be dropped). Note that ba_2motifs satisfies our Thm. 4.1 with no noise, and GSAT achieves perfect interpretation performance on it.

Figure 2. An example of the learned attention weights.

Reference

If you find our paper and repo useful, please cite our paper:

@article{miao2022interpretable,
  title={Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism},
  author={Miao, Siqi and Liu, Miaoyuan and Li, Pan},
  journal={arXiv preprint arXiv:2201.12987},
  year={2022}
}
PIXIE: Collaborative Regression of Expressive Bodies

PIXIE: Collaborative Regression of Expressive Bodies [Project Page] This is the official Pytorch implementation of PIXIE. PIXIE reconstructs an expres

Yao Feng 331 Jan 04, 2023
PyTorch implementation of NIPS 2017 paper Dynamic Routing Between Capsules

Dynamic Routing Between Capsules - PyTorch implementation PyTorch implementation of NIPS 2017 paper Dynamic Routing Between Capsules from Sara Sabour,

Adam Bielski 475 Dec 24, 2022
Paddle Graph Learning (PGL) is an efficient and flexible graph learning framework based on PaddlePaddle

DOC | Quick Start | 中文 Breaking News !! 🔥 🔥 🔥 OGB-LSC KDD CUP 2021 winners announced!! (2021.06.17) Super excited to announce our PGL team won TWO

1.5k Jan 06, 2023
Deep Convolutional Generative Adversarial Networks

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks Alec Radford, Luke Metz, Soumith Chintala All images in t

Alec Radford 3.4k Dec 29, 2022
PSGAN running with ncnn⚡妆容迁移/仿妆⚡Imitation Makeup/Makeup Transfer⚡

PSGAN running with ncnn⚡妆容迁移/仿妆⚡Imitation Makeup/Makeup Transfer⚡

WuJinxuan 144 Dec 26, 2022
"Projelerle Yapay Zeka Ve Bilgisayarlı Görü" Kitabımın projeleri

"Projelerle Yapay Zeka Ve Bilgisayarlı Görü" Kitabımın projeleri Bu Github Reposundaki tüm projeler; kaleme almış olduğum "Projelerle Yapay Zekâ ve Bi

Ümit Aksoylu 4 Aug 03, 2022
Segmentation for medical image.

EfficientSegmentation Introduction EfficientSegmentation is an open source, PyTorch-based segmentation framework for 3D medical image. Features A whol

68 Nov 28, 2022
This is a Python wrapper for TA-LIB based on Cython instead of SWIG.

TA-Lib This is a Python wrapper for TA-LIB based on Cython instead of SWIG. From the homepage: TA-Lib is widely used by trading software developers re

John Benediktsson 7.3k Jan 03, 2023
Weakly Supervised Text-to-SQL Parsing through Question Decomposition

Weakly Supervised Text-to-SQL Parsing through Question Decomposition The official repository for the paper "Weakly Supervised Text-to-SQL Parsing thro

14 Dec 19, 2022
optimization routines for hyperparameter tuning

Hyperopt: Distributed Hyperparameter Optimization Hyperopt is a Python library for serial and parallel optimization over awkward search spaces, which

Marc Claesen 398 Nov 09, 2022
DFFNet: An IoT-perceptive Dual Feature Fusion Network for General Real-time Semantic Segmentation

DFFNet Paper DFFNet: An IoT-perceptive Dual Feature Fusion Network for General Real-time Semantic Segmentation. Xiangyan Tang, Wenxuan Tu, Keqiu Li, J

4 Sep 23, 2022
Implementation of EMNLP 2017 Paper "Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog" using PyTorch and ParlAI

Language Emergence in Multi Agent Dialog Code for the Paper Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog Satwik Kottur, José M.

Karan Desai 105 Nov 25, 2022
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. Diffrax is a JAX-based library providing numerical differe

Patrick Kidger 717 Jan 09, 2023
Scalable Graph Neural Networks for Heterogeneous Graphs

Neighbor Averaging over Relation Subgraphs (NARS) NARS is an algorithm for node classification on heterogeneous graphs, based on scalable neighbor ave

Facebook Research 67 Dec 03, 2022
A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

Zain 1 Feb 01, 2022
Recursive Bayesian Networks

Recursive Bayesian Networks This repository contains the code to reproduce the results from the NeurIPS 2021 paper Lieck R, Rohrmeier M (2021) Recursi

Robert Lieck 11 Oct 18, 2022
Details about the wide minima density hypothesis and metrics to compute width of a minima

wide-minima-density-hypothesis Details about the wide minima density hypothesis and metrics to compute width of a minima This repo presents the wide m

Nikhil Iyer 9 Dec 27, 2022
Le dataset des images du projet d'IA de 2021

face-mask-dataset-ilc-2021 Le dataset des images du projet d'IA de 2021, Indiquez vos id git dans la issue pour les droits TL;DR: Choisir 200 images J

7 Nov 15, 2021
This respository includes implementations on Manifoldron: Direct Space Partition via Manifold Discovery

Manifoldron: Direct Space Partition via Manifold Discovery This respository includes implementations on Manifoldron: Direct Space Partition via Manifo

dayang_wang 4 Apr 28, 2022