Unofficial PyTorch implementation of MobileViT, based on the paper "MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer".

Overview

MobileViT

RegNet

Model Architecture

MobileViT Architecture

Usage

Training

python main.py
optional arguments:
  -h, --help            show this help message and exit
  --gpu_device GPU_DEVICE
                        Select specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 20)
  --num-class N         Number of classes to classify (default: 10)
  --lr LR               Learning rate (default: 0.01)
  --weight-decay WD     Weight decay (default: 1e-5)
  --model-path PATH     Path to save the model
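
The help output above could come from a parser along these lines. This is a minimal sketch, not the repository's actual main.py: the flag names, metavars, and printed defaults come from the help text, while `build_parser`, the argument types, and the defaults for `--gpu_device` and `--model-path` are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented options (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="MobileViT training (sketch)")
    # Type and default for --gpu_device are assumptions; the help text gives neither.
    parser.add_argument("--gpu_device", type=int, default=0,
                        help="Select specific GPU to run the model")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="Input batch size for training (default: 64)")
    parser.add_argument("--epochs", type=int, default=20, metavar="N",
                        help="Number of epochs to train (default: 20)")
    parser.add_argument("--num-class", type=int, default=10, metavar="N",
                        help="Number of classes to classify (default: 10)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
                        help="Learning rate (default: 0.01)")
    parser.add_argument("--weight-decay", type=float, default=1e-5, metavar="WD",
                        help="Weight decay (default: 1e-5)")
    # The default save path is an assumption; the help text does not state one.
    parser.add_argument("--model-path", type=str, default="./checkpoint",
                        metavar="PATH", help="Path to save the model")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--epochs", "5", "--lr", "0.002"])
print(args.batch_size, args.epochs, args.lr)
```

argparse converts dashes in option names to underscores, so `--batch-size` is read back as `args.batch_size`.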

Citation

@InProceedings{Sachin2021,
  title = {MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE, AND MOBILE-FRIENDLY VISION TRANSFORMER},
  author = {Sachin Mehta and Mohammad Rastegari},
  booktitle = {},
  year = {2021}
}

If you find any problems with this implementation, please let me know. Thank you.

Comments
  • Training settings

    I really appreciate your efforts in implementing this model in PyTorch. I have one concern about the training settings: if I understand correctly, you trained the model for fewer than 5 epochs.

    In addition, the hyperparameters you adopted differ from those in the original paper. For instance, the authors train MobileViT with the AdamW optimizer, label-smoothing cross-entropy, and a multi-scale sampler, and the training phase includes a warmup stage.

    I also found that the classification accuracy reported here is much lower than in the original paper.

    I conjecture that the gap in accuracy is caused by the different training settings.

    opened by hkzhang91 6
  • load pretrain weight failed

    import torch
    import models
    
    model = models.MobileViT_S()
    PATH = "./MobileVit-S.pth.tar"
    weights = torch.load(PATH, map_location=lambda storage, loc: storage)
    model.load_state_dict(weights['state_dict'])
    model.eval()
    torch.save(model, './model.pt')
    
    • I tried to load the pretrained weights to test a demo, but the network structure does not seem to match the weights. Is there a solution?

    opened by hererookie 2
  • model training hyperparameter

    A problem has been bothering me: the learning rate, optimizer, batch size, L2 regularization, label smoothing, and number of epochs are inconsistent with the paper. How should I modify the code?

    opened by Agino-ltp 1
  • Have you test MobileVit on cifar-10?

    Thanks for your wonderful work!

    I plan to try MobileViT on a small dataset, such as MNIST, which requires adjusting the network structure. Before doing so, I would like to know whether MobileViT performs better than other networks on small datasets.

    I noticed "get_cifar10_dataset" in utils.py. Have you tested MobileViT on CIFAR-10? If so, could you share the accuracy and inference time?

    opened by Jerryme-xxm 1
  • Issues when loading MobileViT_S()

    I wanted to load the MobileViT_S() model with the pretrained weights, but I ran into some errors. To help others (in case another beginner like me hits the same problem), here is my solution:

    import torch
    from models import MobileViT_S  # assumes this repo's models module, as in the earlier snippet

    def load_mobilevit_weights(model_path):
        # Create an instance of the MobileViT model
        net = MobileViT_S()

        # Load the checkpoint on the CPU and take its state_dict
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

        # The checkpoint keys carry a 'module.' prefix that the model does not expect
        # (typically added when a model is saved while wrapped in nn.DataParallel).
        # Strip the prefix so the key names match the MobileViT architecture.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # With the keys fixed, load the parameters into the model
        net.load_state_dict(state_dict)

        return net

    net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
    
    opened by Sehaba95 4
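
Several of the comments above ask how to move the training settings closer to the paper's, which they describe as using AdamW, label-smoothing cross-entropy, and a warmup stage. The sketch below illustrates two of those pieces in plain Python so the arithmetic is visible: a learning-rate schedule with linear warmup, and a label-smoothed cross-entropy for a single example. It is not code from this repository. The function names, the choice of cosine decay after warmup, and every numeric value are assumptions picked for illustration.

```python
import math


def lr_at_step(step, total_steps, warmup_steps, base_lr, min_lr):
    """Linear warmup from min_lr to base_lr, then cosine decay back to min_lr.

    The cosine decay after warmup is an assumption of this sketch.
    """
    if step < warmup_steps:
        return min_lr + (base_lr - min_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Cross-entropy for one example against a label-smoothed target.

    The target distribution puts (1 - smoothing) on the true class and
    spreads `smoothing` uniformly over all classes.
    """
    # Numerically stable log-softmax.
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    log_probs = [x - log_z for x in logits]

    num_classes = len(logits)
    loss = 0.0
    for index, log_prob in enumerate(log_probs):
        weight = smoothing / num_classes
        if index == target:
            weight += 1.0 - smoothing
        loss -= weight * log_prob
    return loss


# Placeholder values, chosen only to exercise the functions.
schedule = [lr_at_step(s, total_steps=100, warmup_steps=10,
                       base_lr=0.002, min_lr=0.0002) for s in range(101)]
print(schedule[0], schedule[10], schedule[100])
print(smoothed_cross_entropy([5.0, 0.0, 0.0], target=0))
```

In PyTorch, the same ideas are available through `torch.optim.AdamW`, `torch.nn.CrossEntropyLoss(label_smoothing=...)`, and `torch.optim.lr_scheduler.LambdaLR`. `LambdaLR` expects a multiplicative factor, so a function like `lr_at_step` would be divided by the optimizer's base learning rate before being passed in.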
Releases(weight)
Owner
Hong-Jia Chen
Master student at National Chung Cheng University, Taiwan. Interested in Deep Learning and Computer Vision.