[ICML 2021] A fast algorithm for fitting robust decision trees.

Overview

GROOT: Growing Robust Trees

Growing Robust Trees (GROOT) is an algorithm that fits binary classification decision trees such that they are robust against user-specified adversarial examples. The algorithm closely resembles algorithms used for fitting normal decision trees (i.e. CART) but changes the splitting criterion and the way samples propagate when creating a split.
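The changed splitting criterion can be illustrated with a minimal sketch. It assumes numerical features, labels 0/1, and an attacker that can shift any value by at most eps in either direction. A sample whose whole interval [x - eps, x + eps] lies on one side of the threshold stays on that side. The attacker may place every other sample on either side, and the split is scored by the worst weighted Gini impurity the attacker can cause. The function names are hypothetical. The brute-force search over placements is only for illustration and is not how groot computes the criterion.

```python
from itertools import product


def gini(n0, n1):
    """Gini impurity of a node holding n0 class-0 and n1 class-1 samples."""
    n = n0 + n1
    if n == 0:
        return 0.0
    p0, p1 = n0 / n, n1 / n
    return 1.0 - p0 * p0 - p1 * p1


def weighted_gini(left, right):
    """Sample-weighted Gini impurity of a split; left/right are (n0, n1) pairs."""
    nl, nr = sum(left), sum(right)
    return (nl * gini(*left) + nr * gini(*right)) / (nl + nr)


def adversarial_gini(x, y, threshold, eps):
    """Worst-case weighted Gini of the split x <= threshold when every value
    may be shifted by at most eps (exhaustive search, illustration only)."""
    left, right, unknown = [0, 0], [0, 0], [0, 0]
    for xi, yi in zip(x, y):
        if xi + eps <= threshold:    # even the largest shift stays left
            left[yi] += 1
        elif xi - eps > threshold:   # even the largest shift stays right
            right[yi] += 1
        else:                        # the attacker chooses the side
            unknown[yi] += 1
    worst = 0.0
    # i movable class-0 and j movable class-1 samples go left, the rest go right.
    for i, j in product(range(unknown[0] + 1), range(unknown[1] + 1)):
        l = (left[0] + i, left[1] + j)
        r = (right[0] + unknown[0] - i, right[1] + unknown[1] - j)
        worst = max(worst, weighted_gini(l, r))
    return worst


x = [0.0, 1.0, 2.0, 3.0]
y = [0, 0, 1, 1]
print(adversarial_gini(x, y, threshold=1.5, eps=0.0))  # 0.0: the split is pure
print(adversarial_gini(x, y, threshold=1.5, eps=0.6))  # 0.5: 1.0 and 2.0 can be moved
```

With eps = 0 this reduces to the ordinary weighted Gini impurity used by CART. The sketch covers only the scoring of a split, not how GROOT propagates samples into the children.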

This repository contains the module groot, which implements GROOT as a Scikit-learn compatible classifier, an adversary for model evaluation, and convenience functions for importing datasets. For documentation see https://groot.cyber-analytics.nl

Simple example

To train and evaluate GROOT on a toy dataset against an attacker that can move samples by up to 0.5 in each direction, one can use the following code:

from groot.adversary import DecisionTreeAdversary
from groot.model import GrootTreeClassifier

from sklearn.datasets import make_moons

X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)

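attack_model = [0.5, 0.5]  # the attacker may shift each feature by up to 0.5 in either direction
is_numerical = [True, True]  # both make_moons features are numerical
tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)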

tree.fit(X, y)
accuracy = tree.score(X_test, y_test)
adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)

print("Accuracy:", accuracy)
print("Adversarial Accuracy:", adversarial_accuracy)
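Assuming the usual definition, adversarial accuracy counts a test sample as correct only if no perturbation allowed by the attack model makes the tree predict a wrong label. For a single tree and a box-shaped attacker such as attack_model = [0.5, 0.5], this can be checked exactly by collecting every leaf the perturbation box can reach. The sketch below uses a hypothetical Node class instead of groot's own tree representation. It is not the DecisionTreeAdversary implementation.

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class Node:
    """Hypothetical node: internal nodes send x left when x[feature] <= threshold."""
    feature: int = 0
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[int] = None  # set only on leaves


def reachable_labels(node: Node, lo: List[float], hi: List[float]) -> Set[int]:
    """Labels of every leaf whose region intersects the box [lo, hi].

    The box is narrowed while descending, so contradictory paths are pruned.
    Narrowed bounds stay closed, so a later split at exactly the same threshold
    may count both sides as reachable; this only makes a sample look less robust.
    """
    if node.label is not None:
        return {node.label}
    f, t = node.feature, node.threshold
    labels: Set[int] = set()
    if lo[f] <= t:  # part of the box satisfies x[f] <= t
        labels |= reachable_labels(node.left, lo, hi[:f] + [min(hi[f], t)] + hi[f + 1:])
    if hi[f] > t:  # part of the box satisfies x[f] > t
        labels |= reachable_labels(node.right, lo[:f] + [max(lo[f], t)] + lo[f + 1:], hi)
    return labels


def adversarial_accuracy(tree: Node, X, y, eps) -> float:
    """Fraction of samples whose whole box [x - eps, x + eps] reaches only correct leaves."""
    robust = 0
    for x, label in zip(X, y):
        lo = [v - e for v, e in zip(x, eps)]
        hi = [v + e for v, e in zip(x, eps)]
        if reachable_labels(tree, lo, hi) == {label}:
            robust += 1
    return robust / len(y)


# A stump that predicts 0 when x[0] <= 0 and 1 otherwise.
stump = Node(feature=0, threshold=0.0, left=Node(label=0), right=Node(label=1))
X = [[-1.0, 0.0], [-0.2, 0.0], [0.3, 0.0], [2.0, 0.0]]
y = [0, 0, 1, 1]
print(adversarial_accuracy(stump, X, y, [0.0, 0.0]))  # 1.0, i.e. ordinary accuracy
print(adversarial_accuracy(stump, X, y, [0.5, 0.5]))  # 0.5
```

The two samples closest to the threshold can be pushed across it, so they no longer count. This is why adversarial accuracy is never higher than ordinary accuracy.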

Installation

groot can be installed from PyPI: pip install groot-trees

To use Kantchelian's MILP attack, you need GUROBI installed along with its Python package: python -m pip install -i https://pypi.gurobi.com gurobipy

Specific dependency versions

To reproduce our experiments with exact package versions, you can clone the repository and run: pip install -r requirements.txt

We recommend using virtual environments.

Reproducing 'Efficient Training of Robust Decision Trees Against Adversarial Examples' (article)

To reproduce the results from the paper, we provide generate_k_fold_results.py, a script that takes the trained models (stored as JSON) and generates tables and figures. The resulting figures are written to /out/.

To not only generate the results but also retrain all models, we include the scripts train_kfold_models.py and fit_chen_xgboost.py. The first script runs the algorithms in parallel for each dataset and writes the models to /out/trees/ and /out/forests/. Warning: this script can take a long time to run (about a day on 16 cores). The second script trains only the Chen et al. boosting ensembles. /out/results.zip contains all results from our own runs of the scripts.

To experiment on image datasets, we provide the script image_experiments.py, which fits the models and outputs the results. In this script, set the dataset variable to 'mnist' or 'fmnist' to switch between the two.

The scripts summarize_datasets.py and visualize_threat_models.py generate some of the figures used in the paper.

Implementation details

The TREANT implementation (groot.treant.py) is copied almost entirely from the TREANT authors' code at https://github.com/gtolomei/treant, with small modifications to better interface with the experiments. The heuristic by Chen et al. runs within the GROOT code, only with a different score function, which can be enabled by setting chen_heuristic=True on a GrootTreeClassifier before calling .fit(X, y). The provably robust boosting implementation comes almost entirely from its authors' code at https://github.com/max-andr/provably-robust-boosting, and we use a small wrapper (groot.provably_robust_boosting.wrapper.py) to call it. When recording the runtimes, we turned off all parallel options in the @jit annotations in the code. The implementation of Chen et al. boosting can be found in their own repository at https://github.com/chenhongge/RobustTrees; compile it and copy the resulting xgboost binary into the current directory. The script fit_chen_xgboost.py then calls this binary through its command line interface to fit all models.
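As we understand Chen et al.'s heuristic, it does not search over all placements of the samples an attacker can move across the threshold. Instead it scores only four representative cases: leave them where they are, send all of them left, send all of them right, or flip each one to the opposite side. The worst of these is taken as the split's score. The standalone sketch below illustrates that idea under the same assumptions as before (numerical feature, labels 0/1, shift of at most eps). The names are hypothetical. It is not groot's code path behind chen_heuristic=True.

```python
def gini(n0, n1):
    """Gini impurity of a node holding n0 class-0 and n1 class-1 samples."""
    n = n0 + n1
    if n == 0:
        return 0.0
    p0, p1 = n0 / n, n1 / n
    return 1.0 - p0 * p0 - p1 * p1


def weighted_gini(left, right):
    """Sample-weighted Gini impurity of a split; left/right are (n0, n1) pairs."""
    nl, nr = sum(left), sum(right)
    return (nl * gini(*left) + nr * gini(*right)) / (nl + nr)


def four_case_gini(x, y, threshold, eps):
    """Worst weighted Gini over four representative placements of ambiguous samples."""
    left, right = [0, 0], [0, 0]          # samples that cannot cross the threshold
    amb_left, amb_right = [0, 0], [0, 0]  # ambiguous samples, by unperturbed side
    for xi, yi in zip(x, y):
        if xi + eps <= threshold:
            left[yi] += 1
        elif xi - eps > threshold:
            right[yi] += 1
        elif xi <= threshold:
            amb_left[yi] += 1
        else:
            amb_right[yi] += 1
    amb = [amb_left[0] + amb_right[0], amb_left[1] + amb_right[1]]

    def score(to_left, to_right):
        return weighted_gini((left[0] + to_left[0], left[1] + to_left[1]),
                             (right[0] + to_right[0], right[1] + to_right[1]))

    cases = [
        score(amb_left, amb_right),  # no perturbation
        score(amb, [0, 0]),          # all ambiguous samples go left
        score([0, 0], amb),          # all ambiguous samples go right
        score(amb_right, amb_left),  # every ambiguous sample swaps sides
    ]
    return max(cases)


print(four_case_gini([0.0, 1.0, 2.0, 3.0], [0, 0, 1, 1], 1.5, 0.6))  # 0.5
```

Each of the four cases is one valid placement. The heuristic score can therefore only match or underestimate the exhaustive worst case, in exchange for a fixed, small amount of work per threshold.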

Important note on TREANT

To encode L-infinity norms correctly, we had to modify TREANT to NOT apply rules recursively: we added a single break statement in the treant.Attacker.__compute_attack() method. If you plan to use TREANT with recursive attacker rules, remove this statement or use TREANT's unmodified code at https://github.com/gtolomei/treant.
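One way to see why recursion matters for an L-infinity threat model: suppose perturbation rules may be applied again to already-perturbed samples. Then an attacker with enough budget can chain the same shift and move a feature further than the intended radius. The toy sketch below contrasts the two behaviors. It uses a hypothetical rule format (feature, delta, cost) and hypothetical function names, and does not reproduce TREANT's Attacker class. Costs are assumed positive.

```python
def reachable_points(x, rules, budget, recursive):
    """Perturbed copies of x reachable under (feature, delta, cost) rules.

    recursive=True: rules are re-applied to perturbed points while budget remains.
    recursive=False: each rule is applied at most once, to the original point.
    """
    start = (tuple(x), budget)
    points = {tuple(x)}
    seen_states = {start}
    frontier = [start]
    while frontier:
        point, remaining = frontier.pop()
        for feature, delta, cost in rules:
            if cost > remaining:
                continue
            moved = list(point)
            moved[feature] += delta
            moved = tuple(moved)
            points.add(moved)
            state = (moved, remaining - cost)
            if recursive and state not in seen_states:
                seen_states.add(state)
                frontier.append(state)
    return points


def max_shift(points, x, feature):
    """Largest distance any reachable point moved along one feature."""
    return max(abs(p[feature] - x[feature]) for p in points)


x = [0.0, 0.0]
rules = [(0, 0.5, 1), (0, -0.5, 1)]  # shift feature 0 by +/-0.5 at cost 1
print(max_shift(reachable_points(x, rules, 3, recursive=False), x, 0))  # 0.5
print(max_shift(reachable_points(x, rules, 3, recursive=True), x, 0))   # 1.5
```

With a budget of 3, chaining the +0.5 rule moves feature 0 by 1.5, well outside an L-infinity ball of radius 0.5. Applying the rule only once keeps every point inside it.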

Contact

For any questions or comments please create an issue or contact me directly.

Comments
  • Reproducing results from the article, issue with runtimes.csv

    Hello! I am trying to reproduce results from the article, and I can't figure out a certain problem. First I am trying to run train_kfold_models, but the code always outputs an error: "ImportError: cannot import name 'GrootTree' from 'groot.model'". Is there something wrong with the .py file I am trying to run, or is this a problem that doesn't occur for you and everyone else (i.e. something wrong with my computer, files or environment)?

    Onni Mansikkamäki

    opened by OnniMansikkamaki 3
  • is_numerical argument GrootTreeClassifier

    Running the example code on the make moons data in the README I get:

    Traceback (most recent call last):
      File "/home/.../groot_test.py", line 11, in <module>
        tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)
    TypeError: __init__() got an unexpected keyword argument 'is_numerical'
    

    Leaving out the argument and having this line instead: tree = GrootTreeClassifier(attack_model=attack_model, random_state=0) results in this error:

    Traceback (most recent call last):
      File "/home/.../groot_test.py", line 15, in <module>
        adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)
      File "/home/.../venv/lib/python3.9/site-packages/groot/adversary.py", line 259, in __init__
        self.is_numeric = self.decision_tree.is_numerical
    AttributeError: 'GrootTreeClassifier' object has no attribute 'is_numerical'
    

    I'm guessing the code got an update, but the readme didn't. Or I made a stupid mistake, also very possible.

    opened by laudv 2
  • Reproducing result from paper

    Hello! I am trying to reproduce the results from the paper. I am struggling to find where these files are provided: generate_k_fold_results.py, train_kfold_models.py, fit_chen_xgboost.py, image_experiments.py, summarize_datasets.py and visualize_threat_models.py.

    Onni Mansikkamäki

    opened by OnniMansikkamaki 0
  • Regression decision trees and random forests

    This PR adds GROOT decision trees and random forests that use the adversarial sum of absolute errors to make splits. It also adds new tests, speeds them up and updates the documentation.

    opened by daniel-vos 0
  • Add regression, tests and refactor into base class

    This PR adds a regression GROOT tree based on the adversarial sum of absolute errors, more tests and refactors GROOT trees into a base class (BaseGrootTree) with subclasses GrootTreeClassifier and GrootTreeRegressor extending it.

    opened by daniel-vos 0
Releases: v0.0.1
Owner: Cyber Analytics Lab @ Delft University of Technology
[CVPR 21] Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, IEEE Conf. on Computer Vision and Pattern Recognition (CVPR), 2021.

Vectorization and Rasterization: Self-Supervised Learning for Sketch and Handwriting, CVPR 2021. Ayan Kumar Bhunia, Pinaki nath Chowdhury, Yongxin Yan

Ayan Kumar Bhunia 44 Dec 12, 2022
Graph Transformer Architecture. Source code for

Graph Transformer Architecture Source code for the paper "A Generalization of Transformer Networks to Graphs" by Vijay Prakash Dwivedi and Xavier Bres

NTU Graph Deep Learning Lab 561 Jan 08, 2023
Program your own vulkan.gpuinfo.org query in Python. Used to determine baseline hardware for WebGPU.

query-gpuinfo-data License This software is not presently released under a license. The data in data/ is obtained under CC BY 4.0 as specified there.

Kai Ninomiya 5 Jul 18, 2022
This is the official implementation for the paper "Heterogeneous Multi-player Multi-armed Bandits: Closing the Gap and Generalization" in NeurIPS 2021.

MPMAB_BEACON This is code used for the paper "Decentralized Multi-player Multi-armed Bandits: Beyond Linear Reward Functions", Neurips 2021. Requireme

Cong Shen Research Group 0 Oct 26, 2021
A baseline code for VSPW

A baseline code for VSPW Preparation Download VSPW dataset The VSPW dataset with extracted frames and masks is available here.

28 Aug 22, 2022
Implementation of EMNLP 2017 Paper "Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog" using PyTorch and ParlAI

Language Emergence in Multi Agent Dialog Code for the Paper Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog Satwik Kottur, José M.

Karan Desai 105 Nov 25, 2022
Task-based end-to-end model learning in stochastic optimization

Task-based End-to-end Model Learning in Stochastic Optimization This repository is by Priya L. Donti, Brandon Amos, and J. Zico Kolter and contains th

CMU Locus Lab 164 Dec 29, 2022
WSDM2022 Challenge - Large scale temporal graph link prediction

WSDM 2022 Large-scale Temporal Graph Link Prediction - Baseline and Initial Test Set WSDM Cup Website link Link to this challenge This branch offers A

Deep Graph Library 34 Dec 29, 2022
[NeurIPS 2021] ORL: Unsupervised Object-Level Representation Learning from Scene Images

Unsupervised Object-Level Representation Learning from Scene Images This repository contains the official PyTorch implementation of the ORL algorithm

Jiahao Xie 55 Dec 03, 2022
Classification of ecg datas for disease detection

ecg_classification Classification of ecg datas for disease detection

Atacan ÖZKAN 5 Sep 09, 2022
《Single Image Reflection Removal Beyond Linearity》(CVPR 2019)

Single-Image-Reflection-Removal-Beyond-Linearity Paper Single Image Reflection Removal Beyond Linearity. Qiang Wen, Yinjie Tan, Jing Qin, Wenxi Liu, G

Qiang Wen 51 Jun 24, 2022
D2LV: A Data-Driven and Local-Verification Approach for Image Copy Detection

Facebook AI Image Similarity Challenge: Matching Track —— Team: imgFp This is the source code of our 3rd place solution to matching track of Image Sim

16 Dec 25, 2022
The Multi-Mission Maximum Likelihood framework (3ML)

PyPi Conda The Multi-Mission Maximum Likelihood framework (3ML) A framework for multi-wavelength/multi-messenger analysis for astronomy/astrophysics.

The Multi-Mission Maximum Likelihood (3ML) 62 Dec 30, 2022
[CVPR 2022] Back To Reality: Weak-supervised 3D Object Detection with Shape-guided Label Enhancement

Back To Reality: Weak-supervised 3D Object Detection with Shape-guided Label Enhancement Announcement 🔥 We have not tested the code yet. We will fini

Xiuwei Xu 7 Oct 30, 2022
A benchmark framework for Tensorflow

TensorFlow benchmarks This repository contains various TensorFlow benchmarks. Currently, it consists of two projects: PerfZero: A benchmark framework

1.1k Dec 30, 2022
Files for a tutorial to train SegNet for road scenes using the CamVid dataset

SegNet and Bayesian SegNet Tutorial This repository contains all the files for you to complete the 'Getting Started with SegNet' and the 'Bayesian Seg

Alex Kendall 800 Dec 31, 2022
This is a simple backtesting framework to help you test your crypto currency trading. It includes a way to download and store historical crypto data and to execute a trading strategy.

You can use this simple crypto backtesting script to ensure your trading strategy is successful Minimal setup required and works well with static TP a

Andrei 154 Sep 12, 2022
Reinforcement learning framework and algorithms implemented in PyTorch.

Reinforcement learning framework and algorithms implemented in PyTorch.

Robotic AI & Learning Lab Berkeley 2.1k Jan 04, 2023
A few stylization coreML models that I've trained with CreateML

CoreML-StyleTransfer A few stylization coreML models that I've trained with CreateML You can open and use the .mlmodel files in the "models" folder in

Doron Adler 8 Aug 18, 2022
FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation.

FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation [Project] [Paper] [arXiv] [Home] Official implementation of FastFCN:

Wu Huikai 815 Dec 29, 2022