A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules


CapsNet-Tensorflow



(Figure: capsule vs. neuron)

Notes:

  1. The current version supports the MNIST and Fashion-MNIST datasets. The current test accuracy is 99.64% on MNIST and 90.60% on Fashion-MNIST; see the Results section for details.
  2. See dist_version for multi-GPU support.
  3. Here (on Zhihu) is an article explaining my understanding of the paper. It may be helpful for understanding the code.

Important:

If you need to apply the CapsNet model to your own datasets or build a new model from the basic building blocks of CapsNet, please follow my new project CapsLayer, an advanced library for capsule theory. It aims to integrate capsule-related technologies, provide analysis tools, develop application examples, and promote the development of capsule theory. For example, you can easily use capsule layer blocks in your code through the APIs capsLayer.layers.fully_connected and capsLayer.layers.conv2d.

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git, or click the Download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download the MNIST or Fashion-MNIST dataset. You have two options:

  • a) Automatic download with the download_data.py script:
$ python download_data.py   # MNIST dataset
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist   # Fashion-MNIST dataset
  • b) Manual download with wget or another tool: move the files into the data/mnist or data/fashion-mnist directory and extract them there, for example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz
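After Step 2 the four extracted files are in the IDX binary format that MNIST and Fashion-MNIST share: two zero bytes, a type byte (0x08 for unsigned bytes), a dimension count, the dimension sizes as big-endian 32-bit integers, and then the raw data. A minimal pure-Python sketch of reading such a file is below. It is not the loader used by main.py, and the function names parse_idx and load_idx are hypothetical.

```python
import gzip
import struct


def parse_idx(data: bytes):
    """Parse an IDX buffer (the MNIST file format).

    Returns (dims, values): a tuple of dimension sizes and a flat list of
    the unsigned-byte values. Only the unsigned-byte type (0x08) is handled.
    """
    # Header: two zero bytes, a type code, and the number of dimensions.
    zero1, zero2, type_code, ndim = struct.unpack_from(">BBBB", data, 0)
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX buffer")
    if type_code != 0x08:
        raise ValueError("only unsigned-byte IDX data is supported")
    # Each dimension size is a big-endian unsigned 32-bit integer.
    dims = struct.unpack_from(">" + "I" * ndim, data, 4)
    offset = 4 + 4 * ndim
    count = 1
    for d in dims:
        count *= d
    values = list(data[offset:offset + count])
    if len(values) != count:
        raise ValueError("truncated IDX buffer")
    return dims, values


def load_idx(path: str):
    """Read an IDX file from disk, accepting either the .gz or the extracted file."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        return parse_idx(f.read())
```

For example, load_idx("data/mnist/train-images-idx3-ubyte") should report dims of (60000, 28, 28). Keeping all 47,040,000 pixels in a Python list is slow and memory-hungry, so for real use numpy.frombuffer over the same byte range is the better choice; the list version only makes the format explicit.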

Step 3. Start training (the MNIST dataset is used by default):

$ python main.py
$ # or train on the Fashion-MNIST dataset
$ python main.py --dataset fashion-mnist
$ # To monitor training, launch TensorBoard with this command
$ tensorboard --logdir=logdir
$ # or, on Linux, follow the validation log with the `tail` command
$ tail -f results/val_acc.csv
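Besides watching results/val_acc.csv with tail, the log can be summarized programmatically. The sketch below finds the step with the highest validation accuracy. It ASSUMES the file starts with one header line followed by rows of the form step,accuracy; check the actual file before relying on it. The function name best_val_accuracy is hypothetical.

```python
import csv
import io


def best_val_accuracy(text: str):
    """Return (step, accuracy) for the row with the highest accuracy.

    ASSUMPTION: one header line, then rows of the form "step,accuracy".
    """
    reader = csv.reader(io.StringIO(text))
    next(reader, None)  # skip the (assumed) header line
    best = None
    for row in reader:
        if len(row) < 2 or not row[1]:
            continue  # skip blank or incomplete rows
        step, acc = int(row[0]), float(row[1])
        if best is None or acc > best[1]:
            best = (step, acc)
    return best
```

Usage: with open("results/val_acc.csv") as f: print(best_val_accuracy(f.read())).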

Step 4. Calculate the test accuracy:

$ python main.py --is_training=False
$ # for fashion-mnist dataset
$ python main.py --dataset fashion-mnist --is_training=False
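In CapsNet the length of each output (digit) capsule encodes the probability that its class is present, so the predicted class is the capsule with the greatest length, and test accuracy is the fraction of examples where that class matches the label. A pure-Python sketch of this rule, with hypothetical function names and nested lists standing in for the model's tensors:

```python
import math


def predict(digit_caps):
    """Return the index of the longest output capsule (the predicted class).

    digit_caps: one output vector per class.
    """
    lengths = [math.sqrt(sum(x * x for x in v)) for v in digit_caps]
    return max(range(len(lengths)), key=lambda k: lengths[k])


def accuracy(batch_caps, labels):
    """Fraction of examples whose predicted class equals the label."""
    correct = sum(predict(caps) == y for caps, y in zip(batch_caps, labels))
    return correct / len(labels)
```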

Note: The default batch size is 128 and the default number of epochs is 50. You may need to modify the config.py file or use command-line parameters to suit your case, e.g. to set the batch size to 48 and write a test summary every 200 steps: python main.py --test_sum_freq=200 --batch_size=48
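The sketch below mirrors the command-line options named in this README using Python's argparse, to show how flags override the defaults. It is not the project's config.py. The --epoch flag name and the --test_sum_freq default of 100 are assumptions. One detail worth copying into any such parser: bool("False") is True in Python, so a flag like --is_training=False needs an explicit string-to-boolean converter.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (bool('False') would be True)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="CapsNet options (sketch)")
    parser.add_argument("--dataset", default="mnist")
    parser.add_argument("--is_training", type=str2bool, default=True)
    parser.add_argument("--batch_size", type=int, default=128)
    # ASSUMED flag name for the number of training epochs.
    parser.add_argument("--epoch", type=int, default=50)
    # HYPOTHETICAL default; the real value is whatever config.py sets.
    parser.add_argument("--test_sum_freq", type=int, default=100)
    return parser
```

Parsing ["--test_sum_freq=200", "--batch_size=48"] overrides only those two values and leaves the 50-epoch default in place.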

Results

The figures here were plotted with TensorBoard and my tool plot_acc.R.

  • Training loss (figures: total_loss, margin_loss, reconstruction_loss)
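The loss figures track the loss terms defined in the paper: a margin loss on the output capsule lengths, L_k = T_k max(0, m+ - ||v_k||)^2 + λ (1 - T_k) max(0, ||v_k|| - m-)^2 with m+ = 0.9, m- = 0.1, λ = 0.5, plus a sum-of-squared-differences reconstruction loss scaled down by 0.0005 so that it does not dominate. The pure-Python sketch below follows those formulas as stated in the paper; the repository's own code may normalize the reconstruction term differently, and the function names are hypothetical.

```python
def margin_loss(lengths, label, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss for one example, summed over classes.

    lengths: ||v_k|| for each output capsule; label: the true class index.
    """
    loss = 0.0
    for k, length in enumerate(lengths):
        if k == label:
            # Present class: penalize a capsule shorter than m_plus.
            loss += max(0.0, m_plus - length) ** 2
        else:
            # Absent classes: penalize capsules longer than m_minus, down-weighted by lam.
            loss += lam * max(0.0, length - m_minus) ** 2
    return loss


def reconstruction_loss(image, reconstruction):
    """Sum of squared pixel differences between input and reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(image, reconstruction))


def total_loss(lengths, label, image, reconstruction, scale=0.0005):
    """Margin loss plus the down-weighted reconstruction loss."""
    return margin_loss(lengths, label) + scale * reconstruction_loss(image, reconstruction)
```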

The models I trained, my talk, and other materials are available here:

Baidu Netdisk (password: ahjs)

  • Best validation error (%, using reconstruction):

      Routing iterations    1       3       4
      Val error             0.36    0.36    0.41
      Paper                 0.29    0.25    -

(Figure: test accuracy)

My brief comments on capsules:

  1. A new kind of neural unit (vector in, vector out, rather than scalar in, scalar out)
  2. The routing algorithm is similar to an attention mechanism
  3. Overall, a work with great potential and a lot to be built upon
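The first two comments can be made concrete. A capsule's output is a vector whose length is "squashed" into [0, 1), and routing-by-agreement repeatedly recomputes coupling coefficients c_ij as a softmax over routing logits b_ij, raising b_ij when a child's prediction û_j|i agrees (by dot product) with the parent's output v_j. Those softmax weights are what make the comparison with attention apt. Below is a minimal pure-Python sketch of the squash function and the routing procedure as described in the paper; it is not the code in this repository, and the function names are hypothetical. The iterations parameter corresponds to the "routing iterations" in the table above.

```python
import math


def squash(s, eps=1e-9):
    """Scale vector s to a length in [0, 1) while keeping its direction."""
    sq_norm = sum(x * x for x in s)
    # eps avoids division by zero for an all-zero vector.
    scale = sq_norm / (1.0 + sq_norm) / math.sqrt(sq_norm + eps)
    return [scale * x for x in s]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, iterations=3):
    """Route predictions u_hat[i][j] (child i -> parent j) to parent outputs v_j.

    u_hat: nested lists of shape [num_in][num_out][dim].
    Returns a list of num_out output vectors, each of length dim.
    """
    num_in, num_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    b = [[0.0] * num_out for _ in range(num_in)]  # routing logits b_ij
    v = []
    for r in range(iterations):
        c = [softmax(row) for row in b]  # coupling coefficients per child capsule
        v = []
        for j in range(num_out):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(num_in))
                   for d in range(dim)]
            v.append(squash(s_j))
        if r < iterations - 1:
            # Agreement: increase b_ij when u_hat[i][j] points along v_j.
            for i in range(num_in):
                for j in range(num_out):
                    b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v
```

In this toy version both child capsules agree about parent 0 and disagree about parent 1, so routing sends weight to parent 0 and v_0 ends up longer than v_1.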

My WeChat:

(Image: my WeChat contact)


Owner: Huadong Liao (Explore Nature from an Omics Perspective)