A best practice for tensorflow project template architecture.

Overview

Tensorflow Project Template

A simple and well designed structure is essential for any Deep Learning project, so after a lot of practice and contributing in tensorflow projects here's a tensorflow project template that combines simplcity, best practice for folder structure and good OOP design. The main idea is that there's much stuff you do every time you start your tensorflow project, so wrapping all this shared stuff will help you to change just the core idea every time you start a new tensorflow project.

So, here's a simple tensorflow template that help you get into your main project faster and just focus on your core (Model, Training, ...etc)

Table Of Contents

In a Nutshell

In a nutshell here's how to use this template, so for example assume you want to implement VGG model so you should do the following:

  • In models folder create a class named VGG that inherit the "base_model" class
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            #call the build_model and init_saver functions.
            self.build_model() 
            self.init_saver() 
  • Override these two functions "build_model" where you implement the vgg model, and "init_saver" where you define a tensorflow saver, then call them in the initalizer.
     def build_model(self):
        # here you build the tensorflow graph of any model you want and also define the loss.
        pass
            
     def init_saver(self):
        # here you initalize the tensorflow saver that will be used in saving the checkpoints.
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • In trainers folder create a VGG trainer that inherit from "base_train" class
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • Override these two functions "train_step", "train_epoch" where you write the logic of the training process
    def train_epoch(self):
        """
       implement the logic of epoch:
       -loop on the number of iterations in the config and call the train step
       -add any summaries you want using the summary
        """
        pass

    def train_step(self):
        """
       implement the logic of the train step
       - run the tensorflow session
       - return any metrics you need to summarize
       """
        pass
  • In main file, you create the session and instances of the following objects "Model", "Logger", "Data_Generator", "Trainer", and config
    sess = tf.Session()
    # create instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create tensorboard logger
    logger = Logger(sess, config)
  • Pass the all these objects to the trainer object, and start your training by calling "trainer.train()"
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

You will find a template file and a simple example in the model and trainer folder that shows you how to try your first model simply.

In Details

Project architecture

Folder structure

├──  base
│   ├── base_model.py   - this file contains the abstract class of the model.
│   └── base_train.py   - this file contains the abstract class of the trainer.
│
│
├── model               - this folder contains any model of your project.
│   └── example_model.py
│
│
├── trainer             - this folder contains trainers of your project.
│   └── example_trainer.py
│   
├──  mains              - here's the main(s) of your project (you may need more than one main).
│    └── example_main.py  - here's an example of main that is responsible for the whole pipeline.

│  
├──  data _loader  
│    └── data_generator.py  - here's the data_generator that is responsible for all data handling.
│ 
└── utils
     ├── logger.py
     └── any_other_utils_you_need

Main Components

Models


  • Base model

    Base model is an abstract class that must be Inherited by any model you create, the idea behind this is that there's much shared stuff between all models. The base model contains:

    • Save -This function to save a checkpoint to the desk.
    • Load -This function to load a checkpoint from the desk.
    • Cur_epoch, Global_step counters -These variables to keep track of the current epoch and global step.
    • Init_Saver An abstract function to initialize the saver used for saving and loading the checkpoint, Note: override this function in the model you want to implement.
    • Build_model Here's an abstract function to define the model, Note: override this function in the model you want to implement.
  • Your model

    Here's where you implement your model. So you should :

    • Create your model class and inherit the base_model class
    • override "build_model" where you write the tensorflow model you want
    • override "init_save" where you create a tensorflow saver to use it to save and load checkpoint
    • call the "build_model" and "init_saver" in the initializer.

Trainer


  • Base trainer

    Base trainer is an abstract class that just wrap the training process.

  • Your trainer

    Here's what you should implement in your trainer.

    1. Create your trainer class and inherit the base_trainer class.
    2. override these two functions "train_step", "train_epoch" where you implement the training process of each step and each epoch.

Data Loader

This class is responsible for all data handling and processing and provide an easy interface that can be used by the trainer.

Logger

This class is responsible for the tensorboard summary, in your trainer create a dictionary of all tensorflow variables you want to summarize then pass this dictionary to logger.summarize().

This class also supports reporting to Comet.ml which allows you to see all your hyper-params, metrics, graphs, dependencies and more including real-time metric. Add your API key in the configuration file:

For example: "comet_api_key": "your key here"

Comet.ml Integration

This template also supports reporting to Comet.ml which allows you to see all your hyper-params, metrics, graphs, dependencies and more including real-time metric.

Add your API key in the configuration file:

For example: "comet_api_key": "your key here"

Here's how it looks after you start training:

You can also link your Github repository to your comet.ml project for full version control. Here's a live page showing the example from this repo

Configuration

I use Json as configuration method and then parse it, so write all configs you want then parse it using "utils/config/process_config" and pass this configuration object to all other objects.

Main

Here's where you combine all previous part.

  1. Parse the config file.
  2. Create a tensorflow session.
  3. Create an instance of "Model", "Data_Generator" and "Logger" and parse the config to all of them.
  4. Create an instance of "Trainer" and pass all previous objects to it.
  5. Now you can train your model by calling "Trainer.train()"

Future Work

  • Replace the data loader part with new tensorflow dataset API.

Contributing

Any kind of enhancement or contribution is welcomed.

Acknowledgments

Thanks for my colleague Mo'men Abdelrazek for contributing in this work. and thanks for Mohamed Zahran for the review. Thanks for Jtoy for including the repo in Awesome Tensorflow.

Owner
Mahmoud Gamal Salem
MSc. in AI at university of Guelph and Vector Institute. AI intern @samsung
Mahmoud Gamal Salem
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

Joshua Ji 3 Aug 20, 2022
For auto aligning, cropping, and scaling HR and LR images for training image based neural networks

ImgAlign For auto aligning, cropping, and scaling HR and LR images for training image based neural networks Usage Make sure OpenCV is installed, 'pip

15 Dec 04, 2022
Employs neural networks to classify images into four categories: ship, automobile, dog or frog

Neural Net Image Classifier Employs neural networks to classify images into four categories: ship, automobile, dog or frog Viterbi_1.py uses a classic

Riley Baker 1 Jan 18, 2022
Parametric Contrastive Learning (ICCV2021)

Parametric-Contrastive-Learning This repository contains the implementation code for ICCV2021 paper: Parametric Contrastive Learning (https://arxiv.or

DV Lab 156 Dec 21, 2022
List of papers, code and experiments using deep learning for time series forecasting

Deep Learning Time Series Forecasting List of state of the art papers focus on deep learning and resources, code and experiments using deep learning f

Alexander Robles 2k Jan 06, 2023
NOD: Taking a Closer Look at Detection under Extreme Low-Light Conditions with Night Object Detection Dataset

NOD (Night Object Detection) Dataset NOD: Taking a Closer Look at Detection under Extreme Low-Light Conditions with Night Object Detection Dataset, BM

Igor Morawski 17 Nov 05, 2022
Self-describing JSON-RPC services made easy

ReflectRPC Self-describing JSON-RPC services made easy Contents What is ReflectRPC? Installation Features Datatypes Custom Datatypes Returning Errors

Andreas Heck 31 Jul 16, 2022
Human motion synthesis using Unity3D

Human motion synthesis using Unity3D Prerequisite: Software: amc2bvh.exe, Unity 2017, Blender. Unity: RockVR (Video Capture), scenes, character models

Hao Xu 9 Jun 01, 2022
Accelerated Multi-Modal MR Imaging with Transformers

Accelerated Multi-Modal MR Imaging with Transformers Dependencies numpy==1.18.5 scikit_image==0.16.2 torchvision==0.8.1 torch==1.7.0 runstats==1.8.0 p

54 Dec 16, 2022
MPI-IS Mesh Processing Library

Perceiving Systems Mesh Package This package contains core functions for manipulating meshes and visualizing them. It requires Python 3.5+ and is supp

Max Planck Institute for Intelligent Systems 494 Jan 06, 2023
Find the Heart simple Python Game

This is a simple Python game for finding a heart emoji. There is a 3 x 3 matrix in which a heart emoji resides. The location of the heart is randomized and is not revealed. The player must guess the

p.katekomol 1 Jan 24, 2022
Semi-automated OpenVINO benchmark_app with variable parameters

Semi-automated OpenVINO benchmark_app with variable parameters. User can specify multiple options for any parameters in the benchmark_app and the progam runs the benchmark with all combinations of gi

Yasunori Shimura 8 Apr 11, 2022
BMVC 2021: This is the github repository for "Few Shot Temporal Action Localization using Query Adaptive Transformers" accepted in British Machine Vision Conference (BMVC) 2021, Virtual

FS-QAT: Few Shot Temporal Action Localization using Query Adaptive Transformer Accepted as Poster in BMVC 2021 This is an official implementation in P

Sauradip Nag 14 Dec 09, 2022
Tracking code for the winner of track 1 in the MMP-Tracking Challenge at ICCV 2021 Workshop.

Tracking Code for the winner of track1 in MMP-Trakcing challenge This repository contains our tracking code for the Multi-camera Multiple People Track

DamoCV 29 Nov 13, 2022
Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary Differential Equations

ODE GAN (Prototype) in PyTorch Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary

Somshubra Majumdar 15 Feb 10, 2022
PyTorch Implementation of Daft-Exprt: Robust Prosody Transfer Across Speakers for Expressive Speech Synthesis

Daft-Exprt - PyTorch Implementation PyTorch Implementation of Daft-Exprt: Robust Prosody Transfer Across Speakers for Expressive Speech Synthesis The

Keon Lee 47 Dec 18, 2022
Code for the Shortformer model, from the paper by Ofir Press, Noah A. Smith and Mike Lewis.

Shortformer This repository contains the code and the final checkpoint of the Shortformer model. This file explains how to run our experiments on the

Ofir Press 138 Apr 15, 2022
Zeyuan Chen, Yangchao Wang, Yang Yang and Dong Liu.

Principled S2R Dehazing This repository contains the official implementation for PSD Framework introduced in the following paper: PSD: Principled Synt

zychen 78 Dec 30, 2022
An educational AI robot based on NVIDIA Jetson Nano.

JetBot Looking for a quick way to get started with JetBot? Many third party kits are now available! JetBot is an open-source robot based on NVIDIA Jet

NVIDIA AI IOT 2.6k Dec 29, 2022
Pyramid Grafting Network for One-Stage High Resolution Saliency Detection. CVPR 2022

PGNet Pyramid Grafting Network for One-Stage High Resolution Saliency Detection. CVPR 2022, CVPR 2022 (arXiv 2204.05041) Abstract Recent salient objec

CVTEAM 109 Dec 05, 2022