Unofficial implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (https://arxiv.org/abs/2103.14030)

Overview

Swin-Transformer-Tensorflow

A direct translation of the official PyTorch implementation of "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" to TensorFlow 2.

The official Pytorch implementation can be found here.

Introduction:

Swin Transformer Architecture Diagram

Swin Transformer (the name Swin stands for Shifted window) is initially described in arxiv, which capably serves as a general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping local windows while also allowing for cross-window connection.

Swin Transformer achieves strong performance on COCO object detection (58.7 box AP and 51.1 mask AP on test-dev) and ADE20K semantic segmentation (53.5 mIoU on val), surpassing previous models by a large margin.

Usage:

1. To Run a Pre-trained Swin Transformer

Swin-T:

python main.py --cfg configs/swin_tiny_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-S:

python main.py --cfg configs/swin_small_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

Swin-B:

python main.py --cfg configs/swin_base_patch4_window7_224.yaml --include_top 1 --resume 1 --weights_type imagenet_1k

The possible options for cfg and weights_type are:

cfg weights_type 22K model 1K Model
configs/swin_tiny_patch4_window7_224.yaml imagenet_1k - github
configs/swin_small_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window7_224.yaml imagenet_22kto1k - github
configs/swin_large_patch4_window12_384.yaml imagenet_22kto1k - github
configs/swin_base_patch4_window7_224.yaml imagenet_22k github -
configs/swin_base_patch4_window12_384.yaml imagenet_22k github -
configs/swin_large_patch4_window7_224.yaml imagenet_22k github -
configs/swin_large_patch4_window12_384.yaml imagenet_22k github -

2. Create custom models

To create a custom classification model:

import argparse

import tensorflow as tf

from config import get_config
from models.build import build_model

parser = argparse.ArgumentParser('Custom Swin Transformer')

parser.add_argument(
    '--cfg',
    type=str,
    metavar="FILE",
    help='path to config file',
    default="CUSTOM_YAML_FILE_PATH"
)
parser.add_argument(
    '--resume',
    type=int,
    help='Whether or not to resume training from pretrained weights',
    choices={0, 1},
    default=1,
)
parser.add_argument(
    '--weights_type',
    type=str,
    help='Type of pretrained weight file to load including number of classes',
    choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"},
    default="imagenet_1k",
)

args = parser.parse_args()
custom_config = get_config(args, include_top=False)

swin_transformer = tf.keras.Sequential([
    build_model(config=custom_config, load_pretrained=args.resume, weights_type=args.weights_type),
    tf.keras.layers.Dense(CUSTOM_NUM_CLASSES)
)

Model ouputs are logits, so don't forget to include softmax in training/inference!!

You can easily customize the model configs with custom YAML files. Predefined YAML files provided by Microsoft are located in the configs directory.

3. Convert PyTorch pretrained weights into Tensorflow checkpoints

We provide a python script with which we convert official PyTorch weights into Tensorflow checkpoints.

$ python convert_weights.py --cfg config_file --weights the_path_to_pytorch_weights --weights_type type_of_pretrained_weights --output the_path_to_output_tf_weights

TODO:

  • Translate model code over to TensorFlow
  • Load PyTorch pretrained weights into TensorFlow model
  • Write trainer code
  • Reproduce results presented in paper
    • Object Detection
  • Reproduce training efficiency of official code in TensorFlow

Citations:

@misc{liu2021swin,
      title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows}, 
      author={Ze Liu and Yutong Lin and Yue Cao and Han Hu and Yixuan Wei and Zheng Zhang and Stephen Lin and Baining Guo},
      year={2021},
      eprint={2103.14030},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
This is an official implementation of our CVPR 2021 paper "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression" (https://arxiv.org/abs/2104.02300)

Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression Introduction In this paper, we are interested in the bottom-up paradigm of estima

Non-Official Pytorch implementation of
Non-Official Pytorch implementation of "Face Identity Disentanglement via Latent Space Mapping" https://arxiv.org/abs/2005.07728 Using StyleGAN2 instead of StyleGAN

Face Identity Disentanglement via Latent Space Mapping - Implement in pytorch with StyleGAN 2 Description Pytorch implementation of the paper Face Ide

Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.
Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

PAWS-TF 🐾 Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS)

A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks
A PyTorch implementation of EventProp [https://arxiv.org/abs/2009.08378], a method to train Spiking Neural Networks

Spiking Neural Network training with EventProp This is an unofficial PyTorch implemenation of EventProp, a method to compute exact gradients for Spiki

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)
Tensorflow implementation of Semi-supervised Sequence Learning (https://arxiv.org/abs/1511.01432)

Transfer Learning for Text Classification with Tensorflow Tensorflow implementation of Semi-supervised Sequence Learning(https://arxiv.org/abs/1511.01

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Comments
  • Custom Swin Transformer: error: unrecognized arguments

    Custom Swin Transformer: error: unrecognized arguments

    parser = argparse.ArgumentParser('Custom Swin Transformer')

    parser.add_argument( '--cfg', type=str, metavar="FILE", help='/content/Swin-Transformer-Tensorflow/configs/swin_tiny_patch4_window7_224.yaml', default="CUSTOM_YAML_FILE_PATH" ) parser.add_argument( '--resume', type=int, help=1, choices={0, 1}, default=1, ) parser.add_argument( '--weights_type', type=str, help='imagenet_22k', choices={"imagenet_1k", "imagenet_22k", "imagenet_22kto1k"}, default="imagenet_1k", )

    args = parser.parse_args() custom_config = get_config(args, include_top=False)

    i am trying to use it but it throws an error below

    usage: Custom Swin Transformer [-h] [--cfg FILE] [--resume {0,1}] [--weights_type {imagenet_22kto1k,imagenet_1k,imagenet_22k}] Custom Swin Transformer: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-ee309a98-1f20-4bb7-aa12-c2980aea076c.json An exception has occurred, use %tb to see the full traceback.

    SystemExit: 2

    opened by AliKayhanAtay 1
  • train dataset

    train dataset

    Thank you for Thank you for providing your code. I've been running the pretrained model, and I'd like to know how to learn about custom data from the code you provided and how to transfer learning to custom data using the pretrained model. Thank you.

    opened by hoyeoung 1
Human4D Dataset tools for processing and visualization

HUMAN4D: A Human-Centric Multimodal Dataset for Motions & Immersive Media HUMAN4D constitutes a large and multimodal 4D dataset that contains a variet

tofis 15 Nov 09, 2022
1st Solution For ICDAR 2021 Competition on Mathematical Formula Detection

This project releases our 1st place solution on ICDAR 2021 Competition on Mathematical Formula Detection. We implement our solution based on MMDetection, which is an open source object detection tool

yuxzho 94 Dec 25, 2022
Model Zoo of BDD100K Dataset

Model Zoo of BDD100K Dataset

ETH VIS Group 200 Dec 27, 2022
A Framework for Encrypted Machine Learning in TensorFlow

TF Encrypted is a framework for encrypted machine learning in TensorFlow. It looks and feels like TensorFlow, taking advantage of the ease-of-use of t

TF Encrypted 0 Jul 06, 2022
Selene is a Python library and command line interface for training deep neural networks from biological sequence data such as genomes.

Selene is a Python library and command line interface for training deep neural networks from biological sequence data such as genomes.

Troyanskaya Laboratory 323 Jan 01, 2023
ESP32 python application to read data from a Tilt™ Hydrometer for homebrewing

TitlESP32 ESP32 MicroPython application to read and log data from a Tilt™ Hydrometer. Requirements A board with an ESP32 chip USB cable - USB A / micr

IoBeer 5 Dec 01, 2022
[CVPR 2022] Official Pytorch code for OW-DETR: Open-world Detection Transformer

OW-DETR: Open-world Detection Transformer (CVPR 2022) [Paper] Akshita Gupta*, Sanath Narayan*, K J Joseph, Salman Khan, Fahad Shahbaz Khan, Mubarak Sh

Akshita Gupta 127 Dec 27, 2022
Mercer Gaussian Process (MGP) and Fourier Gaussian Process (FGP) Regression

Mercer Gaussian Process (MGP) and Fourier Gaussian Process (FGP) Regression We provide the code used in our paper "How Good are Low-Rank Approximation

Aristeidis (Ares) Panos 0 Dec 13, 2021
Code release for General Greedy De-bias Learning

General Greedy De-bias for Dataset Biases This is an extention of "Greedy Gradient Ensemble for Robust Visual Question Answering" (ICCV 2021, Oral). T

4 Mar 15, 2022
PyTorch implementation of "PatchGame: Learning to Signal Mid-level Patches in Referential Games" to appear in NeurIPS 2021

PatchGame: Learning to Signal Mid-level Patches in Referential Games This repository is the official implementation of the paper - "PatchGame: Learnin

Kamal Gupta 22 Mar 16, 2022
Training, generation, and analysis code for Learning Particle Physics by Example: Location-Aware Generative Adversarial Networks for Physics

Location-Aware Generative Adversarial Networks (LAGAN) for Physics Synthesis This repository contains all the code used in L. de Oliveira (@lukedeo),

Deep Learning for HEP 57 Oct 22, 2022
Official implementation of Self-supervised Image-to-text and Text-to-image Synthesis

Self-supervised Image-to-text and Text-to-image Synthesis This is the official implementation of Self-supervised Image-to-text and Text-to-image Synth

6 Jul 31, 2022
ExCon: Explanation-driven Supervised Contrastive Learning

ExCon: Explanation-driven Supervised Contrastive Learning Contributors of this repo: Zhibo Zhang ( Zhibo (Darren) Zhang 18 Nov 01, 2022

Road Crack Detection Using Deep Learning Methods

Road-Crack-Detection-Using-Deep-Learning-Methods This is my Diploma Thesis ¨Road Crack Detection Using Deep Learning Methods¨ under the supervision of

Aggelos Katsaliros 3 May 03, 2022
Multi-View Consistent Generative Adversarial Networks for 3D-aware Image Synthesis (CVPR2022)

Multi-View Consistent Generative Adversarial Networks for 3D-aware Image Synthesis Multi-View Consistent Generative Adversarial Networks for 3D-aware

Xuanmeng Zhang 78 Dec 10, 2022
Deep Residual Networks with 1K Layers

Deep Residual Networks with 1K Layers By Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Microsoft Research Asia (MSRA). Table of Contents Introduc

Kaiming He 856 Jan 06, 2023
This repository consists of Blender python scripts and corresponding assets to generate variants of the CANDLE dataset

candle-simulator This repository consists of Blender python scripts and corresponding assets to generate variants of the IITH-CANDLE dataset. The rend

1 Dec 15, 2021
Code for "Learning Structural Edits via Incremental Tree Transformations" (ICLR'21)

Learning Structural Edits via Incremental Tree Transformations Code for "Learning Structural Edits via Incremental Tree Transformations" (ICLR'21) 1.

NeuLab 40 Dec 23, 2022
It is an open dataset for object detection in remote sensing images.

RSOD-Dataset It is an open dataset for object detection in remote sensing images. The dataset includes aircraft, oiltank, playground and overpass. The

136 Dec 08, 2022
Specificity-preserving RGB-D Saliency Detection

Specificity-preserving RGB-D Saliency Detection Authors: Tao Zhou, Huazhu Fu, Geng Chen, Yi Zhou, Deng-Ping Fan, and Ling Shao. 1. Preface This reposi

Tao Zhou 35 Jan 08, 2023