Code for Understanding Pooling in Graph Neural Networks

Related tags

Deep LearningSRC
Overview

Select, Reduce, Connect

This repository contains the code used for the experiments of:

"Understanding Pooling in Graph Neural Networks"

Setup

Install TensorFlow and other dependencies:

pip install -r requirements.txt

Running experiments

Experiments are found in the following folders:

  • autoencoder/
  • spectral_similarity/
  • graph_classification/

Each folder has a bash script called run_all.sh that will reproduce the results reported in the paper.

To generate the plots and tables that we included in the paper, you can use the plots.py, plots_datasets.py, or tables.py found in the folders.

To run experiments for an individual pooling operator, you can use the run_[OPERATOR NAME].py scripts in each folder.

The pooling operators that we used for the experiments are found in layers/ (trainable) and modules/ (non-trainable). The GNN architectures used in the experiments are found in models/.

The SRCPool class

The core of this repository is the SRCPool class that implements a general interface to create SRC pooling layers with the Keras API.

Our implementation of MinCutPool, DiffPool, LaPool, Top-K, and SAGPool using the SRCPool class can be found in src/layers.

In general, SRC layers compute:

Where is a node equivariant selection function that computes the supernode assignments , is a permutation-invariant function to reduce the supernodes into the new node attributes, and is a permutation-invariant connection function that computes the links between the pooled nodes.

By extending this class, it is possible to create any pooling layer in the SRC framework.

Input

  • X: Tensor of shape ([batch], N, F) representing node features;
  • A: Tensor or SparseTensor of shape ([batch], N, N) representing the adjacency matrix;
  • I: (optional) Tensor of integers with shape (N, ) representing the batch index;

Output

  • X_pool: Tensor of shape ([batch], K, F), representing the node features of the output. K is the number of output nodes and depends on the specific pooling strategy;
  • A_pool: Tensor or SparseTensor of shape ([batch], K, K) representing the adjacency matrix of the output;
  • I_pool: (only if I was given as input) Tensor of integers with shape (K, ) representing the batch index of the output;
  • S_pool: (if return_sel=True) Tensor or SparseTensor representing the supernode assignments;

API

  • pool(X, A, I, **kwargs): pools the graph and returns the reduced node features and adjacency matrix. If the batch index I is not None, a reduced version of I will be returned as well. Any given kwargs will be passed as keyword arguments to select(), reduce() and connect() if any matching key is found. The mandatory arguments of pool() (X, A, and I) must be computed in call() by calling self.get_inputs(inputs).
  • select(X, A, I, **kwargs): computes supernode assignments mapping the nodes of the input graph to the nodes of the output.
  • reduce(X, S, **kwargs): reduces the supernodes to form the nodes of the pooled graph.
  • connect(A, S, **kwargs): connects the reduced supernodes.
  • reduce_index(I, S, **kwargs): helper function to reduce the batch index (only called if I is given as input).

When overriding any function of the API, it is possible to access the true number of nodes of the input (N) as a Tensor in the instance variable self.N (this is populated by self.get_inputs() at the beginning of call()).

Arguments:

  • return_sel: if True, the Tensor used to represent supernode assignments will be returned with X_pool, A_pool, and I_pool;
Owner
Daniele Grattarola
PhD student @ Università della Svizzera italiana
Daniele Grattarola
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022
Use Python, OpenCV, and MediaPipe to control a keyboard with facial gestures

CheekyKeys A Face-Computer Interface CheekyKeys lets you control your keyboard using your face. View a fuller demo and more background on the project

69 Nov 09, 2022
This is the code of paper ``Contrastive Coding for Active Learning under Class Distribution Mismatch'' with python.

Contrastive Coding for Active Learning under Class Distribution Mismatch Official PyTorch implementation of ["Contrastive Coding for Active Learning u

21 Dec 22, 2022
ByteTrack: Multi-Object Tracking by Associating Every Detection Box

ByteTrack ByteTrack is a simple, fast and strong multi-object tracker. ByteTrack: Multi-Object Tracking by Associating Every Detection Box Yifu Zhang,

Yifu Zhang 2.9k Jan 04, 2023
Beancount-mercury - Beancount importer for Mercury Startup Checking

beancount-mercury beancount-mercury provides an Importer for converting CSV expo

Michael Lynch 4 Oct 31, 2022
Unofficial TensorFlow implementation of the Keyword Spotting Transformer model

Keyword Spotting Transformer This is the unofficial TensorFlow implementation of the Keyword Spotting Transformer model. This model is used to train o

Intelligent Machines Limited 8 May 11, 2022
Updated for TTS(CE) = Also Known as TTN V3. The code requires the first server to be 'ttn' protocol.

Updated Updated for TTS(CE) = Also Known as TTN V3. The code requires the first server to be 'ttn' protocol. Introduction This balenaCloud (previously

Remko 1 Oct 17, 2021
OpenCV, MediaPipe Pose Estimation, Affine Transform for Icon Overlay

Yoga Pose Identification and Icon Matching Project Goal Detect yoga poses performed by a user and overlay a corresponding icon image. Running the main

Anna Garverick 1 Dec 03, 2021
Python Implementation of the CoronaWarnApp (CWA) Event Registration

Python implementation of the Corona-Warn-App (CWA) Event Registration This is an implementation of the Protocol used to generate event and location QR

MaZderMind 17 Oct 05, 2022
Building Ellee — A GPT-3 and Computer Vision Powered Talking Robotic Teddy Bear With Human Level Conversation Intelligence

Using an object detection and facial recognition system built on MobileNetSSDV2 and Dlib and running on an NVIDIA Jetson Nano, a GPT-3 model, Google Speech Recognition, Amazon Polly and servo motors,

24 Oct 26, 2022
[ICCV'21] Learning Conditional Knowledge Distillation for Degraded-Reference Image Quality Assessment

CKDN The official implementation of the ICCV2021 paper "Learning Conditional Knowledge Distillation for Degraded-Reference Image Quality Assessment" O

Multimedia Research 50 Dec 13, 2022
Official code for "Stereo Waterdrop Removal with Row-wise Dilated Attention (IROS2021)"

Stereo-Waterdrop-Removal-with-Row-wise-Dilated-Attention This repository includes official codes for "Stereo Waterdrop Removal with Row-wise Dilated A

29 Oct 01, 2022
An efficient framework for reinforcement learning.

rl: An efficient framework for reinforcement learning Requirements Introduction PPO Test Requirements name version Python =3.7 numpy =1.19 torch =1

16 Nov 30, 2022
Urban mobility simulations with Python3, RLlib (Deep Reinforcement Learning) and Mesa (Agent-based modeling)

Deep Reinforcement Learning for Smart Cities Documentation RLlib: https://docs.ray.io/en/master/rllib.html Mesa: https://mesa.readthedocs.io/en/stable

1 May 15, 2022
GAN-STEM-Conv2MultiSlice - Exploring Generative Adversarial Networks for Image-to-Image Translation in STEM Simulation

GAN-STEM-Conv2MultiSlice GAN method to help covert lower resolution STEM images generated by convolution methods to higher resolution STEM images gene

UW-Madison Computational Materials Group 2 Feb 10, 2021
CondLaneNet: a Top-to-down Lane Detection Framework Based on Conditional Convolution

CondLaneNet: a Top-to-down Lane Detection Framework Based on Conditional Convolution This is the official implementation code of the paper "CondLaneNe

Alibaba Cloud 311 Dec 30, 2022
Controlling the MicriSpotAI robot from scratch

Abstract: The SpotMicroAI project is designed to be a low cost, easily built quadruped robot. The design is roughly based off of Boston Dynamics quadr

Florian Wilk 405 Jan 05, 2023
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Christoph Reich 23 Sep 21, 2022
A cool little repl-based simulation written in Python

A cool little repl-based simulation written in Python planned to integrate machine-learning into itself to have AI battle to the death before your eye

Em 6 Sep 17, 2022
Jupyter notebooks for the code samples of the book "Deep Learning with Python"

Jupyter notebooks for the code samples of the book "Deep Learning with Python"

François Chollet 16.2k Dec 30, 2022