Learning to Self-Train for Semi-Supervised Few-Shot

Overview

Learning to Self-Train for Semi-Supervised Few-Shot Classification

LICENSE Python TensorFlow

This repository contains the TensorFlow implementation for NeurIPS 2019 Paper "Learning to Self-Train for Semi-Supervised Few-Shot Classification".

Check the few-shot classification leaderboard.

Summary

Installation

In order to run this repository, we advise you to install python 2.7 or 3.5 and TensorFlow 1.3.0 with Anaconda.

You may download Anaconda and read the installation instruction on their official website: https://www.anaconda.com/download/

Create a new environment and install tensorflow on it:

conda create --name lst-tf python=2.7
conda activate lst-tf
conda install tensorflow-gpu=1.3.0

Install other requirements:

pip install scipy tqdm opencv-python pillow matplotlib

Clone this repository:

git clone https://github.com/xinzheli1217/learning-to-self-train.git 
cd learning-to-self-train

Project Architecture

.
├── data_generator              # dataset generator 
|   └── meta_data_generator.py  # data genertor for meta-train phase
├── models                      # tensorflow model files 
|   ├── models.py               # resnet12 CNN class
|   └── meta_model_LST.py       # semi-supervised meta-train model class
├── trainer                     # tensorflow trianer files  
|   └── meta_LST.py             # semi-supervised meta-train trainer class
├── utils                       # a series of tools used in this repo
|   └── misc.py                 # miscellaneous tool functions
| 
├── data                        # the folder containing datasets for experiments
├── pretrain_weights_dir        # the folder containing MTL pre-training weights
├── weights_saving_dir          # the folder containing meta-training weights
├── test_output_dir             # the folder containing meta-testing files
├── filenames_and_labels        # the folder containing image file paths and labels for experiments
|
├── exp_train.py                # the python file with main function and parameter settings for meta-training
└── exp_test.py                 # the python file with main function and parameter settings for meta-testing

Running Experiments

First, download our processed images: miniImagenet[Download Page] or tieredImagenet[Download Page], move the unziped folder to ./data. And then download the pre-trained models: miniImagenet[Download Page] or tieredImagenet[Download Page], move the unziped folder to ./pretrain_weights_dir.

Training from Pre-Trained Models

Run semi-supervised meta-train phase (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot) :

python exp_train.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --nb_ul_samples=10 --metatrain_iterations=15000 --exp_name='LST_mini_1_shot'

Run semi-supervised meta-test phase (e.g. 𝑚𝑖𝑛𝑖ImageNet, 1-shot) :

python exp_test.py --shot_num=1 --dataset='miniImagenet' --pretrain_class_num=64 --use_distractors=False --nb_ul_samples=100 --unfiles_num=10 --test_iter=15000 --recurrent_stage_nums=6 --nums_in_folders=30 --hard_selection=20 --exp_name='LST_mini_1_shot' 

Hyperparameters and Options

There are some main hyperparameters used in the experiments, you can edit them in the exp_train.py and the exp_test.py file for meta-train and meta-test phase respectively. There are two kinds of hyperparameters: (1) common hyperparameters that shared with meta-train and meta-test, (2) test-specific hyperparameters that used for recurrent self-training process in meta-test.

  • Common hyperparameters:

    • way_num number of classes
    • shot_num number of examples per class
    • dataset dataset used in the experiment (miniImagenet or tieredImagenet)
    • pretrain_class_num number of meta-train classes
    • exp_name name for the current experiment
    • meta_batch_size number of tasks sampled per meta-update in meta-train phase
    • base_lr step size alpha for inner gradient update
    • meta_lr the meta learning rate for SS and initial model parameters
    • min_meta_lr the min meta learning rate for all meta-parameters
    • swn_lr the meta learning rate for SWN
    • nb_ul_samples number of unlabeled examples per class
    • re_train_epoch_num number of re-training inner gradient updates
    • train_base_epoch_num number of total inner gradient updates during train (meta-train only)
    • test_base_epoch_num number of total inner gradient updates during test (meta-test only)
  • Test-specific hyperparameters:

    • use_distractors if using distractor classes during meta-test
    • num_dis number of distracting classes used for meta-testing
    • unfiles_num number of unlabeled sample files used in the experiment (There are 10 unlabeled samples per class in each file)
    • recurrent_stage_nums number of recurrent stages used during meta-test
    • local_update_num number of inner gradient updates used in each recurrent stage
    • nums_in_folders number of unlabeled samples (per class) used in each recurrent stage
    • hard_selection number of remaining samples (per class) after applying hard-selection

If you want to change other settings, please see the comments and descriptions in exp_train.py and exp_test.py.

Performance

(%) 𝑚𝑖𝑛𝑖 𝒕𝒊𝒆𝒓𝒆𝒅 𝑚𝑖𝑛𝑖 (w/D) 𝒕𝒊𝒆𝒓𝒆𝒅 (w/D)
1-shot 70.1 ± 1.9 77.7 ± 1.6 64.1 ± 1.9 73.5 ± 1.6
5-shot 78.7 ± 0.8 85.2 ± 0.8 77.4 ± 1.8 83.4 ± 0.8

Citation

Please cite our paper if it is helpful to your work:

@inproceedings{li2019lst,
  title={Learning to Self-Train for Semi-Supervised Few-Shot Classification},
  author = {Li, Xinzhe and Sun, Qianru and Liu, Yaoyao and Zhou, Qin and Zheng, Shibao and Chua, Tat-Seng and Schiele, Bernt},
  booktitle={NeurIPS},
  year={2019}
}

Acknowledgements

Our implementations use the source code from the following repositories and users:

An OpenAI Gym environment for multi-agent car racing based on Gym's original car racing environment.

Multi-Car Racing Gym Environment This repository contains MultiCarRacing-v0 a multiplayer variant of Gym's original CarRacing-v0 environment. This env

Igor Gilitschenski 56 Nov 01, 2022
This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis

This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis Install the package in the requirements.txt, the

108 Dec 23, 2022
Feature board for ERPNext

ERPNext Feature Board Feature board for ERPNext Development Prerequisites k3d kubectl helm bench Install K3d Cluster # export K3D_FIX_CGROUPV2=1 # use

Revant Nandgaonkar 16 Nov 09, 2022
Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Transformer in Transformer Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image c

Phil Wang 272 Dec 23, 2022
Self-Guided Contrastive Learning for BERT Sentence Representations

Self-Guided Contrastive Learning for BERT Sentence Representations This repository is dedicated for releasing the implementation of the models utilize

Taeuk Kim 16 Dec 04, 2022
Artifacts for paper "MMO: Meta Multi-Objectivization for Software Configuration Tuning"

MMO: Meta Multi-Objectivization for Software Configuration Tuning This repository contains the data and code for the following paper that is currently

0 Nov 17, 2021
Multispectral Object Detection with Yolov5

Multispectral-Object-Detection Intro Official Code for Cross-Modality Fusion Transformer for Multispectral Object Detection. Multispectral Object Dete

Richard Fang 121 Jan 01, 2023
This repository contains the source code of an efficient 1D probabilistic model for music time analysis proposed in ICASSP2022 venue.

Jump Reward Inference for 1D Music Rhythmic State Spaces An implementation of the probablistic jump reward inference model for music rhythmic informat

Mojtaba Heydari 25 Dec 16, 2022
Learning infinite-resolution image processing with GAN and RL from unpaired image datasets, using a differentiable photo editing model.

Exposure: A White-Box Photo Post-Processing Framework ACM Transactions on Graphics (presented at SIGGRAPH 2018) Yuanming Hu1,2, Hao He1,2, Chenxi Xu1,

Yuanming Hu 719 Dec 29, 2022
Code for testing convergence rates of Lipschitz learning on graphs

📈 LipschitzLearningRates The code in this repository reproduces the experimental results on convergence rates for k-nearest neighbor graph infinity L

2 Dec 20, 2021
Certis - Certis, A High-Quality Backtesting Engine

Certis - Backtesting For y'all Certis is a powerful, lightweight, simple backtes

Yeachan-Heo 46 Oct 30, 2022
Implementing DropPath/StochasticDepth in PyTorch

%load_ext memory_profiler Implementing Stochastic Depth/Drop Path In PyTorch DropPath is available on glasses my computer vision library! Introduction

Francesco Saverio Zuppichini 13 Jan 05, 2023
Fibonacci Method Gradient Descent

An implementation of the Fibonacci method for gradient descent, featuring a TKinter GUI for inputting the function / parameters to be examined and a matplotlib plot of the function and results.

Emma 1 Jan 28, 2022
Official Repository for the ICCV 2021 paper "PixelSynth: Generating a 3D-Consistent Experience from a Single Image"

PixelSynth: Generating a 3D-Consistent Experience from a Single Image (ICCV 2021) Chris Rockwell, David F. Fouhey, and Justin Johnson [Project Website

Chris Rockwell 95 Nov 22, 2022
Official Implement of CVPR 2021 paper “Cross-Modal Collaborative Representation Learning and a Large-Scale RGBT Benchmark for Crowd Counting”

RGBT Crowd Counting Lingbo Liu, Jiaqi Chen, Hefeng Wu, Guanbin Li, Chenglong Li, Liang Lin. "Cross-Modal Collaborative Representation Learning and a L

37 Dec 08, 2022
Experiments for Fake News explainability project

fake-news-explainability Experiments for fake news explainability project This repository only contains the notebooks used to train the models and eva

Lorenzo Flores (Lj) 1 Dec 03, 2022
Pose Transformers: Human Motion Prediction with Non-Autoregressive Transformers

Pose Transformers: Human Motion Prediction with Non-Autoregressive Transformers This is the repo used for human motion prediction with non-autoregress

Idiap Research Institute 26 Dec 14, 2022
Hitters Linear Regression - Hitters Linear Regression With Python

Hitters_Linear_Regression Kullanacağımız veri seti Carnegie Mellon Üniversitesi'

AyseBuyukcelik 2 Jan 26, 2022
Unofficial implementation of "Coordinate Attention for Efficient Mobile Network Design"

Unofficial implementation of "Coordinate Attention for Efficient Mobile Network Design". CoordAttention tensorflow slim

Billy 9 Aug 22, 2022
Implementations of CNNs, RNNs, GANs, etc

Tensorflow Programs and Tutorials This repository will contain Tensorflow tutorials on a lot of the most popular deep learning concepts. It'll also co

Adit Deshpande 1k Dec 30, 2022