PyTorch common framework to accelerate network implementation, training and validation

Overview

pytorch-framework

PyTorch common framework to accelerate network implementation, training and validation.

This framework is inspired by work from MMLab, which modularizes the data, network, loss, metric, etc. so that the framework is flexible and easy to modify and extend.

How to use

# install necessary libs
pip install -r requirements.txt

The framework contains six different subfolders:

  • networks: all networks should be implemented under the networks folder with {NAME}_network.py filename.
  • datasets: all datasets should be implemented under the datasets folder with {NAME}_dataset.py filename.
  • losses: all losses should be implemented under the losses folder with {NAME}_loss.py filename.
  • metrics: all metrics should be implemented under the metrics folder with {NAME}_metric.py filename.
  • models: all models should be implemented under the models folder with {NAME}_model.py filename.
  • utils: all util functions should be implemented under the utils folder with {NAME}_util.py filename.
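A fixed file-name suffix per folder makes it possible to discover modules by scanning the folder. The sketch below shows one way such a scan could look. It is an assumption for illustration, not the framework's confirmed behavior, and `find_modules` is a hypothetical helper.

```python
from pathlib import Path


def find_modules(folder, suffix):
    """Return the sorted module names in `folder` whose files end with `{suffix}.py`.

    Hypothetical helper: the framework may locate its modules differently.
    """
    return sorted(p.stem for p in Path(folder).glob(f"*{suffix}.py"))


# Lists e.g. 'vgg_network' if a file networks/vgg_network.py exists.
print(find_modules("networks", "_network"))
```

Files that do not follow the convention (for example `__init__.py`) are simply skipped by the pattern.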

The training and validation procedures are defined in the .yaml file passed with --opt.

# training 
CUDA_VISIBLE_DEVICES=gpu_ids python train.py --opt options/train.yaml

# validation/test
CUDA_VISIBLE_DEVICES=gpu_ids python test.py --opt options/test.yaml
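As a rough sketch of what happens with the `--opt` argument, the snippet below parses the flag and loads the file into a nested dictionary. It assumes PyYAML is installed. `parse_args` and `load_options` are hypothetical names, and the real train.py and test.py may handle options differently.

```python
import argparse


def parse_args(argv=None):
    """Parse the command line; only the --opt flag from the commands above is assumed."""
    parser = argparse.ArgumentParser(description="Load a YAML option file.")
    parser.add_argument("--opt", type=str, required=True,
                        help="path to the .yaml option file")
    return parser.parse_args(argv)


def load_options(path):
    """Read the option file into a plain nested dict."""
    import yaml  # PyYAML (assumed dependency), imported here so parsing works without it

    with open(path, "r") as f:
        return yaml.safe_load(f)


# Passing an explicit argument list instead of reading sys.argv:
args = parse_args(["--opt", "options/train.yaml"])
print(args.opt)
```

Calling `load_options(args.opt)` on the training file shown in this README would give a dictionary in which, for example, `opt["train"]["total_iter"]` is `10000`.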

In the .yaml file for training, you can define everything related to training, such as the experiment name, model, dataset, network, loss, optimizer, metrics and other hyperparameters. Here is an example that trains VGG16 for image classification:

# general setting
name: vgg_train
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto

# path to resume network
path:
  resume_state: ~

# datasets
datasets:
  train_dataset:
    name: TrainDataset
    type: ImageNet
    data_root: ../data/train_data
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/val_data
  # setting for train dataset
  batch_size: 8

# network setting
networks:
  classifier:
    type: VGG16
    num_classes: 1000

# training setting
train:
  total_iter: 10000
  optims:
    classifier:
      type: Adam
      lr: 1.0e-4
  schedulers:
    classifier:
      type: none
  losses:
    ce_loss:
      type: CrossEntropyLoss

# validation setting
val:
  val_freq: 10000

# log setting
logger:
  print_freq: 100
  save_checkpoint_freq: 10000
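To make the interplay of `total_iter`, `print_freq`, `val_freq` and `save_checkpoint_freq` concrete, here is a minimal sketch of an iteration-based schedule. It assumes that all frequencies are counted in iterations and fire when the iteration number is a multiple of the frequency. The real loop in train.py may differ, and `run_schedule` is a hypothetical function.

```python
def run_schedule(total_iter, print_freq, val_freq, save_checkpoint_freq):
    """Return the (iteration, event) pairs an iteration-based loop would trigger."""
    events = []
    for current_iter in range(1, total_iter + 1):
        # A real loop would run the forward/backward pass and optimizer step here.
        if current_iter % print_freq == 0:
            events.append((current_iter, "log"))
        if current_iter % val_freq == 0:
            events.append((current_iter, "validate"))
        if current_iter % save_checkpoint_freq == 0:
            events.append((current_iter, "save"))
    return events


events = run_schedule(total_iter=10000, print_freq=100,
                      val_freq=10000, save_checkpoint_freq=10000)
print(events[-3:])
```

Under these assumptions, the values in the example config would log every 100 iterations and validate and save a checkpoint once, at the final iteration.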

In the .yaml file for validation, you can define everything related to validation, such as the model, dataset and metrics. Here is an example:

# general setting
name: test
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto
manual_seed: 1234

# path
path:
  resume_state: experiments/train/models/final.pth
  resume: false

# datasets
datasets:
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/test_data

# network setting
networks:
  classifier:
    type: VGG16
    num_classes: 1000

# validation setting
val:
  metrics:
    accuracy:
      type: calculate_accuracy

Framework Details

The core of the framework is the BaseModel in base_model.py. The BaseModel controls the whole training/validation procedure, from initialization through the training/validation iterations to saving the results.

  • Initialization: During initialization, the model reads the configuration from the .yaml file and constructs the corresponding networks, datasets, losses, optimizers, metrics, etc.
  • Training/Validation: For the training/validation procedure, refer to train.py for the training process and to test.py for the validation process.
  • Results saving: During training, the model automatically saves the state_dict of the networks and optimizers, along with other hyperparameters.
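The initialization step can be pictured as turning each config entry into an object by looking up its `type`. The sketch below illustrates that idea under stated assumptions: `build_from_config`, the `LOSSES` map and the stand-in loss class are all hypothetical, and the actual BaseModel may construct its components differently.

```python
import copy


class StandInCrossEntropyLoss:
    """Placeholder class for illustration; not the framework's or PyTorch's loss."""

    def __init__(self, weight=1.0):
        self.weight = weight


# Hypothetical name-to-class map playing the role of a registry.
LOSSES = {"CrossEntropyLoss": StandInCrossEntropyLoss}


def build_from_config(cfg, registry):
    """Instantiate registry[cfg['type']] with the remaining keys as keyword arguments."""
    cfg = copy.deepcopy(cfg)  # leave the caller's config untouched
    cls_name = cfg.pop("type")
    return registry[cls_name](**cfg)


# Mirrors the `losses` entry of the training config shown earlier.
loss_cfgs = {"ce_loss": {"type": "CrossEntropyLoss"}}
losses = {name: build_from_config(cfg, LOSSES) for name, cfg in loss_cfgs.items()}
print(type(losses["ce_loss"]).__name__)
```

Keeping `type` as the only reserved key means any extra key in the config entry is passed straight to the constructor.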

The configuration of the framework is done by the Register in registry.py. The Register holds an object map of key-value pairs: the key is the name of the object and the value is the class of the object. There are four registers in total, for networks, datasets, losses and metrics. Here is an example that registers a new network:

import torch
import torch.nn as nn

from utils.registry import NETWORK_REGISTRY

# add MyNet to the network registry under its name
@NETWORK_REGISTRY.register()
class MyNet(nn.Module):
  ...
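For readers unfamiliar with the pattern, a registry of this kind can be sketched in a few lines. The class below is an illustrative assumption written from the description above (a name-to-class map filled by a decorator). It is not the code in registry.py, and `Registry`, `DEMO_REGISTRY` and `get` are hypothetical names.

```python
class Registry:
    """Minimal name-to-class map filled through a decorator."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        """Return a decorator that stores the decorated class under its class name."""
        def decorator(obj):
            key = obj.__name__
            if key in self._obj_map:
                raise KeyError(f"'{key}' is already registered in '{self._name}'")
            self._obj_map[key] = obj
            return obj  # hand the class back unchanged
        return decorator

    def get(self, key):
        """Look up a registered class by name."""
        if key not in self._obj_map:
            raise KeyError(f"'{key}' is not registered in '{self._name}'")
        return self._obj_map[key]


DEMO_REGISTRY = Registry("demo")


@DEMO_REGISTRY.register()
class MyNet:
    pass


print(DEMO_REGISTRY.get("MyNet") is MyNet)
```

Rejecting duplicate names at registration time is a design choice of this sketch. It surfaces naming clashes early instead of silently overwriting an entry.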
Owner
Dongliang Cao