A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)

Overview

Minimal implementation of diffusion models

A minimal implementation of diffusion models with the goal of democratizing the use of synthetic data from these models.

See the experimental results section for quantitative measures of synthetic data quality, and the FAQs for a broader discussion. We experiment with nine commonly used datasets and release all assets, including models and synthetic data, for each of them.

Requirements: pip install scipy opencv-python. We assume torch and torchvision are already installed.

Structure

main.py   - Train or sample from a diffusion model.
unets.py  - UNet-based network architecture for the diffusion model.
data.py   - Common datasets and their metadata.
scripts/
  ├── train.sh   - Training scripts for all datasets.
  └── sample.sh  - Sampling scripts for all datasets.

Training

Use the following command to train the diffusion model on four GPUs.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --epochs 500

We provide the exact script used for training in ./scripts/train.sh.
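To make concrete what a training step optimizes, here is a minimal, framework-free sketch of a class-conditional diffusion objective. It assumes the standard DDPM formulation of Ho et al. (the model predicts the added Gaussian noise, linear β schedule from 1e-4 to 0.02 over 1000 steps); the schedule, the parameterization, and names such as `diffusion_loss` are illustrative assumptions, not necessarily what main.py does. The real code operates on batched image tensors with the UNet from unets.py rather than Python lists.

```python
import math
import random
from typing import Callable, List

# (x_t, t, class label) -> predicted noise; stands in for the UNet (assumption).
Model = Callable[[List[float], int, int], List[float]]


def linear_alpha_bars(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    abar, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        abar.append(prod)
    return abar


def q_sample(x0: List[float], t: int, abar: List[float], noise: List[float]) -> List[float]:
    """Forward (noising) process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    a = abar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, noise)]


def diffusion_loss(model: Model, x0: List[float], y: int, abar: List[float], rng: random.Random) -> float:
    """Single-sample estimate of the simple noise-prediction MSE loss."""
    t = rng.randrange(len(abar))                  # uniformly random timestep
    noise = [rng.gauss(0.0, 1.0) for _ in x0]    # epsilon ~ N(0, I)
    xt = q_sample(x0, t, abar, noise)
    pred = model(xt, t, y)                        # the class label y conditions the model
    return sum((p - e) ** 2 for p, e in zip(pred, noise)) / len(x0)


if __name__ == "__main__":
    abar = linear_alpha_bars()
    x0, label = [0.5, -0.25, 1.0], 3
    untrained = lambda xt, t, y: [0.0] * len(xt)  # placeholder model that always predicts zero noise
    print(diffusion_loss(untrained, x0, label, abar, random.Random(0)))
```

A model that recovers the exact noise drives this loss to zero; the network only has to approximate that regression target, which is one reason diffusion training tends to be stable.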

Sampling

We reuse main.py for sampling, but with the --sampling-only flag. Use the following command to sample 50K images from a pretrained diffusion model.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 \
  --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

We provide the exact script used for sampling in ./scripts/sample.sh.
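The --sampling-steps flag suggests that sampling runs fewer reverse steps than training. A common way to do this, used in OpenAI's improved-diffusion code, is to keep an evenly spaced subset of timesteps and recompute the betas so that the shortened chain matches the original ᾱ_t at the kept steps. The sketch below assumes 1000 training steps, a linear β schedule, and a noise-predicting model; all function names are hypothetical and the exact respacing in main.py may differ.

```python
import math
import random
from typing import List


def linear_alpha_bars(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    abar, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        abar.append(prod)
    return abar


def respaced_timesteps(T: int, num_steps: int) -> List[int]:
    """Evenly spaced subset of the T training timesteps, e.g. 250 out of 1000."""
    return [i * T // num_steps for i in range(num_steps)]


def respaced_betas(abar: List[float], timesteps: List[int]) -> List[float]:
    """Betas for the shortened chain, chosen so its alpha_bar equals abar at the kept steps."""
    betas, prev = [], 1.0
    for t in timesteps:
        betas.append(1.0 - abar[t] / prev)
        prev = abar[t]
    return betas


def sample(model, dim: int, label: int, abar: List[float], timesteps: List[int],
           rng: random.Random) -> List[float]:
    """Ancestral DDPM sampling over the respaced timesteps; model(x, t, y) predicts noise."""
    betas = respaced_betas(abar, timesteps)
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # start from pure Gaussian noise
    for i in reversed(range(len(timesteps))):
        t, beta = timesteps[i], betas[i]
        a_bar = abar[t]
        eps = model(x, t, label)
        # Mean of p(x_{t-1} | x_t): (x_t - beta / sqrt(1 - abar_t) * eps) / sqrt(1 - beta)
        mean = [(xi - beta / math.sqrt(1.0 - a_bar) * ei) / math.sqrt(1.0 - beta)
                for xi, ei in zip(x, eps)]
        if i > 0:
            a_bar_prev = abar[timesteps[i - 1]]
            var = beta * (1.0 - a_bar_prev) / (1.0 - a_bar)  # posterior variance
            x = [m + math.sqrt(var) * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added at the final step
    return x


if __name__ == "__main__":
    abar = linear_alpha_bars(1000)
    steps = respaced_timesteps(1000, 250)  # mirrors --sampling-steps 250 (assumes T = 1000)
    untrained = lambda x, t, y: [0.0] * len(x)  # placeholder for the trained UNet
    print(sample(untrained, dim=4, label=0, abar=abar, timesteps=steps, rng=random.Random(0)))
```

With a trained model, `x` would be a batch of image tensors; the loop itself stays this short, and the number of model calls scales with the number of sampling steps.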

How useful is synthetic data from diffusion models? 🤔

Takeaway: Across all datasets, training only on synthetic data suffices to achieve a competitive classification score on real data.

Goal: Our goal is to measure not only the photorealism of synthetic images but also how well they cover the data distribution, i.e., how diverse the synthetic data is. Note that a generative model (GANs are a common example) can generate high-quality images yet still fail to generate diverse ones.

Choice of datasets: We use nine commonly used image recognition datasets. Using multiple datasets lets us capture diversity in the number of samples, the number of classes, and coarse- vs. fine-grained classification. In addition, by using a common setup across datasets, we can test the success of diffusion models without any dataset-specific assumptions.

Diffusion model: For each dataset, we train a class-conditional diffusion model. We choose a modest-size network and train it for a limited number of hours on a 4xA4000 cluster, as shown by the training times in the table below. Next, we sample 50,000 synthetic images from the diffusion model.

Metric to measure synthetic data quality: We train one ResNet-50 classifier on only real images and another on only synthetic images, and measure their accuracy on the validation set of real images. This metric, also referred to as the classification accuracy score, gives us a unified way to measure both the quality and the diversity of synthetic data across datasets.
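The protocol boils down to: fit a classifier on one data source (real-only or synthetic-only), then score it on held-out real data. Below is a minimal sketch of that protocol; the nearest-centroid classifier is a hypothetical stand-in for ResNet-50 so the example stays dependency-free, and all function names are illustrative.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Example = Tuple[Sequence[float], int]            # (feature vector, class label)
Predictor = Callable[[Sequence[float]], int]


def accuracy(predict: Predictor, data: List[Example]) -> float:
    """Fraction of examples whose predicted label matches the true label."""
    return sum(predict(x) == y for x, y in data) / len(data)


def classification_accuracy_score(train_fn: Callable[[List[Example]], Predictor],
                                  train_set: List[Example],
                                  real_val: List[Example]) -> float:
    """Train on `train_set` (real-only or synthetic-only), evaluate on real validation data."""
    predict = train_fn(train_set)
    return accuracy(predict, real_val)


def nearest_centroid(train: List[Example]) -> Predictor:
    """Hypothetical stand-in for ResNet-50: predict the class with the closest mean."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for x, y in train:
        acc = sums.setdefault(y, [0.0] * len(x))
        for i, v in enumerate(x):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    centroids = {y: [v / counts[y] for v in s] for y, s in sums.items()}

    def predict(x: Sequence[float]) -> int:
        return min(centroids, key=lambda c: sum((a - b) ** 2 for a, b in zip(x, centroids[c])))

    return predict


if __name__ == "__main__":
    synthetic = [([0.0, 0.0], 0), ([0.0, 1.0], 0), ([5.0, 5.0], 1), ([5.0, 6.0], 1)]
    real_val = [([0.2, 0.3], 0), ([4.8, 5.1], 1), ([3.0, 3.0], 1), ([2.0, 2.0], 0)]
    print(classification_accuracy_score(nearest_centroid, synthetic, real_val))
```

The "Real only" column uses the same function with real training images as `train_set`; only the training source changes, while the real evaluation set stays fixed, which is what makes the two numbers comparable.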

Released assets for each dataset: pretrained diffusion models, 50,000 synthetic images, and downstream classifiers trained on real-only or synthetic-only data.

Table 1: Training images and Classes are the number of training images and the number of classes in each dataset. Training time is the time taken to train the diffusion model. Real only is the test accuracy (%) of a ResNet-50 model trained only on real training images. Synthetic only is the test accuracy of a ResNet-50 model trained only on 50K synthetic images.

Dataset                 Training images  Classes  Training time (hours)  Real only  Synthetic only
MNIST                   60,000           10       2.1                    99.6       99.0
MNIST-M                 60,000           10       5.3                    99.3       97.3
CIFAR-10                50,000           10       10.7                   93.8       87.3
Skin Cancer*            33,126           2        19.1                   69.7       64.1
AFHQ                    14,630           3        8.6                    97.9       98.7
CelebA                  109,036          4        12.8                   90.1       88.9
Stanford Cars           8,144            196      7.4                    33.7       76.6
Oxford Flowers          2,040            102      6.0                    29.3       76.3
Traffic signs (GTSRB)   39,252           43       8.3                    96.6       96.1

* Due to heavy class imbalance, we use AUROC to measure classification performance.
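AUROC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting as one half. The minimal quadratic-time sketch below implements that definition directly; it is handy for sanity-checking a library result such as scikit-learn's `sklearn.metrics.roc_auc_score`. The function name is illustrative.

```python
from typing import Sequence


def auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUROC via pairwise comparisons (Mann-Whitney U divided by n_pos * n_neg)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half a win
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    # Label 1 = positive class (e.g. melanoma); score = predicted probability.
    print(auroc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Unlike accuracy, AUROC does not reward a classifier that always predicts the majority class (constant scores give 0.5), which is why it suits a heavily imbalanced binary dataset.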

Note: Except for CIFAR-10, MNIST, MNIST-M, and GTSRB, we use 64x64 image resolution for all datasets. We chose a lower resolution mainly to reduce the computational resources needed to train the diffusion models.

Discussion: Across most datasets, training only on synthetic data achieves performance competitive with training on real data. This shows that the synthetic data has 1) high-quality images, otherwise the model wouldn't learn much from it, and 2) high coverage of the data distribution, otherwise the model trained on synthetic data wouldn't do well on the whole test set. Moreover, synthetic data has a unique advantage: we can easily generate a very large amount of it. This advantage is clearly visible in the low-data regime (the Flowers and Cars datasets), where training on synthetic data (50K images) achieves much better performance than training on real data, which has fewer than 10K images. A more principled investigation of sample complexity, i.e., performance vs. number of synthetic images, is available in one of my previous papers (Fig. 9).
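The low-data argument can be checked directly against Table 1. This small sketch, with values copied from the table, computes the synthetic-minus-real gap for every dataset and the ratio of synthetic to real training images for datasets with fewer than 10K real images (the threshold used in the discussion). The names are illustrative.

```python
# (real training images, real-only score, synthetic-only score), copied from Table 1.
# Skin Cancer scores are AUROC; all other scores are accuracy (%).
TABLE_1 = {
    "MNIST": (60_000, 99.6, 99.0),
    "MNIST-M": (60_000, 99.3, 97.3),
    "CIFAR-10": (50_000, 93.8, 87.3),
    "Skin Cancer": (33_126, 69.7, 64.1),
    "AFHQ": (14_630, 97.9, 98.7),
    "CelebA": (109_036, 90.1, 88.9),
    "Stanford Cars": (8_144, 33.7, 76.6),
    "Oxford Flowers": (2_040, 29.3, 76.3),
    "Traffic signs": (39_252, 96.6, 96.1),
}
NUM_SYNTHETIC = 50_000  # synthetic training images per dataset


def synthetic_gap(table):
    """Synthetic-only minus real-only score, rounded to one decimal place."""
    return {name: round(syn - real, 1) for name, (_, real, syn) in table.items()}


def low_data(table, threshold=10_000):
    """Datasets whose real training set has fewer than `threshold` images."""
    return sorted(name for name, (n, _, _) in table.items() if n < threshold)


if __name__ == "__main__":
    gaps = synthetic_gap(TABLE_1)
    for name in low_data(TABLE_1):
        ratio = NUM_SYNTHETIC / TABLE_1[name][0]
        print(f"{name}: {ratio:.1f}x more synthetic images, gap {gaps[name]:+.1f} points")
```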

FAQs

Q. Why use diffusion models?
A. This is a broad question with multiple answers. 1) They are easy to train: unlike GANs, they don't suffer from instabilities in the optimization process. 2) Diffusion models have excellent mode coverage while still generating quite photorealistic images. 3) The training pipeline is consistent across datasets, i.e., it makes no assumptions about the data. For all datasets above, the only parameter we changed was the training time.

Q. Is synthetic data from diffusion models much different from other generative models, in particular GANs?
A. As mentioned in the previous answer, synthetic data from diffusion models has much higher coverage than synthetic data from GANs, with similar image quality. See the previous paper by Prafulla Dhariwal and Alex Nichol ("Diffusion Models Beat GANs on Image Synthesis"), which provides extensive results supporting this claim. In the regime of robust training, you can find a more quantitative comparison of diffusion models with multiple GANs in one of my previous papers.

Q. Why is classification accuracy on some datasets (e.g., Flowers) so low, even when training on real data?
A. For several reasons, the current classification numbers aren't meant to be competitive with the state of the art. 1) We don't tune any hyperparameters per dataset. For each dataset, we train a ResNet-50 model with a 0.1 learning rate, 1e-4 weight decay, 0.9 momentum, and cosine learning rate decay. 2) Instead of full resolution (commonly 224x224), we use low-resolution images (64x64), which makes classification harder.
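For reference, the shared classifier recipe corresponds to SGD with momentum 0.9, weight decay 1e-4, and a learning rate that decays from 0.1 along a half cosine. In PyTorch this is typically `torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=1e-4)` combined with `torch.optim.lr_scheduler.CosineAnnealingLR`; the framework-free sketch below spells out the same arithmetic. Decaying all the way to zero with no warmup is an assumption, since the answer above does not say, and the epoch count in the example is hypothetical.

```python
import math
from typing import List, Tuple

BASE_LR = 0.1        # hyperparameters stated in the answer above
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4


def cosine_lr(step: int, total_steps: int, base_lr: float = BASE_LR) -> float:
    """Half-cosine decay from base_lr at step 0 to 0 at total_steps (assumed: no warmup)."""
    progress = min(step, total_steps) / total_steps
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


def sgd_step(w: List[float], grad: List[float], buf: List[float], lr: float,
             momentum: float = MOMENTUM,
             weight_decay: float = WEIGHT_DECAY) -> Tuple[List[float], List[float]]:
    """One SGD update in PyTorch's convention: L2 term added to the gradient, then momentum.

    With `buf` initialized to zeros, the first step matches PyTorch's as well.
    """
    d = [g + weight_decay * wi for g, wi in zip(grad, w)]
    buf = [momentum * b + di for b, di in zip(buf, d)]
    w = [wi - lr * b for wi, b in zip(w, buf)]
    return w, buf


if __name__ == "__main__":
    epochs = 100  # hypothetical; the classifier's epoch count is not stated
    for epoch in (0, 25, 50, 75, 100):
        print(epoch, round(cosine_lr(epoch, epochs), 4))
```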

Q. Using only synthetic data, how can we further improve test accuracy on real data?
A. Diffusion models benefit tremendously from scaling up the training setup. One can do so by increasing the network width (base_width) and training the network for more epochs (2-4x).

References

This implementation was originally motivated by the original implementation of diffusion models by Jonathan Ho. For common design choices in diffusion models, I followed OpenAI's recent PyTorch implementation.

The experiments testing the potential of synthetic data from diffusion models are inspired by one of my previous works. We found that using synthetic data from the diffusion model alone surpasses the benefits of multiple algorithmic innovations in robust training, one of the simple yet extremely hard problems to solve for neural networks. The next step is to repeat the Table 1 experiments, but this time with robust training.

Visualizing real and synthetic images

For each dataset, we plot real images on the left and synthetic images on the right. Each row corresponds to a unique class, and the classes are identical for real and synthetic data.

MNIST

MNIST-M

CIFAR-10

GTSRB

CelebA

AFHQ

Cars

Flowers

Melanoma (Skin cancer)

Note: Real images for each dataset follow the same license as their respective dataset.

Owner
Vikash Sehwag
PhD candidate at Princeton University. Interested in problems at the intersection of Security, Privacy, and Machine Learning.