Release for Improved Denoising Diffusion Probabilistic Models

Overview

improved-diffusion

This is the codebase for Improved Denoising Diffusion Probabilistic Models.

Usage

This section of the README walks through how to train and sample from a model.

Installation

Clone this repository and navigate to it in your terminal. Then run:

pip install -e .

This should install the improved_diffusion Python package that the scripts depend on.

Preparing Data

The training code reads images from a directory of image files. In the datasets folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

To create your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories are enumerated automatically as well, so the images can be organized into a recursive structure. However, the directory names are ignored: the part of each file name before the first underscore is used as the label.
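The labeling rule above can be sketched in a few lines of Python. This is a minimal illustration, assuming the label is everything before the first underscore in the file's base name and that labels are mapped to integer indices in sorted order; the function name `list_images_with_labels` is hypothetical and not part of the improved_diffusion package.

```python
import os

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def list_images_with_labels(data_dir):
    """Hypothetical helper: recursively collect image files and derive class labels.

    Directory names are ignored; the label is the part of the file's base name
    before the first underscore (e.g. "mylabel1_XXX.jpg" -> "mylabel1").
    """
    paths = []
    for root, _dirs, files in os.walk(data_dir):
        for name in files:
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS:
                paths.append(os.path.join(root, name))
    paths.sort()  # deterministic order regardless of filesystem listing order
    names = [os.path.basename(p).split("_")[0] for p in paths]
    # Map each distinct label string to an integer class index (assumed: sorted order).
    label_to_index = {label: i for i, label in enumerate(sorted(set(names)))}
    labels = [label_to_index[n] for n in names]
    return paths, labels, label_to_index
```

With files `a/cat_1.jpg`, `b/dog_2.png`, and `cat_3.jpeg`, both `cat_` files share one label even though they live in different directories.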

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass --data_dir path/to/images to the training script, and it will take care of the rest.
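The scale-then-center-crop step can be described as pure geometry. The sketch below assumes the shorter side is scaled to `--image_size` with the aspect ratio preserved and a centered square is then cut out; the helper name `resize_and_crop_geometry` is hypothetical, and the package's own loader may differ in details such as the resampling filter.

```python
def resize_and_crop_geometry(width, height, image_size):
    """Hypothetical helper: compute the resized size and the center-crop box.

    The shorter side is scaled to `image_size` (aspect ratio preserved), then a
    square of side `image_size` is taken from the middle.
    Returns ((new_w, new_h), (left, top, right, bottom)).
    """
    scale = image_size / min(width, height)
    new_w = round(width * scale)
    new_h = round(height * scale)
    # Split any excess evenly on both sides (integer pixel offsets).
    left = (new_w - image_size) // 2
    top = (new_h - image_size) // 2
    return (new_w, new_h), (left, top, left + image_size, top + image_size)
```

For a 640x480 photo and `--image_size 64`, this resizes to 85x64 and keeps columns 10 through 73. With Pillow, the result could be applied with `Image.resize` followed by `Image.crop`.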

Training

To train your model, you should first decide on some hyperparameters. We split our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
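The --noise_schedule flag chooses how the per-step noise variances beta_t are spread over --diffusion_steps. The sketch below shows the linear schedule and the cosine schedule from the Improved DDPM paper. It is an illustration, not the package's code: the 1000 / num_steps rescaling of the linear endpoints (which extends the 1000-step schedule of Ho et al. to other step counts) and the 0.999 clip on beta are assumptions stated here.

```python
import math


def linear_betas(num_steps):
    # Assumed: the 1000-step endpoints (1e-4, 0.02) are rescaled by 1000 / num_steps,
    # so longer chains add proportionally less noise per step.
    scale = 1000 / num_steps
    start, end = scale * 1e-4, scale * 0.02
    return [start + (end - start) * i / (num_steps - 1) for i in range(num_steps)]


def cosine_betas(num_steps, s=0.008, max_beta=0.999):
    # Cosine schedule: alpha_bar(t) follows a squared cosine over t in [0, 1],
    # and beta_t = 1 - alpha_bar(t + 1) / alpha_bar(t), clipped at max_beta.
    def alpha_bar(t):
        return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2

    return [
        min(1 - alpha_bar((i + 1) / num_steps) / alpha_bar(i / num_steps), max_beta)
        for i in range(num_steps)
    ]
```

With `--diffusion_steps 4000`, the linear sketch runs from 2.5e-5 up to 0.005, while the cosine betas stay small early and reach the clip value only at the very end.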

Here are some changes we experiment with, and how to set them in the flags:

  • Learned sigmas: add --learn_sigma True to MODEL_FLAGS.
  • Cosine schedule: change --noise_schedule linear to --noise_schedule cosine in DIFFUSION_FLAGS.
  • Reweighted VLB: add --use_kl True to DIFFUSION_FLAGS and add --schedule_sampler loss-second-moment to TRAIN_FLAGS.
  • Class-conditional: add --class_cond True to MODEL_FLAGS.
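The loss-second-moment sampler refers to importance sampling of timesteps: t is drawn with probability proportional to sqrt(E[L_t^2]), estimated from the last few losses seen at each t, and each term is reweighted by 1 / (N p_t) so the loss estimate stays unbiased. The sketch below is our reading of that idea, not the package's class. The class name `SecondMomentSampler` is hypothetical, and the small uniform floor `uniform_prob` is an assumed safeguard that keeps every p_t above zero.

```python
import math
import random


class SecondMomentSampler:
    """Hypothetical sketch: importance-sample t with p_t proportional to sqrt(E[L_t^2])."""

    def __init__(self, num_steps, history=10, uniform_prob=0.001):
        self.num_steps = num_steps
        self.history = history
        self.uniform_prob = uniform_prob  # assumed safeguard, keeps every p_t > 0
        self.losses = [[] for _ in range(num_steps)]

    def weights(self):
        # Sample uniformly until every timestep has a full loss history.
        if any(len(h) < self.history for h in self.losses):
            return [1.0 / self.num_steps] * self.num_steps
        rms = [math.sqrt(sum(x * x for x in h) / len(h)) for h in self.losses]
        total = sum(rms)
        if total == 0:
            return [1.0 / self.num_steps] * self.num_steps
        return [
            (1 - self.uniform_prob) * r / total + self.uniform_prob / self.num_steps
            for r in rms
        ]

    def sample(self, batch_size, rng=random):
        p = self.weights()
        ts = rng.choices(range(self.num_steps), weights=p, k=batch_size)
        # Reweighting by 1 / (N * p_t) keeps E[weight * L_t] equal to the uniform average.
        scale = [1.0 / (self.num_steps * p[t]) for t in ts]
        return ts, scale

    def update(self, ts, losses):
        # Keep only the most recent `history` losses for each timestep.
        for t, loss in zip(ts, losses):
            h = self.losses[t]
            h.append(loss)
            if len(h) > self.history:
                h.pop(0)
```

Timesteps whose losses are large or noisy get sampled more often, which is the point of reweighting the VLB.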

Once you have set up your hyperparameters, you can run an experiment like so:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
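Conceptually, each training step picks a timestep t, noises a clean image x_0 in closed form to get x_t, and trains the network to predict the added noise. The sketch below shows only that simple noise-prediction (L_simple) objective on plain Python lists; the function names are hypothetical, the model itself is left out, and with --learn_sigma or --use_kl the package also optimizes variational-bound terms that this sketch does not cover.

```python
import math


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def q_sample(x0, t, abar, noise):
    """Draw x_t from q(x_t | x_0): sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    a = abar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]


def simple_loss(eps_pred, noise):
    """L_simple: mean squared error between predicted and true noise."""
    return sum((p - n) ** 2 for p, n in zip(eps_pred, noise)) / len(noise)
```

In a real step, `noise` is standard Gaussian, `t` is drawn by the schedule sampler, and `eps_pred` is the model's output for `(x_t, t)`.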

You may also want to train in a distributed manner. In this case, run the same command with mpiexec:

mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

When training in a distributed manner, you must manually divide the --batch_size argument by the number of ranks. In lieu of distributed training, you may use --microbatch 16 (or --microbatch 1 in extremely memory-limited cases) to reduce memory usage.
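The batch arithmetic above is easy to get wrong, so here is a small hypothetical helper (`per_rank_plan` is not part of the package). It assumes --batch_size is interpreted per rank, as the text implies, and that --microbatch splits each rank's batch into sequential forward/backward passes; a non-positive or missing microbatch is treated here as "disabled".

```python
def per_rank_plan(global_batch, num_ranks, microbatch=None):
    """Hypothetical helper: the --batch_size to pass per rank, and how many
    forward/backward passes each rank runs per step when --microbatch is set."""
    if global_batch % num_ranks:
        raise ValueError("global batch must divide evenly across ranks")
    per_rank = global_batch // num_ranks
    # Treated here as disabled: no microbatch, a non-positive one, or one >= the batch.
    if microbatch is None or microbatch <= 0 or microbatch >= per_rank:
        return per_rank, 1
    passes = -(-per_rank // microbatch)  # ceiling division
    return per_rank, passes
```

For an effective batch of 128 on 4 GPUs, pass `--batch_size 32`; on a single GPU with `--microbatch 16`, each step runs 8 smaller passes over the 128 images.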

The logs and saved models will be written to a logging directory determined by the OPENAI_LOGDIR environment variable. If it is not set, then a temporary directory will be created in /tmp.
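The logging-directory rule can be sketched as follows. This is an illustration of the behavior described above, not the package's logger; the helper name `resolve_log_dir` and the timestamped `openai-...` directory name are assumptions.

```python
import os
import tempfile
from datetime import datetime


def resolve_log_dir():
    """Hypothetical sketch: use $OPENAI_LOGDIR if set, otherwise a fresh
    timestamped directory under the system temp dir (naming is assumed)."""
    log_dir = os.environ.get("OPENAI_LOGDIR")
    if not log_dir:
        stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
        log_dir = os.path.join(tempfile.gettempdir(), f"openai-{stamp}")
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
```

To keep checkpoints somewhere predictable, set the variable when launching, e.g. `OPENAI_LOGDIR=/path/to/logs python scripts/image_train.py ...` in a POSIX shell.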

Sampling

The above training script saves checkpoints to .pt files in the logging directory. These checkpoints will have names like ema_0.9999_200000.pt and model200000.pt. You will likely want to sample from the EMA models, since those produce much better samples.
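Picking the newest EMA checkpoint from a logging directory is a matter of parsing the names shown above. The helper `latest_ema_checkpoint` is hypothetical; it assumes names of the form `ema_<rate>_<step>.pt` and ignores everything else (such as `model200000.pt`).

```python
import re

EMA_PATTERN = re.compile(r"^ema_(?P<rate>[0-9.]+)_(?P<step>\d+)\.pt$")


def latest_ema_checkpoint(filenames, rate="0.9999"):
    """Return the EMA checkpoint name with the highest step for `rate`, or None."""
    best_step, best_name = -1, None
    for name in filenames:
        m = EMA_PATTERN.match(name)
        if m and m.group("rate") == rate and int(m.group("step")) > best_step:
            best_step, best_name = int(m.group("step")), name
    return best_name
```

Typical use is `latest_ema_checkpoint(os.listdir(log_dir))`, then passing the joined path to `--model_path`.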

Once you have a path to your model, you can generate a large batch of samples like so:

python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

Again, this will save results to a logging directory. Samples are saved as a large npz file, where arr_0 in the file is a large batch of samples.
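The name arr_0 comes from NumPy: `np.savez` stores positional arrays under the keys `arr_0`, `arr_1`, and so on. A minimal loader is sketched below; `load_samples` is a hypothetical name, and the assumption that the array is laid out as (N, H, W, C) with uint8 pixels should be checked against `samples.shape` and `samples.dtype` on your own file.

```python
import numpy as np


def load_samples(npz_path):
    """Load the sample batch stored under NumPy's default key arr_0."""
    with np.load(npz_path) as data:
        samples = data["arr_0"]  # reading the key copies the array into memory
    return samples
```

Assuming the uint8 NHWC layout, an individual image `samples[i]` can be handed straight to an image library for saving as PNG.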

Just like for training, you can run image_sample.py through MPI to use multiple GPUs and machines.
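When sampling is launched from Python rather than a shell, the command can be assembled from the same flag strings used for training. The wrapper `sample_command` is hypothetical; only the script path, `--model_path`, and the `mpiexec -n` form come from this README.

```python
import shlex


def sample_command(model_path, model_flags, diffusion_flags, num_gpus=1):
    """Hypothetical wrapper: build argv for scripts/image_sample.py, reusing the
    training flag strings and prefixing mpiexec when more than one GPU is used."""
    argv = ["python", "scripts/image_sample.py", "--model_path", model_path]
    argv += shlex.split(model_flags) + shlex.split(diffusion_flags)
    if num_gpus > 1:
        argv = ["mpiexec", "-n", str(num_gpus)] + argv
    return argv
```

The resulting list can be passed to `subprocess.run(argv, check=True)`.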

You can change the number of sampling steps using the --timestep_respacing argument. For example, --timestep_respacing 250 uses 250 steps to sample. Passing --timestep_respacing ddim250 is similar, but uses the uniform stride from the DDIM paper rather than our stride.
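To make the two respacing modes concrete, the sketch below chooses which of the original timesteps are kept. It handles only the two spec forms described above (a plain count, and `ddim` plus a count); the function name `respaced_timesteps` is hypothetical, and the package may accept richer specs and round slightly differently.

```python
def respaced_timesteps(num_timesteps, spec):
    """Hypothetical sketch: pick the subset of original timesteps to sample with.

    "ddimN": search for an integer stride that yields exactly N steps (DDIM style).
    "N":     N steps spread evenly, with rounding, from 0 to num_timesteps - 1.
    """
    if spec.startswith("ddim"):
        desired = int(spec[len("ddim"):])
        for stride in range(1, num_timesteps + 1):
            steps = range(0, num_timesteps, stride)
            if len(steps) == desired:
                return list(steps)
        raise ValueError(f"cannot get exactly {desired} steps with an integer stride")
    count = int(spec)
    if not 1 <= count <= num_timesteps:
        raise ValueError("step count must be between 1 and num_timesteps")
    if count == 1:
        return [0]
    frac_stride = (num_timesteps - 1) / (count - 1)
    return [round(i * frac_stride) for i in range(count)]
```

With `--diffusion_steps 4000`, `250` keeps steps from 0 up to 3999 at a fractional stride of about 16.06, while `ddim250` uses the exact stride 16 and stops at 3984.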

To sample using DDIM, pass --use_ddim True.
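For reference, one deterministic DDIM update (eta = 0) is sketched below on plain lists. It assumes the model predicts the noise eps and that `abar_t` and `abar_prev` are the cumulative alpha products at the current and next kept timestep; `ddim_step` is a hypothetical name, and this is the update from the DDIM paper rather than the package's exact code.

```python
import math


def ddim_step(x_t, eps_pred, abar_t, abar_prev):
    """One deterministic DDIM update (eta = 0).

    1. Estimate x_0 from the noise prediction:
       x0_hat = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
    2. Re-noise x0_hat to the earlier timestep along the same noise direction:
       x_prev = sqrt(abar_prev) * x0_hat + sqrt(1 - abar_prev) * eps
    """
    out = []
    for x, e in zip(x_t, eps_pred):
        x0_hat = (x - math.sqrt(1.0 - abar_t) * e) / math.sqrt(abar_t)
        out.append(math.sqrt(abar_prev) * x0_hat + math.sqrt(1.0 - abar_prev) * e)
    return out
```

Because no fresh noise is injected, the same starting noise always maps to the same sample, and large jumps between kept timesteps (as with respacing) remain well defined.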

Owner
OpenAI