Differentiable architecture search for convolutional and recurrent networks

Overview

Differentiable Architecture Search

Code accompanying the paper

DARTS: Differentiable Architecture Search
Hanxiao Liu, Karen Simonyan, Yiming Yang.
arXiv:1806.09055.

darts

The algorithm is based on continuous relaxation and gradient descent in the architecture space. It is able to efficiently design high-performance convolutional architectures for image classification (on CIFAR-10 and ImageNet) and recurrent architectures for language modeling (on Penn Treebank and WikiText-2). Only a single GPU is required.

Requirements

Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0

NOTE: PyTorch 0.4 is not supported at this moment and would lead to OOM.

Datasets

Instructions for acquiring PTB and WT2 can be found here. While CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to a SSD) following the instructions here.

Pretrained models

The easist way to get started is to evaluate our pretrained DARTS models.

CIFAR-10 (cifar10_model.pt)

cd cnn && python test.py --auxiliary --model_path cifar10_model.pt
  • Expected result: 2.63% test error rate with 3.3M model params.

PTB (ptb_model.pt)

cd rnn && python test.py --model_path ptb_model.pt
  • Expected result: 55.68 test perplexity with 23M model params.

ImageNet (imagenet_model.pt)

cd cnn && python test_imagenet.py --auxiliary --model_path imagenet_model.pt
  • Expected result: 26.7% top-1 error and 8.7% top-5 error with 4.7M model params.

Architecture search (using small proxy models)

To carry out architecture search using 2nd-order approximation, run

cd cnn && python train_search.py --unrolled     # for conv cells on CIFAR-10
cd rnn && python train_search.py --unrolled     # for recurrent cells on PTB

Note the validation performance in this step does not indicate the final performance of the architecture. One must train the obtained genotype/architecture from scratch using full-sized models, as described in the next section.

Also be aware that different runs would end up with different local minimum. To get the best result, it is crucial to repeat the search process with different seeds and select the best cell(s) based on validation performance (obtained by training the derived cell from scratch for a small number of epochs). Please refer to fig. 3 and sect. 3.2 in our arXiv paper.

progress_convolutional_normal progress_convolutional_reduce progress_recurrent

Figure: Snapshots of the most likely normal conv, reduction conv, and recurrent cells over time.

Architecture evaluation (using full-sized models)

To evaluate our best cells by training from scratch, run

cd cnn && python train.py --auxiliary --cutout            # CIFAR-10
cd rnn && python train.py                                 # PTB
cd rnn && python train.py --data ../data/wikitext-2 \     # WT2
            --dropouth 0.15 --emsize 700 --nhidlast 700 --nhid 700 --wdecay 5e-7
cd cnn && python train_imagenet.py --auxiliary            # ImageNet

Customized architectures are supported through the --arch flag once specified in genotypes.py.

The CIFAR-10 result at the end of training is subject to variance due to the non-determinism of cuDNN back-prop kernels. It would be misleading to report the result of only a single run. By training our best cell from scratch, one should expect the average test error of 10 independent runs to fall in the range of 2.76 +/- 0.09% with high probability.

cifar10 ptb ptb

Figure: Expected learning curves on CIFAR-10 (4 runs), ImageNet and PTB.

Visualization

Package graphviz is required to visualize the learned cells

python visualize.py DARTS

where DARTS can be replaced by any customized architectures in genotypes.py.

Citation

If you use any part of this code in your research, please cite our paper:

@article{liu2018darts,
  title={DARTS: Differentiable Architecture Search},
  author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  journal={arXiv preprint arXiv:1806.09055},
  year={2018}
}
Owner
Hanxiao Liu
Research Scientist @ Google Brain
Hanxiao Liu
Robustness via Cross-Domain Ensembles

Robustness via Cross-Domain Ensembles [ICCV 2021, Oral] This repository contains tools for training and evaluating: Pretrained models Demo code Traini

Visual Intelligence & Learning Lab, Swiss Federal Institute of Technology (EPFL) 27 Dec 23, 2022
use machine learning to recognize gesture on raspberrypi

Raspberrypi_Gesture-Recognition use machine learning to recognize gesture on raspberrypi 說明 利用 tensorflow lite 訓練手部辨識模型 分辨 "剪刀"、"石頭"、"布" 之手勢 再將訓練模型匯入

1 Dec 10, 2021
Algorithmic trading with deep learning experiments

Deep-Trading Algorithmic trading with deep learning experiments. Now released part one - simple time series forecasting. I plan to implement more soph

Alex Honchar 1.4k Jan 02, 2023
📝 Wrapper library for text generation / language models at char and word level with RNN in TensorFlow

tensorlm Generate Shakespeare poems with 4 lines of code. Installation tensorlm is written in / for Python 3.4+ and TensorFlow 1.1+ pip3 install tenso

Kilian Batzner 63 May 22, 2021
PyTorch implementation of the paper Deep Networks from the Principle of Rate Reduction

Deep Networks from the Principle of Rate Reduction This repository is the official PyTorch implementation of the paper Deep Networks from the Principl

459 Dec 27, 2022
Transformer - Transformer in PyTorch

Transformer 完成进度 Embeddings and PositionalEncoding with example. MultiHeadAttent

Tianyang Li 1 Jan 06, 2022
It's a implement of this paper:Relation extraction via Multi-Level attention CNNs

Relation Classification via Multi-Level Attention CNNs It's a implement of this paper:Relation Classification via Multi-Level Attention CNNs. Training

Aybss 2 Nov 04, 2022
A collection of models for image<->text generation in ACM MM 2021.

Bi-directional Image and Text Generation UMT-BITG (image & text generator) Unifying Multimodal Transformer for Bi-directional Image and Text Generatio

Multimedia Research 63 Oct 30, 2022
PyTorch implementation of Constrained Policy Optimization

PyTorch implementation of Constrained Policy Optimization (CPO) This repository has a simple to understand and use implementation of CPO in PyTorch. A

Sapana Chaudhary 25 Dec 08, 2022
LeViT a Vision Transformer in ConvNet's Clothing for Faster Inference

LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference This repository contains PyTorch evaluation code, training code and pretrained

Facebook Research 504 Jan 02, 2023
NP DRAW paper released code

NP-DRAW: A Non-Parametric Structured Latent Variable Model for Image Generation This repo contains the official implementation for the NP-DRAW paper.

ZENG Xiaohui 22 Mar 13, 2022
StyleGAN2 with adaptive discriminator augmentation (ADA) - Official TensorFlow implementation

StyleGAN2 with adaptive discriminator augmentation (ADA) — Official TensorFlow implementation Training Generative Adversarial Networks with Limited Da

NVIDIA Research Projects 1.7k Dec 29, 2022
Learning Dynamic Network Using a Reuse Gate Function in Semi-supervised Video Object Segmentation.

Training Script for Reuse-VOS This code implementation of CVPR 2021 paper : Learning Dynamic Network Using a Reuse Gate Function in Semi-supervised Vi

HYOJINPARK 22 Jan 01, 2023
A simple program for training and testing vit

Vit This is a simple program for training and testing vit. Key requirements: torch, torchvision and timm. Dataset I put 5 categories of the cub classi

xiezhenyu 2 Oct 11, 2022
RADIal is available now! Check the download section

Latest news: RADIal is available now! Check the download section. However, because we are currently working on the data anonymization, we provide for

valeo.ai 55 Jan 03, 2023
The description of FMFCC-A (audio track of FMFCC) dataset and Challenge resluts.

FMFCC-A This project is the description of FMFCC-A (audio track of FMFCC) dataset and Challenge resluts. The FMFCC-A dataset is shared through BaiduCl

18 Dec 24, 2022
Object Detection Projekt in GKI WS2021/22

tfObjectDetection Object Detection Projekt with tensorflow in GKI WS2021/22 Docker Container: docker run -it --name --gpus all -v path/to/project:p

Tim Eggers 1 Jul 18, 2022
Resources for the Ki testnet challenge

Ki Testnet Challenge This repository hosts ki-testnet-challenge. A set of scripts and resources to be used for the Ki Testnet Challenge What is the te

Ki Foundation 23 Aug 08, 2022
SIR model parameter estimation using a novel algorithm for differentiated uniformization.

TenSIR Parameter estimation on epidemic data under the SIR model using a novel algorithm for differentiated uniformization of Markov transition rate m

The Spang Lab 4 Nov 30, 2022
A multilingual version of MS MARCO passage ranking dataset

mMARCO A multilingual version of MS MARCO passage ranking dataset This repository presents a neural machine translation-based method for translating t

75 Dec 27, 2022