Neural network pruning for finding a sparse computational model for controlling a biological motor task.

Overview

MothPruning

Scientific Overview

Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for modeling complex systems. DNNs are used in a diversity of domains and have helped solve some of the most intractable problems in physics, biology, and computer science. Despite their prevalence, the use of DNNs as a modeling tool comes with some major downsides. DNNs are highly overparameterized, which often results in them being difficult to generalize and interpret, as well as being incredibly computationally expensive. Unlike DNNs, which are often trained until they reach the highest accuracy possible, biological networks have to balance performance with robustness to a noisy and dynamic environment. Biological neural systems use a variety of mechanisms to promote specialized and efficient pathways capable of performing complex tasks in the presence of noise. One such mechanism, synaptic pruning, plays a significant role in refining task-specific behaviors. Synaptic pruning results in a more sparsely connected network that can still perform complex cognitive and motor tasks. Here, we draw inspiration from biology and use DNNs and the method of neural network pruning to find a sparse computational model for controlling a biological motor task.

In this work, we use the inertial dynamics model in [2] to simulate examples of M. sexta hovering flight. These data are used to train a DNN to learn the controllers for hovering. Drawing inspiration from pruning in biological neural systems, we sparsify the network using neural network pruning. Here, we prune weights based simply on their magnitudes, removing those weights closest to zero. Insects must maneuver through high-noise environments to accomplish controlled flight. It is often assumed that there is a trade-off between perfect flight control and robustness to noise, and that the sensory data may be limited by the signal-to-noise ratio. Thus, the network need not be trained to the most accurate model, since in practice noise prevents high-fidelity models from exhibiting their underlying accuracy. Rather, we seek the sparsest model capable of performing the task in the noisy environment. We employ two methods for neural network pruning: manually setting weights to zero or using binary masking layers. Furthermore, the DNN is pruned sequentially: groups of weights are removed gradually from the network, with retraining between successive prunes, until a target sparsity is reached. Monte Carlo simulations are also used to quantify the statistical distribution of network weights during pruning given random initialization of network weights.

For more information, please see our paper [1].


Project Description

The deep, fully-connected neural network was constructed with ten input variables and seven output variables. The inputs to the network are the initial and final state-space conditions: six initial-condition variables (subscript i) and four final-condition variables (subscript f). In its output layer, the network predicts the control variables and the final derivatives of the state space: three control variables (the first two labeled x and y) followed by four final-state derivatives (subscript f).

After the fully-connected network is trained to a minimum error, we use neural network pruning to promote sparsity between the network layers. A target sparsity (the percentage of network weights to prune) is specified, and the smallest-magnitude weights are forced to zero. The network is then retrained until a minimum error is reached. This process is repeated until most of the weights have been pruned from the network.
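A single magnitude-pruning step of this kind can be sketched as follows. This is a minimal NumPy illustration, not the repository's code: the function name magnitude_prune, the example matrix, and the choice to prune one weight matrix at a time (rather than across the whole network) are all assumptions made for the example.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of the entries.

    Hypothetical helper: returns the pruned copy of `weights` and the
    binary mask that was applied (1 = kept, 0 = pruned).
    """
    weights = np.asarray(weights, dtype=float)
    n_prune = int(round(sparsity * weights.size))
    mask = np.ones(weights.size, dtype=float)
    if n_prune > 0:
        # Indices of the n_prune entries whose magnitudes are closest to zero.
        order = np.argsort(np.abs(weights).ravel(), kind="stable")
        mask[order[:n_prune]] = 0.0
    mask = mask.reshape(weights.shape)
    return weights * mask, mask


# Illustrative weight matrix; prune half of its six entries.
w = np.array([[0.5, -0.05, 1.2], [0.01, -0.8, 0.3]])
pruned, mask = magnitude_prune(w, sparsity=0.5)
print(pruned)
```

In the sequential scheme described above, a step like this would be followed by retraining and then repeated with a higher target sparsity.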

The training and pruning protocols were developed using Keras with the TensorFlow backend. To scale up training for the statistical analysis of many networks, the training and pruning protocols were parallelized using the Jax framework.

To ensure weights remain pruned during retraining, we used the pruning functionality of the TensorFlow Model Optimization Toolkit, which contains functions for pruning deep neural networks. In the toolkit, pruning is achieved through binary masking layers that are multiplied element-wise with each weight matrix in the network.

To train and analyze many neural networks, the training and pruning protocols were parallelized in the Jax framework. Jax, however, does not come with a pruning toolkit, so pruning with binary masking matrices was coded directly into the training loop.
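The idea of keeping pruned weights at zero by re-applying a binary mask inside the training loop can be illustrated with a small sketch. It uses plain NumPy and a toy linear regression problem rather than the repository's Jax code; the data, the mask, the learning rate, and the masked_step helper are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy regression problem standing in for the real training data (assumption).
X = rng.normal(size=(64, 4))
true_w = np.array([[2.0], [0.0], [-1.0], [0.0]])
y = X @ true_w

# Hypothetical mask: the weights at rows 1 and 3 have already been pruned.
mask = np.array([[1.0], [0.0], [1.0], [0.0]])
w = rng.normal(size=(4, 1)) * mask


def masked_step(w, mask, X, y, lr=0.1):
    """One gradient-descent step on mean squared error, then re-apply the mask."""
    grad = 2.0 * X.T @ (X @ w - y) / len(X)
    # Multiplying by the mask after the update forces pruned weights back to zero.
    return (w - lr * grad) * mask


for _ in range(200):
    w = masked_step(w, mask, X, y)

loss = float(np.mean((X @ w - y) ** 2))
print(w.ravel(), loss)
```

Because the mask is applied after every update, the pruned entries never drift away from zero, while the remaining weights are free to retrain.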

Installation

Create a new conda environment with the tools for generating data and training the network. (Note that this environment requires a GPU and the correct NVIDIA drivers.)

conda env create -f environment_ODE_DL.yml

Create kernelspec (so you can see this kernel in JupyterLab).

conda activate [environment name]
python -m ipykernel install --user --name [environment name]
conda deactivate

To install Jax and Flax, please follow the instructions on the Jax GitHub.

Data

To use the TensorFlow version of this code, you need to generate simulations of moth hovering for the data. The Jax version (multi-network train and prune) has data provided in this repository.

cd MothMachineLearning/Underactuated/GenerateData

and use 010_OneTorqueParallelSims.ipynb to generate the simulations.

How to use

The following guide walks through the process of training and pruning many networks in parallel using the Jax framework. However, the TensorFlow code is also provided for experimentation and visualization.

Step 1: Train networks

cd MothMachineLearning/Underactuated/TrainNetwork/multiNetPrune/

First, we train and prune the desired number of networks in parallel using the Jax framework. Choose the number of networks to train and prune in parallel by adjusting the numParallel parameter. You can also define the number of layers, units, and other hyperparameters. Use the command

python3 step1_train.py

to train and prune the networks in parallel.

Step 2: Evaluate at prunes

Next, the networks need to be evaluated at each prune. Use the command

python3 step2_pruneEval.py

to evaluate the networks at each prune.

Step 3: Pre-process networks

This code prepares the networks for sparse network identification (explained in the next step). It essentially just reorganizes the data. Open and run step3_preprocess.ipynb to preprocess, making sure to change modeltimestamp and the file names to the correct ones for your run.

Step 4: Find sparse networks

This code finds the optimally sparse networks. For each network, the most pruned version whose loss is below a specified threshold (here 0.001) is kept. For example, the image below shows a single network that has gone through the sequential pruning process, with the red line marking the defined threshold. For this example, the optimally sparse network is the one pruned by 94% (i.e., 6% of the original weights remain).

[Figure: loss at each prune for a single sequentially pruned network; the red line marks the loss threshold.]

The sparse networks are collected and saved to a file called sparseNetworks.pkl. Open and run step4_findSparse.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Note that if a network does not have a single prune that is below the loss threshold, it will be skipped and not included in the list of sparseNetworks. For example, if you trained and pruned 10 networks and 3 did not have a prune below a loss of 0.001, the list sparseNetworks will be length 7.
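The selection rule described in this step, including the skipping of networks that never fall below the threshold, can be sketched as follows. This is an illustrative example only: the find_sparsest helper, the dictionary layout, and the loss values are hypothetical and do not reflect the contents of step4_findSparse.ipynb or sparseNetworks.pkl.

```python
def find_sparsest(losses_by_sparsity, threshold=0.001):
    """Return the highest percent-pruned value whose loss is below `threshold`.

    Returns None when no prune of the network meets the threshold.
    """
    passing = [s for s, loss in losses_by_sparsity.items() if loss < threshold]
    return max(passing) if passing else None


# Hypothetical evaluation results: {network name: {percent pruned: loss}}.
results = {
    "net0": {0: 1e-4, 50: 2e-4, 94: 8e-4, 96: 5e-3},
    "net1": {0: 2e-3, 50: 4e-3, 94: 9e-3, 96: 2e-2},  # never below the threshold
    "net2": {0: 1e-4, 50: 3e-4, 94: 2e-3, 96: 6e-3},
}

sparse_networks = {}
for name, losses in results.items():
    best = find_sparsest(losses)
    if best is not None:  # networks with no passing prune are skipped
        sparse_networks[name] = best

print(sparse_networks)
```

Here net1 is dropped because none of its prunes reaches a loss below 0.001, so the collection ends up shorter than the number of networks trained.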

References

[1] Zahn, O., Bustamante, Jr J., Switzer, C., Daniel, T., and Kutz, J. N. (2022). Pruning deep neural networks generates a sparse, bio-inspired nonlinear controller for insect flight.

[2] Bustamante, Jr J., Ahmed, M., Deora, T., Fabien, B., and Daniel, T. (2021). Abdominal movements in insect flight reshape the role of non-aerodynamic structures for flight maneuverability. J. Integrative and Comparative Biology. In revision.

Owner
Olivia Thomas
Physics graduate student at the University of Washington