Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

Overview

Dataset Distillation by Matching Training Trajectories

This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone achieves test performance similar to that of a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the synthetic images so that they induce network training dynamics similar to those induced by the full real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" networks and "expert" networks trained on real data, and back-propagate through all of the student network updates to optimize the synthetic pixels.
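In the paper, the parameter-space error is normalized by how far the expert itself moved over the same segment. The student starts at expert checkpoint θ*_t and trains for N steps on the synthetic data to reach θ̂_{t+N}. The loss is then ||θ̂_{t+N} - θ*_{t+M}||² / ||θ*_t - θ*_{t+M}||², where θ*_{t+M} is the expert checkpoint M epochs later.

The pure-Python sketch below computes only this quantity, on flattened parameter vectors. The function names are hypothetical. The repo itself works on PyTorch tensors and differentiates this loss through the student updates, which the sketch does not attempt.

```python
from typing import Sequence


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two flattened parameter vectors."""
    if len(a) != len(b):
        raise ValueError("parameter vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def trajectory_matching_loss(
    student_end: Sequence[float],
    expert_start: Sequence[float],
    expert_target: Sequence[float],
) -> float:
    """Normalized parameter-matching loss (hypothetical helper name).

    student_end:   student parameters after N updates on the synthetic data,
                   starting from expert_start
    expert_start:  expert parameters at epoch t (used to initialize the student)
    expert_target: expert parameters at epoch t + M
    """
    denominator = squared_distance(expert_start, expert_target)
    if denominator == 0.0:
        raise ValueError("expert did not move over this segment")
    return squared_distance(student_end, expert_target) / denominator


# Toy example: the expert moves from [0, 0] to [1, 1]; the student gets halfway.
print(trajectory_matching_loss([0.5, 0.5], [0.0, 0.0], [1.0, 1.0]))  # 0.25
```

Dividing by the expert's own displacement keeps the loss on a comparable scale across segments. This holds both where the expert moves a lot (early training) and where it barely moves (late training).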

Wearable ImageNet: Synthesizing Tileable Textures

Instead of treating our synthetic data as individual images, we can encourage every random crop (with circular padding) of a larger canvas of pixels to induce a good training trajectory. This yields class-based textures that are continuous around their edges.
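The wrap-around cropping can be pictured with modular indexing. A crop that runs off one edge of the canvas continues from the opposite edge, so every crop position is valid and the texture is pushed to be continuous across its borders.

Below is a minimal pure-Python sketch on a single-channel canvas. The function names are hypothetical, and the repo's actual implementation on batched PyTorch tensors may differ.

```python
import random
from typing import List

Canvas = List[List[float]]  # single-channel canvas indexed as canvas[row][col]


def circular_crop(canvas: Canvas, top: int, left: int, height: int, width: int) -> Canvas:
    """Crop a height x width window, wrapping around the canvas edges like a torus."""
    rows, cols = len(canvas), len(canvas[0])
    return [
        [canvas[(top + i) % rows][(left + j) % cols] for j in range(width)]
        for i in range(height)
    ]


def random_circular_crop(canvas: Canvas, height: int, width: int, rng: random.Random) -> Canvas:
    """Sample the top-left corner anywhere on the canvas; wrapping handles overflow."""
    top = rng.randrange(len(canvas))
    left = rng.randrange(len(canvas[0]))
    return circular_crop(canvas, top, left, height, width)


canvas = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
# Starting at the bottom-right corner, the crop wraps to the first row and column.
print(circular_crop(canvas, top=2, left=2, height=2, width=2))  # [[8, 6], [2, 0]]
print(random_circular_crop(canvas, 2, 2, random.Random(0)))
```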

These tileable textures can then be applied wherever seamless repetition is needed, such as in clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For a quick installation, we provide conda environment .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation

Quadro Users Take Note:

torch.nn.DataParallel does not seem to work on Quadro A5000 GPUs, and this issue may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py.
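An expert trajectory is the sequence of parameters a network passes through while training normally on real data. The toy sketch below records such a trajectory for a one-parameter linear model trained by full-batch gradient descent.

It assumes, as the --train_epochs and --max_start_epoch flags suggest, that one snapshot is kept per epoch, plus the initialization. The model, data, and function name are illustrative only. buffer.py trains ConvNets on images and stores its trajectories under --buffer_path.

```python
import random
from typing import List, Sequence


def train_expert(
    xs: Sequence[float],
    ys: Sequence[float],
    epochs: int,
    lr: float,
    w0: float = 0.0,
) -> List[float]:
    """Fit y = w * x by full-batch gradient descent on mean squared error.

    Returns the trajectory [w_0, w_1, ..., w_epochs]: the initial parameter
    followed by one snapshot after each epoch.
    """
    w = w0
    trajectory = [w]
    n = len(xs)
    for _ in range(epochs):
        # Gradient of mean((w * x - y)^2) with respect to w.
        grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        w -= lr * grad
        trajectory.append(w)
    return trajectory


xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # "real" data generated by y = 2x
rng = random.Random(0)
# Several experts from different random initializations, in the spirit of --num_experts.
buffer = [train_expert(xs, ys, epochs=50, lr=0.05, w0=rng.uniform(-1.0, 1.0)) for _ in range(3)]
print(len(buffer), len(buffer[0]))  # 3 51
```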

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. You can obtain worse (but still interesting) results faster by training fewer experts with a smaller --num_experts. Note that experts only need to be trained once and can be reused across multiple distillation experiments.
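The --zca flag in the command above applies ZCA whitening to the data. ZCA decorrelates the input features while staying as close as possible to the original pixel space. With covariance C = U diag(λ) U^T, the whitening matrix is W = U diag(1/sqrt(λ + ε)) U^T.

The NumPy sketch below illustrates only this math, on toy feature vectors. The repo's own preprocessing may use a library implementation and a different regularization constant ε.

```python
from typing import Tuple

import numpy as np


def fit_zca(x: np.ndarray, eps: float = 1e-5) -> Tuple[np.ndarray, np.ndarray]:
    """Fit ZCA whitening on x of shape (num_samples, num_features).

    Returns the per-feature mean and the symmetric whitening matrix W such that
    (x - mean) @ W has approximately identity covariance.
    """
    mean = x.mean(axis=0)
    centered = x - mean
    cov = centered.T @ centered / x.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)  # eigh, since cov is symmetric
    # W = U diag(1 / sqrt(lambda + eps)) U^T; eps guards against tiny eigenvalues.
    w = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return mean, w


def apply_zca(x: np.ndarray, mean: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Center x with the fitted mean, then apply the whitening matrix."""
    return (x - mean) @ w


rng = np.random.default_rng(0)
# Toy data: 2000 samples of 8 features correlated by an invertible mixing matrix.
mixing = np.triu(np.ones((8, 8)))
data = rng.normal(size=(2000, 8)) @ mixing
mean, w = fit_zca(data)
white = apply_zca(data, mean, w)
white_cov = white.T @ white / white.shape[0]
print(np.allclose(white_cov, np.eye(8), atol=1e-3))  # True
```

Unlike PCA whitening, the final multiplication by U^T rotates the data back into the original feature axes. This is why ZCA-whitened images still look like (edge-enhanced) images.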

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
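In the command above, three flags define each matched trajectory segment:

- --syn_steps sets how many updates (N) the student takes on the synthetic images.
- --expert_epochs sets how many epochs later (M) the target expert checkpoint lies.
- --max_start_epoch bounds the randomly chosen starting epoch t.

The sketch below shows one way to draw such a segment from a stored trajectory. The helper name and the exact sampling range ([0, max_start_epoch)) are assumptions rather than a copy of distill.py.

```python
import random
from typing import List, Sequence, Tuple

Params = List[float]  # one flattened parameter snapshot


def sample_expert_segment(
    trajectory: Sequence[Params],
    max_start_epoch: int,
    expert_epochs: int,
    rng: random.Random,
) -> Tuple[int, Params, Params]:
    """Pick a start epoch t uniformly from [0, max_start_epoch) and return
    (t, snapshot at epoch t, snapshot at epoch t + expert_epochs)."""
    if max_start_epoch < 1 or expert_epochs < 1:
        raise ValueError("max_start_epoch and expert_epochs must be positive")
    if max_start_epoch - 1 + expert_epochs >= len(trajectory):
        raise ValueError("trajectory is too short for this segment")
    start = rng.randrange(max_start_epoch)
    return start, list(trajectory[start]), list(trajectory[start + expert_epochs])


# A stand-in trajectory with 51 snapshots (initialization + 50 epochs); each
# snapshot just stores its epoch index so the chosen segment is easy to read.
trajectory = [[float(epoch)] for epoch in range(51)]
rng = random.Random(0)
t, start_params, target_params = sample_expert_segment(
    trajectory, max_start_epoch=20, expert_epochs=3, rng=rng
)
print(t, start_params, target_params)
```

The student would then be initialized from start_params, take syn_steps updates on the synthetic set, and be scored against target_params by their distance in parameter space.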

ImageNet

Our method can also distill subsets of ImageNet into very small synthetic sets (e.g., one image per class).

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into 10 synthetic images, one per class.

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list of the desired class IDs and add it to the dictionary.
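A hypothetical sketch of such an entry follows. The subset name mysubset and the attribute name subsets are placeholders, so match the actual list and dictionary names in utils.py. The example IDs are the ten ImageNet class indices used by the Imagenette subset. The small validation helper is an illustrative addition, not part of the repo.

```python
from typing import Dict, List, Sequence


class Config:
    # Hypothetical custom subset: ImageNet class indices in the range 0-999.
    # These ten indices are the classes used by the Imagenette subset.
    mysubset = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]

    # Placeholder name for the dictionary that maps a --subset name to its class IDs.
    subsets: Dict[str, List[int]] = {
        "mysubset": mysubset,
    }


def validate_subset(class_ids: Sequence[int], num_imagenet_classes: int = 1000) -> None:
    """Raise if a subset contains duplicate or out-of-range ImageNet class IDs."""
    if len(set(class_ids)) != len(class_ids):
        raise ValueError("duplicate class IDs in subset")
    if not all(0 <= c < num_imagenet_classes for c in class_ids):
        raise ValueError("class IDs must lie in [0, num_imagenet_classes)")


validate_subset(Config.subsets["mysubset"])
print(len(Config.subsets["mysubset"]))  # 10
```

With such an entry in place, passing --subset=mysubset to buffer.py and distill.py would presumably select those ten classes.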

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into 10 textures, one per class.

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2021
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}