vit for few-shot classification

Overview

Few-Shot ViT

Requirements

  • PyTorch (>= 1.9)
  • TorchVision
  • timm (latest)
  • einops
  • tqdm
  • numpy
  • scikit-learn
  • scipy
  • argparse
  • tensorboardx

Pretrained Checkpoints

Currently we provide SUN-M (Visformer) trained on miniImageNet (5-way 1-shot and 5-way 5-shot), see Google Drive for details.

More pretrained checkpoints coming soon.

Evaluate the Pretrained Checkpoints

Prepare data

For example, miniImageNet:

cd test_phase

Download miniImageNet dataset from miniImageNet (courtesy of Spyros Gidaris)

unzip the package to materials/mini-imagenet, then obtain materials/mini-imagenet with pickle files.

Prapare pretrained checkpoints

Download corresponding checkpoints from Google Drive and store the checkpoints in test_phase/ directory.

Evaluation

cd test_phase
python test_few_shot.py --config configs/test_1_shot.yaml --shot 1 --gpu 1 # for 1-shot
python test_few_shot.py --config configs/test_5_shot.yaml --shot 5 --gpu 1 # for 5-shot

For 1-shot, you can obtain: test epoch 1: acc=67.80 +- 0.45 (%)

For 5-shot, you can obtain: test epoch 1: acc=83.25 +- 0.28 (%)

Test accuracy may slightly vary with different pytorch/cuda versions or different hardwares

TODO

  • more checkpoints
  • training code
You might also like...
So-ViT: Mind Visual Tokens for Vision Transformer
So-ViT: Mind Visual Tokens for Vision Transformer

So-ViT: Mind Visual Tokens for Vision Transformer        Introduction This repository contains the source code under PyTorch framework and models trai

A PyTorch Implementation of ViT (Vision Transformer)
A PyTorch Implementation of ViT (Vision Transformer)

ViT - Vision Transformer This is an implementation of ViT - Vision Transformer by Google Research Team through the paper "An Image is Worth 16x16 Word

PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT.

MoCo v3 for Self-supervised ResNet and ViT Introduction This is a PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT. The original M

Official implement of Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer
Official implement of Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer

Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer This repository contains the PyTorch code for Evo-ViT. This work proposes a slow-fas

This repository contains an overview of important follow-up works based on the original Vision Transformer (ViT) by Google.

This repository contains an overview of important follow-up works based on the original Vision Transformer (ViT) by Google.

A simple approach to emable dense segmentation with ViT.
A simple approach to emable dense segmentation with ViT.

Vision Transformer Segmentation Network This implementation of ViT in pytorch uses a super simple and straight-forward way of generating an output of

PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

MAE for Self-supervised ViT Introduction This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-sup

A simple program for training and testing vit

Vit This is a simple program for training and testing vit. Key requirements: torch, torchvision and timm. Dataset I put 5 categories of the cub classi

Implementing Vision Transformer (ViT) in PyTorch
Implementing Vision Transformer (ViT) in PyTorch

Lightning-Hydra-Template A clean and scalable template to kickstart your deep learning project 🚀 ⚡ 🔥 Click on Use this template to initialize new re

Comments
  • timm version

    timm version

    hello, I met a question when run your code as follow? Traceback (most recent call last): File "train_classifier.py", line 296, in <module> main(config) File "train_classifier.py", line 133, in main lr_scheduler = CosineLRScheduler(optimizer, warmup_lr_init=float(config['optimizer_args']['warmup_lr']), t_initial=config['max_epoch'], cycle_decay=0.1, warmup_t=int(config['optimizer_args']['warmup'])) TypeError: __init__() got an unexpected keyword argument 'cycle_decay' I think it's the version of timm package is not right, and the requirement in your code just say that is the latest version. can your provide the version of timm package??

    opened by JIAOJIAYUASD 2
  • The variant of visformer

    The variant of visformer

    Hi Bowen

    Thanks for opensource the inference code. I am just curious which variant of the visformer achieves the best results in Table 5 on mini-ImageNet? Is it visformer_80_small?

    opened by RongKaiWeskerMA 1
Releases(SUN)
Owner
Martin Dong
HIT student, major in Computer Science and Technology. CS.CV, object detection, segmentation, generation.
Martin Dong
OpenMMLab Detection Toolbox and Benchmark

MMDetection is an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.

OpenMMLab 22.5k Jan 05, 2023
TrackFormer: Multi-Object Tracking with Transformers

TrackFormer: Multi-Object Tracking with Transformers This repository provides the official implementation of the TrackFormer: Multi-Object Tracking wi

Tim Meinhardt 321 Dec 29, 2022
Official code of paper "PGT: A Progressive Method for Training Models on Long Videos" on CVPR2021

PGT Code for paper PGT: A Progressive Method for Training Models on Long Videos. Install Run pip install -r requirements.txt. Run python setup.py buil

Bo Pang 27 Mar 30, 2022
MazeRL is an application oriented Deep Reinforcement Learning (RL) framework

MazeRL is an application oriented Deep Reinforcement Learning (RL) framework, addressing real-world decision problems. Our vision is to cover the complete development life cycle of RL applications ra

EnliteAI GmbH 222 Dec 24, 2022
This repository contains small projects related to Neural Networks and Deep Learning in general.

ILearnDeepLearning.py Description People say that nothing develops and teaches you like getting your hands dirty. This repository contains small proje

Piotr Skalski 1.2k Dec 22, 2022
PyTorch3D is FAIR's library of reusable components for deep learning with 3D data

Introduction PyTorch3D provides efficient, reusable components for 3D Computer Vision research with PyTorch. Key features include: Data structure for

Facebook Research 6.8k Jan 01, 2023
Automatically measure the facial Width-To-Height ratio and get facial analysis results provided by Microsoft Azure

fwhr-calc-website This project is to automatically measure the facial Width-To-Height ratio and get facial analysis results provided by Microsoft Azur

SoohyunPark 1 Feb 07, 2022
[CVPR 2021] MiVOS - Scribble to Mask module

MiVOS (CVPR 2021) - Scribble To Mask Ho Kei Cheng, Yu-Wing Tai, Chi-Keung Tang [arXiv] [Paper PDF] [Project Page] A simplistic network that turns scri

Rex Cheng 65 Dec 22, 2022
PyTorch implementation of Neural Dual Contouring.

NDC PyTorch implementation of Neural Dual Contouring. Citation We are still writing the paper while adding more improvements and applications. If you

Zhiqin Chen 140 Dec 26, 2022
Implementation of gaze tracking and demo

Predicting Customer Demand by Using Gaze Detecting and Object Tracking This project is the integration of gaze detecting and object tracking. Predict

2 Oct 20, 2022
To prepare an image processing model to classify the type of disaster based on the image dataset

Disaster Classificiation using CNNs bunnysaini/Disaster-Classificiation Goal To prepare an image processing model to classify the type of disaster bas

Bunny Saini 1 Jan 24, 2022
Repository for the semantic WMI loss

Installation: pip install -e . Installing DL2: First clone DL2 in a separate directory and install it using the following commands: git clone https:/

Nick Hoernle 4 Sep 15, 2022
Implementation for HFGI: High-Fidelity GAN Inversion for Image Attribute Editing

HFGI: High-Fidelity GAN Inversion for Image Attribute Editing High-Fidelity GAN Inversion for Image Attribute Editing Update: We released the inferenc

Tengfei Wang 371 Dec 30, 2022
Bag of Tricks for Natural Policy Gradient Reinforcement Learning

Bag of Tricks for Natural Policy Gradient Reinforcement Learning [ArXiv] Setup Python 3.8.0 pip install -r req.txt Mujoco 200 license Main Files main.

Brennan Gebotys 1 Oct 10, 2022
Disentangled Cycle Consistency for Highly-realistic Virtual Try-On, CVPR 2021

Disentangled Cycle Consistency for Highly-realistic Virtual Try-On, CVPR 2021 [WIP] The code for CVPR 2021 paper 'Disentangled Cycle Consistency for H

ChongjianGE 94 Dec 11, 2022
Reinforcement learning library(framework) designed for PyTorch, implements DQN, DDPG, A2C, PPO, SAC, MADDPG, A3C, APEX, IMPALA ...

Automatic, Readable, Reusable, Extendable Machin is a reinforcement library designed for pytorch. Build status Platform Status Linux Windows Supported

Iffi 348 Dec 24, 2022
Easy way to add GoogleMaps to Flask applications. maintainer: @getcake

Flask Google Maps Easy to use Google Maps in your Flask application requires Jinja Flask A google api key get here Contribute To contribute with the p

Flask Extensions 611 Dec 05, 2022
An Implementation of SiameseRPN with Feature Pyramid Networks

SiameseRPN with FPN This project is mainly based on HelloRicky123/Siamese-RPN. What I've done is just add a Feature Pyramid Network method to the orig

3 Apr 16, 2022
Flexible time series feature extraction & processing

tsflex is a toolkit for flexible time series processing & feature extraction, that is efficient and makes few assumptions about sequence data. Useful

PreDiCT.IDLab 206 Dec 28, 2022
Meta graph convolutional neural network-assisted resilient swarm communications

Resilient UAV Swarm Communications with Graph Convolutional Neural Network This repository contains the source codes of Resilient UAV Swarm Communicat

62 Dec 06, 2022