Implementations of paper Controlling Directions Orthogonal to a Classifier

Overview

Classifier Orthogonalization

Implementations of paper Controlling Directions Orthogonal to a Classifier , ICLR 2022,  Yilun Xu, Hao He, Tianxiao Shen, Tommi Jaakkola

Let's construct orthogonal classifiers for controlled style transfer, domain adaptation with label shifts and fairness problems 🤠 !

Outline

Controlled Style Transfer

Prepare CelebA-GH dataset:

python style_transfer/celeba_dataset.py --data_dir {path}

path: path to the CelebA dataset

bash example: python style_transfer/celeba_dataset.py --data_dir ./data

One can modify the domain_fn dictionary in the style_transfer/celeba_dataset.py file to create new groups 💡

Step 1: Train principal, full and oracle orthogonal classifiers

sh style_transfer/train_classifiers.sh {gpu} {path} {dataset} {alg}

gpu: the number of gpu
path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
alg: ERM, Fish, TRM or MLDG

CMNIST bash example: sh style_transfer/train_classifiers.sh 0 ./data CMNIST ERM

Step 2: Train controlled CycleGAN

python style_transfer/train_cyclegan.py --data_dir {path} --dataset {dataset} \
  --obj {obj} --name {name}

path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
obj: training objective (vanilla | orthogonal)
name: name of the model

CMNIST bash example: python style_transfer/train_cyclegan.py --data_dir ./data --dataset CMNIST --obj orthogonal --name cmnist

To view training results and loss plots, run python -m visdom.server and click the URL http://localhost:8097

Evaluation and Generation

python style_transfer/generate.py --data_dir {path} --dataset {dataset} --name {name} \
 --obj {obj} --out_path {out_path} --resume_epoch {epoch} (--save)

path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
name: name of the model
obj: training objective (vanilla | orthogonal)
out_path: output path
epoch: resuming epoch of checkpoint

Images will be save to style_transfer/generated_images/out_path

CMNIST bash example: python style_transfer/generate.py --data_dir ./data --dataset CMNIST --name cmnist --obj orthogonal --out_path cmnist_out --resume_epoch 5


Domain Adaptation (DA) with label shifts

Prepare src/tgt pairs with label shifts

Please cd /da/data and run

python {dataset}.py --r {r0} {r1}

r0: subsample ratio for the first half classes (default=0.7)
r1: subsample ratio for the first half classes (default=0.3)
dataset: mnist | mnistm | svhn | cifar | stl | signs | digits

For SynthDigits / SynthSignsdataset, please download them at link_digits / link_signs. All the other datasets will be automatically downloaded 😉

Training

python da/vada_train.py --r {r0} {r1} --src {source} --tgt {target}  --seed {seed} \
 (--iw) (--orthogonal) (--source_only)

r0: subsample ratio for the first half classes (default=0.7)
r1: subsample ratio for the first half classes (default=0.3)
source: source domain (mnist | mnistm | svhn | cifar | stl | signs | digits)
target: target domain (mnist | mnistm | svhn | cifar | stl | signs | digits)
seed: random seed
--source_only: vanilla ERM on the source domain
--iw: use importance-weighted domain adaptation algorithm [1]
--orthogonal: use orthogonal classifier
--vada: vanilla VADA [2]

Fairness

python fairness/methods/train.py --data {data} --gamma {gamma} --sigma {sigma} \
 (--orthogonal) (--laftr) (--mifr) (--hsic)

data: dataset (adult | german)
gamma: hyper-parameter for MIFR, HSIC, LAFTR
sigma: hyper-parameter for HSIC (kernel width)
--orthogonal: use orthogonal classifier
--MIFR: use L-MIFR algorithm [3]
--HSIC: use ReBias algorithm [4]
--LAFTR: use LAFTR algorithm [5]



Reference

[1] Remi Tachet des Combes, Han Zhao, Yu-Xiang Wang, and Geoffrey J. Gordon. Domain adaptation with conditional distribution matching and generalized label shift. ArXiv, abs/2003.04475, 2020.

[2] Rui Shu, H. Bui, H. Narui, and S. Ermon. A dirt-t approach to unsupervised domain adaptation. ArXiv, abs/1802.08735, 2018.

[3] Jiaming Song, Pratyusha Kalluri, Aditya Grover, Shengjia Zhao, and S. Ermon. Learning controllable fair representations. In AISTATS, 2019.

[4] Hyojin Bahng, Sanghyuk Chun, Sangdoo Yun, Jaegul Choo, and Seong Joon Oh. Learning de-biased representations with biased representations. In ICML, 2020.

[5] David Madras, Elliot Creager, T. Pitassi, and R. Zemel. Learning adversarially fair and transferable representations. In ICML, 2018.


The implementation of this repo is based on / inspired by:

Owner
Yilun Xu
Hello!
Yilun Xu
MiniSom is a minimalistic implementation of the Self Organizing Maps

MiniSom Self Organizing Maps MiniSom is a minimalistic and Numpy based implementation of the Self Organizing Maps (SOM). SOM is a type of Artificial N

Giuseppe Vettigli 1.2k Jan 03, 2023
Language Models Can See: Plugging Visual Controls in Text Generation

Language Models Can See: Plugging Visual Controls in Text Generation Authors: Yixuan Su, Tian Lan, Yahui Liu, Fangyu Liu, Dani Yogatama, Yan Wang, Lin

Yixuan Su 195 Dec 22, 2022
Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling".

PSSL Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling". It consists of the pre-tra

2 Dec 21, 2021
PyTorchCV: A PyTorch-Based Framework for Deep Learning in Computer Vision.

PyTorchCV: A PyTorch-Based Framework for Deep Learning in Computer Vision @misc{CV2018, author = {Donny You ( Donny You 40 Sep 14, 2022

GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks

GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks This repository implements a capsule model Inten

Joel Huang 15 Dec 24, 2022
The GitHub repository for the paper: “Time Series is a Special Sequence: Forecasting with Sample Convolution and Interaction“.

SCINet This is the original PyTorch implementation of the following work: Time Series is a Special Sequence: Forecasting with Sample Convolution and I

386 Jan 01, 2023
Human Dynamics from Monocular Video with Dynamic Camera Movements

Human Dynamics from Monocular Video with Dynamic Camera Movements Ri Yu, Hwangpil Park and Jehee Lee Seoul National University ACM Transactions on Gra

215 Jan 01, 2023
A fast model to compute optical flow between two input images.

DCVNet: Dilated Cost Volumes for Fast Optical Flow This repository contains our implementation of the paper: @InProceedings{jiang2021dcvnet, title={

Huaizu Jiang 8 Sep 27, 2021
This repo is official PyTorch implementation of MobileHumanPose: Toward real-time 3D human pose estimation in mobile devices(CVPRW 2021).

Github Code of "MobileHumanPose: Toward real-time 3D human pose estimation in mobile devices" Introduction This repo is official PyTorch implementatio

Choi Sang Bum 203 Jan 05, 2023
Aws-machine-learning-university-accelerated-tab - Machine Learning University: Accelerated Tabular Data Class

Machine Learning University: Accelerated Tabular Data Class This repository contains slides, notebooks, and datasets for the Machine Learning Universi

AWS Samples 916 Dec 23, 2022
Code for "Unsupervised State Representation Learning in Atari"

Unsupervised State Representation Learning in Atari Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm This

Mila 217 Jan 03, 2023
Self-supervised spatio-spectro-temporal represenation learning for EEG analysis

EEG-Oriented Self-Supervised Learning and Cluster-Aware Adaptation This repository provides a tensorflow implementation of a submitted paper: EEG-Orie

Wonjun Ko 4 Jun 09, 2022
YOLOv5 Series Multi-backbone, Pruning and quantization Compression Tool Box.

YOLOv5-Compression Update News Requirements 环境安装 pip install -r requirements.txt Evaluation metric Visdrone Model mAP ZhangYuan 719 Jan 02, 2023

EfficientDet (Scalable and Efficient Object Detection) implementation in Keras and Tensorflow

EfficientDet This is an implementation of EfficientDet for object detection on Keras and Tensorflow. The project is based on the official implementati

1.3k Dec 19, 2022
Synthetic structured data generators

Join us on What is Synthetic Data? Synthetic data is artificially generated data that is not collected from real world events. It replicates the stati

YData 850 Jan 07, 2023
TensorFlow implementation of "A Simple Baseline for Bayesian Uncertainty in Deep Learning"

TensorFlow implementation of "A Simple Baseline for Bayesian Uncertainty in Deep Learning"

YeongHyeon Park 7 Aug 28, 2022
The undersampled DWI image using Slice-Interleaved Diffusion Encoding (SIDE) method can be reconstructed by the UNet network.

UNet-SIDE The undersampled DWI image using Slice-Interleaved Diffusion Encoding (SIDE) method can be reconstructed by the UNet network. For Super Reso

TIANTIAN XU 1 Jan 13, 2022
Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Matthias Wright 169 Dec 26, 2022
A curated list of the latest breakthroughs in AI (in 2021) by release date with a clear video explanation, link to a more in-depth article, and code.

2021: A Year Full of Amazing AI papers- A Review 📌 A curated list of the latest breakthroughs in AI by release date with a clear video explanation, l

Louis-François Bouchard 2.9k Dec 31, 2022
GCC: Graph Contrastive Coding for Graph Neural Network Pre-Training @ KDD 2020

GCC: Graph Contrastive Coding for Graph Neural Network Pre-Training Original implementation for paper GCC: Graph Contrastive Coding for Graph Neural N

THUDM 274 Dec 27, 2022