Unofficial JAX implementations of Deep Learning models

JAX Models

Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were originally released without code or with code written in frameworks other than JAX. The goal of this project is to build a collection of models, layers, activations and other utilities that are commonly used in research. All papers and all derived or translated code are cited in either the README or the docstrings. If you think a citation is missing, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT : Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2022)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2016)
  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)
  3. Depthwise Convolution (François Chollet, 2017)

Getting Started

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

Using a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be directly installed via pip.

pip install jax-models

If you wish to use the latest version, you can also clone the repository directly.

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model
load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training/inference or if any citation is missing.

You can contribute to jax_models by supporting me with compute resources or by contributing your own resources to provide pretrained weights.

If you wish to donate to this initiative, please drop me a mail here.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out with any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn

Comments
  • Missing Axis Swap in ExtractPatches and MergePatches

    In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

    >>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
    >>> inputs[0, :, :, 0]
    
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]], dtype=int32)
    
    >>> patch_size = 2
    >>> batch, height, width, channels = inputs.shape
    >>> height, width = height // patch_size, width // patch_size
    >>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    >>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    >>> x[0, 0, :]
    
    DeviceArray([0, 1, 2, 3], dtype=int32)
    

    We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3]. To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    

    For MergePatches, this should be:

    batch, length, _ = inputs.shape
    height = width = int(length**0.5)
    x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
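The two fixes above can be checked together with a small round-trip sketch. The names extract_patches and merge_patches are hypothetical plain-function stand-ins for the ExtractPatches and MergePatches modules, not the repository's actual API, and the sketch assumes a square grid of patches, as the proposed MergePatches code does.

```python
import jax.numpy as jnp


def extract_patches(inputs, patch_size):
    # Hypothetical stand-in for ExtractPatches, with the axis swap applied.
    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    # Bring the two grid axes together so each patch is contiguous.
    x = jnp.swapaxes(x, 2, 3)
    return jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))


def merge_patches(patches, patch_size):
    # Hypothetical stand-in for MergePatches; assumes a square patch grid.
    batch, length, _ = patches.shape
    height = width = int(length ** 0.5)
    x = jnp.reshape(patches, (batch, height, width, patch_size, patch_size, -1))
    # Undo the swap performed during extraction.
    x = jnp.swapaxes(x, 2, 3)
    return jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))


inputs = jnp.arange(16).reshape(1, 4, 4, 1)
patches = extract_patches(inputs, patch_size=2)
first_patch = patches[0, 0, :].tolist()  # the top-left 2x2 block, flattened
restored = merge_patches(patches, patch_size=2)
round_trip_ok = bool(jnp.array_equal(restored, inputs))
```

With the swap in place, first_patch holds the top-left 2x2 block [0, 1, 4, 5] rather than the stripe [0, 1, 2, 3], and merging the patches reproduces the original array.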
    
    bug 
    opened by young-geng 4
  • fix convnext to make it work with jax.jit

    Hey, first of all, thanks for the nice codebase. When doing inference using the convnext model, I noticed the following issue:

    Calling x.item() calls float(x), which breaks the jit tracer. Removing the unnecessary conversion in the list comprehension makes jax.jit work. Without jax.jit, the model is very slow for me, running at only ~30% GPU utilization (RTX 3090).

    This issue could apply to other models as well. It might be a good idea to include a test that applies jax.jit to each model.
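As a rough illustration of the failure mode described in this issue, the sketch below builds a per-layer rate schedule in two ways. It is not the repository's ConvNeXt code: scale_traced, scale_static, drop_path and depth are hypothetical names chosen for the example. The first variant calls .item() on values produced by jnp.linspace, which are traced under jax.jit; the second computes the schedule with NumPy, so the rates stay ordinary Python floats.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale_traced(x, drop_path=0.1, depth=4):
    # Under jax.jit, jnp.linspace is staged out and yields traced values,
    # so .item() cannot turn them into concrete Python floats.
    rates = [r.item() for r in jnp.linspace(0.0, drop_path, depth)]
    return x * (1.0 - rates[-1])


def scale_static(x, drop_path=0.1, depth=4):
    # NumPy runs eagerly at trace time, so the rates are plain Python floats.
    rates = np.linspace(0.0, drop_path, depth).tolist()
    return x * (1.0 - rates[-1])


x = jnp.ones((3,))

try:
    jax.jit(scale_traced)(x)
    traced_failed = False
except Exception as err:  # the exact error class may vary across JAX versions
    traced_failed = True
    print("jit failed with", type(err).__name__)

out = jax.jit(scale_static)(x)
```

A per-model smoke test along the lines suggested above could wrap each model's forward pass in jax.jit in the same way and check that it runs on a dummy input.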

    opened by maxidl 1
Releases(v0.5-van)