Optimized code based on M2 for faster image captioning training

Overview

Transformer Captioning

This repository contains the code for Transformer-based image captioning. Building on meshed-memory-transformer, we further optimize the code for faster training without any loss of accuracy.

Specifically, we optimize the following aspects:

  • Vocab: we pre-tokenize the dataset so that no ' ' (space) token appears in the vocabulary or in generated sentences.
  • Dataloader: we optimize the dataloader and achieve a 2x~6x speed-up.
  • Beam search:
    • Parallelize ops in beam_search.py (e.g. replace a looped gather with a single batched gather).
    • Use cheaper ops (e.g. torch.topk instead of torch.sort).
    • Use faster, specialized functions instead of general ones.
  • Self-critical sequence training (SCST):
    • Compute CIDEr on token indices instead of raw text.
    • Cache the TF-IDF vectors of the ground-truth captions (gts) instead of recomputing them at every step.
    • Drop on-the-fly tokenization, which is too slow.
  • Contiguous model parameters.
  • Other minor details.
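The two beam-search changes above can be illustrated with a minimal sketch, assuming cumulative log-probabilities of shape (batch, beam, vocab). The names select_beams, reorder_state and reorder_state_loop are hypothetical and do not come from beam_search.py; the sketch only shows torch.topk replacing a full sort and one torch.gather replacing a per-sample loop.

```python
import torch


def select_beams(log_probs: torch.Tensor, beam_size: int):
    """Pick the beam_size best (beam, token) pairs for every batch row.

    log_probs: cumulative scores of shape (batch, beam, vocab).
    """
    batch, _, vocab = log_probs.shape
    flat = log_probs.reshape(batch, -1)  # (batch, beam * vocab)
    # topk keeps only the k best entries instead of sorting all beam * vocab scores
    top_scores, flat_idx = torch.topk(flat, beam_size, dim=-1)
    beam_idx = torch.div(flat_idx, vocab, rounding_mode="floor")  # parent beam
    token_idx = flat_idx % vocab  # next token
    return top_scores, beam_idx, token_idx


def reorder_state(state: torch.Tensor, beam_idx: torch.Tensor) -> torch.Tensor:
    """Batched version: select parent-beam states for all samples with one gather.

    state: (batch, beam, d); beam_idx: (batch, beam_size) parent indices.
    """
    index = beam_idx.unsqueeze(-1).expand(-1, -1, state.size(-1))
    return torch.gather(state, 1, index)


def reorder_state_loop(state: torch.Tensor, beam_idx: torch.Tensor) -> torch.Tensor:
    """Looped version, kept only for comparison: one indexing op per sample."""
    return torch.stack([state[b, beam_idx[b]] for b in range(state.size(0))])
```

Both reordering functions return the same tensor; the batched one issues a single kernel regardless of batch size, which is where the speed-up comes from.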

Speed-up results (one GeForce GTX 1080 Ti GPU, num_workers=8, batch_size=50 for XE / 100 for SCST; XE = cross-entropy training, SCST = self-critical sequence training):

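| Training (its/s) | Original | Optimized | Speed-up |
|---|---|---|---|
| XE | 7.5 | 10.3 | 138% |
| SCST | 0.6 | 1.3 | 204% |

| Dataloader (its/s) | Original XE | Optimized XE | Speed-up (XE) | Original SCST | Optimized SCST | Speed-up (SCST) |
|---|---|---|---|---|---|---|
| batch size = 50 | 12.5 | 52.5 | 320% | 29.3 | 90.7 | 209% |
| batch size = 100 | 5.5 | 33.5 | 510% | 22.3 | 88.5 | 297% |
| batch size = 150 | 3.7 | 25.4 | 580% | 13.4 | 71.8 | 435% |
| batch size = 200 | 2.7 | 20.1 | 650% | 11.4 | 54.1 | 376% |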
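The dataloader speed-up figures appear to be percentage increases, (optimized - original) / original; for batch size 50 (XE) this gives (52.5 - 12.5) / 12.5 = 320%. The check below is our own reading of the column, not something the table states; the other rows come out close to, but not exactly at, the reported values, presumably because the its/s numbers are rounded.

```python
def percent_increase(original: float, optimized: float) -> float:
    """Throughput gain relative to the original, in percent."""
    return (optimized - original) / original * 100.0


# XE dataloader rows copied from the table: (batch size, original its/s, optimized its/s)
xe_rows = [(50, 12.5, 52.5), (100, 5.5, 33.5), (150, 3.7, 25.4), (200, 2.7, 20.1)]
for batch_size, original, optimized in xe_rows:
    print(batch_size, round(percent_increase(original, optimized)))
```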

Things I tried that did not help

  • TorchText n-gram counter: slower than the original implementation.
  • nn.MultiheadAttention: only slightly faster than the original.
  • GPU CIDEr: very slow.
  • BeamableMM: slower than the original.
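For context on the SCST items (CIDEr on token indices, cached TF-IDF vectors of the gts), here is a minimal sketch under our own assumptions. CachedTfidfScorer and ngram_counts are hypothetical names, not the repository's code, and the score is a simplified CIDEr-style TF-IDF cosine: it omits CIDEr-D's clipping, length penalty and x10 scaling. Captions are lists of token indices, so n-grams are tuples of ints and no string processing happens at scoring time.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence


def ngram_counts(tokens: Sequence[int], n_max: int = 4) -> Counter:
    """Count all 1..n_max-grams of a caption given as token indices."""
    return Counter(
        tuple(tokens[i:i + n])
        for n in range(1, n_max + 1)
        for i in range(len(tokens) - n + 1)
    )


class CachedTfidfScorer:
    """Hypothetical CIDEr-style scorer: TF-IDF cosine per n-gram order,
    averaged over orders and references (simplified, not full CIDEr-D)."""

    def __init__(self, gts: Dict[int, List[List[int]]], n_max: int = 4):
        self.n_max = n_max
        # Document frequency: each n-gram counted at most once per image.
        self.doc_freq = Counter()
        for refs in gts.values():
            seen = set()
            for ref in refs:
                seen.update(ngram_counts(ref, n_max))
            self.doc_freq.update(seen)
        self.log_num_images = math.log(len(gts))
        # Reference vectors are computed once here, not on every SCST step.
        self.ref_vecs = {
            img_id: [self._tfidf(ngram_counts(ref, n_max)) for ref in refs]
            for img_id, refs in gts.items()
        }

    def _tfidf(self, counts: Counter) -> List[Dict[tuple, float]]:
        vecs = [{} for _ in range(self.n_max)]  # one sparse vector per n-gram order
        for gram, tf in counts.items():
            idf = self.log_num_images - math.log(max(1.0, self.doc_freq[gram]))
            vecs[len(gram) - 1][gram] = tf * idf
        return vecs

    def score(self, img_id: int, candidate: List[int]) -> float:
        cand_vecs = self._tfidf(ngram_counts(candidate, self.n_max))
        refs = self.ref_vecs[img_id]
        total = 0.0
        for ref_vecs in refs:
            for c, r in zip(cand_vecs, ref_vecs):
                dot = sum(w * r.get(g, 0.0) for g, w in c.items())
                norm = math.sqrt(sum(w * w for w in c.values())) * math.sqrt(
                    sum(w * w for w in r.values()))
                total += dot / norm if norm > 0 else 0.0
        return total / (self.n_max * len(refs))
```

Only the candidate vectors are built per call; the document frequencies and reference vectors are paid for once when the scorer is constructed.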

Environment setup

Clone the repository and create the m2release conda environment using the environment.yml file:

conda env create -f environment.yml
conda activate m2release

Then download the spaCy data by running the following command:

python -m spacy download en

Note: Python 3.6 is required to run our code.

Data preparation

To run the code, annotations and detection features for the COCO dataset are needed. Please download the annotations file annotations.zip and extract it.

Detection features are computed with the code provided by [1]. To reproduce our results, please download the COCO features file coco_detections.hdf5 (~53.5 GB), in which the detections of each image are stored under the <image_id>_features key. <image_id> is the id of each COCO image, without leading zeros (e.g. the <image_id> for COCO_val2014_000000037209.jpg is 37209), and each value should be an (N, 2048) tensor, where N is the number of detections.
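A minimal sketch of the key convention described above: strip the leading zeros from the COCO file name to get the HDF5 key, then read the (N, 2048) array with h5py. coco_feature_key and load_detections are hypothetical helpers, not part of the repository, and the loader assumes h5py is installed.

```python
import os
import re


def coco_feature_key(filename: str) -> str:
    """Map e.g. 'COCO_val2014_000000037209.jpg' to the HDF5 key '37209_features'."""
    match = re.search(r"(\d+)\.jpg$", os.path.basename(filename))
    if match is None:
        raise ValueError("not a COCO image filename: %s" % filename)
    return "%d_features" % int(match.group(1))  # int() drops the leading zeros


def load_detections(features_path: str, filename: str):
    """Return the (N, 2048) detection features of one image as a numpy array."""
    import h5py  # imported lazily so coco_feature_key works without h5py installed

    with h5py.File(features_path, "r") as f:
        return f[coco_feature_key(filename)][()]
```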

Remember to pre-tokenize the dataset:

python pre_tokenize.py
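We do not reproduce pre_tokenize.py here. As a hypothetical illustration of the idea behind it (tokenize once, offline, so that no ' ' token reaches the vocabulary), the sketch below uses a simple regex tokenizer instead of spaCy, purely to stay self-contained; tokenize, build_vocab and the special tokens are our own assumptions.

```python
import re
from collections import Counter
from typing import Dict, Iterable, List

SPECIALS = ["<pad>", "<bos>", "<eos>", "<unk>"]  # hypothetical special tokens


def tokenize(caption: str) -> List[str]:
    # Lowercase, turn punctuation into spaces, then split on runs of whitespace;
    # str.split() with no argument never yields empty or ' ' tokens.
    return re.sub(r"[^\w\s]", " ", caption.lower()).split()


def build_vocab(captions: Iterable[List[str]], min_freq: int = 5) -> Dict[str, int]:
    counts = Counter(tok for caption in captions for tok in caption)
    words = sorted(w for w, c in counts.items() if c >= min_freq)
    return {w: i for i, w in enumerate(SPECIALS + words)}


raw = ["A  man riding a horse.", "a man on a horse ", "Two dogs play."]
tokenized = [tokenize(c) for c in raw]  # done once, then saved to disk
vocab = build_vocab(tokenized, min_freq=2)
print(tokenized[0])
print(vocab)
```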

Evaluation

Run python test.py using the following arguments:

| Argument | Possible values |
|---|---|
| --batch_size | Batch size (default: 10) |
| --workers | Number of workers (default: 0) |
| --features_path | Path to detection features file |
| --annotation_folder | Path to folder with COCO annotations |

Training procedure

Run python train.py using the following arguments:

| Argument | Possible values |
|---|---|
| --exp_name | Experiment name |
| --batch_size | Batch size (default: 10) |
| --workers | Number of workers (default: 0) |
| --head | Number of heads (default: 8) |
| --resume_last | If used, training is resumed from the last checkpoint. |
| --resume_best | If used, training is resumed from the best checkpoint. |
| --features_path | Path to detection features file |
| --annotation_folder | Path to folder with COCO annotations |
| --logs_folder | Path to the folder for TensorBoard logs (default: "tensorboard_logs") |

For example, to train our model with the parameters used in our experiments, use:

python train.py --exp_name test --batch_size 50 --head 8 --features_path ~/datassd/coco_detections.hdf5 --annotation_folder annotation --workers 8 --rl_batch_size 100 --image_field FasterImageDetectionsField --model transformer --seed 118

We recommend a batch size of 100 during the SCST stage (--rl_batch_size 100), since it accelerates convergence without an obvious drop in accuracy.
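The example command uses four flags that the argument table does not list (--rl_batch_size, --image_field, --model, --seed). The sketch below is a hypothetical argparse mirror of all these flags, not train.py's actual parser: defaults come from the table where given, and the types of the undocumented flags are our guesses.

```python
import argparse
import shlex


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Hypothetical mirror of train.py flags")
    # Flags documented in the argument table (defaults copied from it).
    p.add_argument("--exp_name", type=str)
    p.add_argument("--batch_size", type=int, default=10)
    p.add_argument("--workers", type=int, default=0)
    p.add_argument("--head", type=int, default=8)
    p.add_argument("--resume_last", action="store_true")
    p.add_argument("--resume_best", action="store_true")
    p.add_argument("--features_path", type=str)
    p.add_argument("--annotation_folder", type=str)
    p.add_argument("--logs_folder", type=str, default="tensorboard_logs")
    # Flags that only appear in the example command; types are assumptions.
    p.add_argument("--rl_batch_size", type=int)
    p.add_argument("--image_field", type=str)
    p.add_argument("--model", type=str)
    p.add_argument("--seed", type=int)
    return p


cmd = ("--exp_name test --batch_size 50 --head 8 "
       "--features_path ~/datassd/coco_detections.hdf5 --annotation_folder annotation "
       "--workers 8 --rl_batch_size 100 --image_field FasterImageDetectionsField "
       "--model transformer --seed 118")
args = build_parser().parse_args(shlex.split(cmd))
print(args.batch_size, args.rl_batch_size)
```

In this reading, --batch_size (50) applies to the XE stage and --rl_batch_size (100) to the SCST stage, matching the batch sizes used for the speed-up measurements.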

References

Owner: lyricpoem