Official PyTorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation" (NeurIPS 2021, Oral)

This repository provides the official PyTorch implementation of the following paper:

Learning Debiased Representation via Disentangled Feature Augmentation
Jungsoo Lee* (KAIST AI, Kakao Enterprise), Eungyeup Kim* (KAIST AI, Kakao Enterprise),
Juyoung Lee (Kakao Enterprise), Jihyeon Lee (KAIST AI), and Jaegul Choo (KAIST AI)
(* indicates equal contribution. The order of first authors was chosen by tossing a coin.)
NeurIPS 2021, Oral

Paper: arXiv

Abstract: Image classification models tend to make decisions based on peripheral attributes of data items that have strong correlation with a target variable (i.e., dataset bias). These biased models suffer from the poor generalization capability when evaluated on unbiased datasets. Existing approaches for debiasing often identify and emphasize those samples with no such correlation (i.e., bias-conflicting) without defining the bias type in advance. However, such bias-conflicting samples are significantly scarce in biased datasets, limiting the debiasing capability of these approaches. This paper first presents an empirical analysis revealing that training with "diverse" bias-conflicting samples beyond a given training set is crucial for debiasing as well as the generalization capability. Based on this observation, we propose a novel feature-level data augmentation technique in order to synthesize diverse bias-conflicting samples. To this end, our method learns the disentangled representation of (1) the intrinsic attributes (i.e., those inherently defining a certain class) and (2) bias attributes (i.e., peripheral attributes causing the bias), from a large number of bias-aligned samples, the bias attributes of which have strong correlation with the target variable. Using the disentangled representation, we synthesize bias-conflicting samples that contain the diverse intrinsic attributes of bias-aligned samples by swapping their latent features. By utilizing these diversified bias-conflicting features during the training, our approach achieves superior classification accuracy and debiasing results against the existing baselines on both synthetic as well as a real-world dataset.
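
The swapping step described in the abstract can be illustrated with a minimal sketch. This is not the repository's code: it uses plain Python lists in place of tensors, and the names `swap_features`, `z_intrinsic`, and `z_bias` are hypothetical. It assumes each sample has already been encoded into an intrinsic feature and a bias feature, and that the class label serves as a proxy for the bias attribute in a bias-aligned batch.

```python
import random


def swap_features(z_intrinsic, z_bias, labels, seed=None):
    """Pair each intrinsic feature with a bias feature from a shuffled batch.

    Hypothetical illustration: features are plain lists of floats.
    Returns (swapped_features, intrinsic_labels, bias_labels).
    """
    rng = random.Random(seed)
    perm = list(range(len(labels)))
    rng.shuffle(perm)  # random permutation of the batch indices

    swapped, bias_labels = [], []
    for k, j in enumerate(perm):
        # Concatenate sample k's intrinsic feature with sample j's bias feature.
        swapped.append(z_intrinsic[k] + z_bias[j])
        # Assumption: the bias part is labeled with sample j's class label.
        bias_labels.append(labels[j])
    return swapped, list(labels), bias_labels


z_i = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
z_b = [[1.0], [2.0], [3.0]]
y = [0, 1, 2]
swapped, y_intrinsic, y_bias = swap_features(z_i, z_b, y, seed=0)
```

Each swapped feature keeps the intrinsic part of its own sample, so the intrinsic label is unchanged, while the bias part (and its label) comes from another position in the batch.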

Code Contributors

Jungsoo Lee [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Eungyeup Kim [Website] [LinkedIn] [Google Scholar] (KAIST AI, Kakao Enterprise)
Juyoung Lee [Website] (Kakao Enterprise)

PyTorch Implementation

Installation

Clone this repository.

git clone https://github.com/kakaoenterprise/Learning-Debiased-Disentangled.git
cd Learning-Debiased-Disentangled
pip install -r requirements.txt

Datasets

We used three datasets in our paper: CMNIST, Corrupted CIFAR10, and BFFHQ.

Download the datasets from the following URL. Note that BFFHQ is the dataset used in "BiaSwap: Removing Dataset Bias with Bias-Tailored Swapping Augmentation" (Kim et al., ICCV 2021). After unzipping the files, the directory structure should be as follows:

cmnist
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conflict
     └ valid
 └ test
cifar10c
 └ 0.5pct / 1pct / 2pct / 5pct
     └ align
     └ conflict
     └ valid
 └ test
bffhq
 └ 0.5pct
 └ valid
 └ test
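
Before training, it can help to confirm that the unzipped files match the tree above. The following is a minimal sketch, not part of the repository: the function names `expected_dirs` and `missing_dirs` are hypothetical, the root folder is whatever directory you unzipped into, and the name of the bias-conflicting folder is a parameter (defaulting to `conflict`, an assumption) in case your copy spells it differently.

```python
from pathlib import Path

PERCENTS = ["0.5pct", "1pct", "2pct", "5pct"]


def expected_dirs(root, conflict_name="conflict"):
    """List the directories the tree above implies under `root`."""
    root = Path(root)
    dirs = []
    for name in ("cmnist", "cifar10c"):
        for pct in PERCENTS:
            for split in ("align", conflict_name, "valid"):
                dirs.append(root / name / pct / split)
        dirs.append(root / name / "test")
    # BFFHQ ships only the 0.5pct setting, with valid and test beside it.
    for sub in ("0.5pct", "valid", "test"):
        dirs.append(root / "bffhq" / sub)
    return dirs


def missing_dirs(root, conflict_name="conflict"):
    """Return the expected directories that do not exist on disk."""
    return [d for d in expected_dirs(root, conflict_name) if not d.is_dir()]
```

Calling `missing_dirs("dataset")` (with `dataset` standing in for your own root folder) returns an empty list when every expected directory is present.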

How to Run

CMNIST

Vanilla
python train.py --dataset cmnist --exp=cmnist_0.5_vanilla --lr=0.01 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_vanilla --lr=0.01 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_vanilla --lr=0.01 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_vanilla --lr=0.01 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cmnist_vanilla.sh
Ours
python train.py --dataset cmnist --exp=cmnist_0.5_ours --lr=0.01 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_1_ours --lr=0.01 --percent=1pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_2_ours --lr=0.01 --percent=2pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cmnist --exp=cmnist_5_ours --lr=0.01 --percent=5pct  --curr_step=10000 --lambda_swap=1 --lambda_dis_align=10 --lambda_swap_align=10 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cmnist_ours.sh
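
The `--use_lr_decay`, `--lr_decay_step`, and `--lr_gamma` flags are not explained in this README. Assuming they configure a step decay, in which the learning rate is multiplied by `lr_gamma` once every `lr_decay_step` iterations (in the style of PyTorch's `StepLR`), the schedule implied by the CMNIST commands would look like the sketch below. The function name `lr_at` is hypothetical, and the actual behavior should be checked against `train.py`.

```python
def lr_at(step, base_lr=0.01, decay_step=10000, gamma=0.5):
    """Learning rate after `step` iterations under an assumed step decay."""
    return base_lr * gamma ** (step // decay_step)


# Print the assumed learning rate at a few training steps.
for step in (0, 9999, 10000, 25000):
    print(step, lr_at(step))
```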

Corrupted CIFAR10

Vanilla
python train.py --dataset cifar10c --exp=cifar10c_0.5_vanilla --lr=0.001 --percent=0.5pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_vanilla --lr=0.001 --percent=1pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_vanilla --lr=0.001 --percent=2pct --train_vanilla --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_vanilla --lr=0.001 --percent=5pct --train_vanilla --tensorboard --wandb
bash scripts/run_cifar10c_vanilla.sh
Ours
python train.py --dataset cifar10c --exp=cifar10c_0.5_ours --lr=0.0005 --percent=0.5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_1_ours --lr=0.001 --percent=1pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_2_ours --lr=0.001 --percent=2pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=5 --lambda_swap_align=5 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
python train.py --dataset cifar10c --exp=cifar10c_5_ours --lr=0.001 --percent=5pct --curr_step=10000 --lambda_swap=1 --lambda_dis_align=1 --lambda_swap_align=1 --use_lr_decay --lr_decay_step=10000 --lr_gamma=0.5 --train_ours --tensorboard --wandb
bash scripts/run_cifar10c_ours.sh

BFFHQ

Vanilla
python train.py --dataset bffhq --exp=bffhq_0.5_vanilla --lr=0.0001 --percent=0.5pct --train_vanilla --tensorboard --wandb
bash scripts/run_bffhq_vanilla.sh
Ours
python train.py --dataset bffhq --exp=bffhq_0.5_ours --lr=0.0001 --percent=0.5pct --lambda_swap=0.1 --curr_step=10000 --use_lr_decay --lr_decay_step=10000 --lambda_dis_align 2. --lambda_swap_align 2. --train_ours --tensorboard --wandb
bash scripts/run_bffhq_ours.sh
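
The vanilla commands above differ only in dataset, learning rate, and percent setting, so they can be generated rather than typed out. The sketch below is an illustration, not a script from the repository: `vanilla_command` is a hypothetical helper, and it uses only flags and values that appear in the commands listed in this section.

```python
# Learning rates copied from the vanilla commands above (kept as strings
# so that they are printed exactly as written there).
VANILLA_LR = {"cmnist": "0.01", "cifar10c": "0.001", "bffhq": "0.0001"}
PERCENTS = {
    "cmnist": ["0.5pct", "1pct", "2pct", "5pct"],
    "cifar10c": ["0.5pct", "1pct", "2pct", "5pct"],
    "bffhq": ["0.5pct"],
}


def vanilla_command(dataset, percent):
    """Build the argument list for one vanilla training run."""
    exp = f"{dataset}_{percent.replace('pct', '')}_vanilla"
    return [
        "python", "train.py",
        "--dataset", dataset,
        f"--exp={exp}",
        f"--lr={VANILLA_LR[dataset]}",
        f"--percent={percent}",
        "--train_vanilla", "--tensorboard", "--wandb",
    ]


# Print every vanilla command, one per line.
for dataset, percents in PERCENTS.items():
    for percent in percents:
        print(" ".join(vanilla_command(dataset, percent)))
```

Each returned list can also be passed directly to `subprocess.run` to launch the run from Python.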

Pretrained Models

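To test our pretrained models, run the following command, filling in the checkpoint path, dataset name, and percent setting:

python test.py --pretrained_path=<path_to_pretrained_model> --dataset=<dataset> --percent=<percent>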

We provide the pretrained models at the following URLs.
CMNIST 0.5pct
CMNIST 1pct
CMNIST 2pct
CMNIST 5pct

CIFAR10C 0.5pct
CIFAR10C 1pct
CIFAR10C 2pct
CIFAR10C 5pct

BFFHQ 0.5pct

Citations

BibTeX coming soon!

Contact

Jungsoo Lee

Eungyeup Kim

Juyoung Lee

Kakao Enterprise/Vision Team

Acknowledgments

This work was mainly done while both first authors were interns at Vision Team/AI Lab/Kakao Enterprise. Our PyTorch implementation is based on LfF. We thank its authors for the implementation.
