Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Related tags

Deep Learningproxprop
Overview

Proximal Backpropagation

Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient steps to update the network parameters. We have analyzed this algorithm in our ICLR 2018 paper:

Proximal Backpropagation (Thomas Frerix, Thomas Möllenhoff, Michael Moeller, Daniel Cremers; ICLR 2018) [https://arxiv.org/abs/1706.04638]

tl;dr

  • We provide a PyTorch implementation of ProxProp for Python 3 and PyTorch 1.0.1.
  • The results of our paper can be reproduced by executing the script paper_experiments.sh.
  • ProxProp is implemented as a torch.nn.Module (a 'layer') and can be combined with any other layer and first-order optimizer. While a ProxPropConv2d and a ProxPropLinear layer already exist, you can generate a ProxProp layer for your favorite linear layer with one line of code.

Installation

  1. Make sure you have a running Python 3 (tested with Python 3.7) ecosytem. We recommend that you use a conda install, as this is also the recommended option to get the latest PyTorch running. For this README and for the scripts, we assume that you have conda running with Python 3.7.
  2. Clone this repository and switch to the directory.
  3. Install the dependencies via conda install --file conda_requirements.txt and pip install -r pip_requirements.txt.
  4. Install PyTorch with magma support. We have tested our code with PyTorch 1.0.1 and CUDA 10.0. You can install this setup via
    conda install -c pytorch magma-cuda100
    conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
    
  5. (optional, but necessary to reproduce paper experiments) Download the CIFAR-10 dataset by executing get_data.sh

Training neural networks with ProxProp

ProxProp is implemented as a custom linear layer (torch.nn.Module) with its own backward pass to take implicit gradient steps on the network parameters. With this design choice it can be combined with any other layer, for which one takes explicit gradient steps. Furthermore, the resulting update direction can be used with any first-order optimizer that expects a suitable update direction in parameter space. In our paper we prove that ProxProp generates a descent direction and show experiments with Nesterov SGD and Adam.

You can use our pre-defined layers ProxPropConv2d and ProxPropLinear, corresponding to nn.Conv2d and nn.Linear, by importing

from ProxProp import ProxPropConv2d, ProxPropLinear

Besides the usual layer parameters, as detailed in the PyTorch docs, you can provide:

  • tau_prox: step size for a proximal step; default is tau_prox=1
  • optimization_mode: can be one of 'prox_exact', 'prox_cg{N}', 'gradient' for an exact proximal step, an approximate proximal step with N conjugate gradient steps and an explicit gradient step, respectively; default is optimization_mode='prox_cg1'. The 'gradient' mode is for a fair comparison with SGD, as it incurs the same overhead as the other methods in exploiting a generic implementation with the provided PyTorch API.

If you want to use ProxProp to optimize your favorite linear layer, you can generate the respective module with one line of code. As an example for the the Conv3d layer:

from ProxProp import proxprop_module_generator
ProxPropConv3d = proxprop_module_generator(torch.nn.Conv3d)

This gives you a default implementation for the approximate conjugate gradient solver, which treats all parameters as a stacked vector. If you want to use the exact solver or want to use the conjugate gradient solver more efficiently, you have to provide the respective reshaping methods to proxprop_module_generator, as this requires specific knowledge of the layer's structure and cannot be implemented generically. As a template, take a look at the ProxProp.py file, where we have done this for the ProxPropLinear layer.

By reusing the forward/backward implementations of existing PyTorch modules, ProxProp becomes readily accessible. However, we pay an overhead associated with generically constructing the backward pass using the PyTorch API. We have intentionally sided with genericity over speed.

Reproduce paper experiments

To reproduce the paper experiments execute the script paper_experiments.sh. This will run our paper's experiments, store the results in the directory paper_experiments/ and subsequently compile the results into the file paper_plots.pdf. We use an NVIDIA Titan X GPU; executing the script takes roughly 3 hours.

Acknowledgement

We want to thank Soumith Chintala for helping us track down a mysterious bug and the whole PyTorch dev team for their continued development effort and great support to the community.

Publication

If you use ProxProp, please acknowledge our paper by citing

@article{Frerix-et-al-18,
    title = {Proximal Backpropagation},
    author={Thomas Frerix, Thomas Möllenhoff, Michael Moeller, Daniel Cremers},
    journal={International Conference on Learning Representations},
    year={2018},
    url = {https://arxiv.org/abs/1706.04638}
}
Owner
Thomas Frerix
Thomas Frerix
The Ludii general game system, developed as part of the ERC-funded Digital Ludeme Project.

The Ludii General Game System Ludii is a general game system being developed as part of the ERC-funded Digital Ludeme Project (DLP). This repository h

Digital Ludeme Project 50 Jan 04, 2023
Optimal space decomposition based-product quantization for approximate nearest neighbor search

Optimal space decomposition based-product quantization for approximate nearest neighbor search Abstract Product quantization(PQ) is an effective neare

Mylove 1 Nov 19, 2021
利用Tensorflow实现基于CNN的中文短文本分类

Text Classification with CNN 使用卷积神经网络进行中文文本分类 CNN做句子分类的论文可以参看: Convolutional Neural Networks for Sentence Classification 还可以去读dennybritz大牛的博客:Implemen

Jeremiah 4 Nov 08, 2022
Transformer Tracking (CVPR2021)

TransT - Transformer Tracking [CVPR2021] Official implementation of the TransT (CVPR2021) , including training code and trained models. We are revisin

chenxin 465 Jan 06, 2023
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Dec 27, 2022
Using pytorch to implement unet network for liver image segmentation.

Using pytorch to implement unet network for liver image segmentation.

zxq 1 Dec 17, 2021
A dataset for online Arabic calligraphy

Calliar Calliar is a dataset for Arabic calligraphy. The dataset consists of 2500 json files that contain strokes manually annotated for Arabic callig

ARBML 114 Dec 28, 2022
Split your patch similarly to `git add -p` but supporting multiple buckets

split-patch.py This is git add -p on steroids for patches. Given a my.patch you can run ./split-patch.py my.patch You can choose in which bucket to p

102 Oct 06, 2022
Data Engineering ZoomCamp

Data Engineering ZoomCamp I'm partaking in a Data Engineering Bootcamp / Zoomcamp and will be tracking my progress here. I can't promise these notes w

Aaron 61 Jan 06, 2023
a dnn ai project to classify which food people are eating on audio recordings

Deep Learning - EAT Challenge About This project is part of an AI challenge of the DeepLearning course 2021 at the University of Augsburg. The objecti

Marco Tröster 1 Oct 24, 2021
Structural Constraints on Information Content in Human Brain States

Structural Constraints on Information Content in Human Brain States Code accompanying the paper "The information content of brain states is explained

Leon Weninger 3 Sep 07, 2022
DNA-RECON { Automatic Web Reconnaissance Tool }

ABOUT TOOL : DNA-RECON is an automatic web reconnaissance tool written in python. This tool made for reconnaissance and information gathering with an

NIKUNJ BHATT 25 Aug 11, 2021
Code for the paper: "On the Bottleneck of Graph Neural Networks and Its Practical Implications"

On the Bottleneck of Graph Neural Networks and its Practical Implications This is the official implementation of the paper: On the Bottleneck of Graph

75 Dec 22, 2022
Distance-Ratio-Based Formulation for Metric Learning

Distance-Ratio-Based Formulation for Metric Learning Environment Python3 Pytorch (http://pytorch.org/) (version 1.6.0+cu101) json tqdm Preparing datas

Hyeongji Kim 1 Dec 07, 2022
An introduction to satellite image analysis using Python + OpenCV and JavaScript + Google Earth Engine

A Gentle Introduction to Satellite Image Processing Welcome to this introductory course on Satellite Image Analysis! Satellite imagery has become a pr

Edward Oughton 32 Jan 03, 2023
This repository provides an efficient PyTorch-based library for training deep models.

s3sec Test AWS S3 buckets for read/write/delete access This tool was developed to quickly test a list of s3 buckets for public read, write and delete

Bytedance Inc. 123 Jan 05, 2023
Text-to-Image generation

Generate vivid Images for Any (Chinese) text CogView is a pretrained (4B-param) transformer for text-to-image generation in general domain. Read our p

THUDM 1.3k Dec 29, 2022
O-CNN: Octree-based Convolutional Neural Networks for 3D Shape Analysis

O-CNN This repository contains the implementation of our papers related with O-CNN. The code is released under the MIT license. O-CNN: Octree-based Co

Microsoft 607 Dec 28, 2022
PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence) and pre-trained model on ImageNet dataset

Reference-Based-Sketch-Image-Colorization-ImageNet This is a PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization usin

Yuzhi ZHAO 11 Jul 28, 2022
Rainbow DQN implementation that outperforms the paper's results on 40% of games using 20x less data 🌈

Rainbow 🌈 An implementation of Rainbow DQN which outperforms the paper's (Hessel et al. 2017) results on 40% of tested games while using 20x less dat

Dominik Schmidt 31 Dec 21, 2022