Implementation of the SUMO (Slim U-Net trained on MODA) model

SUMO - Slim U-Net trained on MODA

Implementation of the SUMO (Slim U-Net trained on MODA) model as described in:

TODO: add reference to paper once available

Installation Guide

On Linux with Anaconda or Miniconda installed, the project can be set up by running the following commands, which clone the repository, create a new environment and install the required dependencies:

git clone https://github.com/dslaborg/sumo.git
cd sumo
conda env create --file environment.yaml
conda activate sumo

Scripts - Quick Guide

Running and evaluating an experiment

The main training and evaluation procedures are implemented in bin/train.py and bin/eval.py using the PyTorch Lightning framework. A configuration chosen to train the model is called an experiment; evaluation is carried out using a configuration together with the result folder of a training run.

train.py

Trains the model as specified in the corresponding configuration file, writes its log to the console, and saves a log file, intermediate results for TensorBoard and model checkpoints to a result directory.

Arguments:

  • -e NAME, --experiment NAME: name of the experiment to run, for which a NAME.yaml file has to exist in the config directory; default is default

eval.py

Evaluates a trained model on either the validation data or the test data and reports the achieved metrics.

Arguments:

  • -e NAME, --experiment NAME: name of the configuration file to be used for evaluation, for which a NAME.yaml file has to exist in the config directory; usually the same experiment that was used to train the model; default is default
  • -i PATH, --input PATH: path to the model that should be evaluated; this can either be a model checkpoint, which is then used directly, or the output directory of a train.py run, in which case the best model in PATH/models/ is used; if the configuration has cross validation enabled, an output directory is expected and the best model of each fold is taken from PATH/fold_*/models/; no default value
  • -t, --test: if given, the test data is used instead of the validation data
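As an illustration of the documented eval.py interface, the following sketch rebuilds its three arguments with argparse and shows where a model would be looked up for a given -i PATH. It is not the project's own code: the helper name model_locations is hypothetical, making -i required is an assumption based on it having no default, and how the best checkpoint inside a models/ directory is chosen is not reproduced.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the options documented for eval.py; not the project's actual parser.
    parser = argparse.ArgumentParser(description="Evaluate a trained SUMO model (sketch)")
    parser.add_argument("-e", "--experiment", default="default",
                        help="name of a NAME.yaml file in the config directory")
    parser.add_argument("-i", "--input", required=True,
                        help="model checkpoint or output directory of a train.py run")
    parser.add_argument("-t", "--test", action="store_true",
                        help="use the test data instead of the validation data")
    return parser


def model_locations(input_path: Path, cross_validation: bool) -> list:
    """Return the checkpoint or the models/ directories to search (hypothetical helper)."""
    if input_path.is_file():
        # A checkpoint file is used directly.
        return [input_path]
    if cross_validation:
        # One models/ directory per fold, i.e. PATH/fold_*/models/
        return sorted(p / "models" for p in input_path.glob("fold_*") if p.is_dir())
    # A single training run keeps its checkpoints in PATH/models/
    return [input_path / "models"]
```

For example, parsing ["-e", "final", "-i", "../output/final.ckpt", "-t"] yields the experiment name final with test evaluation enabled, and if that checkpoint file exists, model_locations returns it unchanged.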

Further example scripts

In addition to the scripts used to create the figures in our manuscript (spindle_analysis.py, spindle_analysis_correlations.py and spindle_detection_example.py), the scripts directory contains two scripts that demonstrate the usage of this project.

create_data_splits.py

Demonstrates the procedure used to split the data into test and non-test subjects and the subsequent creation of a hold-out validation set and (alternatively) cross validation folds.

Arguments:

  • -i PATH, --input PATH: path containing the required input data, as produced by the MODA file MODA02_genEEGVectBlock.m; relative paths are resolved from the scripts directory; default is ../input/
  • -o PATH, --output PATH: path in which the generated data splits are stored; relative paths are resolved from the scripts directory; default is ../output/datasets_{datetime}
  • -n NUMBER, --n_datasets NUMBER: number of random split candidates to generate; default is 25
  • -t FRACTION, --test FRACTION: proportion of the data used as test data; 0<=FRACTION<=1; default is 0.2
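The split procedure described above can be sketched with the standard library as follows. This is a minimal sketch assuming subjects are identified by plain IDs and assigned uniformly at random; the function names are hypothetical, and how the project's script chooses among the n_datasets candidates is not reproduced.

```python
import random


def split_subjects(subjects, fraction, seed=None):
    """Randomly hold out `fraction` of the subjects; returns (remaining, held_out)."""
    if not 0 <= fraction <= 1:
        raise ValueError("fraction must lie in [0, 1]")
    rng = random.Random(seed)
    shuffled = list(subjects)
    rng.shuffle(shuffled)
    n_held_out = round(len(shuffled) * fraction)
    return shuffled[n_held_out:], shuffled[:n_held_out]


def make_folds(subjects, k, seed=None):
    """Distribute subjects round-robin over k folds named fold_0 ... fold_{k-1}."""
    if k < 2:
        raise ValueError("k must be at least 2")
    rng = random.Random(seed)
    shuffled = list(subjects)
    rng.shuffle(shuffled)
    return {f"fold_{i}": shuffled[i::k] for i in range(k)}


def candidate_splits(subjects, n_datasets=25, test_fraction=0.2, seed=0):
    """Draw n_datasets independent (non-test, test) split candidates."""
    rng = random.Random(seed)
    return [split_subjects(subjects, test_fraction, seed=rng.randrange(2**32))
            for _ in range(n_datasets)]
```

The same split_subjects helper can carve a hold-out validation set out of the non-test subjects, while make_folds produces keys named like those expected by an integer cross_validation setting in the configuration.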

predict_plain_data.py

Demonstrates how to predict spindles with a trained SUMO model on arbitrary EEG data, which is expected as a dict with the keys representing the EEG channels and the values the corresponding data vector.

Arguments:

  • -d PATH, --data_path PATH: path containing the input data, either in .pickle or .npy format, as a dict with the channel name as key and the EEG data as value; relative paths are resolved from the scripts directory; no default value
  • -m PATH, --model_path PATH: path to the model checkpoint that should be used to predict spindles; relative paths are resolved from the scripts directory; default is ../output/final.ckpt
  • -g NUMBER, --gpus NUMBER: number of GPUs to use; if 0 is given, computations are done on the CPU; default is 0
  • -sr RATE, --sample_rate RATE: sample rate of the provided data; default is 100.0
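To prepare input for predict_plain_data.py, the EEG data has to be stored as a dict that maps channel names to data vectors. The sketch below writes and reads such a dict as a .pickle file using only the standard library. The channel name C3, the file name and the synthetic 13 Hz sine are hypothetical, and in practice the vectors would more likely be NumPy arrays, which pickle serializes as well.

```python
import math
import pickle
from pathlib import Path


def synthetic_channel(seconds, sample_rate=100.0, freq_hz=13.0):
    """Hypothetical test signal: a sine wave sampled at `sample_rate` Hz."""
    n_samples = int(seconds * sample_rate)
    return [math.sin(2 * math.pi * freq_hz * n / sample_rate) for n in range(n_samples)]


def save_plain_data(channels, path):
    """Pickle a dict mapping EEG channel names to their data vectors."""
    for name, samples in channels.items():
        if not isinstance(name, str):
            raise TypeError(f"channel name {name!r} must be a string")
        if len(samples) == 0:
            raise ValueError(f"channel {name!r} contains no samples")
    with open(path, "wb") as f:
        pickle.dump(channels, f)


def load_plain_data(path):
    """Read the dict back, e.g. to check it before running a prediction."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

A file written this way could then be passed as -d, together with an -sr value matching the rate at which the vectors were sampled (100.0 if the default is kept).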

Project Setup

The project is set up as follows:

  • bin/: contains the train.py and eval.py scripts, which are used for model training and subsequent evaluation in experiments (as configured within the config directory) using the PyTorch Lightning framework
  • config/: contains the configurations of the experiments, configuring how to train or evaluate the model
    • default.yaml: provides a sensible default configuration
    • final.yaml: contains the configuration used to train the final model checkpoint (output/final.ckpt)
    • predict.yaml: configuration that can be used to predict spindles on arbitrary data, e.g. by using the script at scripts/predict_plain_data.py
  • input/: should contain the input files used, e.g. the EEG data and annotated spindles as produced by the MODA repository and transformed as demonstrated in scripts/create_data_splits.py
  • output/: contains generated output by any experiment runs or scripts, e.g. the created figures
    • final.ckpt: the final model checkpoint, with which the test data performance reported in the paper was obtained
  • scripts/: various scripts used to create the plots of our paper and to demonstrate the usage of this project
    • a7/: python implementation of the A7 algorithm as described in:
      Karine Lacourse, Jacques Delfrate, Julien Beaudry, Paul E. Peppard and Simon C. Warby. "A sleep spindle detection algorithm that emulates human expert spindle scoring." Journal of Neuroscience Methods 316 (2019): 3-11.
      
    • create_data_splits.py: demonstrates how the data set splits were obtained, including the evaluation of the A7 algorithm
    • predict_plain_data.py: demonstrates the prediction of spindles on arbitrary EEG data, using a trained model checkpoint
    • spindle_analysis.py, spindle_analysis_correlations.py, spindle_detection_example.py: scripts used to create some of the figures used in our paper
  • sumo/: the implementation of the SUMO model and used classes and functions, for more information see the docstrings

Configuration Parameters

The configuration of an experiment is implemented using YAML configuration files. These files must be placed within the config directory, and their names must match the name passed as --experiment to the eval.py or train.py script. The default.yaml file is always loaded as a set of default parameters, and parameters specified in an additional file overwrite the default values. Any parameters or groups of parameters that should be None have to be set to either null or Null, following the YAML definition.
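The override behaviour described above can be sketched as a recursive dictionary merge. This assumes that nested parameter groups are merged key by key rather than replaced as a whole, which the description does not state explicitly; merge_config is a hypothetical name, and plain dicts stand in for the parsed YAML files (a YAML loader such as PyYAML's yaml.safe_load turns null and Null into None).

```python
import copy


def merge_config(defaults: dict, overrides: dict) -> dict:
    """Overlay `overrides` on `defaults` recursively, without mutating either input."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are parameter groups: merge them key by key.
            merged[key] = merge_config(merged[key], value)
        else:
            # Scalars, lists and None (YAML null) replace the default outright.
            merged[key] = copy.deepcopy(value)
    return merged
```

Under this assumption, an experiment file that only sets experiment.train.n_epochs keeps every other training default, while setting data to null disables the whole data group.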

The available parameters are as follows:

  • data: configuration of the used input data; optional, can be None if spindles should be annotated on arbitrary EEG data
    • directory and file_name: the input file containing the Subject objects (see scripts/create_data_splits.py) is expected at ${directory}/${file_name}, where relative paths are resolved from the project root directory; the file should be a (pickled) dict with the name of a data set as key and the list of corresponding subjects as value; default is input/subjects.pickle
    • split: the keys of the data sets to be used, specifying either train and validation, or cross_validation, and optionally test
      • cross_validation: either an integer k>=2, in which case the keys fold_0, ..., fold_{k-1} are expected to exist, or a list of keys
    • batch_size: size of the mini-batches used during training; default is 12
    • preprocessing: whether z-scoring is performed on the EEG data; default is True
  • experiment: definition of the performed experiment; mandatory
    • model: definition of the model configuration; mandatory
      • n_classes: number of output classes; default is 2
      • activation: name of an activation function as defined in torch.nn package; default is ReLU
      • depth: number of layers of the U excluding the last layer; default is 2
      • channel_size: number of filters of the convolutions in the first layer; default is 16
      • pools: list containing the sizes of the pooling and upsampling operations; has to contain as many values as the value of depth; default is [4, 4]
      • convolution_params: parameters used by the Conv1d modules
      • moving_avg_size: width of the moving average filter; default is 42
    • train: configuration used in training the model; mandatory
      • n_epochs: maximum number of epochs to be run before stopping training; default is 800
      • early_stopping: number of epochs without any improvement in the val_f1_mean metric, after which training is stopped; default is 300
      • optimizer: configuration of an optimizer as defined in torch.optim package; contains class_name (default is Adam) and parameters, which are passed to the constructor of the used optimizer class
      • lr_scheduler: used learning rate scheduler; optional, default is None
      • loss: configuration of loss function as defined either in sumo.loss package (GeneralizedDiceLoss) or torch.nn package; contains class_name (default is GeneralizedDiceLoss) and parameters, which are passed to the constructor of the used loss class
    • validation: configuration used in evaluating the model; mandatory
      • overlap_threshold_step: step size of the overlap thresholds used to calculate (validation) F1 scores
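Two of the constraints listed above, the meaning of split.cross_validation and the requirement that pools contains one entry per level of depth, can be checked as in the following sketch. The function names are hypothetical and do not correspond to functions in the sumo package.

```python
def resolve_cv_keys(cross_validation):
    """Turn split.cross_validation into the list of data set keys it refers to."""
    if isinstance(cross_validation, bool):
        # bool is a subclass of int in Python, so reject it explicitly.
        raise TypeError("cross_validation must be an integer or a list of keys")
    if isinstance(cross_validation, int):
        if cross_validation < 2:
            raise ValueError("an integer cross_validation must satisfy k >= 2")
        return [f"fold_{i}" for i in range(cross_validation)]
    return list(cross_validation)


def check_pools(depth, pools):
    """Require one pooling/upsampling size per level of the U."""
    if len(pools) != depth:
        raise ValueError(f"pools has {len(pools)} entries but depth is {depth}")
```

With the defaults, check_pools(2, [4, 4]) passes silently, and an integer setting of 3 resolves to the keys fold_0, fold_1 and fold_2.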