Code for "Training Neural Networks with Fixed Sparse Masks" (NeurIPS 2021).


Fisher Induced Sparse uncHanging (FISH) Mask

This repo contains the code for Fisher Induced Sparse uncHanging (FISH) Mask training, from "Training Neural Networks with Fixed Sparse Masks" by Yi-Lin Sung, Varun Nair, and Colin Raffel. To appear in Neural Information Processing Systems (NeurIPS) 2021.

Abstract: During typical gradient-based training of deep neural networks, all of the model's parameters are updated at each iteration. Recent work has shown that it is possible to update only a small subset of the model's parameters during training, which can alleviate storage and communication requirements. In this paper, we show that it is possible to induce a fixed sparse mask on the model’s parameters that selects a subset to update over many iterations. Our method constructs the mask out of the parameters with the largest Fisher information as a simple approximation as to which parameters are most important for the task at hand. In experiments on parameter-efficient transfer learning and distributed training, we show that our approach matches or exceeds the performance of other methods for training with sparse updates while being more efficient in terms of memory usage and communication costs.
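
As an illustration of the mask construction described in the abstract, the sketch below estimates the diagonal Fisher information of a toy binary logistic-regression model and keeps the parameters with the largest values. It is a minimal pure-Python sketch under stated assumptions, not the repository's implementation. The toy model and the helper names fisher_diagonal and fish_mask are hypothetical. The expectation over the label is taken under the model's own predictive distribution, which is the textbook definition of the Fisher; the repository may estimate it differently, for example by sampling labels or using observed ones. As in the scripts, the keep fraction is written as a fraction (0.005 means 0.5%).

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def fisher_diagonal(weights: Sequence[float],
                    inputs: Sequence[Sequence[float]]) -> List[float]:
    """Diagonal Fisher information of a toy binary logistic-regression model.

    With p = sigmoid(w . x), d/dw_j log p(y | x) = (y - p) * x_j.
    The expectation over y ~ Bernoulli(p) is p * (1 - p) * x_j**2,
    which is then averaged over the N samples in `inputs`.
    """
    fisher = [0.0] * len(weights)
    for x in inputs:
        p = sigmoid(sum(w * xj for w, xj in zip(weights, x)))
        for j, xj in enumerate(x):
            fisher[j] += p * (1.0 - p) * xj * xj
    return [f / len(inputs) for f in fisher]


def fish_mask(fisher: Sequence[float], top_k_percentage: float) -> List[bool]:
    """Keep the top_k_percentage fraction of parameters with the largest Fisher values."""
    k = max(1, round(top_k_percentage * len(fisher)))
    ranked = sorted(range(len(fisher)), key=lambda j: fisher[j], reverse=True)
    keep = set(ranked[:k])
    return [j in keep for j in range(len(fisher))]


# Usage: 4 parameters, 2 samples, keep the top 50%.
fisher = fisher_diagonal([0.0] * 4, [[1.0, 2.0, 0.0, 0.5], [1.0, 0.0, 3.0, 0.5]])
mask = fish_mask(fisher, 0.5)  # [False, True, True, False]
```

Ranking all parameters jointly (rather than per layer) is one simple choice here; the mask is computed once and then reused for every training step.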

Setup

pip install transformers/.
pip install datasets torch==1.8.0 tqdm torchvision==0.9.0

FISH Mask: GLUE Experiments

Parameter-Efficient Transfer Learning

To run FISH Mask training on a GLUE dataset, use the following format:

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh <dataset-name> <seed> <top_k_percentage> <num_samples_for_fisher>

The following example, used to generate Table 1 in the paper, runs the eight listed GLUE tasks with seed 0, a FISH mask sparsity of 0.5% (<top_k_percentage> = 0.005), and 1024 samples for estimating the Fisher information:

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh "qqp mnli rte cola stsb sst2 mrpc qnli" 0 0.005 1024
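
Because the mask stays fixed during fine-tuning, each task only needs the masked parameter values on top of the shared pretrained model. The sketch below illustrates this with plain Python lists. The helpers masked_sgd_step, sparse_delta, and restore are hypothetical, and plain SGD stands in for whatever optimizer the scripts actually use.

```python
from typing import Dict, List, Sequence


def masked_sgd_step(params: Sequence[float], grads: Sequence[float],
                    mask: Sequence[bool], learning_rate: float) -> List[float]:
    """One SGD step that leaves every parameter outside the fixed mask unchanged."""
    return [p - learning_rate * g if keep else p
            for p, g, keep in zip(params, grads, mask)]


def sparse_delta(finetuned: Sequence[float], mask: Sequence[bool]) -> Dict[int, float]:
    """Per-task storage: only the (index, value) pairs of masked parameters."""
    return {j: finetuned[j] for j, keep in enumerate(mask) if keep}


def restore(pretrained: Sequence[float], delta: Dict[int, float]) -> List[float]:
    """Rebuild the fine-tuned parameters from the shared pretrained weights."""
    params = list(pretrained)
    for j, value in delta.items():
        params[j] = value
    return params


pretrained = [1.0, 2.0, 3.0, 4.0]
mask = [False, True, True, False]
finetuned = masked_sgd_step(pretrained, [1.0, 1.0, 1.0, 1.0], mask, learning_rate=0.5)
delta = sparse_delta(finetuned, mask)  # {1: 1.5, 2: 2.5}
assert restore(pretrained, delta) == finetuned
```

At a sparsity of 0.5%, the stored delta holds 0.5% of the parameter values plus their indices.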

Distributed Training

To train with the FISH mask on GLUE tasks in a distributed setting, use the following format:

$ bash transformers/examples/text-classification/scripts/distributed_training.sh <dataset-name> <seed> <num_workers> <training_epochs> <gpu_id>

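Unlike the transfer-learning script, <dataset-name> here must name a single task. For example, the following command trains on MNLI with seed 0, 2 workers, 3.5 training epochs, and GPU 0: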

$ bash transformers/examples/text-classification/scripts/distributed_training.sh "mnli" 0 2 3.5 0
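
In the distributed setting, the communication saving comes from workers exchanging only the masked coordinates. The minimal sketch below assumes synchronous parameter averaging after local training, with every worker starting from the same global parameters and updating only masked coordinates. The helper average_masked is hypothetical and is not taken from distributed_training.sh.

```python
from typing import List, Sequence, Tuple


def average_masked(global_params: Sequence[float],
                   worker_params: Sequence[Sequence[float]],
                   mask: Sequence[bool]) -> Tuple[List[float], int]:
    """Average worker parameters on masked coordinates only.

    Assumes all workers started from `global_params` and updated only masked
    coordinates, so unmasked coordinates never need to be communicated.
    Returns the new global parameters and the number of values transmitted.
    """
    indices = [j for j, keep in enumerate(mask) if keep]
    # Each worker sends one value per masked coordinate.
    messages = [[params[j] for j in indices] for params in worker_params]
    new_params = list(global_params)
    for pos, j in enumerate(indices):
        new_params[j] = sum(msg[pos] for msg in messages) / len(messages)
    return new_params, sum(len(msg) for msg in messages)


new_params, sent = average_masked(
    [0.0, 0.0, 0.0, 0.0],
    [[1.0, 0.0, 3.0, 0.0], [3.0, 0.0, 5.0, 0.0]],
    [True, False, True, False],
)
# new_params == [2.0, 0.0, 4.0, 0.0]; sent == 4 (versus 8 for dense averaging)
```

Per-round traffic therefore scales with the mask size rather than the model size.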

FISH Mask: CIFAR10 Experiments

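The CIFAR10 experiments cover two settings: distributed training and efficient checkpointing.

Distributed Training

To run distributed training with the FISH mask on CIFAR10, use the following format: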

$ bash cifar10-fast/scripts/distributed_training_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <worker_updates> <learning_rate> <num_workers>

For example, the paper computes a FISH mask at the 0.5% sparsity level from 256 samples and distributes training across 2 workers for a total of 50 epochs (with <worker_updates> = 2 and a learning rate of 0.4). The corresponding command is:

$ bash cifar10-fast/scripts/distributed_training_fish.sh 256 0.005 50 2 0.4 2

Efficient Checkpointing

To train with the FISH mask on CIFAR10 while saving small checkpoints, use the following format:

$ bash cifar10-fast/scripts/small_checkpoints_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <learning_rate> <fix_mask>

The arguments match those of the distributed training script, except that there are no <worker_updates> or <num_workers> arguments and there is an additional <fix_mask> flag, which controls whether the mask is kept fixed. Valid values are 0 and 1, where 1 fixes the mask.
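
To give a rough sense of what <fix_mask> trades off, the sketch below counts how many numbers a run of checkpoints would store. It is a back-of-the-envelope model under stated assumptions, not the accounting used in small_checkpoints_fish.sh: the sparse scheme keeps one full copy of the weights as a base and then stores only masked values per checkpoint; with <fix_mask> set to 1 all checkpoints share a single index list, whereas a mask that may change must store its indices with every checkpoint. The function name checkpoint_storage is hypothetical.

```python
from typing import Tuple


def checkpoint_storage(num_params: int, top_k_percentage: float,
                       num_checkpoints: int, fix_mask: int) -> Tuple[int, int]:
    """Rough count of stored numbers for dense vs. sparse checkpoints.

    Assumption: the sparse scheme stores one full base copy of the weights,
    then only the masked values for each checkpoint. A fixed mask (fix_mask=1)
    shares one index list; otherwise indices are stored per checkpoint.
    """
    k = max(1, round(top_k_percentage * num_params))
    dense = num_params * num_checkpoints
    if fix_mask:
        sparse = num_params + k + k * num_checkpoints  # base, indices once, values each time
    else:
        sparse = num_params + 2 * k * num_checkpoints  # base, indices and values each time
    return dense, sparse


print(checkpoint_storage(1_000_000, 0.005, 50, fix_mask=1))
```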

Replicating Results

To replicate each table and figure in the original paper, run the corresponding command below:

# Table 1 - Parameter-Efficient Fine-Tuning on GLUE
$ bash transformers/examples/text-classification/scripts/run_table_1.sh

# Figure 2 - Mask Sparsity Ablation and Sample Ablation
$ bash transformers/examples/text-classification/scripts/run_figure_2.sh

# Table 2 - Distributed Training on GLUE
$ bash transformers/examples/text-classification/scripts/run_table_2.sh

# Table 3 - Distributed Training on CIFAR10
$ bash cifar10-fast/scripts/distributed_training.sh

# Table 4 - Efficient Checkpointing
$ bash cifar10-fast/scripts/small_checkpoints.sh

Notes

  • For reproduction of Diff Pruning results from Table 1, see code here.

Acknowledgements

We thank Yoon Kim, Michael Matena, and Demi Guo for helpful discussions.
