How to train a CNN to 99% accuracy on MNIST in less than a second on a laptop

Overview

Training a NN to 99% accuracy on MNIST in 0.76 seconds

A quick study on how fast you can reach 99% accuracy on MNIST with a single laptop. Our answer is 0.76 seconds, reaching 99% accuracy in just one epoch of training. This is more than 200 times faster than the default training code from PyTorch. To see the final results, check 8_Final_00s76.ipynb. If you're interested in the process, read on below for a step-by-step description of the changes made.

The repo is organized into Jupyter notebooks, showing in chronological order the changes required to go from the initial PyTorch tutorial, which trains for 3 minutes, to less than a second of training time on a laptop with a GeForce GTX 1660 Ti GPU. I aimed for a coordinate-ascent-like procedure, changing only one thing at a time so the source of each improvement is clear, but sometimes I bundled correlated or small changes together.

Requirements

Python 3.x and PyTorch 1.8 (most likely works with >= 1.3). For fast runtimes you'll need CUDA and a compatible GPU.

0_Pytorch_initial_2m_52s.ipynb: Starting benchmark

First we need to benchmark the starting performance, which can be found in 0_Pytorch_initial_2m_52s.ipynb. Note that the code downloads the dataset if it is not already present, so the reported time is from the second run. Each run trains for 14 epochs; the average test accuracy over two runs is 99.185%, and the mean runtime is 2min 52s ± 38.1ms.

1_Early_stopping_57s40.ipynb: Stop early

Since our goal is to reach only 99% accuracy, we don't need the full training time. Our first modification is to simply stop training after the epoch in which we hit 99% test accuracy. This is typically reached within 3-5 epochs, with an average final accuracy of 99.07%, cutting training time to around a third of the original at 57.4s ± 6.85s.
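A minimal sketch of the early-stopping loop, assuming tutorial-style train() and test() helpers where test() is modified to return test accuracy in percent (these names are placeholders, not necessarily the notebook's exact ones):

```python
# Sketch of the early-stopping loop. train() and test() stand in for the
# tutorial's helpers, with test() assumed to return test accuracy in percent.
TARGET_ACC = 99.0

def train_until_target(model, device, train_loader, test_loader,
                       optimizer, scheduler, max_epochs=14):
    for epoch in range(1, max_epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        acc = test(model, device, test_loader)
        scheduler.step()
        if acc >= TARGET_ACC:
            break              # stop as soon as 99% test accuracy is reached
    return epoch, acc
```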

2_Smaller_NN_30s30.ipynb: Reduce network size

Next we employ the trick of reducing both network size and regularization to speed up convergence. This is done by adding a 2x2 max pool layer after the first convolutional layer, reducing the number of parameters in our fully connected layers by more than 4x. To compensate, we also remove one of the two dropout layers. This reduces the number of epochs we need to converge to 2-3, and training time to 30.3s ± 5.28s.
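As a rough sketch, starting from the official PyTorch MNIST example (the notebook's exact details may differ slightly), the change could look like this; the extra max pool shrinks the flattened feature map from 9216 to 1600 values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout = nn.Dropout(0.5)       # only one dropout layer kept
        self.fc1 = nn.Linear(1600, 128)      # 64 * 5 * 5 after two 2x2 pools
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # new 2x2 max pool
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)
```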

3_Data_loading_07s31.ipynb: Optimize Data Loading!

This is probably the biggest and most surprising time save of this project. Just by better optimizing the data loading process we can save 75% of the entire training run time. It turns out that torch.utils.data.DataLoader is really inefficient for small datasets like MNIST: instead of reading the data from disk one batch at a time, we can simply load the entire dataset into GPU memory at once and keep it there. To do this we save the entire dataset, with the same preprocessing we had before, to disk as a single PyTorch tensor using data_loader.save_data(). This takes around 10s and is not counted in the training time, as it only has to be done once. With this optimization, our average training time goes down to 7.31s ± 1.36s.
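A minimal sketch of the idea, assuming the preprocessed images and labels have been saved as a tensor pair (the file name below is a placeholder, not necessarily the one data_loader.save_data() uses):

```python
import torch

device = torch.device("cuda")
# Load the whole preprocessed training set once and keep it on the GPU.
data, targets = torch.load("mnist_train.pt")        # placeholder file name
data, targets = data.to(device), targets.to(device)

def gpu_batches(data, targets, batch_size=64, shuffle=True):
    """Yield (images, labels) batches by indexing the in-memory tensors."""
    n = data.shape[0]
    order = torch.randperm(n, device=device) if shuffle else torch.arange(n, device=device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield data[idx], targets[idx]
```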

4_128_Batch_04s66.ipynb: Increase batch size

Now that we have optimized data loading, increasing the batch size can significantly speed up training. Simply increasing the batch size from 64 to 128 reduces our average training time to 4.66s ± 583ms.

5_Onecycle_lr_03s14.ipynb: Better learning rate schedule

For this step, we turn our attention to the learning rate schedule. Previously we used an exponential decay where the learning rate is multiplied by 0.7 after each epoch. We replace this with Superconvergence, also known as OneCycleLR, where the learning rate starts close to 0, is increased linearly (or with a cosine schedule) to its peak value in the middle of training, and is slowly lowered back to zero at the end. This allows using much higher learning rates than otherwise. We used a peak LR of 4.0, 4 times higher than the starting LR used previously. The network now reaches 99% in 2 epochs every time, and this takes our training time down to 3.14s ± 4.72ms.
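A minimal sketch of the schedule setup, assuming the Adadelta optimizer used so far, a fixed 2-epoch run, and the model/batch iterator from the earlier steps (the notebook's exact pct_start/anneal settings may differ); note that OneCycleLR is stepped once per batch:

```python
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR

optimizer = torch.optim.Adadelta(model.parameters(), lr=4.0)
epochs = 2
scheduler = OneCycleLR(optimizer, max_lr=4.0,
                       epochs=epochs, steps_per_epoch=len(train_loader))

for epoch in range(epochs):
    for data, target in train_loader:      # batches already on the GPU
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
        scheduler.step()                   # step every batch, not every epoch
```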

6_256_Batch_02s31.ipynb: Increase batch size, again

With our better LR schedule we can once more double the batch size without hurting performance much. Note that this time around the network doesn't reach 99% on all random seeds, but I count it as a success as long as I'm confident the mean accuracy is greater than 99%. This is because Superconvergence requires fixed-length training and we can't guarantee that every seed works. This cuts our training time down to 2.31s ± 23.2ms.

7_Smaller_NN2_01s74.ipynb: Remove dropout and reduce size, again

Next we repeat our procedure from step 2 once again: remove the remaining dropout layer and compensate by reducing the width of our convolutional layers, the first from 32 to 24 channels and the second from 64 to 32. This reduces the time to train an epoch and even nets us increased accuracy, averaging around 99.1% after two epochs of training. This gives us a mean time of 1.74s ± 18.3ms.

8_Final_00s76.ipynb: Tune everything

Now that we have a fast working model and have grabbed most of the low-hanging improvements, it is time to dive into final finetuning. To start off, we simply move our max pool operations before the ReLU activation, which doesn't change the network's output but saves us a bit of compute.
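This works because max pooling and ReLU commute (both just take maxima), so pooling first means the ReLU runs on a 4x smaller tensor. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 32, 26, 26)
a = F.max_pool2d(F.relu(x), 2)   # ReLU first, then pool
b = F.relu(F.max_pool2d(x, 2))   # pool first, then ReLU (cheaper)
assert torch.equal(a, b)         # identical outputs
```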

The next changes were the result of a large search operation, where I tried a number of different things, optimizing one hyperparameter at a time. For each change I trained on 30 different seeds and measured which option gives the highest mean accuracy. Thirty seeds were necessary to draw statistically significant conclusions about small changes, and it is worth noting that training 30 seeds took less than a minute at this point. Higher accuracy can then be traded for faster times by cutting down the number of epochs.

First, I actually made the network bigger in select places that didn't slow down performance too much. The kernel size of the first convolutional layer was increased from 3 to 5, and the final fully connected layer from 128 to 256 units.

Next, it was time to change the optimizer. I found that with proper hyperparameters, Adam actually outperforms the Adadelta we had used so far. The hyperparameters I changed from the defaults are a learning rate of 0.01 (default 0.001), beta1 of 0.7 (default 0.9) and beta2 of 0.9 (default 0.999).
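For reference, these settings map onto torch.optim.Adam like this (model stands in for the final network):

```python
import torch

# Adam with the hyperparameters listed above; everything else is left at
# the PyTorch defaults. `model` stands in for the final network.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=0.01,           # default: 0.001
                             betas=(0.7, 0.9))  # defaults: (0.9, 0.999)
```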

All of this led to a large boost in accuracy (99.245% after 2 epochs), which I was finally able to trade for faster training by cutting training down to just one epoch! Our final result is 99.04% mean accuracy in just 762ms ± 24.9ms.

Owner
Tuomas Oikarinen
PhD student at UC San Diego, trying to understand ML and hopefully make it safer. Previously @MIT.