CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows

Overview

CSWin-Transformer

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". The code and models for downstream tasks are coming soon.

Introduction

CSWin Transformer (the name CSWin stands for Cross-Shaped Window), introduced on arXiv, is a new general-purpose backbone for computer vision. It is a hierarchical Transformer that replaces the traditional full attention with our newly proposed cross-shaped window self-attention. This mechanism computes self-attention in parallel in the horizontal and vertical stripes that form a cross-shaped window, with each stripe obtained by splitting the input feature into stripes of equal width. With CSWin, we can realize global attention at a limited computation cost.
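To make the stripe geometry concrete, here is a minimal pure-Python sketch of which positions fall inside the cross-shaped window of a single token. It only illustrates the region described above and does not compute attention. The function names, the 8x8 feature map, and the stripe width of 2 are illustrative assumptions, not part of the released code.

```python
def stripe_index(pos, stripe_width):
    """Index of the stripe that a row or column coordinate falls into."""
    return pos // stripe_width


def cross_shaped_window(h, w, row, col, sw):
    """Positions covered by the cross-shaped window of token (row, col).

    Assumption for this sketch: horizontal stripes are `sw` rows tall and
    span the full width; vertical stripes are `sw` columns wide and span
    the full height.
    """
    if h % sw or w % sw:
        raise ValueError("feature map size must be divisible by the stripe width")
    r0 = stripe_index(row, sw) * sw  # first row of the horizontal stripe
    c0 = stripe_index(col, sw) * sw  # first column of the vertical stripe
    horizontal = {(r, c) for r in range(r0, r0 + sw) for c in range(w)}
    vertical = {(r, c) for r in range(h) for c in range(c0, c0 + sw)}
    return horizontal, vertical


# Hypothetical 8x8 feature map, stripe width 2, query token at (3, 5).
horizontal, vertical = cross_shaped_window(h=8, w=8, row=3, col=5, sw=2)
cross = horizontal | vertical

# Draw the window: '#' marks positions inside the cross.
for r in range(8):
    print("".join("#" if (r, c) in cross else "." for c in range(8)))
```

Widening the stripes enlarges the region each token can reach, at the price of more computation per stripe.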

CSWin Transformer achieves strong performance on ImageNet classification (87.5 acc@1 on val with only 97G FLOPs) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

Main Results on ImageNet

| model | pretrain | resolution | acc@1 | #params | FLOPs | 22K model | 1K model |
|---|---|---|---|---|---|---|---|
| CSWin-T | ImageNet-1K | 224x224 | 82.8 | 23M | 4.3G | - | model |
| CSWin-S | ImageNet-1K | 224x224 | 83.6 | 35M | 6.9G | - | model |
| CSWin-B | ImageNet-1K | 224x224 | 84.2 | 78M | 15.0G | - | model |
| CSWin-B | ImageNet-1K | 384x384 | 85.5 | 78M | 47.0G | - | model |
| CSWin-L | ImageNet-22K | 224x224 | 86.5 | 173M | 31.5G | model | model |
| CSWin-L | ImageNet-22K | 384x384 | 87.5 | 173M | 96.8G | - | model |

Main Results on Downstream Tasks

COCO Object Detection

| backbone | method | pretrain | lr schd | box mAP | mask mAP | #params | FLOPs |
|---|---|---|---|---|---|---|---|
| CSWin-T | Mask R-CNN | ImageNet-1K | 3x | 49.0 | 43.6 | 42M | 279G |
| CSWin-S | Mask R-CNN | ImageNet-1K | 3x | 50.0 | 44.5 | 54M | 342G |
| CSWin-B | Mask R-CNN | ImageNet-1K | 3x | 50.8 | 44.9 | 97M | 526G |
| CSWin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 52.5 | 45.3 | 80M | 757G |
| CSWin-S | Cascade Mask R-CNN | ImageNet-1K | 3x | 53.7 | 46.4 | 92M | 820G |
| CSWin-B | Cascade Mask R-CNN | ImageNet-1K | 3x | 53.9 | 46.4 | 135M | 1004G |

ADE20K Semantic Segmentation (val)

| backbone | method | pretrain | crop size | lr schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
|---|---|---|---|---|---|---|---|---|
| CSWin-T | Semantic FPN | ImageNet-1K | 512x512 | 80K | 48.2 | - | 26M | 202G |
| CSWin-S | Semantic FPN | ImageNet-1K | 512x512 | 80K | 49.2 | - | 39M | 271G |
| CSWin-B | Semantic FPN | ImageNet-1K | 512x512 | 80K | 49.9 | - | 81M | 464G |
| CSWin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 49.3 | 50.4 | 60M | 959G |
| CSWin-S | UPerNet | ImageNet-1K | 512x512 | 160K | 50.0 | 50.8 | 65M | 1027G |
| CSWin-B | UPerNet | ImageNet-1K | 512x512 | 160K | 50.8 | 51.7 | 109M | 1222G |
| CSWin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 51.8 | 52.6 | 109M | 1941G |
| CSWin-L | UPerNet | ImageNet-22K | 640x640 | 160K | 53.4 | 55.7 | 208M | 2745G |

Requirements

To install timm==0.3.4, pytorch>=1.4, opencv, ... , run:

bash install_req.sh

Apex for mixed precision training is used for finetuning. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Data preparation: arrange ImageNet in the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
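Before launching a long training run, it can help to confirm that the dataset really follows this layout. The following is a minimal sketch using only the standard library; `summarize_imagenet` is a hypothetical helper, not part of this repository, and it assumes image files carry the `.JPEG` suffix shown in the tree above.

```python
from pathlib import Path


def summarize_imagenet(root):
    """Count class folders and .JPEG files under <root>/train and <root>/val."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Each class is expected to be one sub-folder, e.g. n01440764.
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        n_images = sum(
            1 for d in class_dirs for f in d.iterdir() if f.suffix == ".JPEG"
        )
        summary[split] = {"classes": len(class_dirs), "images": n_images}
    return summary
```

Calling `summarize_imagenet("<data path>")` returns a dictionary with the number of class folders and images per split, which can be compared against the expected dataset size.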

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If the GPU memory is not enough, please use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or enable gradient checkpointing with '--use-chk'.
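The suggested low-memory settings are consistent with scaling the learning rate and the EMA update rate (1 - decay) linearly with the batch size. The sketch below shows that arithmetic; the linear rule is an assumption inferred from the numbers above rather than a documented policy of this repository, and `rescale_for_batch_size` is a hypothetical helper.

```python
def rescale_for_batch_size(base_batch, base_lr, base_ema_decay, new_batch):
    """Scale lr and the EMA update rate linearly with the per-GPU batch size.

    Assumption: both the learning rate and (1 - ema_decay) are proportional
    to the batch size. This matches the values suggested in the text but is
    not confirmed by the authors.
    """
    ratio = new_batch / base_batch
    new_lr = base_lr * ratio
    new_ema_decay = 1.0 - (1.0 - base_ema_decay) * ratio
    return new_lr, new_ema_decay


# Going from the default '-b 256 --lr 2e-3 --model-ema-decay 0.99984' to '-b 128'.
lr, ema_decay = rescale_for_batch_size(256, 2e-3, 0.99984, 128)
print(f"--lr {lr:g} --model-ema-decay {ema_decay:.5f}")
```

For other batch sizes the same helper gives a starting point only; the resulting values have not been validated here.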

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If the GPU memory is not enough, please enable gradient checkpointing with '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows},
      author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
      year={2021},
      eprint={2107.00652},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao or Dong Chen.

Comments
  • About the patches_resolution of the segmentation model

    Hello, this work is interesting, but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resolution. For example, in stage 3, the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchange information outside the 'patches_resolution'?

    opened by danczs 2
  • The downstream task results from my implementation are poor

    I have some questions:

    1. Is the split size still [1 2 7 7]?
    2. Is branch_num in the last stage 2 or 1? The image resolution in the last stage of a downstream task does not equal 7 (the split size). If branch_num is not 1, the sizes of the pretrained weights do not match.
    3. Is the padding in my implementation correct?

       pad_l = pad_t = 0
       pad_r = (W_sp - W % W_sp) % W_sp
       pad_b = (H_sp - H % H_sp) % H_sp
       q = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       k = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       v = q.transpose(-2,-1).contiguous().view(B, H, W, C)
       if pad_r > 0 or pad_b > 0:
           q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b))
           k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b))
           v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
       _, Hp, Wp, _ = q.shape
    opened by Sunting78 2
  • Experiment setting for semantic segmentation

    Hi, thank you for the code. I implemented CSWin-T with FPN for semantic segmentation on ADE20K but couldn't reach the 48.2 mIoU reported in the table. The maximum I could get was 39.9 mIoU. It would be great if you could share the exact experiment settings you used. Thanks

    opened by AnukritiSinghh 2
  • CSWin significantly slower than Swin?

    Greetings,

    From my benchmarks, I have noticed that CSWin seems to be significantly slower than Swin at inference time. Is this the expected behavior? While I can get predictions in as little as 20 milliseconds with Swin Large 384, it takes over 900 milliseconds with CSWin_144_24322_large_384.

    I performed tests using FP16, torchscript, optimize_for_inference and torch.inference_mode

    opened by ErenBalatkan 2
  • Pretrained settings for object detection

    Hi, I'm impressed by your excellent work.

    I have a question.

    I wonder which type of the pre-trained weights (224x224 or 384x384 finetuned) is used for object detection.

    I know both 224x224 and 384x384 are pre-trained on ImageNet-1k.

    opened by youngwanLEE 2
  • Error about building 384 models

    The code at https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L407 should be:

    model = CSWinTransformer(img_size=384, patch_size=4, embed_dim=96, depth=[2,4,32,2],
    

    The same applies to: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L414

    opened by TingquanGao 2
  • The problem is shown in the figure

    model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], split_size=[1,2,12,12], num_heads=[4,8,16,32], mlp_ratio=4.).cuda().eval()
    inp = torch.rand(1, 3, 224, 224).cuda()
    outs = model(inp)
    for out in outs:
        print(out.shape)

    RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

    Why?

    opened by rui-cf 2
  • about the test result on imagenet-2012

    Hi! I've tested the released CSWin-Tiny-224 pretrained weights. These are my data transforms during testing:

    DEFAULT_CROP_SIZE = 0.9
    scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
    transform = transforms.Compose(
        [
            transforms.Resize(scale_size, interpolation=3)  # 3: bicubic
            if image_size == 224
            else transforms.Resize(image_size, interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
    )
    

    I can only get 80.5% on the ImageNet-2012 dataset, which is inconsistent with the results reported in this repo. Did I miss some details about the data augmentation during testing?

    opened by rentainhe 1
  • CSwin Code for Segmentation with MMSegmentation

    Hi, Thank you for your work! I wonder if you plan to release the mmsegmentation code you used for the downstream segmentation task, just like in Swin-Transformer repo.

    Best,

    opened by WalBouss 1
  • How do you produce Table 9 (ablation on different attention mechanisms) in the paper?

    Hi, thanks for your nice work. I'm comparing different attention mechanisms and want to follow your experimental settings. I ran into two problems:

    1. Why is the reported mIoU 41.9 for Swin-T in Table 9, while it is 46.1 in the Swin paper?
    2. Can you provide the detailed experimental settings for semantic segmentation and object detection in Table 9?
    opened by rayleizhu 0
  • about the setting of --use_chk

    Passing --use_chk enables torch.utils.checkpoint to save GPU memory. I wonder if this could hurt the final performance. Thanks a lot!

    opened by go-ahead-maker 0
  • Input image normalization parameters for semantic segmentation.

    Hey, guys, cool work!

    Unfortunately, it is not quite obvious to me which normalization parameters (mean, std) you used when training the semantic segmentation model on ADE20K. Are they still IMAGENET_DEFAULT_MEAN, or did you use other values?

    opened by NikAleksFed 0
  • Using transfer to train over food101

    Hi all! I'm trying to train a model for food101 using the CSWin_64_12211_tiny_224 model with its pretrained weights. The thing is, during execution it looks like it is training from scratch rather than reusing the pretrained weights. By this I mean the initial top-5 accuracy is around 5%, but I expected it to be higher than this.

    For this, I loaded the pretrained model, changed its classification layer in a separate script, and saved it for use as follows:

    model = create_model(
        'CSWin_64_12211_tiny_224', pretrained=True, num_classes=1000, drop_rate=0.0,
        drop_connect_rate=None,  # DEPRECATED, use drop_path
        drop_path_rate=0.2, drop_block_rate=None, global_pool=None, bn_tf=False,
        bn_momentum=None, bn_eps=None, checkpoint_path='', img_size=224, use_chk=True)
    chk_path = './pretrained/cswin_tiny_224.pth'
    load_checkpoint(model, chk_path)
    model.reset_classifier(101, 'max')

    These are some of the runs I tried:

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base

    Is there something I'm missing or a proper way I should try this?

    Thanks in advance for any help! :)

    opened by andreynz691 0
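
The RuntimeError quoted in the comments above (shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632) can be checked with plain arithmetic: 192 x 14 x 14 = 37632, whereas the requested shape holds fewer elements, because a split size of 12 does not divide a side length of 14. The sketch below assumes the four stages run at strides 4, 8, 16 and 32 (consistent with patch_size=4 and the 224 / 16 = 14 figure mentioned in the comments); the helper names are hypothetical and not part of the repository.

```python
from math import prod


def stage_resolutions(img_size, first_stride=4, num_stages=4):
    """Feature-map side length per stage, assuming the stride doubles each stage."""
    return [img_size // (first_stride * 2 ** i) for i in range(num_stages)]


def indivisible_stages(img_size, split_size):
    """Return (stage, resolution, split) for stages where split does not divide resolution."""
    return [
        (stage, res, sp)
        for stage, (res, sp) in enumerate(zip(stage_resolutions(img_size), split_size), 1)
        if res % sp != 0
    ]


requested_shape = [1, 192, 1, 14, 1, 12]
print(prod(requested_shape), "elements requested vs", 192 * 14 * 14, "available")

print(indivisible_stages(224, [1, 2, 12, 12]))  # 224x224 input
print(indivisible_stages(384, [1, 2, 12, 12]))  # 384x384 input
```

Under these assumptions the split sizes [1, 2, 12, 12] fit a 384x384 input but not a 224x224 one. The last stage may be handled differently in the actual model, so a mismatch reported there is not necessarily an error.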