Public Code for NIPS submission SimiGrad: Fine-Grained Adaptive Batching for Large Scale Training using Gradient Similarity Measurement

Overview

Public code for NIPS submission "SimiGrad: Fine-Grained Adaptive Batching for Large Scale Training using Gradient Similarity Measurement"

This repo contains both our SimiGrad framework (integrated with DeepSpeed) and all training codes used to generate the results in the paper.

Installation

Please run ./DeepSpeed/install.sh to install our SimiGrad framework. For detailed installation options, see the script itself. We recommend installing SimiGrad in a virtual environment.

Usage

To use SimiGrad, simply add an additional parameter adaptive_batch_params when initializing DeepSpeed. For example,

model, optimizer, _, _ = deepspeed.initialize(
        args=...,
        model=...,
        model_parameters=...,
        adaptive_batch_params={
            "enable_adjust": args.similarity_target, # bool, set to `True` to use adaptive batch size and `False` for fixed batch size
            "verbose": True, # bool, set to `True` to print details of batch size adjustment
            "similarity_target": args.similarity_target, # float, -1.0~1.0, the similarity target that controls how aggressive the batch size adjustment is.
            "batch_size_lower_bound": args.batchsize_lower_bound, # int, optional, the lower bound of batch size. Recommended only if you have a well-tuned warmup learning rate scheduling.
            "batch_size_upper_bound": args.batchsize_upper_bound, # int, optional, the upper bound of batch size.
            "max_micro_batch_size": args.max_micro_batch_size, # int, optional, the upper bound of micro batch size to prevent out-of-memory error. If unspecified, the initial micro batch size will be used as the max_micro_batch_size.
        })

Please refer to our code (e.g. DeepSpeedExamples/pytorch-cifar/main.py) for details such as how to read the metrics from the framework.
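
The optional keys in adaptive_batch_params can be left out when they are not needed. The sketch below shows one way to assemble the dictionary from command-line arguments. It is only an illustration: the helper build_adaptive_batch_params, the argparse flags, the range check, and the example values are assumptions made here and are not part of SimiGrad or DeepSpeed.

```python
import argparse


def build_adaptive_batch_params(args):
    """Assemble an adaptive_batch_params dict from parsed arguments.

    Hypothetical helper for illustration; not part of SimiGrad or DeepSpeed.
    """
    # The README documents similarity_target as a float in -1.0~1.0.
    if not -1.0 <= args.similarity_target <= 1.0:
        raise ValueError("similarity_target must lie in [-1.0, 1.0]")

    params = {
        "enable_adjust": True,   # assumption: adaptive batching is wanted
        "verbose": True,
        "similarity_target": args.similarity_target,
    }
    # The three bounds are documented as optional, so add them only
    # when the user actually supplied a value.
    optional = {
        "batch_size_lower_bound": args.batchsize_lower_bound,
        "batch_size_upper_bound": args.batchsize_upper_bound,
        "max_micro_batch_size": args.max_micro_batch_size,
    }
    params.update({k: v for k, v in optional.items() if v is not None})
    return params


parser = argparse.ArgumentParser()
parser.add_argument("--similarity_target", type=float, required=True)
parser.add_argument("--batchsize_lower_bound", type=int, default=None)
parser.add_argument("--batchsize_upper_bound", type=int, default=None)
parser.add_argument("--max_micro_batch_size", type=int, default=None)

# Example values only; the settings used in the paper are listed there.
args = parser.parse_args(
    ["--similarity_target", "0.2", "--batchsize_upper_bound", "4096"]
)
adaptive_batch_params = build_adaptive_batch_params(args)
print(adaptive_batch_params)
```

The resulting dictionary can then be passed as adaptive_batch_params to deepspeed.initialize as in the example above.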

For usage of DeepSpeed, please refer to their website https://www.deepspeed.ai/
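
For intuition about the -1.0~1.0 range of similarity_target, the sketch below computes a cosine similarity between two gradient vectors and compares it with a target. This is an assumption-laden illustration: it assumes the similarity is a cosine similarity between gradients computed on two parts of a batch, and the function names, the flattened-list representation, and the toy vectors are all hypothetical. How SimiGrad actually measures similarity and reacts to it is defined in the paper and the framework code.

```python
import math


def cosine_similarity(grad_a, grad_b):
    """Cosine similarity of two flattened gradient vectors, in [-1.0, 1.0]."""
    dot = sum(a * b for a, b in zip(grad_a, grad_b))
    norm_a = math.sqrt(sum(a * a for a in grad_a))
    norm_b = math.sqrt(sum(b * b for b in grad_b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Similarity is undefined for a zero vector; treat it as 0 here.
        return 0.0
    return dot / (norm_a * norm_b)


def meets_target(similarity, similarity_target):
    """Report whether the measured similarity reaches the target.

    Hypothetical helper: it only performs the comparison and does not
    decide how the batch size should change.
    """
    return similarity >= similarity_target


# Toy gradients standing in for the two parts of a batch.
grad_first_part = [1.0, 0.0]
grad_second_part = [0.0, 1.0]
sim = cosine_similarity(grad_first_part, grad_second_part)
print(sim, meets_target(sim, similarity_target=0.2))
```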

Reproduce Paper's Results

The parameters we used to get the claimed results are included in the paper.

BERT Large Pretrain

All scripts can be found in DeepSpeedExamples/bert_pretrain/. Use the script ds_train_bert_bsz64k_seq128.sh for the sequence length 128 phase of BERT Large pretraining (epochs 1-150). In the script, you need to specify parameters such as similarity_target, as well as the location of the WikiandBookCorpus dataset.

After the sequence length 128 phase, use ds_train_bert_bsz32k_seq512.sh to run the sequence length 512 phase of pretraining (epochs 151-170). You need to specify the checkpoint from the sequence length 128 phase as the starting point for the sequence length 512 phase. The BERT Large model is then ready for downstream tasks.

SQuAD Score from BERT Large Pretrain

After the BERT pretraining, use DeepSpeedExamples/BingBertSquad/run_squad_deepspeed.sh to get the SQuAD 1.1 score. You need to specify the checkpoint from the sequence length 512 phase and the location of the SQuAD 1.1 dataset.

ResNet18 on CIFAR10

All scripts can be found in DeepSpeedExamples/pytorch-cifar/. Use the script run.sh to train ResNet18 with specific parameters. Use grid_search.py and baseline_grid_search.py to get the Pareto results of test accuracy vs. batch size reported in the paper.
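
As a rough illustration of what a Pareto result of test accuracy vs. batch size means, the sketch below extracts the non-dominated points from a list of (batch size, test accuracy) pairs. It is not taken from grid_search.py: the function pareto_front is hypothetical, it assumes that both a larger batch size and a higher accuracy are preferred, and the numbers are made-up placeholders rather than results from the paper.

```python
def pareto_front(points):
    """Return the non-dominated (batch_size, test_accuracy) pairs.

    A point is dominated if another, different point has a batch size
    and an accuracy that are both at least as large.
    """
    front = []
    for p in points:
        dominated = any(
            q[0] >= p[0] and q[1] >= p[1] and q != p
            for q in points
        )
        if not dominated:
            front.append(p)
    return sorted(front)


# Placeholder values for illustration only, not measured results.
runs = [
    (128, 0.94),
    (256, 0.95),
    (512, 0.94),
    (512, 0.93),
    (1024, 0.90),
]
front = pareto_front(runs)
print(front)
```

Each remaining point is a run for which no other run offers both a larger batch size and an equal or better accuracy.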

ResNet50 on ImageNet

All scripts can be found in DeepSpeedExamples/imagenet_deepspeed/. Use the script run_with2kmin.sh to train ResNet50 with specific parameters.

Future of SimiGrad

SimiGrad will be officially integrated as part of DeepSpeed soon!

Owner
Heyang Qin