A TensorFlow 2.x implementation of Masked Autoencoders Are Scalable Vision Learners

Overview

Masked Autoencoders Are Scalable Vision Learners

Open In Colab

A TensorFlow implementation of Masked Autoencoders Are Scalable Vision Learners [1]. Our implementation of the proposed method is available in mae-pretraining.ipynb notebook. It includes evaluation with linear probing as well. Furthermore, the notebook can be fully executed on Google Colab. Our main objective is to present the core idea of the proposed method in a minimal and readable manner. We have also prepared a blog for getting started with Masked Autoencoder easily.


With just 100 epochs of pre-training and a fairly lightweight and asymmetric Autoencoder architecture we achieve 49.33%% accuracy with linear probing on the CIFAR-10 dataset. Our training logs and encoder weights are released in Weights and Logs. For comparison, we took the encoder architecture and trained it from scratch (refer to regular-classification.ipynb) in a fully supervised manner. This gave us ~76% test top-1 accuracy.

We note that with further hyperparameter tuning and more epochs of pre-training, we can achieve a better performance with linear-probing. Below we present some more results:

Config Masking
proportion
LP
performance
Encoder weights
& logs
Encoder & decoder layers: 3 & 1
Batch size: 256
0.6 44.25% Link
Do 0.75 46.84% Link
Encoder & decoder layers: 6 & 2
Batch size: 256
0.75 48.16% Link
Encoder & decoder layers: 9 & 3
Batch size: 256
Weight deacy: 1e-5
0.75 49.33% Link

LP denotes linear-probing. Config is mostly based on what we define in the hyperparameters section of this notebook: mae-pretraining.ipynb.

Acknowledgements

References

[1] Masked Autoencoders Are Scalable Vision Learners; He et al.; arXiv 2021; https://arxiv.org/abs/2111.06377.

You might also like...
A repository that shares tuning results of trained models generated by TensorFlow / Keras. Post-training quantization (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization), Quantization-aware training. TensorFlow Lite. OpenVINO. CoreML. TensorFlow.js. TF-TRT. MediaPipe. ONNX. [.tflite,.h5,.pb,saved_model,tfjs,tftrt,mlmodel,.xml/.bin, .onnx] Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax

Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported

Official implementation of the paper
Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders"

AAVAE Official implementation of the paper "AAVAE: Augmentation-AugmentedVariational Autoencoders" Abstract Recent methods for self-supervised learnin

VIMPAC: Video Pre-Training via Masked Token Prediction and Contrastive Learning

This is a release of our VIMPAC paper to illustrate the implementations. The pretrained checkpoints and scripts will be soon open-sourced in HuggingFace transformers.

EMNLP 2021 - Frustratingly Simple Pretraining Alternatives to Masked Language Modeling

Frustratingly Simple Pretraining Alternatives to Masked Language Modeling This is the official implementation for "Frustratingly Simple Pretraining Al

The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

PRIMER The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization. PRIMER is a pre-trained model for mu

SimMIM: A Simple Framework for Masked Image Modeling
SimMIM: A Simple Framework for Masked Image Modeling

SimMIM By Zhenda Xie*, Zheng Zhang*, Yue Cao*, Yutong Lin, Jianmin Bao, Zhuliang Yao, Qi Dai and Han Hu*. This repo is the official implementation of

SeMask: Semantically Masked Transformers for Semantic Segmentation.
SeMask: Semantically Masked Transformers for Semantic Segmentation.

SeMask: Semantically Masked Transformers Jitesh Jain, Anukriti Singh, Nikita Orlov, Zilong Huang, Jiachen Li, Steven Walton, Humphrey Shi This repo co

FocusFace: Multi-task Contrastive Learning for Masked Face Recognition
FocusFace: Multi-task Contrastive Learning for Masked Face Recognition

FocusFace This is the official repository of "FocusFace: Multi-task Contrastive Learning for Masked Face Recognition" accepted at IEEE International C

Comments
  • Excellent work (`mae.ipynb`)!

    Excellent work (`mae.ipynb`)!

    @ariG23498 this is fantastic stuff. Super clean, readable, and coherent with the original implementation. A couple of suggestions that would likely make things even better:

    • Since you have already implemented masking visualization utilities how about making them part of the PatchEncoder itself? That way you could let it accept a test image, apply random masking, and plot it just like the way you are doing in the earlier cells. This way I believe the notebook will be cleaner.
    • AdamW (tfa.optimizers.adamw) is a better choice when it comes to training Transformer-based models.
    • Are we taking the loss on the correct component? I remember you mentioning it being dealt with differently.

    After these points are addressed I will take a crack at porting the training loop to TPUs along with other performance monitoring callbacks.

    opened by sayakpaul 7
  • Unshuffle the patches?

    Unshuffle the patches?

    Your code helps me a lot! However, I still have some questions. In the paper, the authors say they unshuffle the full list before applying the deocder. In the MaskedAutoencoder class of your implementation, decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
    no unshuffling is used. I wonder if you can tell me the purpose of doing so? Thanks a lot!

    opened by changtaoli 2
  • Could you also share the weight of the pretrained decoder?

    Could you also share the weight of the pretrained decoder?

    Hi,

    Thanks for your excellent implementation! I found that you have shared the weights of the encoder, but if we want to replicate the reconstruction, the pretrained decoder is still needed. So, could you also share the weight of the pretrained decoder?

    Best Regards, Hongxin

    opened by hongxin001 1
  • Issue with the plotting utility `show_masked_image`

    Issue with the plotting utility `show_masked_image`

    Should be:

    def show_masked_image(self, patches):
            # Utility function that helps visualize maksed images.
            _, unmask_indices = self.get_random_indices()
            unmasked_patches = tf.gather(patches, unmask_indices, axis=1, batch_dims=1)
    
            # Necessary for plotting.
            ids = tf.argsort(unmask_indices)
            sorted_unmask_indices = tf.sort(unmask_indices)
            unmasked_patches = tf.gather(unmasked_patches, ids, batch_dims=1)
    
            # Select a random index for visualization.
            idx = np.random.choice(len(sorted_unmask_indices))
            print(f"Index selected: {idx}.")
    
            n = int(np.sqrt(NUM_PATCHES))
            unmask_index = sorted_unmask_indices[idx]
            unmasked_patch = unmasked_patches[idx]
    
            plt.figure(figsize=(4, 4))
    
            count = 0
            for i in range(NUM_PATCHES):
                ax = plt.subplot(n, n, i + 1)
    
                if count < unmask_index.shape[0] and unmask_index[count].numpy() == i:
                    patch = unmasked_patch[count]
                    patch_img = tf.reshape(patch, (PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
                    count = count + 1
                else:
                    patch_img = tf.zeros((PATCH_SIZE, PATCH_SIZE, 3))
                    plt.imshow(patch_img)
                    plt.axis("off")
            plt.show()
    
            # Return the random index to validate the image outside the method.
            return idx
    
    opened by ariG23498 1
Releases(v1.0.0)
Owner
Aritra Roy Gosthipaty
Learning with a learning rate of 1e-10.
Aritra Roy Gosthipaty
CTRL-C: Camera calibration TRansformer with Line-Classification

CTRL-C: Camera calibration TRansformer with Line-Classification This repository contains the official code and pretrained models for CTRL-C (Camera ca

57 Nov 14, 2022
KUIELAB-MDX-Net got the 2nd place on the Leaderboard A and the 3rd place on the Leaderboard B in the MDX-Challenge ISMIR 2021

KUIELAB-MDX-Net got the 2nd place on the Leaderboard A and the 3rd place on the Leaderboard B in the MDX-Challenge ISMIR 2021

IELab@ Korea University 74 Dec 28, 2022
Keras community contributions

keras-contrib : Keras community contributions Keras-contrib is deprecated. Use TensorFlow Addons. The future of Keras-contrib: We're migrating to tens

Keras 1.6k Dec 21, 2022
Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

DTU Acoustic Technology Group 11 Dec 17, 2022
Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models

Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models Abstract Many applications of generative models rely on the marginali

Stanford Intelligent Systems Laboratory 9 Jun 06, 2022
Depth-Aware Video Frame Interpolation (CVPR 2019)

DAIN (Depth-Aware Video Frame Interpolation) Project | Paper Wenbo Bao, Wei-Sheng Lai, Chao Ma, Xiaoyun Zhang, Zhiyong Gao, and Ming-Hsuan Yang IEEE C

Wenbo Bao 7.7k Dec 31, 2022
Source code for paper "ATP: AMRize Than Parse! Enhancing AMR Parsing with PseudoAMRs" @NAACL-2022

ATP: AMRize Then Parse! Enhancing AMR Parsing with PseudoAMRs Hi this is the source code of our paper "ATP: AMRize Then Parse! Enhancing AMR Parsing w

Chen Liang 13 Nov 23, 2022
Predictive Modeling on Electronic Health Records(EHR) using Pytorch

Predictive Modeling on Electronic Health Records(EHR) using Pytorch Overview Although there are plenty of repos on vision and NLP models, there are ve

81 Jan 01, 2023
BMW TechOffice MUNICH 148 Dec 21, 2022
Deep Face Recognition in PyTorch

Face Recognition in PyTorch By Alexey Gruzdev and Vladislav Sovrasov Introduction A repository for different experimental Face Recognition models such

Alexey Gruzdev 141 Sep 11, 2022
Greedy Gaussian Segmentation

GGS Greedy Gaussian Segmentation (GGS) is a Python solver for efficiently segmenting multivariate time series data. For implementation details, please

Stanford University Convex Optimization Group 72 Dec 07, 2022
ICLR21 Tent: Fully Test-Time Adaptation by Entropy Minimization

⛺️ Tent: Fully Test-Time Adaptation by Entropy Minimization This is the official project repository for Tent: Fully-Test Time Adaptation by Entropy Mi

Dequan Wang 204 Dec 25, 2022
Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis in JAX

SYMPAIS: Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis Overview | Installation | Documentation | Examples | Notebo

Yicheng Luo 4 Sep 13, 2022
Speech recognition tool to convert audio to text transcripts, for Linux and Raspberry Pi.

Spchcat Speech recognition tool to convert audio to text transcripts, for Linux and Raspberry Pi. Description spchcat is a command-line tool that read

Pete Warden 279 Jan 03, 2023
MERLOT: Multimodal Neural Script Knowledge Models

merlot MERLOT: Multimodal Neural Script Knowledge Models MERLOT is a model for learning what we are calling "neural script knowledge" -- representatio

Rowan Zellers 190 Dec 22, 2022
Visualization toolkit for neural networks in PyTorch! Demo -->

FlashTorch A Python visualization toolkit, built with PyTorch, for neural networks in PyTorch. Neural networks are often described as "black box". The

Misa Ogura 692 Dec 29, 2022
PyTorch implementation of PSPNet

PSPNet with PyTorch Unofficial implementation of "Pyramid Scene Parsing Network" (https://arxiv.org/abs/1612.01105). This repository is just for caffe

Kazuto Nakashima 52 Nov 16, 2022