Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions

Overview

torch-imle

Concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for transforming any combinatorial black-box solver in a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since in our setting, we do not have access to an empirical distribution, we have to design surrogate empirical distributions. Here we propose two families of surrogate distributions which are widely applicable and work well in practice.

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Black areas have the highest and white areas the lowest cost. In the centre, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).

Gradients and Learning

Let us assume that the optimal shortest path is the one of the left. Starting from random weights, the model can learn to produce the weights that will result in the optimal shortest path via Gradient Descent, by minimising the Hamming loss between the produced path and the gold path. Here we show the paths being produced during training (middle), and the corresponding map weights (right).

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.

Code

Using this library is extremely easy -- see this example as a reference. Assuming we have a method that implements a black-box combinatorial solver such as Dijkstra's algorithm:

import numpy as np

import torch
from torch import Tensor

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                    noise_distribution=noise_distribution,
                    nb_samples=10,
                    input_noise_temperature=input_noise_temperature,
                    target_noise_temperature=target_noise_temperature)

Or, alternatively, using a simple function annotation:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
Owner
UCL Natural Language Processing
UCL Natural Language Processing
Detecting Blurred Ground-based Sky/Cloud Images

Detecting Blurred Ground-based Sky/Cloud Images With the spirit of reproducible research, this repository contains all the codes required to produce t

1 Oct 20, 2021
A code implementation of AC-GC: Activation Compression with Guaranteed Convergence, in NeurIPS 2021.

Code For AC-GC: Lossy Activation Compression with Guaranteed Convergence This code is intended to be used as a supplemental material for submission to

Dave Evans 2 Nov 01, 2022
一些经典的CTR算法的复现; LR, FM, FFM, AFM, DeepFM,xDeepFM, PNN, DCN, DCNv2, DIFM, AutoInt, FiBiNet,AFN,ONN,DIN, DIEN ... (pytorch, tf2.0)

CTR Algorithm 根据论文, 博客, 知乎等方式学习一些CTR相关的算法 理解原理并自己动手来实现一遍 pytorch & tf2.0 保持一颗学徒的心! Schedule Model pytorch tensorflow2.0 paper LR ✔️ ✔️ \ FM ✔️ ✔️ Fac

luo han 149 Dec 20, 2022
Continuous Augmented Positional Embeddings (CAPE) implementation for PyTorch

PyTorch implementation of Continuous Augmented Positional Embeddings (CAPE), by Likhomanenko et al. Enhance your Transformer positional embeddings with easy-to-use augmentations!

Guillermo Cámbara 26 Dec 13, 2022
Course content and resources for the AIAIART course.

AIAIART course This repo will house the notebooks used for the AIAIART course. Part 1 (first four lessons) ran via Discord in September/October 2021.

Jonathan Whitaker 492 Jan 06, 2023
TUPÃ was developed to analyze electric field properties in molecular simulations

TUPÃ: Electric field analyses for molecular simulations What is TUPÃ? TUPÃ (pronounced as tu-pan) is a python algorithm that employs MDAnalysis engine

Marcelo D. Polêto 10 Jul 17, 2022
Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

DTU Acoustic Technology Group 11 Dec 17, 2022
Self-Supervised Vision Transformers Learn Visual Concepts in Histopathology (LMRL Workshop, NeurIPS 2021)

Self-Supervised Vision Transformers Learn Visual Concepts in Histopathology Self-Supervised Vision Transformers Learn Visual Concepts in Histopatholog

Richard Chen 95 Dec 24, 2022
《Towards High Fidelity Face Relighting with Realistic Shadows》(CVPR 2021)

Towards High Fidelity Face-Relighting with Realistic Shadows Andrew Hou, Ze Zhang, Michel Sarkis, Ning Bi, Yiying Tong, Xiaoming Liu. In CVPR, 2021. T

114 Dec 10, 2022
A web-based application for quick, scalable, and automated hyperparameter tuning and stacked ensembling in Python.

Xcessiv Xcessiv is a tool to help you create the biggest, craziest, and most excessive stacked ensembles you can think of. Stacked ensembles are simpl

Reiichiro Nakano 1.3k Nov 17, 2022
MonoScene: Monocular 3D Semantic Scene Completion

MonoScene: Monocular 3D Semantic Scene Completion MonoScene: Monocular 3D Semantic Scene Completion] [arXiv + supp] | [Project page] Anh-Quan Cao, Rao

298 Jan 08, 2023
PyKale is a PyTorch library for multimodal learning and transfer learning as well as deep learning and dimensionality reduction on graphs, images, texts, and videos

PyKale is a PyTorch library for multimodal learning and transfer learning as well as deep learning and dimensionality reduction on graphs, images, texts, and videos. By adopting a unified pipeline-ba

PyKale 370 Dec 27, 2022
Public scripts, services, and configuration for running a smart home K3S network cluster

makerhouse_network Public scripts, services, and configuration for running MakerHouse's home network. This network supports: TODO features here For mo

Scott Martin 1 Jan 15, 2022
Structured Edge Detection Toolbox

################################################################### # # # Structure

Piotr Dollar 779 Jan 02, 2023
FlexConv: Continuous Kernel Convolutions with Differentiable Kernel Sizes

FlexConv: Continuous Kernel Convolutions with Differentiable Kernel Sizes This repository contains the source code accompanying the paper: FlexConv: C

Robert-Jan Bruintjes 96 Dec 12, 2022
Python implementation of "Elliptic Fourier Features of a Closed Contour"

PyEFD An Python/NumPy implementation of a method for approximating a contour with a Fourier series, as described in [1]. Installation pip install pyef

Henrik Blidh 71 Dec 09, 2022
GBIM(Gesture-Based Interaction map)

手势交互地图 GBIM(Gesture-Based Interaction map),基于视觉深度神经网络的交互地图,通过电脑摄像头观察使用者的手势变化,进而控制地图进行简单的交互。网络使用PaddleX提供的轻量级模型PPYOLO Tiny以及MobileNet V3 small,使得整个模型大小约10MB左右,即使在CPU下也能快速定位和识别手势。

8 Feb 10, 2022
UA-GEC: Grammatical Error Correction and Fluency Corpus for the Ukrainian Language

UA-GEC: Grammatical Error Correction and Fluency Corpus for the Ukrainian Language This repository contains UA-GEC data and an accompanying Python lib

Grammarly 226 Dec 29, 2022
Classification Modeling: Probability of Default

Credit Risk Modeling in Python Introduction: If you've ever applied for a credit card or loan, you know that financial firms process your information

Aktham Momani 2 Nov 07, 2022
code for EMNLP 2019 paper Text Summarization with Pretrained Encoders

PreSumm This code is for EMNLP 2019 paper Text Summarization with Pretrained Encoders Updates Jan 22 2020: Now you can Summarize Raw Text Input!. Swit

Yang Liu 1.2k Dec 28, 2022