Implementation of state-of-the-art vision transformers with TensorFlow

Overview

ViT Tensorflow

This repository contains TensorFlow implementations of state-of-the-art vision transformers (a category of computer vision models first introduced in An Image is Worth 16x16 Words). It is inspired by lucidrains' vit-pytorch. I hope you enjoy these implementations :)

Models

Requirements

pip install tensorflow

Vision Transformer

Vision Transformer was introduced in An Image is Worth 16x16 Words. This model uses a Transformer encoder to classify images with pure attention and no convolution.

Usage

Defining the Model

from vit import ViT
import tensorflow as tf

vitClassifier = ViT(
    num_classes=1000,
    patch_size=16,
    num_of_patches=(224 // 16) ** 2,
    d_model=128,
    heads=2,
    num_layers=4,
    mlp_rate=2,
    dropout_rate=0.1,
    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it is (((h - patch_size) // patch_size) + 1) * (((w - patch_size) // patch_size) + 1), where h is the height of the image and w is the width of the image. When the height and width of the image are divisible by patch_size, this equals (h // patch_size) * (w // patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model
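
As a small sketch of the num_of_patches formulas above (the helper names here are hypothetical and not part of this repository), the two expressions can be compared directly. With Python integer division they give the same result whenever h and w are at least patch_size, so the shorter form is usually enough:

```python
def num_patches_general(h, w, patch_size):
    """General formula from the parameter description above."""
    rows = ((h - patch_size) // patch_size) + 1
    cols = ((w - patch_size) // patch_size) + 1
    return rows * cols


def num_patches_simple(h, w, patch_size):
    """Shorter formula, written for image sizes divisible by patch_size."""
    return (h // patch_size) * (w // patch_size)


if __name__ == "__main__":
    # 224x224 image with 16x16 patches: a 14x14 grid of patches.
    print(num_patches_general(224, 224, 16))
    print(num_patches_simple(224, 224, 16))
```

The value returned here is what would be passed as num_of_patches when constructing the model.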

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = vitClassifier(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

vitClassifier.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top_5_accuracy"),
    ],
)

vitClassifier.fit(
    trainingData,  # TensorFlow dataset of images and labels with shapes ((b, h, w, 3), (b,))
    validation_data=valData,  # same structure as the training data
    epochs=100,
)

Convolutional Vision Transformer

Convolutional Vision Transformer was introduced in CvT: Introducing Convolutions to Vision Transformers. This model uses a hierarchical (multi-stage) architecture with a convolutional embedding at the beginning of each stage. It also uses Convolutional Transformer Blocks to improve on the original vision transformer by adding the inductive bias of CNNs to the architecture.

Usage

Defining the Model

from cvt import CvT, CvTStage
import tensorflow as tf

cvtModel = CvT(
    num_of_classes=1000,
    stages=[
        CvTStage(
            projectionDim=64,
            heads=1,
            embeddingWindowSize=(7, 7),
            embeddingStrides=(4, 4),
            layers=1,
            projectionWindowSize=(3, 3),
            projectionStrides=(2, 2),
            ffnRate=4,
            dropoutRate=0.1,
        ),
        CvTStage(
            projectionDim=192,
            heads=3,
            embeddingWindowSize=(3, 3),
            embeddingStrides=(2, 2),
            layers=1,
            projectionWindowSize=(3, 3),
            projectionStrides=(2, 2),
            ffnRate=4,
            dropoutRate=0.1,
        ),
        CvTStage(
            projectionDim=384,
            heads=6,
            embeddingWindowSize=(3, 3),
            embeddingStrides=(2, 2),
            layers=1,
            projectionWindowSize=(3, 3),
            projectionStrides=(2, 2),
            ffnRate=4,
            dropoutRate=0.1,
        ),
    ],
    dropout=0.5,
)
CvT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of CvTStage
    list of cvt stages
  • dropout: float
    dropout rate used for the prediction head
CvTStage Params
  • projectionDim: int
    dimension used for the multi-head attention mechanism and the convolutional embedding
  • heads: int
    number of heads in the multi-head attention mechanism
  • embeddingWindowSize: tuple(int , int)
    window size used for the convolutional embedding
  • embeddingStrides: tuple(int , int)
    strides used for the convolutional embedding
  • layers: int
    number of convolutional transformer blocks
  • projectionWindowSize: tuple(int , int)
    window size used for the convolutional projection in each convolutional transformer block
  • projectionStrides: tuple(int , int)
    strides used for the convolutional projection in each convolutional transformer block
  • ffnRate: int
    expansion rate of the mlp block in each convolutional transformer block
  • dropoutRate: float
    dropout rate used in each convolutional transformer block
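
To get a feel for the multi-stage layout, the sketch below estimates the feature-map resolution after each stage's convolutional embedding. It assumes the embeddings use "same"-style padding, so each spatial axis shrinks to ceil(size / stride); this padding choice is an assumption, not something confirmed by this repository, and the helper name is hypothetical.

```python
import math


def stage_resolutions(h, w, embedding_strides):
    """Estimate (height, width) after each stage's convolutional embedding.

    Assumes "same"-style padding, i.e. output size = ceil(input size / stride).
    """
    resolutions = []
    for stride_h, stride_w in embedding_strides:
        h = math.ceil(h / stride_h)
        w = math.ceil(w / stride_w)
        resolutions.append((h, w))
    return resolutions


if __name__ == "__main__":
    # Strides taken from the three CvTStage definitions above.
    for stage, (sh, sw) in enumerate(stage_resolutions(224, 224, [(4, 4), (2, 2), (2, 2)]), start=1):
        print(f"stage {stage}: {sh}x{sw} feature map, {sh * sw} tokens")
```

Under that assumption, a 224x224 input is reduced stage by stage, so later stages attend over far fewer tokens than the first one.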

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = cvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

cvtModel.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top_5_accuracy"),
    ],
)

cvtModel.fit(
    trainingData,  # TensorFlow dataset of images and labels with shapes ((b, h, w, 3), (b,))
    validation_data=valData,  # same structure as the training data
    epochs=100,
)

Pyramid Vision Transformer V1

Pyramid Vision Transformer V1 was introduced in Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions. This model stacks multiple Transformer encoders to form the first convolution-free multi-scale backbone for various visual tasks, including image segmentation and object detection. The paper also introduces a new attention mechanism called Spatial Reduction Attention (SRA), which reduces the quadratic complexity of the multi-head attention mechanism.

Usage

Defining the Model

from pvt_v1 import PVT, PVTStage
import tensorflow as tf

pvtModel = PVT(
    num_of_classes=1000,
    stages=[
        PVTStage(
            d_model=64,
            patch_size=(2, 2),
            heads=1,
            reductionFactor=2,
            mlp_rate=2,
            layers=2,
            dropout_rate=0.1,
        ),
        PVTStage(
            d_model=128,
            patch_size=(2, 2),
            heads=2,
            reductionFactor=2,
            mlp_rate=2,
            layers=2,
            dropout_rate=0.1,
        ),
        PVTStage(
            d_model=320,
            patch_size=(2, 2),
            heads=5,
            reductionFactor=2,
            mlp_rate=2,
            layers=2,
            dropout_rate=0.1,
        ),
    ],
    dropout=0.5,
)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTStage
    list of pvt stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the SRA mechanism and the patch embedding
  • patch_size: tuple(int , int)
    window size used for the patch embedding
  • heads: int
    number of heads in the SRA mechanism
  • reductionFactor: int
    reduction factor used for the down sampling of the K and V in the SRA mechanism
  • mlp_rate: int
    expansion rate used in the feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder
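
As a rough illustration of what reductionFactor buys, the sketch below counts the entries of the attention score matrix (queries x keys) with and without spatial reduction. It assumes K and V are down-sampled by reductionFactor along each spatial axis, as described in the PVT paper; whether this repository reduces them in exactly that way is an assumption, and the helper name is hypothetical.

```python
def attention_score_entries(h, w, reduction_factor=1):
    """Number of query-key pairs for an h x w token grid.

    Queries keep the full h * w tokens. Keys and values are assumed to be
    down-sampled by `reduction_factor` along each spatial axis.
    """
    num_queries = h * w
    num_keys = (h // reduction_factor) * (w // reduction_factor)
    return num_queries * num_keys


if __name__ == "__main__":
    full = attention_score_entries(56, 56)                      # plain attention
    reduced = attention_score_entries(56, 56, reduction_factor=2)  # SRA
    print(full, reduced, full / reduced)
```

Under this assumption the score matrix shrinks by a factor of reduction_factor squared, which is where the savings over plain multi-head attention come from.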

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtModel.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top_5_accuracy"),
    ],
)

pvtModel.fit(
    trainingData,  # TensorFlow dataset of images and labels with shapes ((b, h, w, 3), (b,))
    validation_data=valData,  # same structure as the training data
    epochs=100,
)

Pyramid Vision Transformer V2

Pyramid Vision Transformer V2 was introduced in PVT v2: Improved Baselines with Pyramid Vision Transformer. This model is an improved version of PVT V1. The improvements are as follows:

  1. It uses overlapping patch embedding by using padded convolutions
  2. It uses convolutional feed-forward blocks which have a depth-wise convolution after the first fully-connected layer
  3. It uses a fixed pooling instead of convolutions for down sampling the K and V in the SRA attention mechanism (The new attention mechanism is called Linear SRA)

Usage

Defining the Model

from pvt_v2 import PVTV2, PVTV2Stage
import tensorflow as tf

pvtV2Model = PVTV2(
    num_of_classes=1000,
    stages=[
        PVTV2Stage(
            d_model=64,
            windowSize=(2, 2),
            heads=1,
            poolingSize=(7, 7),
            mlp_rate=2,
            mlp_windowSize=(3, 3),
            layers=2,
            dropout_rate=0.1,
        ),
        PVTV2Stage(
            d_model=128,
            windowSize=(2, 2),
            heads=2,
            poolingSize=(7, 7),
            mlp_rate=2,
            mlp_windowSize=(3, 3),
            layers=2,
            dropout_rate=0.1,
        ),
        PVTV2Stage(
            d_model=320,
            windowSize=(2, 2),
            heads=5,
            poolingSize=(7, 7),
            mlp_rate=2,
            mlp_windowSize=(3, 3),
            layers=2,
            dropout_rate=0.1,
        ),
    ],
    dropout=0.5,
)
PVTV2 Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTV2Stage
    list of pvt v2 stages
  • dropout: float
    dropout rate used for the prediction head
PVTV2Stage Params
  • d_model: int
    dimension used for the Linear SRA mechanism and the convolutional patch embedding
  • windowSize: tuple(int , int)
    window size used for the convolutional patch embedding
  • heads: int
    number of heads in the Linear SRA mechanism
  • poolingSize: tuple(int , int)
    size of the K and V after the fixed pooling
  • mlp_rate: int
    expansion rate used in the convolutional feed-forward block
  • mlp_windowSize: tuple(int , int)
    the window size used for the depth-wise convolution in the convolutional feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder
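
The idea behind poolingSize can be sketched with a plain-Python adaptive average pooling: whatever the input resolution, K and V end up on a fixed grid, so the number of key/value tokens stays constant. This is only an illustration on nested lists with a hypothetical helper; the exact pooling used in this repository is not shown here and may differ.

```python
import math


def adaptive_avg_pool_2d(grid, out_h, out_w):
    """Average-pool a 2D grid (list of rows) down to a fixed out_h x out_w grid."""
    in_h, in_w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_h):
        # Row range covered by output row i.
        r0 = (i * in_h) // out_h
        r1 = math.ceil((i + 1) * in_h / out_h)
        row = []
        for j in range(out_w):
            # Column range covered by output column j.
            c0 = (j * in_w) // out_w
            c1 = math.ceil((j + 1) * in_w / out_w)
            window = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(window) / len(window))
        pooled.append(row)
    return pooled


if __name__ == "__main__":
    for size in (14, 28, 56):
        feature_map = [[float(r * size + c) for c in range(size)] for r in range(size)]
        pooled = adaptive_avg_pool_2d(feature_map, 7, 7)
        # The key/value token count is the same for every input resolution.
        print(size, len(pooled) * len(pooled[0]))
```

With a fixed 7x7 grid, the attention score matrix grows linearly with the number of query tokens rather than quadratically, which is what the name Linear SRA refers to.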

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtV2Model(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtV2Model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top_5_accuracy"),
    ],
)

pvtV2Model.fit(
    trainingData,  # TensorFlow dataset of images and labels with shapes ((b, h, w, 3), (b,))
    validation_data=valData,  # same structure as the training data
    epochs=100,
)

DeiT

DeiT was introduced in Training Data-Efficient Image Transformers & Distillation Through Attention. Because the original vision transformer lacks the inductive biases of CNNs, it is data hungry: a lot of data is required to train it to surpass state-of-the-art CNNs such as ResNet. Therefore, the authors of this paper use a pre-trained CNN such as ResNet as a teacher during training, together with a special loss function, to perform distillation through attention.

Usage

Defining the Model

from deit import DeiT
import tensorflow as tf

teacherModel = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    input_shape=(224, 224, 3),
)

deitModel = DeiT(
    num_classes=1000,
    patch_size=16,
    num_of_patches=(224 // 16) ** 2,
    d_model=128,
    heads=2,
    num_layers=4,
    mlp_rate=2,
    teacherModel=teacherModel,
    temperature=1.0,
    alpha=0.5,
    hard=False,
    dropout_rate=0.1,
    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it is (((h - patch_size) // patch_size) + 1) * (((w - patch_size) // patch_size) + 1), where h is the height of the image and w is the width of the image. When the height and width of the image are divisible by patch_size, this equals (h // patch_size) * (w // patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • teacherModel: Tensorflow Model
    the teacher model used for distillation during training. This is a pre-trained CNN with the same input_shape and output_shape as the Transformer
  • temperature: float
    the temperature parameter in the loss
  • alpha: float
    the coefficient balancing the Kullback–Leibler divergence loss (KL) and the cross-entropy loss
  • hard: bool
    whether to use hard-label distillation instead of soft distillation
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model
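
To make the roles of temperature, alpha and hard more concrete, here is a minimal plain-Python sketch of the two distillation losses as they are formulated in the DeiT paper, for a single example. It is not this repository's loss code: the function names are hypothetical, it uses a single student output instead of a separate distillation token, the KL term is computed from teacher to student, and the way alpha is applied here is an assumption.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Cross-entropy of a probability vector against an integer label."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """KL(p || q) for two probability vectors of the same length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label,
                      alpha=0.5, temperature=1.0, hard=False):
    """Soft or hard-label distillation loss for one example (illustrative)."""
    student_probs = softmax(student_logits)
    ce = cross_entropy(student_probs, label)
    if hard:
        # Hard-label distillation: the teacher's argmax acts as a second label.
        teacher_label = max(range(len(teacher_logits)), key=lambda i: teacher_logits[i])
        return 0.5 * ce + 0.5 * cross_entropy(student_probs, teacher_label)
    # Soft distillation: match the temperature-softened distributions.
    kl = kl_divergence(softmax(teacher_logits, temperature),
                       softmax(student_logits, temperature))
    return (1 - alpha) * ce + alpha * temperature ** 2 * kl


if __name__ == "__main__":
    student = [2.0, 1.0, 0.1]
    teacher = [1.5, 1.2, 0.3]
    print(distillation_loss(student, teacher, label=0, alpha=0.5, temperature=2.0))
    print(distillation_loss(student, teacher, label=0, hard=True))
```

In this formulation a larger alpha puts more weight on agreeing with the teacher, and a higher temperature softens both distributions before they are compared.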

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = deitModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

# Note that the loss is defined inside the model, so no loss should be passed here
deitModel.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="top_5_accuracy"),
    ],
)

deitModel.fit(
    trainingData,  # TensorFlow dataset of images and labels with shapes ((b, h, w, 3), (b, num_classes))
    validation_data=valData,  # same structure as the training data
    epochs=100,
)
Owner
Mohammadmahdi NouriBorji