TensorFlow implementation of the Swin Transformer model.

Overview

Swin Transformer (TensorFlow)

TensorFlow reimplementation of the Swin Transformer model, based on the official PyTorch implementation.

Requirements

  • tensorflow >= 2.4.1

Pretrained Swin Transformer Checkpoints

ImageNet-1K and ImageNet-22K Pretrained Checkpoints

name            pretrain      resolution  acc@1  #params  model
swin_tiny_224   ImageNet-1K   224x224     81.2   28M      github
swin_small_224  ImageNet-1K   224x224     83.2   50M      github
swin_base_224   ImageNet-22K  224x224     85.2   88M      github
swin_base_384   ImageNet-22K  384x384     86.4   88M      github
swin_large_224  ImageNet-22K  224x224     86.3   197M     github
swin_large_384  ImageNet-22K  384x384     87.3   197M     github

Examples

Initializing the model:

from swintransformer import SwinTransformer

model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)

You can use a pretrained model like this:

import tensorflow as tf
from swintransformer import SwinTransformer

# Placeholder values; set these for your own dataset.
IMAGE_SIZE = [224, 224]
NUM_CLASSES = 10

model = tf.keras.Sequential([
  # Cast to float32 and apply "torch"-style ImageNet normalization.
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])
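
The Lambda layer in the example above relies on preprocess_input with mode="torch", which scales pixels to the range [0, 1] and then normalizes each channel with the ImageNet mean and standard deviation. The following NumPy sketch reproduces that arithmetic outside of Keras so the expected input range is easier to see. The function name torch_style_preprocess is made up for this illustration and is not part of this repository or of Keras.

```python
import numpy as np

# ImageNet per-channel statistics (RGB order) used by "torch"-style preprocessing.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def torch_style_preprocess(images):
    """Scale 0..255 RGB pixels to [0, 1], then normalize each channel.

    Hypothetical helper for illustration; expects channels-last input.
    """
    x = np.asarray(images, dtype=np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD


# An all-white batch of one 224x224 RGB image.
batch = np.full((1, 224, 224, 3), 255, dtype=np.uint8)
out = torch_style_preprocess(batch)
print(out.shape, out[0, 0, 0])
```

Because the scaling assumes raw 0..255 pixel values, images that were already rescaled to [0, 1] would end up normalized twice if passed through the same layer.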

If you use a pretrained model with a TPU on Kaggle, specify the use_tpu option:

import tensorflow as tf
from swintransformer import SwinTransformer

# Placeholder values; set these for your own dataset.
IMAGE_SIZE = [224, 224]
NUM_CLASSES = 10

model = tf.keras.Sequential([
  # Cast to float32 and apply "torch"-style ImageNet normalization.
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

Example: TPU training on Kaggle

Citation

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
Comments
  • no module name 'swintransformer' error

    I wonder where from swintransformer import SwinTransformer comes from. I tried to pip install it, but pip also reported that there is no such module. How can I overcome this problem?

    opened by HunarAA 2
  • Pretrained Swin-Transformer for multiple output

    Hi rishigami,

    Thank you for the implementation in TensorFlow. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pretrained model you wrap it in a Sequential model, but that way I am not able to stack multiple dense layers for the multiple classification outputs. Could you help me understand how I can adapt your TF code to my problem, perhaps by using the Functional API?

    opened by imanuelroz 2
  • NotImplementedError during model save

    I have defined a model as follows:

    def buildModel(LR = LR):
        backbone = SwinTransformer('swin_large_224', num_classes=None, include_top=False, pretrained=True, use_tpu=False)
        
        inp = L.Input(shape=(224,224,3))
        emb = backbone(inp)
        out = L.Dense(1,activation="relu")(emb)
        
        model = tf.keras.Model(inputs=inp,outputs=out)
        optimizer = tf.keras.optimizers.Adam(lr = LR)
        model.compile(loss="mse",optimizer=optimizer,metrics=[tf.keras.metrics.RootMeanSquaredError()])
        return model
    

    Now when I save this model using model.save("./model.hdf5") I get the following error:

    NotImplementedError                       Traceback (most recent call last)
    /tmp/ipykernel_43/131311624.py in <module>
    ----> 1 model.save("model.hdf5")
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
       2000     # pylint: enable=line-too-long
       2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
    -> 2002                     signatures, options, save_traces)
       2003 
       2004   def save_weights(self,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
        152           'or using `save_weights`.')
        153     hdf5_format.save_model_to_hdf5(
    --> 154         model, filepath, overwrite, include_optimizer)
        155   else:
        156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
        113 
        114   try:
    --> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
        116     for k, v in model_metadata.items():
        117       if isinstance(v, (dict, list, tuple)):
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        156   except NotImplementedError as e:
        157     if require_config:
    --> 158       raise e
        159 
        160   metadata = dict(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        153   model_config = {'class_name': model.__class__.__name__}
        154   try:
    --> 155     model_config['config'] = model.get_config()
        156   except NotImplementedError as e:
        157     if require_config:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
        648 
        649   def get_config(self):
    --> 650     return copy.deepcopy(get_network_config(self))
        651 
        652   @classmethod
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
       1347         filtered_inbound_nodes.append(node_data)
       1348 
    -> 1349     layer_config = serialize_layer_fn(layer)
       1350     layer_config['name'] = layer.name
       1351     layer_config['inbound_nodes'] = filtered_inbound_nodes
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        248         return serialize_keras_class_and_config(
        249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    --> 250       raise e
        251     serialization_config = {}
        252     for key, item in config.items():
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        243     name = get_registered_name(instance.__class__)
        244     try:
    --> 245       config = instance.get_config()
        246     except NotImplementedError as e:
        247       if _SKIP_FAILED_SERIALIZATION:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
       2252 
       2253   def get_config(self):
    -> 2254     raise NotImplementedError
       2255 
       2256   @classmethod
    
    NotImplementedError: 
    
    opened by Bibhash123 1
  • Invalid argument

    This is my basic model:

    
    with tpu_strategy.scope():
        model = tf.keras.Sequential([
                            tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(data, mode="torch"), 
                                                                input_shape=[224,224, 3]),
                            SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
                            tf.keras.layers.Dense(1, activation='sigmoid')
                                            ])
    
    model.compile(loss = tf.keras.losses.BinaryCrossentropy(),
                              optimizer = tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
                              metrics   = RMSE)
    
    

    I am getting this error:

    (3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]

    opened by AliKayhanAtay 1
  • relative_position_bias_table initialization

    Hi, in the official code, relative_position_bias_table is initialized from a truncated normal distribution. Is that part missing in this repo?

    Official code: https://github.com/microsoft/Swin-Transformer/blob/6bbd83ca617db8480b2fb9b335c476ffaf5afb1a/models/swin_transformer.py#L110

    This implementation: https://github.com/rishigami/Swin-Transformer-TF/blob/8986ca7b0e1f984437db2d8f17e0ecd87fadcd4f/swintransformer/model.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L70

    opened by gathierry 1
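
For readers unfamiliar with the table discussed in this issue, here is a minimal NumPy sketch of a relative position bias table and the index that maps token pairs in a window to its rows. It is not taken from either repository. The standard deviation of 0.02 and the rule of redrawing samples that fall more than two standard deviations from zero are assumptions made for this illustration, and the function names are hypothetical.

```python
import numpy as np


def relative_position_index(window_size):
    """Map each (query, key) token pair in a window to a row of the bias table."""
    wh, ww = window_size
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))  # (2, wh, ww)
    flat = coords.reshape(2, -1)                       # (2, N) with N = wh * ww
    rel = flat[:, :, None] - flat[:, None, :]          # (2, N, N) signed offsets
    rel = rel.transpose(1, 2, 0).copy()                # (N, N, 2)
    rel[:, :, 0] += wh - 1                             # shift row offsets to start at 0
    rel[:, :, 1] += ww - 1                             # shift column offsets to start at 0
    rel[:, :, 0] *= 2 * ww - 1                         # row-major flattening of the 2D offset
    return rel.sum(-1)                                 # (N, N) integer indices


def init_bias_table(window_size, num_heads, std=0.02, seed=0):
    """Truncated-normal table: redraw any sample farther than 2 * std from zero.

    Both std and the truncation rule are assumptions for this sketch.
    """
    wh, ww = window_size
    rng = np.random.default_rng(seed)
    table = rng.normal(0.0, std, size=((2 * wh - 1) * (2 * ww - 1), num_heads))
    mask = np.abs(table) > 2 * std
    while mask.any():
        table[mask] = rng.normal(0.0, std, size=int(mask.sum()))
        mask = np.abs(table) > 2 * std
    return table


index = relative_position_index((7, 7))
table = init_bias_table((7, 7), num_heads=3)
bias = table[index]  # (49, 49, 3): one bias per token pair and head
print(index.shape, table.shape, bias.shape)
```

If a table like this were left at zeros instead, a model trained from scratch would start with no positional bias, whereas loading pretrained weights would overwrite the initial values either way.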
  • Image size other than default ones doesn't work

    Notebook: https://colab.research.google.com/drive/1nqYkQCUzShkVdqGxW4TyMrtAb0n5MBZR#scrollTo=G9ZVlphmqD7d

    Issue:
    • In swin_tiny_224 I've tried multiples of 224, 512x512, and multiples of window_size, but nothing seems to work other than 224x224.
    • The same goes for swin_large_384: only the default size 384x384 works.

    I'm wondering if this is expected behavior or not. Is there any way to make it work for non-square images?

    opened by awsaf49 1
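
As background for the question above, the sketch below works through the window arithmetic of a Swin-style hierarchy. It assumes the defaults described in the Swin paper (patch size 4, four stages, window size 7 for the 224 models and 12 for the 384 models) and ignores any padding or window shrinking an implementation might apply. It checks only whether the token grid splits into whole windows at every stage, and the function names are hypothetical. It does not explain why multiples of 224 fail in this repository; the model may additionally fix its input size at construction time.

```python
def stage_sides(image_size, patch_size=4, num_stages=4):
    """Token-grid side length entering each stage.

    Patch embedding divides the side by patch_size; each of the
    num_stages - 1 patch-merging steps then halves it.
    """
    total_stride = patch_size * 2 ** (num_stages - 1)
    if image_size % total_stride != 0:
        raise ValueError(f"{image_size} is not a multiple of {total_stride}")
    first = image_size // patch_size
    return [first // 2 ** i for i in range(num_stages)]


def windows_tile_evenly(image_size, window_size=7, patch_size=4, num_stages=4):
    """True if every stage's grid splits into whole window_size-wide windows."""
    try:
        sides = stage_sides(image_size, patch_size, num_stages)
    except ValueError:
        return False
    return all(side % window_size == 0 for side in sides)


for size, window in [(224, 7), (448, 7), (512, 7), (384, 12)]:
    print(size, window, windows_tile_evenly(size, window_size=window))
```

Under these assumptions 224 and 448 pass with window size 7, 512 does not (its first-stage grid of 128 tokens is not divisible by 7), and 384 passes with window size 12. For a non-square image the same check would apply to height and width separately.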
  • Added 3D support for SwinTransformerModel, e.g. for medical imaging tasks

    Tested and working, e.g.:

    IMAGE_SIZE = [112, 112, 112]
    NUM_CLASSES = 10
    
    model_3d = tf.keras.Sequential([
      swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model_3d.compile(tf.keras.optimizers.Adam(), "categorical_crossentropy")
    
    for i in range(100):
        x = np.zeros([1, *IMAGE_SIZE, 1])
        y = tf.zeros([1, NUM_CLASSES])
        
        model_3d.fit(x, y)
        print("Trained on a batch")
    
    opened by MohamadZeina 0
  • Could you provide weights convert script?

    I tried the code and weights you provided and found that the performance is bad. Could you please provide the weight conversion script so that I can figure out this issue?

    Many thanks

    opened by edwardyehuang 0
  • tf load model is error

    import tensorflow as tf
    from swintransformer import SwinTransformer

    model = tf.keras.Sequential([
      tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
      SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    TF can't load the pretrained model; this step raises an error.

    opened by jangjiun 0
  • Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Has anyone tried to use the pretrained model with a TimeDistributed layer?

    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[224, 224, 3]),
        SwinTransformer('swin_base_224', include_top=False, pretrained=True)
    ])

    model_f = models.Sequential()
    model_f.add(TimeDistributed(model, input_shape=(8, 224, 224, 3)))
    
    

    I get the following error:

    NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel).
    
    Call arguments received by layer "time_distributed" (type TimeDistributed):
      • inputs=tf.Tensor(shape=(None, 8, 224, 224, 3), dtype=float32)
      • training=False
    
    
    opened by atelili 0
Releases(v0.1-tf-swin-weights)