PyTorch Implementation of DSB for Score Based Generative Modeling. Experiments managed using Hydra.

Overview

Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling

This repository contains the implementation for the paper Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling.

If using this code, please cite the paper:

    @article{de2021diffusion,
              title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
              author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
              journal={arXiv preprint arXiv:2106.01357},
              year={2021}
            }

Contributors

  • Valentin De Bortoli
  • James Thornton
  • Jeremy Heng
  • Arnaud Doucet

What is a Schrödinger bridge?

The Schrödinger Bridge (SB) problem is a classical problem appearing in applied mathematics, optimal control and probability; see [1, 2, 3]. In the discrete-time setting, it takes the following (dynamic) form. Consider as reference density p(x0:N) describing the process adding noise to the data. We aim to find p*(x0:N) such that p*(x0) = pdata(x0) and p*(xN) = pprior(xN) and minimize the Kullback-Leibler divergence between p* and p. In this work we introduce Diffusion Schrodinger Bridge (DSB), a new algorithm which uses score-matching approaches [4] to approximate the Iterative Proportional Fitting algorithm, an iterative method to find the solutions of the SB problem. DSB can be seen as a refinement of existing score-based generative modeling methods [5, 6].

Schrodinger bridge

Installation

This project can be installed from its git repository.

  1. Obtain the sources by:

    git clone https://github.com/anon284/schrodinger_bridge.git

or, if git is unavailable, download as a ZIP from GitHub https://github.com/.

  1. Install:

    conda env create -f conda.yaml

    conda activate bridge

  2. Download data examples:

    • CelebA: python data.py --data celeba --data_dir './data/'
    • MNIST: python data.py --data mnist --data_dir './data/'

How to use this code?

  1. Train Networks:
  • 2d: python main.py dataset=2d model=Basic num_steps=20 num_iter=5000
  • mnist python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>
  • celeba python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>

Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.

References

.. [1] Hans Föllmer Random fields and diffusion processes In: École d'été de Probabilités de Saint-Flour 1985-1987

.. [2] Christian Léonard A survey of the Schrödinger problem and some of its connections with optimal transport In: Discrete & Continuous Dynamical Systems-A 2014

.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon Optimal Transport in Systems and Control In: Annual Review of Control, Robotics, and Autonomous Systems 2020

.. [4] Aapo Hyvärinen and Peter Dayan Estimation of non-normalized statistical models by score matching In: Journal of Machine Learning Research 2005

.. [5] Yang Song and Stefano Ermon Generative modeling by estimating gradients of the data distribution In: Advances in Neural Information Processing Systems 2019

.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel Denoising diffusion probabilistic models In: Advances in Neural Information Processing Systems 2020

Owner
James Thornton
James Thornton
A ssl analyzer which could analyzer target domain's certificate.

ssl_analyzer A ssl analyzer which could analyzer target domain's certificate. Analyze the domain name ssl certificate information according to the inp

vincent 17 Dec 12, 2022
Pseudo-mask Matters in Weakly-supervised Semantic Segmentation

Pseudo-mask Matters in Weakly-supervised Semantic Segmentation By Yi Li, Zhanghui Kuang, Liyang Liu, Yimin Chen, Wayne Zhang SenseTime, Tsinghua Unive

33 Oct 14, 2022
Official implementation of the paper "Lightweight Deep CNN for Natural Image Matting via Similarity Preserving Knowledge Distillation"

Lightweight-Deep-CNN-for-Natural-Image-Matting-via-Similarity-Preserving-Knowledge-Distillation Introduction Accepted at IEEE Signal Processing Letter

DongGeun-Yoon 19 Jun 07, 2022
Voice Gender Recognition

In this project it was used some different Machine Learning models to identify the gender of a voice (Female or Male) based on some specific speech and voice attributes.

Anne Livia 1 Jan 27, 2022
[CVPR2021] Look before you leap: learning landmark features for one-stage visual grounding.

LBYL-Net This repo implements paper Look Before You Leap: Learning Landmark Features For One-Stage Visual Grounding CVPR 2021. Getting Started Prerequ

SVIP Lab 45 Dec 12, 2022
Advantage Actor Critic (A2C): jax + flax implementation

Advantage Actor Critic (A2C): jax + flax implementation Current version supports only environments with continious action spaces and was tested on muj

Andrey 3 Jan 23, 2022
Price-Prediction-For-a-Dream-Home - A machine learning based linear regression trained model for house price prediction.

Price-Prediction-For-a-Dream-Home ROADMAP TO THIS LINEAR REGRESSION BASED HOUSE PRICE PREDICTION PREDICTION MODEL Import all the dependencies of the p

DIKSHA DESWAL 1 Dec 29, 2021
Epidemiology analysis package

zEpid zEpid is an epidemiology analysis package, providing easy to use tools for epidemiologists coding in Python 3.5+. The purpose of this library is

Paul Zivich 111 Jan 08, 2023
Official implementation of "Watermarking Images in Self-Supervised Latent-Spaces"

🔍 Watermarking Images in Self-Supervised Latent-Spaces PyTorch implementation and pretrained models for the paper. For details, see Watermarking Imag

Meta Research 32 Dec 13, 2022
Hydra Lightning Template for Structured Configs

Hydra Lightning Template for Structured Configs Template for creating projects with pytorch-lightning and hydra. How to use this template? Create your

Model-driven Machine Learning 4 Jul 19, 2022
Centroid-UNet is deep neural network model to detect centroids from satellite images.

Centroid UNet - Locating Object Centroids in Aerial/Serial Images Introduction Centroid-UNet is deep neural network model to detect centroids from Aer

GIC-AIT 19 Dec 08, 2022
Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)

Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)

Jiaxi Jiang 282 Jan 02, 2023
McGill Physics Hackathon 2021: Reaction-Diffusion Models for the Generation of Biological Patterns

DiffuseAnimals: Reaction-Diffusion Models for the Generation of Biological Patterns Introduction Reaction-diffusion equations can be utilized in order

Austin Szuminsky 2 Mar 07, 2022
Ego4d dataset repository. Download the dataset, visualize, extract features & example usage of the dataset

Ego4D EGO4D is the world's largest egocentric (first person) video ML dataset and benchmark suite, with 3,600 hrs (and counting) of densely narrated v

Meta Research 118 Jan 07, 2023
A PyTorch implementation of "Semi-Supervised Graph Classification: A Hierarchical Graph Perspective" (WWW 2019)

SEAL ⠀⠀⠀ A PyTorch implementation of Semi-Supervised Graph Classification: A Hierarchical Graph Perspective (WWW 2019) Abstract Node classification an

Benedek Rozemberczki 202 Dec 27, 2022
Official implementation of the article "Unsupervised JPEG Domain Adaptation For Practical Digital Forensics"

Unsupervised JPEG Domain Adaptation for Practical Digital Image Forensics @WIFS2021 (Montpellier, France) Rony Abecidan, Vincent Itier, Jeremie Boulan

Rony Abecidan 6 Jan 06, 2023
Segmentation for medical image.

EfficientSegmentation Introduction EfficientSegmentation is an open source, PyTorch-based segmentation framework for 3D medical image. Features A whol

68 Nov 28, 2022
Meta Learning Backpropagation And Improving It (VSML)

Meta Learning Backpropagation And Improving It (VSML) This is research code for the NeurIPS 2021 publication Kirsch & Schmidhuber 2021. Many concepts

Louis Kirsch 22 Dec 21, 2022
Code to reproduce results from the paper "AmbientGAN: Generative models from lossy measurements"

AmbientGAN: Generative models from lossy measurements This repository provides code to reproduce results from the paper AmbientGAN: Generative models

Ashish Bora 87 Oct 19, 2022
Make your own game in a font!

Project structure. Included is a suite of tools to create font games. Tutorial: For a quick tutorial about how to make your own game go here For devel

Michael Mulet 125 Dec 04, 2022