(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision are NOT for capturing long-range dependencies. In particular, we address the following three key questions about MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSAs and Convs are complementary because MSAs are low-pass filters and convolutions (Convs) are high-pass filters, and (3) MSAs at the end of a stage significantly improve the accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT long-range dependency 😱 Their weak inductive bias disrupts NN training. On the other hand, ViTs suffer from non-convex losses. MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, but Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise, but MSAs are not. Therefore, MSAs and Convs are complementary.
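As a toy illustration of what "low-pass" and "high-pass" mean here, the sketch below filters two 1D signals with two fixed kernels. It is not the Fourier analysis used in the paper. The averaging kernel is only a stand-in for spatial smoothing and the difference kernel is only a stand-in for an edge-like Conv filter; the helper names `circular_filter` and `energy` are made up for this example.

```python
def circular_filter(signal, kernel):
    """Correlate a 1D signal with a kernel, wrapping around at the edges."""
    n = len(signal)
    return [
        sum(k * signal[(i + j) % n] for j, k in enumerate(kernel))
        for i in range(n)
    ]


def energy(signal):
    """Sum of squared values, used as a crude measure of what survives."""
    return sum(x * x for x in signal)


# Two extreme test signals: the lowest and the highest representable frequency.
constant = [1.0] * 8            # no variation at all
alternating = [1.0, -1.0] * 4   # sign flips at every step

box_blur = [0.5, 0.5]       # averaging kernel (stand-in for smoothing)
difference = [1.0, -1.0]    # difference kernel (stand-in for an edge filter)

blur_on_low = energy(circular_filter(constant, box_blur))
blur_on_high = energy(circular_filter(alternating, box_blur))
diff_on_low = energy(circular_filter(constant, difference))
diff_on_high = energy(circular_filter(alternating, difference))

print("blur keeps low / high:", blur_on_low, blur_on_high)
print("diff keeps low / high:", diff_on_low, diff_on_high)
```

Averaging passes the constant signal untouched and cancels the alternating one, while the difference kernel does the opposite. This is the sense in which two filters of opposite type can complement each other.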

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consist of a number of CNN blocks and one (or a few) MSA blocks. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for one MSA block.


In addition, we also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, models that perform poorly on small amounts of data.

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In that paper, we show that a simple (non-trainable) 2 × 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use the Docker image pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
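To make a few of these metrics concrete, here is a minimal pure-Python sketch of Acc, Acc-90, Freq-90, and ECE. It is not the repository's implementation: the function names, the use of "confidence ≥ 0.9" as the meaning of "90% certain", and the equal-width binning for ECE are assumptions made for illustration, and the input values are synthetic.

```python
def accuracy(corrects):
    """Fraction of correct predictions."""
    return sum(corrects) / len(corrects)


def acc_certain(confidences, corrects, threshold=0.9):
    """Accuracy restricted to predictions with confidence >= threshold (assumed Acc-90)."""
    kept = [c for conf, c in zip(confidences, corrects) if conf >= threshold]
    return sum(kept) / len(kept) if kept else float("nan")


def freq_certain(confidences, threshold=0.9):
    """Fraction of predictions with confidence >= threshold (assumed Freq-90)."""
    return sum(conf >= threshold for conf in confidences) / len(confidences)


def expected_calibration_error(confidences, corrects, n_bins=10):
    """Weighted average gap between accuracy and confidence over equal-width bins."""
    n = len(confidences)
    ece = 0.0
    for b in range(n_bins):
        lo, hi = b / n_bins, (b + 1) / n_bins
        idx = [i for i, conf in enumerate(confidences) if lo < conf <= hi]
        if not idx:
            continue
        bin_acc = sum(corrects[i] for i in idx) / len(idx)
        bin_conf = sum(confidences[i] for i in idx) / len(idx)
        ece += len(idx) / n * abs(bin_acc - bin_conf)
    return ece


# Synthetic example: max-softmax confidence and correctness (1 or 0) per sample.
confidences = [0.95, 0.95, 0.55, 0.55]
corrects = [1, 1, 1, 0]

acc = accuracy(corrects)
acc_90 = acc_certain(confidences, corrects)
freq_90 = freq_certain(confidences)
ece = expected_calibration_error(confidences, corrects)
print(acc, acc_90, freq_90, round(ece, 4))
```

The per-bin accuracy and confidence pairs computed inside `expected_calibration_error` are also the points one would plot in a reliability diagram.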

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get the predictive performance of the model on a grid in weight space. We provide a sample loss landscape result.
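The general idea of evaluating a model on a weight-space grid can be sketched as follows: perturb the trained weights along two directions and record the loss at each grid point. This is a generic illustration, not the notebook's code. The function name `loss_on_grid`, the toy quadratic loss, and the axis-aligned directions are all assumptions; how the notebook chooses and normalizes its directions is not described in this README.

```python
def loss_on_grid(loss_fn, weights, dir1, dir2, steps):
    """Evaluate loss_fn at weights + a * dir1 + b * dir2 for every (a, b) in steps x steps."""
    grid = []
    for a in steps:
        row = []
        for b in steps:
            perturbed = [w + a * d1 + b * d2 for w, d1, d2 in zip(weights, dir1, dir2)]
            row.append(loss_fn(perturbed))
        grid.append(row)
    return grid


# Toy stand-in for "evaluate the model with these weights on a dataset".
target = [1.0, -2.0]


def toy_loss(weights):
    return sum((w - t) ** 2 for w, t in zip(weights, target))


trained_weights = [1.0, -2.0]   # pretend these are the trained weights
direction_1 = [1.0, 0.0]        # hypothetical directions in weight space
direction_2 = [0.0, 1.0]
steps = [-1.0, 0.0, 1.0]

grid = loss_on_grid(toy_loss, trained_weights, direction_1, direction_2, steps)
for row in grid:
    print(row)
```

The center of the grid corresponds to the unperturbed weights. With a real model, `loss_fn` would load the perturbed weights and evaluate on a dataset, so the cost grows with the square of the number of steps.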

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluating corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get the predictive performance of the model on datasets corrupted by 15 different types of corruption, each at 5 levels of intensity. We provide a sample robustness result.

How to Apply MSA to Your Own Model

We find that MSAs complement Convs (rather than replace them), and that MSAs closer to the end of a stage improve predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.
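One possible reading of rules 1 and 2 as a greedy search is sketched below. It is an interpretation, not the procedure used to build AlterNet. A model is represented as a list of stages, each a list of block names. The `evaluate` callback stands in for training and validating a candidate model, and `toy_score` is a made-up scoring function with no empirical meaning, included only so the sketch runs. Rule 3 (heads and hidden dimensions) is not covered.

```python
import copy


def build_up(stages, evaluate):
    """Greedily swap Conv blocks for MSA blocks, starting at the end of the last stage."""
    model = copy.deepcopy(stages)
    best = evaluate(model)
    stage = len(model) - 1
    pos = len(model[stage]) - 1
    while stage >= 0:
        if pos < 0:
            # No blocks left to try in this stage: move to the end of the previous one.
            stage -= 1
            if stage >= 0:
                pos = len(model[stage]) - 1
            continue
        candidate = copy.deepcopy(model)
        candidate[stage][pos] = "msa"
        score = evaluate(candidate)
        if score > best:
            # Rule 1: keep the MSA and skip one block so Conv and MSA alternate.
            model, best = candidate, score
            pos -= 2
        else:
            # Rule 2: discard it and try the end of an earlier stage instead.
            stage -= 1
            if stage >= 0:
                pos = len(model[stage]) - 1
    return model, best


def toy_score(model):
    """Hypothetical score: rewards an MSA at a stage end, penalizes MSAs elsewhere."""
    score = 0
    for blocks in model:
        for i, block in enumerate(blocks):
            if block == "msa":
                score += 1 if i == len(blocks) - 1 else -2
    return score


baseline = [["conv", "conv"], ["conv", "conv", "conv"], ["conv", "conv"]]
result, best_score = build_up(baseline, toy_score)
print(result, best_score)
```

In practice each call to `evaluate` means training a model, so the number of candidates tried matters. The search above tries at most one failed replacement per stage before moving on.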

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) 'without considering l2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring criteria without l2 on clean datasets would give incorrect (even opposite) results.
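A minimal sketch of the 'NLL + l2' quantity is given below, assuming the l2 term takes the common form (weight_decay / 2) · ‖w‖². The factor of 1/2, the function names, and the numbers are assumptions for illustration; the repository may define the term differently. In a real analysis the probabilities would come from the model evaluated on augmented inputs.

```python
import math


def nll(true_class_probs):
    """Mean negative log-likelihood of the probabilities assigned to the true classes."""
    return -sum(math.log(p) for p in true_class_probs) / len(true_class_probs)


def l2_penalty(weights, weight_decay):
    """Assumed form of the l2 term: (weight_decay / 2) * sum of squared weights."""
    return 0.5 * weight_decay * sum(w * w for w in weights)


def regularized_objective(true_class_probs, weights, weight_decay):
    """The quantity to visualize: NLL plus the l2 term."""
    return nll(true_class_probs) + l2_penalty(weights, weight_decay)


# Synthetic values: probabilities the model assigns to the correct class
# on two (augmented) samples, and a flattened weight vector.
probs = [0.5, 0.5]
weights = [1.0, 2.0]

plain = nll(probs)
objective = regularized_objective(probs, weights, weight_decay=0.1)
print(plain, objective)
```

Because the l2 term depends only on the weights, it changes the shape of the landscape as the weights are perturbed even when the NLL itself is flat, which is one reason omitting it can change the conclusions.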

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the PyTorch Image Models and Vision Transformer - Pytorch which are Apache 2.0 and MIT licensed.

Copyright the maintainers.
