PConv-Keras - Unofficial implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions". Try it at: www.fixmyphoto.ai

Overview

Partial Convolutions for Image Inpainting using Keras

Keras implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions", https://arxiv.org/abs/1804.07723. A huge shoutout to the authors Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao and Bryan Catanzaro from NVIDIA Corporation for releasing this awesome paper; it's been a great learning experience for me to implement the architecture, the partial convolutional layer, and the loss functions.

Dependencies

  • Python 3.6
  • Keras 2.2.4
  • TensorFlow 1.12

How to use this repository

The easiest way to try a few predictions with this algorithm is to go to www.fixmyphoto.ai, where I've deployed it as a serverless React application with AWS Lambda functions handling inference.

If you want to dig into the code, the primary implementations of the new PConv2D Keras layer and of the UNet-like architecture built from these partial convolutional layers are in libs/pconv_layer.py and libs/pconv_model.py, respectively; this is where the bulk of the implementation lives. Beyond this I've set up five Jupyter notebooks, which detail the steps I went through while implementing the network, namely:

Step 1: Creating random irregular masks
Step 2: Implementing and testing the implementation of the PConv2D layer
Step 3: Implementing and testing the UNet architecture with PConv2D layers
Step 4: Training & testing the final architecture on ImageNet
Step 5: Simplistic attempt at predicting arbitrary image sizes through image chunking

Pre-trained weights

I've ported the VGG16 weights from PyTorch to Keras; this means the 1/255. pixel scaling can be used for the VGG16 network, similarly to PyTorch.
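As a rough illustration of what "PyTorch-style" input means, here is a minimal numpy sketch of the preprocessing that torchvision's pretrained VGG16 expects: scale pixels by 1/255 into [0, 1], then standardise each channel with the ImageNet mean and standard deviation. Assuming the ported weights follow the same convention; the function name preprocess_for_vgg is hypothetical and not taken from this repository.

```python
import numpy as np

# torchvision's ImageNet channel statistics (RGB order), applied after 1/255 scaling.
# Assumption: the ported Keras VGG16 weights expect the same normalisation.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_for_vgg(image_uint8):
    """Hypothetical helper: map an (H, W, 3) uint8 RGB image to VGG16 input."""
    x = image_uint8.astype(np.float32) / 255.0    # 1/255. pixel scaling -> [0, 1]
    return (x - IMAGENET_MEAN) / IMAGENET_STD     # per-channel standardisation
```

Keeping the scaling separate from the mean/std step makes it easy to check that images fed to the inpainting network and to the loss network live in the same [0, 1] range.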

Training on your own dataset

You can either go directly to the Step 4 notebook or use the CLI (make sure to download the converted VGG16 weights first):

python main.py \
    --name MyDataset \
    --train TRAINING_PATH \
    --validation VALIDATION_PATH \
    --test TEST_PATH \
    --vgg_path './data/logs/pytorch_to_keras_vgg16.h5'

Implementation details

Details of the implementation are in the paper itself; however, I'll try to summarize some of them here.

Mask Creation

In the paper they use a technique based on occlusion/dis-occlusion between two consecutive video frames to create random irregular masks. Instead, I've opted for a simple mask-generator function that uses OpenCV to draw random irregular shapes, which I then use as masks. Plugging in a new mask-generation technique later should not be a problem, and I think the end results are pretty decent with this method as well.
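To show the idea without depending on OpenCV, here is a numpy-only sketch of a random irregular mask generator: it carves holes out of an all-ones mask using thick random line segments and filled circles. This swaps OpenCV drawing for a point-to-segment distance test and is not the generator used in this repository; the function names, stroke counts and thickness range are hypothetical choices.

```python
import numpy as np


def _distance_to_segment(yy, xx, p0, p1):
    """Euclidean distance from every pixel (yy, xx) to the segment p0 -> p1."""
    dy, dx = p1[0] - p0[0], p1[1] - p0[1]
    length_sq = dy * dy + dx * dx
    if length_sq == 0:
        t = 0.0
    else:
        # Project each pixel onto the segment and clamp to its endpoints.
        t = np.clip(((yy - p0[0]) * dy + (xx - p0[1]) * dx) / length_sq, 0.0, 1.0)
    return np.hypot(yy - (p0[0] + t * dy), xx - (p0[1] + t * dx))


def random_irregular_mask(height, width, num_strokes=8, max_thickness=15, seed=None):
    """Hypothetical generator: (height, width) float32 mask, 1 = valid, 0 = hole."""
    rng = np.random.default_rng(seed)
    yy, xx = np.mgrid[0:height, 0:width].astype(np.float64)
    mask = np.ones((height, width), dtype=np.float32)
    for _ in range(num_strokes):
        thickness = rng.integers(3, max_thickness + 1)
        if rng.random() < 0.5:
            # Thick line segment between two random endpoints.
            p0 = (rng.integers(0, height), rng.integers(0, width))
            p1 = (rng.integers(0, height), rng.integers(0, width))
            hole = _distance_to_segment(yy, xx, p0, p1) <= thickness / 2
        else:
            # Filled circle at a random centre, radius = thickness.
            cy, cx = rng.integers(0, height), rng.integers(0, width)
            hole = np.hypot(yy - cy, xx - cx) <= thickness
        mask[hole] = 0.0
    return mask
```

For example, random_irregular_mask(512, 512, seed=0) gives a reproducible mask that can be multiplied element-wise with an image to punch holes into it.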

Partial Convolution Layer

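A key element in this implementation is the partial convolutional layer. Basically, given the convolutional filter W and the corresponding bias b, the following partial convolution is applied instead of a normal convolution:

$$x' = \begin{cases} W^T (X \odot M) \dfrac{\operatorname{sum}(\mathbf{1})}{\operatorname{sum}(M)} + b, & \text{if } \operatorname{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

where X holds the feature (pixel) values in the current convolution window, ⊙ is element-wise multiplication, M is the corresponding binary mask of 0s and 1s, and 1 is an all-ones array with the same shape as M. Importantly, after each partial convolution the mask is also updated: if the convolution was able to condition its output on at least one valid input, the mask is removed at that location, i.e.

$$m' = \begin{cases} 1, & \text{if } \operatorname{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

The result of this is that with a sufficiently deep network, the mask will eventually be all ones (i.e. disappear).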
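To make the partial convolution and mask update concrete, here is a minimal, loop-based numpy sketch for a single channel (stride 1, odd kernel size, zero "same" padding). It is not the Keras PConv2D layer from libs/pconv_layer.py; the function name partial_conv2d is hypothetical, and counting the zero padding as hole pixels is an assumption about border handling.

```python
import numpy as np


def partial_conv2d(x, mask, weight, bias=0.0):
    """Hypothetical single-channel partial convolution, stride 1, 'same' padding.

    x, mask: (H, W) arrays; mask is 1 for valid pixels and 0 for holes.
    weight:  (k, k) filter with odd k.
    Returns (output feature map, updated mask).
    """
    k = weight.shape[0]
    pad = k // 2
    xp = np.pad(x * mask, pad)   # X ⊙ M, zero padded
    mp = np.pad(mask, pad)       # padded border counts as hole
    h, w = x.shape
    out = np.zeros((h, w))
    new_mask = np.zeros((h, w))
    window_size = float(k * k)   # sum(1) over one window
    for i in range(h):
        for j in range(w):
            valid = mp[i:i + k, j:j + k].sum()   # sum(M)
            if valid > 0:
                response = np.sum(weight * xp[i:i + k, j:j + k])   # W^T (X ⊙ M)
                # Rescale by sum(1) / sum(M) so partially valid windows are not dimmed.
                out[i, j] = response * (window_size / valid) + bias
                new_mask[i, j] = 1.0   # at least one valid input -> mask removed
    return out, new_mask
```

With a 3x3 filter each application shrinks a hole by one pixel on every side, so feeding the output and updated mask back in three times fills a 5x5 hole completely. The sum(1)/sum(M) rescaling also means a constant image stays constant wherever the output is defined, regardless of how many inputs in the window were valid.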

UNet Architecture

Specific details of the architecture can be found in the paper, but essentially it's based on a UNet-like structure where all normal convolutional layers are replaced with partial convolutional layers, such that in all cases the image is passed through the network alongside the mask. The following provides an overview of the architecture.
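Passing the mask alongside the image matters most at the UNet skip connections. The following numpy sketch shows how the inputs to one decoder partial-convolution layer could be assembled, assuming nearest-neighbour upsampling and channel-wise concatenation of both the features and the masks with the matching encoder layer. The helper names upsample_nearest and decoder_inputs are hypothetical and not taken from libs/pconv_model.py.

```python
import numpy as np


def upsample_nearest(t, factor=2):
    """Nearest-neighbour upsampling of an (H, W, C) array along H and W."""
    return t.repeat(factor, axis=0).repeat(factor, axis=1)


def decoder_inputs(features, mask, skip_features, skip_mask):
    """Hypothetical skip connection: build (image, mask) inputs for a decoder PConv.

    features, mask:           (H, W, C) output of the deeper layer.
    skip_features, skip_mask: (2H, 2W, C_skip) from the matching encoder layer.
    """
    up_x = upsample_nearest(features)
    up_m = upsample_nearest(mask)
    # Features and masks are concatenated separately, so the next partial
    # convolution still knows which inputs are valid.
    x = np.concatenate([up_x, skip_features], axis=-1)
    m = np.concatenate([up_m, skip_mask], axis=-1)
    return x, m
```

The concatenated mask has the same shape as the concatenated features, which is exactly what a partial convolution needs as its second input.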

Loss Function(s)

The loss function used in the paper is kinda intense, and can be reviewed in the paper. In short it includes:

  • Per-pixel losses both for masked and unmasked regions
  • Perceptual loss based on ImageNet pre-trained VGG-16 (pool1, pool2 and pool3 layers)
  • Style loss on VGG-16 features both for the predicted image and for the computed image (non-hole pixels set to ground truth)
  • Total variation loss for a 1-pixel dilation of the hole region
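Several of these terms are simple to state directly. Below is a numpy sketch, under stated assumptions, of the per-pixel hole/valid L1 losses, the computed (composite) image, a normalised Gram matrix as used by style losses, and a total variation term over the 1-pixel dilation of the hole. Masks use 1 = valid, 0 = hole; the normalisations (dividing by the number of elements) and all function names are assumptions for illustration, not the repository's Keras loss code. The perceptual and style terms would additionally need VGG-16 features, which are omitted here.

```python
import numpy as np


def per_pixel_losses(output, target, mask):
    """L1 losses on hole and valid pixels, each divided by the element count."""
    n = target.size
    hole = np.abs((1.0 - mask) * (output - target)).sum() / n
    valid = np.abs(mask * (output - target)).sum() / n
    return hole, valid


def composite(output, target, mask):
    """Computed image: ground truth at valid pixels, prediction inside holes."""
    return mask * target + (1.0 - mask) * output


def gram_matrix(features):
    """Normalised Gram matrix of an (H, W, C) feature map, shape (C, C)."""
    h, w, c = features.shape
    flat = features.reshape(h * w, c)
    return flat.T @ flat / (h * w * c)


def dilate(region):
    """1-pixel (3x3) binary dilation of an (H, W) 0/1 array."""
    padded = np.pad(region, 1)
    h, w = region.shape
    out = np.zeros_like(region)
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


def tv_loss(comp, mask2d):
    """Total variation of an (H, W, C) image on the dilated hole region.

    mask2d: (H, W) array, 1 = valid, 0 = hole.
    """
    p = dilate(1.0 - mask2d)            # region P: hole plus a 1-pixel border
    horiz = p[:, 1:] * p[:, :-1]        # both horizontal neighbours in P
    vert = p[1:, :] * p[:-1, :]         # both vertical neighbours in P
    dh = np.abs(comp[:, 1:] - comp[:, :-1]) * horiz[..., None]
    dv = np.abs(comp[1:, :] - comp[:-1, :]) * vert[..., None]
    return (dh.sum() + dv.sum()) / comp.size
```

In practice the style terms would compare gram_matrix of VGG features for the prediction (and for composite(...)) against those of the ground truth, and tv_loss would be applied to the composite image.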

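The weighting of all these loss terms is as follows:

$$\mathcal{L}_{total} = \mathcal{L}_{valid} + 6\,\mathcal{L}_{hole} + 0.05\,\mathcal{L}_{perceptual} + 120\,(\mathcal{L}_{style_{out}} + \mathcal{L}_{style_{comp}}) + 0.1\,\mathcal{L}_{tv}$$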
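The weighted sum can be written as a plain function over already-computed loss terms. This is a sketch for clarity, using the weights from the paper; the dictionary keys and the name total_loss are hypothetical, and in the actual Keras model the terms would be tensors rather than floats.

```python
# Weights of the individual loss terms, as given in the paper.
LOSS_WEIGHTS = {
    "valid": 1.0,
    "hole": 6.0,
    "perceptual": 0.05,
    "style_out": 120.0,
    "style_comp": 120.0,
    "tv": 0.1,
}


def total_loss(terms):
    """Hypothetical helper: weighted sum of a dict of scalar loss terms."""
    missing = set(LOSS_WEIGHTS) - set(terms)
    if missing:
        raise KeyError("missing loss terms: %s" % sorted(missing))
    return sum(LOSS_WEIGHTS[name] * terms[name] for name in LOSS_WEIGHTS)
```

Failing loudly on a missing term avoids silently training without, say, the style loss, which carries by far the largest weight.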

Training Procedure

The network was trained on ImageNet with a batch size of 1, and each epoch was specified to be 10,000 batches long. Training was furthermore performed with the Adam optimizer in two stages, since batch normalization presents an issue for the masked convolutions (the mean and variance are also calculated over hole pixels).
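A toy numpy illustration of why this is a problem: after X ⊙ M the hole pixels are zeros, so batch statistics computed over all pixels are pulled towards zero and inflated in variance compared with statistics over the valid pixels only. The function name masked_vs_naive_stats is hypothetical and this is not code from the repository.

```python
import numpy as np


def masked_vs_naive_stats(features, mask):
    """Return ((mean, var) over all pixels, (mean, var) over valid pixels only)."""
    naive = (features.mean(), features.var())
    valid = features[mask > 0]
    return naive, (valid.mean(), valid.var())
```

For example, if half the pixels are holes (zeros) and the valid pixels all equal 2.0, the naive statistics are mean 1.0 and variance 1.0, whereas the valid pixels alone have mean 2.0 and variance 0.0.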

Stage 1: learning rate of 0.0001 for 50 epochs, with batch normalization enabled in all layers.

Stage 2: learning rate of 0.00005 for 50 epochs, with batch normalization disabled in all encoding layers.
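The two-stage procedure can be captured as a small, explicit schedule. The sketch below only records the settings described above and computes how many training images they imply; the names Stage, SCHEDULE and images_seen are hypothetical, and wiring the freeze_encoder_batchnorm flag into the Keras model is left out.

```python
from typing import NamedTuple


class Stage(NamedTuple):
    name: str
    learning_rate: float
    epochs: int
    freeze_encoder_batchnorm: bool


STEPS_PER_EPOCH = 10_000   # batches per epoch
BATCH_SIZE = 1

SCHEDULE = (
    Stage("stage 1", 1e-4, 50, False),   # BN enabled in all layers
    Stage("stage 2", 5e-5, 50, True),    # BN disabled in the encoder
)


def images_seen(schedule, steps_per_epoch=STEPS_PER_EPOCH, batch_size=BATCH_SIZE):
    """Total number of training images processed across all stages."""
    return sum(s.epochs * steps_per_epoch * batch_size for s in schedule)
```

With these settings, images_seen(SCHEDULE) is 1,000,000, i.e. 100 epochs of 10,000 single-image batches.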

Training time for the shown images was absolutely crazy long, but that is likely because of my poor personal setup. The few tests I've tried on a 1080Ti (with a batch size of 4) indicate that training time could be around 10 days, as specified in the paper.

Owner
Mathias Gruber
Chief Data Scientist