CATE: Computation-aware Neural Architecture Encoding with Transformers

Overview

CATE: Computation-aware Neural Architecture Encoding with Transformers

Code for paper:

CATE: Computation-aware Neural Architecture Encoding with Transformers
Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang.
ICML 2021 (Long Talk).

CATE
Overview of CATE: It takes computationally similar architecture pairs as the input and trained to predict masked operators given the pairwise computation information. Apart from the cross-attention blocks, the pretrained Transformer encoder is used to extract architecture encodings for the downstream search.

The repository is built upon pybnn and nas-encodings.

Requirements

conda create -n tf python=3.7
source activate tf
cat requirements.txt | xargs -n 1 -L 1 pip install

Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord in ./data folder.

python preprocessing/gen_json.py

Data will be saved in ./data/nasbench101.json.

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench101 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench101 --flag build_pair --k 2 --d 2000000 --metric params

The corresponding training data and pairs will be saved in ./data/nasbench101/.

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k2_d2000000_metric_params.pt, test_pair_k2_d2000000_metric_params.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench101.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench101_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench101_model_best.pth.tar --train_data data/nasbench101/train_data.pt --valid_data data/nasbench101/test_data.pt --dataset nasbench101

The extracted embeddings will be saved in ./cate_nasbench101.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench101.pt from here.

Run search experiments on NAS-Bench-101

bash run_scripts/run_search_nasbench101.sh

Search results will be saved in ./nasbench101/.

Experiments on NAS-Bench-301

Dataset preparation

Install nasbench301 and download the xgb_v1.0 and lgb_runtime_v1.0 file. You may need to make pytorch_geometric compatible with Pytorch and CUDA version.

python preprocessing/gen_json_darts.py # randomly sample 1,000,000 archs

Data will be saved in ./data/nasbench301_proxy.json.

Alternatively, you can download the json file nasbench301_proxy.json from here.

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench301 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench301 --flag build_pair --k 1 --d 5000000 --metric flops

The correspoding training data and pairs will be saved in ./data/nasbench301/.

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k1_d5000000_metric_flops.pt, test_pair_k1_d5000000_metric_flops.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench301.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench301_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench301_model_best.pth.tar --train_data data/nasbench301/train_data.pt --valid_data data/nasbench301/test_data.pt --dataset nasbench301 --n_vocab 11

The extracted encodings will be saved in ./cate_nasbench301.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench301.pt from here.

Run search experiments on NAS-Bench-301

bash run_scripts/run_search_nasbench301.sh

Search results will be saved in ./nasbench301/.

DARTS experiments without surrogate models

Download the pretrained embeddings cate_darts.pt from here.

python search_methods/dngo_ls_darts.py --dim 64 --init_size 16 --topk 5 --dataset darts --output_path bo  --embedding_path cate_darts.pt

Search log will be saved in ./darts/. Final search result will be saved in ./darts/bo/dim64.

Evaluate the learned cell on DARTS Search Space on CIFAR-10

python darts/cnn/train.py --auxiliary --cutout --arch cate_small
python darts/cnn/train.py --auxiliary --cutout --arch cate_large
  • Expected results (CATE-Small): 2.55% avg. test error with 3.5M model params.
  • Expected results (CATE-Large): 2.46% avg. test error with 4.1M model params.

Transfer learning on ImageNet

python darts/cnn/train_imagenet.py  --arch cate_small --seed 1 
python darts/cnn/train_imagenet.py  --arch cate_large --seed 1
  • Expected results (CATE-Small): 26.05% test error with 5.0M model params and 556M mult-adds.
  • Expected results (CATE-Large): 25.01% test error with 5.8M model params and 642M mult-adds.

Visualize the learned cell

python darts/cnn/visualize.py cate_small
python darts/cnn/visualize.py cate_large

Experiments on outside search space

Build outside search space dataset

bash run_scripts/generate_oo.sh

Data will be saved in ./data/nasbench101_oo_train.json and ./data/nasbench101_oo_test.json.

Generate architecture pairs

python preprocessing/data_generate_oo.py --flag extract_seq
python preprocessing/data_generate_oo.py --flag build_pair

The corresponding training data and pair indices will be saved in ./data/nasbench101/.

Pretraining

python run.py --do_train --parallel --train_data data/nasbench101/nasbench101_oo_trainSet_train.pt --train_pair data/nasbench101/oo_train_pairs_k2_params_dist2e6.pt  --valid_data data/nasbench101/nasbench101_oo_trainSet_validation.pt --valid_pair data/nasbench101/oo_validation_pairs_k2_params_dist2e6.pt --dataset oo

The pretrained models will be saved in ./model/.

Extract embeddings on outside search space

# Adjacency encoding
python inference/inference_adj.py
# CATE encoding
python inference/inference.py --pretrained_path model/oo_model_best.pth.tar --train_data data/nasbench101/nasbench101_oo_testSet_split1.pt --valid_data data/nasbench101/nasbench101_oo_testSet_split2.pt --dataset oo_nasbench101

The extracted encodings will be saved as ./adj_oo_nasbench101.pt and ./cate_oo_nasbench101.pt.

Alternatively, you can download the data, pair indices, pretrained models, and extracted embeddings from here.

Run MLP predictor experiments on outside search space

for s in {1..500}; do python search_methods/oo_mlp.py --dim 27 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_adj  --embedding_path adj_oo_nasbench101.pt; done
for s in {1..500}; do python search_methods/oo_mlp.py --dim 64 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_cate  --embedding_path cate_oo_nasbench101.pt; done

Search results will be saved in ./oo_nasbench101.

Citation

If you find this useful for your work, please consider citing:

@InProceedings{yan2021cate,
  title = {CATE: Computation-aware Neural Architecture Encoding with Transformers},
  author = {Yan, Shen and Song, Kaiqiang and Liu, Fei and Zhang, Mi},
  booktitle = {ICML},
  year = {2021}
}
Clustergram - Visualization and diagnostics for cluster analysis in Python

Clustergram Visualization and diagnostics for cluster analysis Clustergram is a diagram proposed by Matthias Schonlau in his paper The clustergram: A

Martin Fleischmann 96 Dec 26, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 03, 2022
Adjusting for Autocorrelated Errors in Neural Networks for Time Series

Adjusting for Autocorrelated Errors in Neural Networks for Time Series This repository is the official implementation of the paper "Adjusting for Auto

Fan-Keng Sun 51 Nov 05, 2022
This repository consists of Blender python scripts and corresponding assets to generate variants of the CANDLE dataset

candle-simulator This repository consists of Blender python scripts and corresponding assets to generate variants of the IITH-CANDLE dataset. The rend

1 Dec 15, 2021
Code for our CVPR 2021 paper "MetaCam+DSCE"

Joint Noise-Tolerant Learning and Meta Camera Shift Adaptation for Unsupervised Person Re-Identification (CVPR'21) Introduction Code for our CVPR 2021

FlyingRoastDuck 59 Oct 31, 2022
Software for Multimodalty 2D+3D Facial Expression Recognition (FER) UI

EmotionUI Software for Multimodalty 2D+3D Facial Expression Recognition (FER) UI. demo screenshot (with RealSense) required packages Python = 3.6 num

Yang Jiao 2 Dec 23, 2021
Metadata-Extractor - Metadata Extractor Script can be used to read in exif metadata

Metadata Extractor The exifextract script can be used to read in exif metadata f

1 Feb 16, 2022
Jupyter notebooks for the code samples of the book "Deep Learning with Python"

Jupyter notebooks for the code samples of the book "Deep Learning with Python"

François Chollet 16.2k Dec 30, 2022
Provided is code that demonstrates the training and evaluation of the work presented in the paper: "On the Detection of Digital Face Manipulation" published in CVPR 2020.

FFD Source Code Provided is code that demonstrates the training and evaluation of the work presented in the paper: "On the Detection of Digital Face M

88 Nov 22, 2022
Underwater industrial application yolov5m6

This project wins the intelligent algorithm contest finalist award and stands out from over 2000teams in China Underwater Robot Professional Contest, entering the final of China Underwater Robot Prof

8 Nov 09, 2022
Points2Surf: Learning Implicit Surfaces from Point Clouds (ECCV 2020 Spotlight)

Points2Surf: Learning Implicit Surfaces from Point Clouds (ECCV 2020 Spotlight)

Philipp Erler 329 Jan 06, 2023
The official implementation of the research paper "DAG Amendment for Inverse Control of Parametric Shapes"

DAG Amendment for Inverse Control of Parametric Shapes This repository is the official Blender implementation of the paper "DAG Amendment for Inverse

Elie Michel 157 Dec 26, 2022
Multi-Objective Loss Balancing for Physics-Informed Deep Learning

Multi-Objective Loss Balancing for Physics-Informed Deep Learning Code for ReLoBRaLo. Abstract Physics Informed Neural Networks (PINN) are algorithms

Rafael Bischof 16 Dec 12, 2022
Locally Constrained Self-Attentive Sequential Recommendation

LOCKER This is the pytorch implementation of this paper: Locally Constrained Self-Attentive Sequential Recommendation. Zhankui He, Handong Zhao, Zhe L

Zhankui (Aaron) He 8 Jul 30, 2022
Interpretable-contrastive-word-mover-s-embedding

Interpretable-contrastive-word-mover-s-embedding Paper Datasets Here is a Dropbox link to the datasets used in the paper: https://www.dropbox.com/sh/n

0 Nov 02, 2021
PyTorch reimplementation of hand-biomechanical-constraints (ECCV2020)

Hand Biomechanical Constraints Pytorch Unofficial PyTorch reimplementation of Hand-Biomechanical-Constraints (ECCV2020). This project reimplement foll

Hao Meng 59 Dec 20, 2022
UT-Sarulab MOS prediction system using SSL models

UTMOS: UTokyo-SaruLab MOS Prediction System Official implementation of "UTMOS: UTokyo-SaruLab System for VoiceMOS Challenge 2022" submitted to INTERSP

sarulab-speech 58 Nov 22, 2022
A python library to build Model Trees with Linear Models at the leaves.

A python library to build Model Trees with Linear Models at the leaves.

Marco Cerliani 212 Dec 30, 2022
This is a Keras-based Python implementation of DeepMask- a complex deep neural network for learning object segmentation masks

NNProject - DeepMask This is a Keras-based Python implementation of DeepMask- a complex deep neural network for learning object segmentation masks. Th

189 Nov 16, 2022
[ICCV '21] In this repository you find the code to our paper Keypoint Communities

Keypoint Communities In this repository you will find the code to our ICCV '21 paper: Keypoint Communities Duncan Zauss, Sven Kreiss, Alexandre Alahi,

Duncan Zauss 262 Dec 13, 2022