A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
A Linux webcam plugin for BGMv2 as used in our demos.

The goal of this repository is to supplement the main Real-Time High Resolution Background Matting repo with a working demo of a videoconferencing plu

Andrey Ryabtsev 144 Dec 27, 2022
A python package to adjust the bias of probabilistic forecasts/hindcasts using "Mean and Variance Adjustment" method.

Documentation A python package to adjust the bias of probabilistic forecasts/hindcasts using "Mean and Variance Adjustment" method. Read documentation

1 Feb 02, 2022
Step by step development of a vending coffee machine project, including tkinter, sqlite3, simulation, etc.

Step by step development of a vending coffee machine project, including tkinter, sqlite3, simulation, etc.

Nikolaos Avouris 2 Dec 05, 2021
Python library for the analysis of dynamic measurements

Python library for the analysis of dynamic measurements The goal of this library is to provide a starting point for users in metrology and related are

Physikalisch-Technische Bundesanstalt - Department 9.4 'Metrology for the digital Transformation' 18 Dec 21, 2022
Diff Match Patch is a high-performance library in multiple languages that manipulates plain text.

The Diff Match and Patch libraries offer robust algorithms to perform the operations required for synchronizing plain text. Diff: Compare two blocks o

Google 5.9k Dec 30, 2022
It converts ING BANK account historic into a csv file you can import in HomeBank application.

ing2homebank It converts your ING Bank account historic csv file into another csv file you can import in HomeBank application

1 Feb 14, 2022
A script to automatically update bot status at GitHub as well as in Telegram channel.

A simple & short repository to show your bot's status in your GitHub README.md file as well as in you channel.

Jainam Oswal 55 Dec 13, 2022
A package with multiple bias correction methods for climatic variables, including the QM, DQM, QDM, UQM, and SDM methods

A package with multiple bias correction methods for climatic variables, including the QM, DQM, QDM, UQM, and SDM methods

Sebastián A. Aedo Quililongo 9 Nov 18, 2022
A general purpose low level programming language written in Python.

A general purpose low level programming language written in Python. Basal is an easy mid level programming language compiling to C. It has an easy syntax, similar to Python, Rust etc.

Snm Logic 6 Mar 30, 2022
Paxos in Python, tested with Jepsen

Python implementation of Multi-Paxos with a stable leader and reconfiguration, roughly following "Paxos Made Moderately Complex". Run python3 paxos/st

A. Jesse Jiryu Davis 25 Dec 15, 2022
- Auto join teams teams ( from calendar invite )

Auto Join Teams Meetings Requirements: Python 3.7 or higher Latest Google Chrome This script automatically logins to your account and joins the meetin

Prajin Khadka 10 Aug 20, 2022
ThinkPHP全日志扫描工具,命令行版和BurpSuite插件版

ThinkPHP3和5日志扫描工具,提供命令行版和BurpSuite插件版,尽可能全的发掘网站日志信息 命令行版 安装 git clone https://github.com/r3change/TPLogScan.git cd TPLogScan/ pip install -r requireme

119 Dec 27, 2022
Robotic hamster to give you financial advice

hampp Robotic hamster to give you financial advice. I am not liable for any advice that the hamster gives. Follow at your own peril. Description Hampp

1 Nov 17, 2021
Python-Kite: Simple python code to make kite pattern

Python-Kite Simple python code to make kite pattern. Getting Started These instr

Anoint 0 Mar 22, 2022
A simple package for interacting with the 9kw.eu anti-captcha service.

Welcome to captcha9kw’s documentation! captcha9kw is a smallish Python package for making use of the 9kw.eu services, including solving of interactive

2 Feb 26, 2022
An easy-to-learn, dynamic, interpreted, procedural programming language

Gen Programming Language WARNING!! THIS LANGUAGE IS IN DEVELOPMENT. ANYTHING CAN CHANGE AT ANY MOMENT. Gen is a dynamic, interpreted, procedural progr

Gen Programming Language 7 Oct 17, 2022
Module for remote in-memory Python package/module loading through HTTP/S

httpimport Python's missing feature! The feature has been suggested in Python Mailing List Remote, in-memory Python package/module importing through H

John Torakis 220 Dec 17, 2022
A random cat fact python module

A random cat fact python module

Fayas Noushad 4 Nov 28, 2021
Open slidebook .sldy files in Python

Work in progress slidebook-python Open slidebook .sldy files in Python To install slidebook-python requires Python = 3.9 pip install slidebook-python

The Institute of Cancer Research 2 May 04, 2022
Team collaborative evaluation tracker.

Team collaborative evaluation tracker.

2 Dec 19, 2021