Implementation of the paper "Shapley Explanation Networks"

Overview

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks" (ICLR 2021). Note that this repo relies heavily on PyTorch's experimental named tensors feature. We found the ideas confusing to implement without it, and named tensors made the implementation far easier.
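For readers new to named tensors: each dimension carries a name, so operations can refer to dimensions by name instead of by position. The snippet below is a minimal illustration that uses only standard named-tensor calls (`names=`, `align_to`, `rename`). The dimension names `"batch"` and `"features"` are illustrative and are not necessarily the names used inside ShapNet. PyTorch may print a warning that the feature is experimental.

```python
import torch

# A batch of 4 samples with 3 features each; both dimensions are named.
x = torch.randn(4, 3, names=("batch", "features"))

# Reductions can refer to a dimension by name rather than by index.
per_sample = x.sum("features")                 # names: ('batch',)

# align_to permutes by name, so the code does not depend on memory layout.
transposed = x.align_to("features", "batch")   # shape: (3, 4)

# Drop the names before passing the tensor to ops that do not support them.
plain = x.rename(None)                         # names: (None, None)
```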

Dependencies

To run ShapNets alone, you mostly need only PyTorch, NumPy, and SciPy.

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)
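A Shapley module explains its inner function over only a few features, so exact Shapley values are affordable to compute. The plain-Python sketch below shows what "exact" means. It is not the library's implementation. It assumes that a feature absent from a coalition is replaced by a reference value (all zeros in the usage example). The function `exact_shapley` and the toy function `g` are hypothetical names used for illustration.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def exact_shapley(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    reference: Sequence[float],
) -> List[float]:
    """Exact Shapley values of f at x, filling absent features from reference."""
    n = len(x)

    def value(coalition: frozenset) -> float:
        # Features in the coalition keep their input value; the rest use the reference.
        z = [x[i] if i in coalition else reference[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Usage: a linear term plus a pairwise interaction.
def g(z: Sequence[float]) -> float:
    return 2.0 * z[0] + z[1] * z[2]


x = [1.0, 2.0, 3.0]
ref = [0.0, 0.0, 0.0]
phi = exact_shapley(g, x, ref)  # approximately [2.0, 3.0, 3.0]
```

The loop visits all 2^(n-1) coalitions per feature, which is why exact computation is only practical for small modules such as the 2-feature modules used below. The values also satisfy efficiency: they sum to `g(x) - g(ref)`.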

For a Shallow ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim)
shallow_shapnet(torch.randn(batch_size, dim, ), )
output, bias = shallow_shapnet(inputs, explain=True, )
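The example builds `dim * (dim - 1) // 2` two-feature modules, one per unordered pair of features, so every feature appears in `dim - 1` modules. If the network's output is the sum of its modules' outputs (an assumption here), then additivity of Shapley values means that each feature's attribution is the sum of its attributions in the modules it belongs to. The sketch below assumes the pairs are assigned in `itertools.combinations` order; that ordering and the helper `aggregate_pairwise` are hypothetical, not part of the ShapNet API.

```python
from itertools import combinations
from typing import Dict, List, Tuple

dim = 32
# Hypothetical assignment: one module per unordered feature pair, in
# the lexicographic order produced by itertools.combinations.
pairs: List[Tuple[int, int]] = list(combinations(range(dim), 2))
num_modules = len(pairs)  # dim * (dim - 1) // 2 == 496 for dim = 32


def aggregate_pairwise(
    pair_attributions: Dict[Tuple[int, int], Tuple[float, float]], dim: int
) -> List[float]:
    """Sum each feature's attributions over every module it belongs to.

    Valid when the network output is the plain sum of the module outputs,
    because Shapley values add across summed functions.
    """
    total = [0.0] * dim
    for (i, j), (phi_i, phi_j) in pair_attributions.items():
        total[i] += phi_i
        total[j] += phi_j
    return total


# Usage with 3 features: modules (0, 1), (0, 2), (1, 2).
toy = {(0, 1): (1.0, 2.0), (0, 2): (0.5, -1.0), (1, 2): (0.0, 3.0)}
per_feature = aggregate_pairwise(toy, 3)  # [1.5, 2.0, 2.0]
```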

For a Deep ShapNet:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim)


dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )
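In this example, the `nn.ModuleDict` key `"(0, 2)"` appears to name the input feature indices (0 and 2) that the module reads. The key is a string because `nn.ModuleDict` keys must be strings. This reading of the key is an assumption based on its format, not something this README states. The hypothetical helpers below show one way to generate keys in the same format for disjoint neighbouring pairs and to parse them back into index tuples.

```python
import ast
from typing import List, Tuple


def pair_keys(dim: int) -> List[str]:
    """Hypothetical helper: keys "(0, 1)", "(2, 3)", ... for disjoint pairs."""
    # str() of a tuple yields the same "(i, j)" format as the example key.
    return [str((i, i + 1)) for i in range(0, dim - 1, 2)]


def parse_key(key: str) -> Tuple[int, ...]:
    """Hypothetical helper: recover feature indices from a key like "(0, 2)"."""
    return tuple(ast.literal_eval(key))


keys = pair_keys(32)  # 16 keys, from "(0, 1)" to "(30, 31)"
```

Each key could then map to a two-feature `ShapleyModule`, as in the example above, if the library indeed interprets keys this way.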

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# Imports ShapNet
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        np.prod(kernel_sizes[i]) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)
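The vision configuration can be checked with a little arithmetic. Following the `ShapleyModule` arguments above, each module sees `prod(kernel_size)` features, and its `nn.Linear` takes `prod(kernel_size) * in_channels` inputs. The sketch below also tracks spatial size through the two layers. It assumes that an integer size `k` means a square `(k, k)` window (what `process_list_sizes` appears to do) and that the convolutional ShapNet slides its window with standard convolution arithmetic, the same formula as `torch.nn.Conv2d`. Neither assumption is verified against the library, and `to_pair` and `out_size` are hypothetical helpers.

```python
from typing import Tuple, Union

Size = Union[int, Tuple[int, int]]


def to_pair(v: Size) -> Tuple[int, int]:
    # Assumption: an int k stands for a square (k, k) setting.
    return (v, v) if isinstance(v, int) else (v[0], v[1])


def out_size(n: int, k: int, d: int, p: int, s: int) -> int:
    # Standard convolution arithmetic, as documented for torch.nn.Conv2d.
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


list_channels = [3, 16, 10]
kernel_sizes = [2, (1, 3)]
dilations, paddings, strides = [1, 2], [0, 0], [1, 1]

h, w = 32, 32
layers = []
for i, k in enumerate(kernel_sizes):
    kh, kw = to_pair(k)
    dh, dw = to_pair(dilations[i])
    ph, pw = to_pair(paddings[i])
    sh, sw = to_pair(strides[i])
    h, w = out_size(h, kh, dh, ph, sh), out_size(w, kw, dw, pw, sw)
    features = kh * kw                       # features per Shapley module
    linear_in = features * list_channels[i]  # nn.Linear input size
    layers.append((h, w, features, linear_in, list_channels[i + 1]))

# layers == [(31, 31, 4, 12, 16), (31, 27, 3, 48, 10)]
```

Under these assumptions, a 32x32 input shrinks to 31x31 after the 2x2 window and to 31x27 after the dilated 1x3 window.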

Citation

If you find this work useful, you can cite it as:

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}
Owner
Prof. David I. Inouye's research lab at Purdue University.