[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Overview

Graph Stochastic Attention (GSAT)

The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism, to appear in ICML 2022.

Introduction

Commonly used attention mechanisms do not impose any constraints during training (besides normalization), and thus may lack interpretability. GSAT is a novel attention mechanism for building interpretable graph learning models. It injects stochasticity to learn attention, where a higher attention weight means a higher probability of the corresponding edge being kept during training. Such a mechanism will push the model to learn higher attention weights for edges that are important for prediction accuracy, which provides interpretability. To further improve the interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps model generalization. Fig. 1 shows the architecture of GSAT.

Figure 1. The architecture of GSAT.

Installation

We have tested our code on Python 3.9 with PyTorch 1.10.0, PyG 2.0.3 and CUDA 11.3. Please follow the steps below to create a virtual environment and install the required packages.

Create a virtual environment:

conda create --name gsat python=3.9
conda activate gsat

Install dependencies:

conda install -y pytorch==1.10.0 torchvision cudatoolkit=11.3 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
pip install -r requirements.txt

If a lower CUDA version is required, please use the following commands to install dependencies instead:

conda install -y pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
pip install -r requirements.txt

Run Examples

We provide examples with minimal code to run GSAT in ./example/example.ipynb. We have tested the provided examples on Ba-2Motifs (GIN), Mutag (GIN) and OGBG-Molhiv (PNA). Note that running GSAT* in the provided example requires loading a pre-trained model first.

The example should also run on other datasets, but some hard-coded hyperparameters might need to be changed accordingly. To reproduce results for other datasets, please follow the instructions in the following section.

Reproduce Results

We provide the source code to reproduce the results in our paper. The results of GSAT can be reproduced by running run_gsat.py. To reproduce GSAT*, one needs to run pretrain_clf.py first and change the configuration file accordingly (from_scratch: false).

To pre-train a classifier:

cd ./src
python pretrain_clf.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

To train GSAT:

cd ./src
python run_gsat.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

dataset_name can be chosen from ba_2motifs, mutag, mnist, Graph-SST2, spmotif_0.5, spmotif_0.7, spmotif_0.9, ogbg_molhiv, ogbg_moltox21, ogbg_molbace, ogbg_molbbbp, ogbg_molclintox, ogbg_molsider.

model_name can be chosen from GIN, PNA.

GPU_id is the id of the GPU to use. To use CPU, please set it to -1.
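
To run several dataset/backbone combinations in a row, the commands above can be assembled programmatically. The following is a minimal sketch, not part of the repository: the helper name build_command is hypothetical, and only the --dataset, --backbone and --cuda flags documented above are used.

```python
from itertools import product


def build_command(script, dataset, backbone, cuda=-1):
    """Return the argument list for one run (hypothetical helper).

    cuda=-1 selects the CPU, as described above.
    """
    return ["python", script,
            "--dataset", dataset,
            "--backbone", backbone,
            "--cuda", str(cuda)]


# Example grid: two datasets and both backbones on GPU 0.
datasets = ["ba_2motifs", "mutag"]
backbones = ["GIN", "PNA"]
commands = [build_command("run_gsat.py", d, b, cuda=0)
            for d, b in product(datasets, backbones)]

for cmd in commands:
    print(" ".join(cmd))
```

Each list can then be passed to subprocess.run(cmd, cwd="./src", check=True) so that the script is launched from the ./src directory, as in the commands above.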

Training Logs

Standard output provides basic training logs, while more detailed logs and interpretation visualizations can be viewed in TensorBoard:

tensorboard --logdir=./data/[dataset_name]/logs

Hyperparameter Settings

All settings can be found in ./src/configs.

Instructions on Acquiring Datasets

  • Ba_2Motifs

    • Raw data files can be downloaded automatically, provided by PGExplainer and DIG.
  • Spurious-Motif

    • Raw data files can be generated automatically, provided by DIR.
  • OGBG-Mol

    • Raw data files can be downloaded automatically, provided by OGBG.
  • Mutag

    • Raw data files need to be downloaded here, provided by PGExplainer.
    • Unzip Mutagenicity.zip and Mutagenicity.pkl.zip.
    • Put the raw data files in ./data/mutag/raw.
  • Graph-SST2

    • Raw data files need to be downloaded here, provided by DIG.
    • Unzip the downloaded Graph-SST2.zip.
    • Put the raw data files in ./data/Graph-SST2/raw.
  • MNIST-75sp

    • Raw data files need to be generated following the instruction here.
    • Put the generated files in ./data/mnist/raw.
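
Before training on a dataset that requires manual setup, it can help to verify that the raw files are in place. The sketch below is an illustration only: the function name missing_raw_dirs is hypothetical, and it checks only that each raw directory listed above exists and is non-empty, not that the files inside are the right ones.

```python
from pathlib import Path

# Raw-data directories named in the instructions above, relative to ./data.
RAW_DIRS = {
    "mutag": "mutag/raw",
    "Graph-SST2": "Graph-SST2/raw",
    "mnist": "mnist/raw",
}


def missing_raw_dirs(data_root="./data"):
    """Return the dataset names whose raw directory is absent or empty."""
    root = Path(data_root)
    missing = []
    for name, rel in RAW_DIRS.items():
        raw_dir = root / rel
        if not (raw_dir.is_dir() and any(raw_dir.iterdir())):
            missing.append(name)
    return missing


print("Datasets still missing raw files:", missing_raw_dirs())
```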

FAQ

Does GSAT encourage sparsity?

No, GSAT doesn't encourage generating sparse subgraphs. We find that r = 0.7 (Eq.(9) in our paper) generally works well for all datasets in our experiments, which means that roughly 70% of edges are kept during training (still a large fraction). This is because GSAT doesn't try to provide interpretability by finding a small/sparse subgraph of the original input graph, which is what previous works normally do and which hurts performance significantly for inherently interpretable models (as shown in Fig. 7 in the paper). By contrast, GSAT provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

How to choose the value of r?

A grid search over [0.5, 0.6, 0.7, 0.8, 0.9] is recommended, but r = 0.7 is a good starting point. Note that in practice we decay the value of r gradually during training, from 0.9 down to the chosen value.
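
One simple way to realize such a decay is a step schedule. The sketch below is an assumption for illustration: the function name decayed_r, the decay interval of 10 epochs and the step size of 0.1 are hypothetical and do not necessarily match the settings in ./src/configs.

```python
def decayed_r(epoch, final_r=0.7, init_r=0.9, decay_interval=10, decay_step=0.1):
    """Step-decay r from init_r towards final_r (hypothetical schedule).

    Every `decay_interval` epochs r is lowered by `decay_step`,
    and it is never allowed to fall below `final_r`.
    """
    steps = epoch // decay_interval
    return max(final_r, init_r - steps * decay_step)


# Print the schedule at a few epochs (values rounded for display).
for epoch in (0, 10, 20, 30):
    print(epoch, round(decayed_r(epoch), 2))
```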

p or α to implement Eq.(9)?

Recall that in Fig. 1, p is the probability of keeping an edge (a higher value means the edge is less likely to be dropped), while α is the sampled result from Bern(p). In our provided implementation, as an empirical choice, α is used to implement Eq.(9) (the Gumbel-softmax trick makes α essentially continuous in practice). We find that using α may provide more regularization and makes the model more robust to hyperparameters. Nonetheless, using p can achieve the same performance, but it needs some more tuning.
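
To illustrate why α is essentially continuous, here is a minimal pure-Python sketch of a Gumbel-softmax style (binary concrete) relaxation of a draw from Bern(p). It is not the repository code: the function name, the temperature value and the clamping constant are assumptions made for this example, and p is taken to be sigmoid(logit).

```python
import math
import random


def sample_relaxed_bernoulli(logit, temperature=1.0, rng=random):
    """Draw a continuous sample in (0, 1) that relaxes Bern(sigmoid(logit))."""
    # Clamp the uniform draw away from 0 and 1 so the logs stay finite.
    u = min(max(rng.random(), 1e-10), 1.0 - 1e-10)
    # Logistic noise, i.e. the difference of two Gumbel variables.
    noise = math.log(u) - math.log(1.0 - u)
    return 1.0 / (1.0 + math.exp(-(logit + noise) / temperature))


# Example: p = 0.7, so the logit is log(p / (1 - p)).
rng = random.Random(0)
logit = math.log(0.7 / 0.3)
samples = [sample_relaxed_bernoulli(logit, rng=rng) for _ in range(20000)]
print("fraction of samples above 0.5:",
      sum(s > 0.5 for s in samples) / len(samples))
```

A sample exceeds 0.5 exactly when logit + noise is positive, which happens with probability sigmoid(logit) = p, so thresholding the relaxed samples recovers Bern(p). Lower temperatures push the samples closer to 0 or 1.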

Can you show an example of how GSAT works?

Below we show an example from the ba_2motifs dataset, where the task is to distinguish five-node cycle motifs (left) from house motifs (right). To make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of the critical edges to be relatively large (ideally close to 1). Otherwise, those critical edges may be dropped too frequently, resulting in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of the other, non-critical edges to be close to r, which is set to 0.7 in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (and the less likely it is to be dropped). Note that ba_2motifs satisfies our Thm. 4.1 with no noise, and GSAT achieves perfect interpretation performance on it.

Figure 2. An example of the learned attention weights.
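
The pull of non-critical edges towards r can be checked numerically. The sketch below assumes that the regularization term for a single edge is the KL divergence between Bern(att) and Bern(r); the function name bernoulli_kl and the clamping constant eps are hypothetical, and the exact form used in the paper is given by its Eq.(9).

```python
import math


def bernoulli_kl(att, r=0.7, eps=1e-6):
    """KL( Bern(att) || Bern(r) ) for one edge's attention weight."""
    # Clamp att away from 0 and 1 so the logarithms stay finite.
    att = min(max(att, eps), 1.0 - eps)
    return (att * math.log(att / r)
            + (1.0 - att) * math.log((1.0 - att) / (1.0 - r)))


# The penalty vanishes at att == r and grows as att moves away from r.
for att in (0.3, 0.5, 0.7, 0.9, 0.99):
    print(att, bernoulli_kl(att))
```

Under this assumption, an edge only keeps a weight well above r if doing so lowers the cross-entropy loss by more than the KL penalty it incurs, which matches the behaviour described above.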

Reference

If you find our paper and repo useful, please cite our paper:

@article{miao2022interpretable,
  title={Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism},
  author={Miao, Siqi and Liu, Miaoyuan and Li, Pan},
  journal={arXiv preprint arXiv:2201.12987},
  year={2022}
}