Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)

Overview

Maximum Likelihood Training of Score-Based Diffusion Models

This repo contains the official implementation for the paper Maximum Likelihood Training of Score-Based Diffusion Models

by Yang Song*, Conor Durkan*, Iain Murray, and Stefano Ermon. Published in NeurIPS 2021 (spotlight).


We prove the connection between the Kullback–Leibler divergence and the weighted combination of score matching losses used for training score-based generative models. Our results can be viewed as a generalization of both the de Bruijn identity in information theory and the evidence lower bound in variational inference.

Our theoretical results enable ScoreFlow, a continuous normalizing flow model trained with a variational objective, which is much more efficient than neural ODEs. We report the state-of-the-art likelihood on CIFAR-10 and ImageNet 32x32 among all flow models, achieving comparable performance to cutting-edge autoregressive models.

How to run the code

Dependencies

Run the following to install a subset of necessary python packages for our code

pip install -r requirements.txt

Stats files for quantitative evaluation

We provide stats files for computing FID and Inception scores for CIFAR-10 and ImageNet 32x32. You can find cifar10_stats.npz and imagenet32_stats.npz under the directory assets/stats in our Google drive. Download them and save to assets/stats/ in the code repo.

Usage

Train and evaluate our models through main.py. Here are some common options:

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval|train_deq>: Running mode: train or eval or training the Flow++ variational dequantization model
  --workdir: Working directory
  • config is the path to the config file. Our config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

    Naming conventions of config files: the name of a config file contains the following attributes:

    • dataset: Either cifar10 or imagenet32
    • model: Either ddpmpp_continuous or ddpmpp_deep_continuous
  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for supporting pre-emption recovery, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval" or "train_deq". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in workdir/checkpoints-meta . When set to "eval", it can do the following:

    • Compute the log-likelihood on the training or test dataset.

    • Compute the lower bound of the log-likelihood on the training or test dataset.

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute its Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in assets/stats.

      When set to "train_deq", it trains a Flow++ variational dequantization model to bridge the gap of likelihoods on continuous and discrete images. Recommended if you want to compete with generative models trained on discrete images, such as VAEs and autoregressive models. train_deq mode also supports pre-emption recovery.

These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package.

Configurations for training

To turn on likelihood weighting, set --config.training.likelihood_weighting. To additionally turn on importance sampling for variance reduction, use --config.training.likelihood_weighting. To train a separate Flow++ variational dequantizer, you need to first finish training a score-based model, then use --mode=train_deq.

Configurations for evaluation

To generate samples and evaluate sample quality, use the --config.eval.enable_sampling flag; to compute log-likelihoods, use the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset. Turn on --config.eval.bound to evaluate the variational bound for the log-likelihood. Enable --config.eval.dequantizer to use variational dequantization for likelihood computation. --config.eval.num_repeats configures the number of repetitions across the dataset (more can reduce the variance of the likelihoods; default to 5).

Pretrained checkpoints

All checkpoints are provided in this Google drive.

Folder structure:

  • assets: contains cifar10_stats.npz and imagenet32_stats.npz. Necessary for computing FID and Inception scores.
  • <cifar10|imagenet32>_(deep)_<vp|subvp>_(likelihood)_(iw)_(flip). Here the part enclosed in () is optional. deep in the name specifies whether the score model is a deeper architecture (ddpmpp_deep_continuous). likelihood specifies whether the model was trained with likelihood weighting. iw specifies whether the model was trained with importance sampling for variance reduction. flip shows whether the model was trained with horizontal flip for data augmentation. Each folder has the following two subfolders:
    • checkpoints: contains the last checkpoint for the score-based model.
    • flowpp_dequantizer/checkpoints: contains the last checkpoint for the Flow++ variational dequantization model.

References

If you find the code useful for your research, please consider citing

@inproceedings{song2021maximum,
  title={Maximum Likelihood Training of Score-Based Diffusion Models},
  author={Song, Yang and Durkan, Conor and Murray, Iain and Ermon, Stefano},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}

This work is built upon some previous papers which might also interest you:

  • Yang Song and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems, 2019.
  • Yang Song and Stefano Ermon. "Improved techniques for training score-based generative models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems, 2020.
  • Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations". Proceedings of the 9th International Conference on Learning Representations, 2021.
Owner
Yang Song
PhD Candidate in Stanford AI Lab
Yang Song
🔊 Audio and fastai v2

Fastaudio An audio module for fastai v2. We want to help you build audio machine learning applications while minimizing the need for audio domain expe

152 Dec 28, 2022
A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019).

ClusterGCN ⠀⠀ A PyTorch implementation of "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" (KDD 2019). A

Benedek Rozemberczki 697 Dec 27, 2022
A rule-based log analyzer & filter

Flog 一个根据规则集来处理文本日志的工具。 前言 在日常开发过程中,由于缺乏必要的日志规范,导致很多人乱打一通,一个日志文件夹解压缩后往往有几十万行。 日志泛滥会导致信息密度骤减,给排查问题带来了不小的麻烦。 以前都是用grep之类的工具先挑选出有用的,再逐条进行排查,费时费力。在忍无可忍之后决

上山打老虎 9 Jun 23, 2022
OneShot Learning-based hotword detection.

EfficientWord-Net Hotword detection based on one-shot learning Home assistants require special phrases called hotwords to get activated (eg:"ok google

ANT-BRaiN 102 Dec 25, 2022
Code for PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning

PackNet: https://arxiv.org/abs/1711.05769 Pretrained models are available here: https://uofi.box.com/s/zap2p03tnst9dfisad4u0sfupc0y1fxt Datasets in Py

Arun Mallya 216 Jan 05, 2023
Exponential Graph is Provably Efficient for Decentralized Deep Training

Exponential Graph is Provably Efficient for Decentralized Deep Training This code repository is for the paper Exponential Graph is Provably Efficient

3 Apr 20, 2022
Pywonderland - A tour in the wonderland of math with python.

A Tour in the Wonderland of Math with Python A collection of python scripts for drawing beautiful figures and animating interesting algorithms in math

Zhao Liang 4.1k Jan 03, 2023
This is my codes that can visualize the psnr image in testing videos.

CVPR2018-Baseline-PSNRplot This is my codes that can visualize the psnr image in testing videos. Future Frame Prediction for Anomaly Detection – A New

Wenhao Yang 12 May 29, 2021
Implementations of polygamma, lgamma, and beta functions for PyTorch

lgamma Implementations of polygamma, lgamma, and beta functions for PyTorch. It's very hacky, but that's usually ok for research use. To build, run: .

Rachit Singh 24 Nov 09, 2021
ByteTrack with ReID module following the paradigm of FairMOT, tracking strategy is borrowed from FairMOT/JDE.

ByteTrack_ReID ByteTrack is the SOTA tracker in MOT benchmarks with strong detector YOLOX and a simple association strategy only based on motion infor

Han GuangXin 46 Dec 29, 2022
Official code for "Focal Self-attention for Local-Global Interactions in Vision Transformers"

Focal Transformer This is the official implementation of our Focal Transformer -- "Focal Self-attention for Local-Global Interactions in Vision Transf

Microsoft 486 Dec 20, 2022
ArcaneGAN by Alex Spirin

ArcaneGAN by Alex Spirin

Alex 617 Dec 28, 2022
Pytorch implementation of face attention network

Face Attention Network Pytorch implementation of face attention network as described in Face Attention Network: An Effective Face Detector for the Occ

Hooks 312 Dec 09, 2022
Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral]

Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral] Learning to Disambiguate Strongly In

Zicong Fan 40 Dec 22, 2022
Code accompanying "Adaptive Methods for Aggregated Domain Generalization"

Adaptive Methods for Aggregated Domain Generalization (AdaClust) Official Pytorch Implementation of Adaptive Methods for Aggregated Domain Generalizat

Xavier Thomas 15 Sep 20, 2022
scikit-learn inspired API for CRFsuite

sklearn-crfsuite sklearn-crfsuite is a thin CRFsuite (python-crfsuite) wrapper which provides interface simlar to scikit-learn. sklearn_crfsuite.CRF i

417 Dec 20, 2022
Research on Event Accumulator Settings for Event-Based SLAM

Research on Event Accumulator Settings for Event-Based SLAM This is the source code for paper "Research on Event Accumulator Settings for Event-Based

Robin Shaun 26 Dec 21, 2022
SelfAugment extends MoCo to include automatic unsupervised augmentation selection.

SelfAugment extends MoCo to include automatic unsupervised augmentation selection. In addition, we've included the ability to pretrain on several new datasets and included a wandb integration.

Colorado Reed 24 Oct 26, 2022
Robot Hacking Manual (RHM). From robotics to cybersecurity. Papers, notes and writeups from a journey into robot cybersecurity.

RHM: Robot Hacking Manual Download in PDF RHM v0.4 ┃ Read online The Robot Hacking Manual (RHM) is an introductory series about cybersecurity for robo

Víctor Mayoral Vilches 233 Dec 30, 2022
BASH - Biomechanical Animated Skinned Human

We developed a method animating a statistical 3D human model for biomechanical analysis to increase accessibility for non-experts, like patients, athletes, or designers.

Machine Learning and Data Analytics Lab FAU 66 Nov 19, 2022