[IJCAI-2021] A benchmark of data-free knowledge distillation from paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"

DataFree

A benchmark of data-free knowledge distillation from paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"

Authors: Gongfan Fang, Jie Song, Xinchao Wang, Chengchao Shen, Xingen Wang, Mingli Song

Figure: inverted samples from CMI (this work), DeepInv, ZSKT and DFQ.

Results

1. CIFAR-10

Accuracy (%) on the CIFAR-10 test set. Each column is a teacher/student pair; the T. Scratch row gives the teacher's accuracy and all other rows give the student's.

Teacher      resnet-34  vgg-11     wrn-40-2   wrn-40-2   wrn-40-2
Student      resnet-18  resnet-18  wrn-16-1   wrn-40-1   wrn-16-2
T. Scratch   95.70      92.25      94.87      94.87      94.87
S. Scratch   95.20      95.20      91.12      93.94      93.95
DAFL         92.22      81.10      65.71      81.33      81.55
ZSKT         93.32      89.46      83.74      86.07      89.66
DeepInv      93.26      90.36      83.04      86.85      89.72
DFQ          94.61      90.84      86.14      91.69      92.01
CMI          94.84      91.13      90.01      92.78      92.52

2. CIFAR-100

Teacher      resnet-34  vgg-11     wrn-40-2   wrn-40-2   wrn-40-2
Student      resnet-18  resnet-18  wrn-16-1   wrn-40-1   wrn-16-2
T. Scratch   78.05      71.32      75.83      75.83      75.83
S. Scratch   77.10      77.01      65.31      72.19      73.56
DAFL         74.47      57.29      22.50      34.66      40.00
ZSKT         67.74      34.72      30.15      29.73      28.44
DeepInv      61.32      54.13      53.77      61.33      61.34
DFQ          77.01      68.32      54.77      62.92      59.01
CMI          77.04      70.56      57.91      68.88      68.75

Quick Start

1. Visualize the inverted samples

Results will be saved as checkpoints/datafree-cmi/synthetic-cmi_for_vis.png

bash scripts/cmi/cmi_cifar10_for_vis.sh

2. Reproduce our results

Note: This repo was refactored from our experimental code and is still under development. I'm struggling to find appropriate hyperparameters for every method (°ー°〃). So far, we only provide the hyperparameters to reproduce the CIFAR-10 results for wrn-40-2 => wrn-16-1. You may need to tune the hyperparameters for other models and datasets. More resources will be uploaded in future updates.

To reproduce our results, please download the pre-trained teacher models from Dropbox-Models (266 MB) and extract them to checkpoints/pretrained. A pre-inverted dataset of ~50k samples for the wrn-40-2 teacher on CIFAR-10 is also available: download it from Dropbox-Data (133 MB) and extract it to run/cmi-preinverted-wrn402/.

  • Non-adversarial CMI: you can train a student model directly on the inverted data. It should reach an accuracy of ~87.38% on CIFAR-10, as reported in Figure 3.

    bash scripts/cmi/nonadv_cmi_cifar10_wrn402_wrn161.sh
    
  • Adversarial CMI: alternatively, you can apply adversarial distillation on top of the pre-inverted data, generating ~10k (256x40) new samples to improve the student. It should reach an accuracy of ~90.01% on CIFAR-10, as reported in Table 1.

    bash scripts/cmi/adv_cmi_cifar10_wrn402_wrn161.sh
    
  • Scratch CMI: you can also run the CMI algorithm without any pre-inverted data, but the student may overfit to the early samples due to the limited amount of data. It should reach an accuracy of ~88.82% on CIFAR-10, slightly below our reported result (90.01%).

    bash scripts/cmi/scratch_cmi_cifar10_wrn402_wrn161.sh
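
The gap between the adversarial and scratch variants comes down to how much data the student sees early on. The arithmetic below is only a rough illustration, assuming that "256x40" means 40 synthesis rounds of 256 samples each, that new samples are appended to a single pool, and that the student draws uniformly from that pool; first_round_share is a hypothetical helper, not part of the repo.

```python
def first_round_share(initial_pool, per_round, rounds):
    """Fraction of the pool made up of first-round samples after each round.

    Assumes every round appends `per_round` new samples to a pool that already
    holds `initial_pool` samples, and that the student draws uniformly from it.
    """
    return [per_round / (initial_pool + per_round * r) for r in range(1, rounds + 1)]


PER_ROUND, ROUNDS = 256, 40  # our reading of "256x40" (assumption)
PRE_INVERTED = 50_000        # "~50k" pre-inverted samples, rounded

print(PER_ROUND * ROUNDS)                 # new samples generated: 10240
print(PRE_INVERTED + PER_ROUND * ROUNDS)  # final pool with pre-inverted data: 60240

scratch = first_round_share(0, PER_ROUND, ROUNDS)
adversarial = first_round_share(PRE_INVERTED, PER_ROUND, ROUNDS)

# Without pre-inverted data, the first 256 samples are the whole pool after
# round 1 and still 10% of it after round 10; with ~50k pre-inverted samples
# they never exceed about 0.5% of the pool.
print(scratch[0], scratch[9])
print(round(adversarial[0], 4))
```

Under these assumptions, a student trained without pre-inverted data spends its first iterations on a handful of batches, which is consistent with the overfitting described above.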
    

3. Scratch training

python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 256 --lr 0.1 --epoch 200 --gpu 0

4. Vanilla KD

# KD with original training data (beta>0 to use hard targets)
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar10 --beta 0.1 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar100 --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data from a specified folder
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set run/cmi --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 
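
The comments above distinguish hard targets (beta > 0) from purely soft targets (beta = 0). A minimal plain-Python sketch of such a loss, assuming the usual Hinton-style formulation: a temperature-scaled KL term between teacher and student plus beta times the cross-entropy on ground-truth labels. The temperature default, the exact weighting and the function names are assumptions; vanilla_kd.py may combine the terms differently and operates on PyTorch tensors.

```python
import math


def softmax(logits, T=1.0):
    # Temperature-scaled softmax, shifted by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label=None, T=4.0, beta=0.0):
    """Soft-target KD loss plus an optional hard-label term weighted by beta (hypothetical)."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    # KL(teacher || student) on the softened distributions, scaled by T^2.
    soft = T * T * sum(pt * (math.log(pt) - math.log(ps))
                       for pt, ps in zip(p_t, p_s) if pt > 0)
    if beta > 0 and label is not None:
        # Hard-target cross-entropy at T = 1; only meaningful with real labels.
        hard = -math.log(softmax(student_logits)[label])
        return soft + beta * hard
    return soft


teacher = [2.0, 0.5, -1.0]
student = [1.5, 0.8, -0.5]
unlabeled = kd_loss(student, teacher)                  # e.g. cifar100 or run/cmi as transfer set
labeled = kd_loss(student, teacher, label=0, beta=0.1)  # original cifar10 training data
print(unlabeled, labeled)
```

With an unlabeled transfer set (another dataset or a folder of inverted images) there are no labels in the teacher's label space, so only the soft term applies, which is why those commands set beta to 0.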

5. Data-free KD

bash scripts/xxx/xxx.sh # e.g. scripts/zskt/zskt_cifar10_wrn402_wrn161.sh

Hyper-parameters used by different methods:

Method    adv  bn  oh  balance  act  cr  GAN  Example
DAFL      -    -   ✓   ✓        ✓    -   ✓    scripts/dafl_cifar10.sh
ZSKT      ✓    -   -   -        -    -   ✓    scripts/zskt_cifar10.sh
DeepInv   ✓    ✓   ✓   -        -    -   -    scripts/deepinv_cifar10.sh
DFQ       ✓    ✓   ✓   ✓        -    -   ✓    scripts/dfq_cifar10.sh
CMI       ✓    ✓   ✓   -        -    ✓   ✓    scripts/cmi_cifar10_scratch.sh
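
Most of the column names correspond to terms in the objective used to synthesize samples. As a rough, framework-free illustration, the sketch below implements three of them in plain Python, assuming that oh, balance and act follow DAFL's one-hot, information-entropy and activation losses, and that a method's objective is the weighted sum of its enabled terms (a "-" meaning weight 0). Function names and weight values are hypothetical; adv, bn and cr are passed in as placeholders because they need a student, BatchNorm statistics or a contrastive memory that this sketch does not model.

```python
import math


def softmax(logits):
    # Shift by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def oh_loss(teacher_logits):
    """One-hot loss: cross-entropy between each teacher prediction and its own argmax."""
    loss = 0.0
    for z in teacher_logits:
        p = softmax(z)
        loss -= math.log(max(p))  # the argmax class has the largest probability
    return loss / len(teacher_logits)


def balance_loss(teacher_logits):
    """Negative entropy of the batch-mean prediction; lowest when classes are balanced."""
    probs = [softmax(z) for z in teacher_logits]
    n = len(probs)
    mean = [sum(col) / n for col in zip(*probs)]
    return sum(m * math.log(m) for m in mean if m > 0)


def act_loss(features):
    """Negative mean L1 norm of penultimate-layer features; rewards strong activations."""
    return -sum(sum(abs(v) for v in f) for f in features) / len(features)


def synthesis_loss(weights, terms):
    """Weighted sum of loss terms; a missing or zero weight disables a term."""
    return sum(weights.get(name, 0.0) * value for name, value in terms.items())


# Toy batch: three samples, three classes, 2-d features.
logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]
feats = [[1.0, -2.0], [0.5, 0.5], [0.0, 3.0]]

terms = {
    "oh": oh_loss(logits),
    "balance": balance_loss(logits),   # ≈ -log(3): the batch is perfectly balanced
    "act": act_loss(feats),            # -(3 + 1 + 3) / 3
    "adv": 0.0, "bn": 0.0, "cr": 0.0,  # placeholders for terms not computed here
}
weights = {"oh": 1.0, "balance": 5.0, "act": 0.1}  # hypothetical weights
print(synthesis_loss(weights, terms))
```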

6. Use your models/datasets

You can register your models and datasets in registry.py by modifying NORMALIZE_DICT, MODEL_DICT and get_dataset, then run the commands above to train your own models. Since DAFL requires intermediate features from the penultimate layer, your model must accept a return_features=True argument and, when it is set, return a (logits, features) tuple.
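
A minimal plain-Python sketch of the model contract described above. ToyNet, get_model and the placeholder statistics are hypothetical; only MODEL_DICT, NORMALIZE_DICT and the return_features keyword come from the text, and the exact value format registry.py expects is an assumption. In the real code ToyNet would be a torch nn.Module whose forward(self, x, return_features=False) returns the same two shapes.

```python
class ToyNet:
    """Hypothetical stand-in for an nn.Module honoring return_features."""

    def __init__(self, num_classes=10, feat_dim=2):
        self.num_classes = num_classes
        self.feat_dim = feat_dim

    def __call__(self, x, return_features=False):
        # "Penultimate-layer" features: here simply the first feat_dim inputs.
        features = [list(sample[: self.feat_dim]) for sample in x]
        # Toy classifier head on top of the features.
        logits = [[(c + 1) * sum(f) for c in range(self.num_classes)] for f in features]
        if return_features:
            return logits, features  # what DAFL needs
        return logits                # what every other caller needs


MODEL_DICT = {"toy_net": ToyNet}
# Per-dataset normalization statistics (placeholder values, assumed format).
NORMALIZE_DICT = {"toy_data": dict(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))}


def get_model(name, num_classes):
    # Hypothetical helper: look up the constructor registered under `name`.
    return MODEL_DICT[name](num_classes=num_classes)


net = get_model("toy_net", num_classes=3)
batch = [[0.1, 0.2, 0.3], [1.0, 0.0, 0.5]]
logits, feats = net(batch, return_features=True)
print(feats)
```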

7. Implement your algorithms

Your algorithm should inherit from datafree.synthesis.BaseSynthesizer and implement two interfaces: 1) BaseSynthesizer.synthesize runs several steps to craft new samples and returns a dict of images for visualization; 2) BaseSynthesizer.sample fetches a batch of training data for KD.
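
A minimal sketch of that two-method interface, using Python's abc module. The BaseSynthesizer below is a hypothetical stand-in written for illustration, not the repo's class, and NoiseSynthesizer only "crafts" random noise instead of inverting a teacher; the real constructor and method signatures may differ.

```python
import random
from abc import ABC, abstractmethod


class BaseSynthesizer(ABC):
    """Hypothetical stand-in for datafree.synthesis.BaseSynthesizer."""

    @abstractmethod
    def synthesize(self):
        """Run several steps to craft new samples; return a dict of images for visualization."""

    @abstractmethod
    def sample(self):
        """Return one batch of training data for KD."""


class NoiseSynthesizer(BaseSynthesizer):
    """Toy synthesizer: produces random 'images' and keeps them in a pool."""

    def __init__(self, batch_size=4, image_dim=3, seed=0):
        self.batch_size = batch_size
        self.image_dim = image_dim
        self.rng = random.Random(seed)
        self.pool = []  # every sample synthesized so far

    def synthesize(self):
        batch = [[self.rng.random() for _ in range(self.image_dim)]
                 for _ in range(self.batch_size)]
        self.pool.extend(batch)
        return {"synthetic": batch}

    def sample(self):
        # Draw a KD batch uniformly from the accumulated pool (without replacement).
        return self.rng.sample(self.pool, min(self.batch_size, len(self.pool)))


synth = NoiseSynthesizer(batch_size=4)
for step in range(3):
    vis = synth.synthesize()  # e.g. save vis["synthetic"] as an image grid
    batch = synth.sample()    # feed this batch to the student's KD step
print(len(synth.pool), len(batch))
```

Keeping synthesis and sampling separate lets the training loop decide how often to synthesize relative to how many KD steps it runs on the stored samples.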

Citation

If you find this work useful for your research, please cite our paper:

@misc{fang2021contrastive,
      title={Contrastive Model Inversion for Data-Free Knowledge Distillation}, 
      author={Gongfan Fang and Jie Song and Xinchao Wang and Chengchao Shen and Xingen Wang and Mingli Song},
      year={2021},
      eprint={2105.08584},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}