InvTorch: Memory-Efficient Invertible Functions

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. Not only are the intermediate activations released from memory; the input tensors are deallocated as well and recomputed from the outputs using the inverse function in the backward pass. This is useful in extreme situations where extra compute is traded for memory. However, there are a few caveats to consider, which are detailed here.
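
For contrast, here is plain activation checkpointing with torch.utils.checkpoint: the intermediates inside the toy net below are recomputed in the backward pass, but the input x must stay allocated the whole time (InvTorch frees it as well).

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

net = nn.Sequential(nn.Linear(3, 5), nn.ReLU())
x = torch.randn(10, 3, requires_grad=True)
y = checkpoint(net, x)  # activations inside net are freed and recomputed
y.sum().backward()      # x itself stayed in memory for the recomputation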

Installation

InvTorch has minimal dependencies. It only requires PyTorch version 1.10.0 or later.

conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install invtorch

Basic Usage

The main module that we are interested in is InvertibleModule, which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn
from invtorch import InvertibleModule


class InvertibleLinear(InvertibleModule):
    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):
        # the forward computation; it runs in no_grad mode, so the
        # requires_grad flag of the output must be set manually
        outputs = inputs @ self.weight.T + self.bias
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):
        # recover the inputs by solving outputs = inputs @ W.T + b
        return (outputs - self.bias) @ self.weight.T.pinverse()

Structure

You can immediately notice a few differences from a regular PyTorch module here. There is no longer a need to define forward(); it is replaced with function(*inputs). Additionally, it is necessary to define its inverse as inverse(*outputs). Both methods can only take one or more positional arguments and must return a torch.Tensor or a tuple of outputs, which can contain anything, including tensors.
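
For example, a hypothetical two-input module (not part of InvTorch; it only assumes the API shown above) could take two positional arguments and return a tuple:

class AdditiveCoupling(InvertibleModule):
    def __init__(self, features):
        super().__init__(invertible=True, checkpoint=True)
        self.net = nn.Linear(features, features)

    def function(self, x1, x2):
        # y1 = x1 and y2 = x2 + f(x1); invertible for any f
        y1 = x1 + 0  # avoid returning the input tensor itself
        y2 = x2 + self.net(x1)
        requires_grad = self.do_require_grad(x1, x2, *self.net.parameters())
        return (y1.requires_grad_(requires_grad),
                y2.requires_grad_(requires_grad))

    def inverse(self, y1, y2):
        return y1 + 0, y2 - self.net(y1)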

Requires Gradient

function() must manually call .requires_grad_(True/False) on all output tensors. The forward pass is run in no_grad mode, so there is no way to detect which outputs need gradients without tracing. It is possible to infer this from the requires_grad values of the inputs and self.parameters(). The code above uses do_require_grad(), which returns True if any of its arguments requires gradient.
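
Conceptually, do_require_grad() could be replaced by a manual check like this sketch (any_requires_grad is a hypothetical helper, not part of the package):

def any_requires_grad(*tensors):
    # True if any argument is a tensor that needs a gradient
    return any(isinstance(t, torch.Tensor) and t.requires_grad
               for t in tensors)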

Example

Now, this model is ready to be instantiated and used directly.

x = torch.randn(10, 3)
model = InvertibleLinear(3, 5)
print('Is invertible:', model.check_inverse(x))

y = model(x)
print('Output requires_grad:', y.requires_grad)
print('Input was freed:', x.storage().size() == 0)

y.backward(torch.randn_like(y))
print('Input was restored:', x.storage().size() != 0)

Checkpoint and Invertible Modes

InvertibleModule has two flags which control its mode of operation: checkpoint and invertible. If checkpoint is set to False, or the module is running in no_grad mode, or no input or parameter has requires_grad set to True, it acts exactly like a normal PyTorch module. Otherwise, the module is either invertible or an ordinary checkpoint, depending on whether invertible is set to True or False, respectively. These flags can be changed at any time during operation without any repercussions.
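
For instance, assuming the two flags are exposed as plain attributes on the module (an assumption; check the code for the exact interface), switching modes could look like this:

model.checkpoint = False  # act exactly like a normal PyTorch module
model.checkpoint = True   # checkpointing enabled again
model.invertible = False  # ordinary checkpoint: inputs stay in memory
model.invertible = True   # invertible checkpoint: inputs are freed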

Limitations

Under the hood, InvertibleModule uses invertible_checkpoint(), a low-level implementation which allows it to function. There are a few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please refer to the documentation in the code for more details.

Overriding forward()

Although forward() now does important work to ensure the validity of the results when calling invertible_checkpoint(), it can still be overridden. The main reason for doing so is to provide a more user-friendly interface: a nicer function signature and output format. For example, function() could return extra outputs that are not needed in the module's outputs but are essential for correctly computing the inverse(). In that case, define forward() to wrap outputs = super().forward(*inputs) and return them more cleanly.
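
As a sketch of that pattern (SubtractMean is hypothetical and only assumes the API described above), function() returns the mean as an extra output that only inverse() needs, and forward() hides it from callers:

class SubtractMean(InvertibleModule):
    def __init__(self):
        super().__init__(invertible=True, checkpoint=True)

    def function(self, inputs):
        # the mean cannot be recovered from the outputs alone, so it
        # is returned as an extra output for inverse() to use
        mean = inputs.mean(dim=-1, keepdim=True)
        requires_grad = self.do_require_grad(inputs)
        return ((inputs - mean).requires_grad_(requires_grad),
                mean.requires_grad_(requires_grad))

    def inverse(self, outputs, mean):
        return outputs + mean

    def forward(self, inputs):
        outputs, _ = super().forward(inputs)  # drop the extra output
        return outputs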

TODOs

Here are a few feature ideas that could be implemented to enrich the utility of this package:

  • Add more basic operations and modules
  • Add coupling- and interleave-based invertible operations
  • Add more checks to help the user in debugging
  • Allow picking some inputs to not be freed in invertible mode
  • Add a context manager to temporarily change the mode of operation
  • Implement dynamic discovery for outputs that require gradients
  • Develop automatic mode optimization of a network for various objectives