Pretraining Representations For Data-Efficient Reinforcement Learning

Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville

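This repo provides code for SGI, which pretrains representations for data-efficient RL by combining self-predictive latent dynamics modeling (SPR), goal-conditioned RL, and inverse dynamics modeling.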

Install

To install the requirements, follow these steps:

# Set a UTF-8 locale
export LANG=C.UTF-8
# Install PyTorch 1.8.1 and torchvision 0.9.1 built for CUDA 11.1
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# Install the remaining requirements
pip install -r requirements.txt

# Finally, install the project in editable mode
pip install --user -e .

Usage:

The default branch, release, contains the latest stable changes.

  • To run SGI:
  1. Download the DQN replay dataset from https://research.google/tools/datasets/dqn-replay/
    • Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions and terminals.
  2. To pretrain with SGI:
python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \
    env.game=pong seed=1 offline_model_save={your_model_name} \
    offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \
    offline.runner.no_eval=1 \
    +offline.algo.goal_weight=1 \
    +offline.algo.inverse_model_weight=1 \
    +offline.algo.spr_weight=1 \
    +offline.algo.target_update_tau=0.01 \
    +offline.agent.model_kwargs.momentum_tau=0.01 \
    do_online=False \
    algo.batch_size=256 \
    +offline.agent.model_kwargs.noisy_nets_std=0 \
    offline.runner.dataloader.dataset_on_disk=True \
    offline.runner.dataloader.samples=1000000 \
    offline.runner.dataloader.checkpoints='{your checkpoints}' \
    offline.runner.dataloader.num_workers=2 \
    offline.runner.dataloader.data_path={your data dir} \
    offline.runner.dataloader.tmp_data_path=./ 
  3. To fine-tune with SGI:
python -m scripts.run public=True env.game=pong seed=1 num_logs=10  \
    model_load={your_model_name} model_folder=./ \
    algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1
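The data layout described in step 1 can be sketched as follows. This is a minimal illustration that assumes each .gz file holds one gzip-compressed NumPy array, which is how Dopamine-style replay checkpoints (the format of the DQN replay dataset) are written. The file names and the load_split helper are hypothetical. Check src/offline_dataset.py for what the dataloader actually reads.

```python
import gzip
import os
import tempfile

import numpy as np

# Assumption: each .gz file is one gzip-compressed .npy array, as written by
# Dopamine-style replay checkpoints. File names below are placeholders.
FIELDS = ("observation", "action", "terminal")


def save_array(path, array):
    """Write one array as a gzip-compressed .npy payload."""
    with gzip.open(path, "wb") as f:
        np.save(f, array, allow_pickle=False)


def load_array(path):
    """Read one array back from a gzip-compressed .npy payload."""
    with gzip.open(path, "rb") as f:
        return np.load(f, allow_pickle=False)


def load_split(data_dir, suffix):
    """Hypothetical helper: load the observation/action/terminal files that
    share one suffix and check that they cover the same number of steps."""
    arrays = {name: load_array(os.path.join(data_dir, f"{name}_{suffix}.gz")) for name in FIELDS}
    lengths = {name: len(array) for name, array in arrays.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"misaligned files: {lengths}")
    return arrays


if __name__ == "__main__":
    n = 100
    with tempfile.TemporaryDirectory() as d:
        # Synthetic stand-ins for 84x84 grayscale frames, actions and episode ends.
        save_array(os.path.join(d, "observation_ckpt.0.gz"), np.zeros((n, 84, 84), dtype=np.uint8))
        save_array(os.path.join(d, "action_ckpt.0.gz"), np.random.randint(0, 18, size=n).astype(np.int32))
        save_array(os.path.join(d, "terminal_ckpt.0.gz"), np.zeros(n, dtype=np.uint8))
        data = load_split(d, "ckpt.0")
        print({name: array.shape for name, array in data.items()})
```

Checking that the three arrays have equal length catches the most common mistake when you substitute your own data: files that come from different checkpoints.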

When reporting scores, we average across 10 fine-tuning seeds.
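Here is a minimal sketch of that aggregation. It assumes one scalar final score per fine-tuning seed. summarize_seeds is a hypothetical helper, not part of the repository, and reporting the standard error next to the mean is an addition for illustration only.

```python
import statistics


def summarize_seeds(scores):
    """Return (mean, standard error) over one final score per fine-tuning seed."""
    if len(scores) < 2:
        raise ValueError("need at least two seeds to estimate a spread")
    mean = statistics.fmean(scores)
    # Sample standard deviation divided by sqrt(n) gives the standard error of the mean.
    stderr = statistics.stdev(scores) / len(scores) ** 0.5
    return mean, stderr


if __name__ == "__main__":
    # Placeholder numbers for 10 seeds, not results from the paper.
    scores = [float(s) for s in range(1, 11)]
    mean, stderr = summarize_seeds(scores)
    print(f"mean={mean:.2f} +/- {stderr:.2f} (n={len(scores)})")
```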

./scripts/experiments contains example configurations for both pre-training and fine-tuning, including those for SGI-M, SGI-M/L and SGI-W. Each script is launched with a game and a seed, e.g., ./scripts/experiments/sgim_pretrain.sh pong 1. These scripts are provided primarily to illustrate the hyperparameters used in the different experiments. You will likely need to modify their arguments to point to your data and model directories.
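A sweep over several games and seeds can be scripted. The sketch below only builds the command lines. sweep_commands is a hypothetical helper, and the only script name it uses is the sgim_pretrain.sh example quoted above. To actually launch a command, pass its argv list to subprocess.run.

```python
import itertools
import shlex


def sweep_commands(script, games, seeds):
    """Build one argv list per (game, seed) pair for an experiment script.

    Assumes the script takes a game and a seed as its two positional
    arguments, as in `./scripts/experiments/sgim_pretrain.sh pong 1`.
    """
    return [[script, game, str(seed)] for game, seed in itertools.product(games, seeds)]


if __name__ == "__main__":
    for argv in sweep_commands("./scripts/experiments/sgim_pretrain.sh", ["pong"], range(1, 4)):
        # Print a shell-safe command line; subprocess.run(argv) would launch it.
        print(shlex.join(argv))
```

Keeping each command as an argv list rather than a single string avoids shell-quoting problems if a game name or path ever contains spaces.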

Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details.

What does each file do?

.
├── scripts
│   ├── run.py                # The main runner script to launch jobs.
│   ├── config.yaml           # The Hydra configuration file, listing hyperparameters and options.
│   └── experiments           # Configurations for the various SGI experiments.
│
├── src
│   ├── agent.py              # Implements the Agent API for action selection.
│   ├── algos.py              # Distributional RL loss and optimization.
│   ├── models.py             # Forward passes, network initialization.
│   ├── networks.py           # Network architecture and forward passes.
│   ├── offline_dataset.py    # Dataloader for offline data.
│   ├── gcrl.py               # Utils for SGI's goal-conditioned RL objective.
│   ├── rlpyt_atari_env.py    # Slightly modified Atari env from rlpyt.
│   ├── rlpyt_utils.py        # Utility methods that we use to extend rlpyt's functionality.
│   └── utils.py              # Command line arguments and helper functions.
│
└── requirements.txt          # Dependencies
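The pretraining command under Usage sets three loss weights (spr_weight, goal_weight, inverse_model_weight) and two update rates (target_update_tau, momentum_tau). The sketch below shows one plausible reading of those flags under two assumptions. First, the total pretraining loss is a weighted sum of the three terms. Second, each tau is the step size of a conventional exponential moving average (EMA) of parameters. The README does not say which network each tau governs. The function names are hypothetical, and the sketch uses plain lists of floats instead of PyTorch tensors to stay dependency-free. Check src/algos.py and src/models.py for the actual formulation.

```python
from typing import Dict, List


def sgi_pretraining_loss(spr: float, goal: float, inverse: float,
                         spr_weight: float = 1.0, goal_weight: float = 1.0,
                         inverse_model_weight: float = 1.0) -> float:
    """Assumed form: a weighted sum of the SPR, goal-conditioned RL and
    inverse-dynamics losses, with weights named after the command-line flags."""
    return spr_weight * spr + goal_weight * goal + inverse_model_weight * inverse


def ema_update(target: Dict[str, List[float]], online: Dict[str, List[float]], tau: float) -> None:
    """Update target in place: target <- (1 - tau) * target + tau * online.

    tau=1 copies the online parameters; a small tau such as 0.01 makes the
    target move slowly toward the online network.
    """
    for name, online_values in online.items():
        target[name] = [(1.0 - tau) * t + tau * o for t, o in zip(target[name], online_values)]


if __name__ == "__main__":
    # Placeholder loss values; all weights are 1, as in the pretraining command.
    print(sgi_pretraining_loss(spr=0.5, goal=0.2, inverse=1.1))
    target = {"encoder.weight": [0.0, 0.0]}
    online = {"encoder.weight": [1.0, 2.0]}
    ema_update(target, online, tau=0.01)
    print(target)
```

Setting a weight to 0 in this form removes that objective entirely. This is one way to read ablations over the three flags, though the repository may implement them differently.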