Einshape: DSL-based reshaping library for JAX and other frameworks.

Overview


The jnp.einsum op provides a DSL-based unified interface to matmul and tensordot ops. This einshape library is designed to offer a similar DSL-based approach to unifying reshape, squeeze, expand_dims, and transpose operations.

Some examples:

  • einshape("n->n111", x) is equivalent to expand_dims(x, axis=1) three times
  • einshape("a1b11->ab", x) is equivalent to squeeze(x, axis=[1,3,4])
  • einshape("nhwc->nchw", x) is equivalent to transpose(x, perm=[0,3,1,2])
  • einshape("mnhwc->(mn)hwc", x) is equivalent to a reshape combining the two leading dimensions
  • einshape("(mn)hwc->mnhwc", x, n=batch_size) is equivalent to a reshape splitting the leading dimension into two, using kwargs (m or n or both) to supply the necessary additional shape information
  • einshape("mn...->(mn)...", x) combines the two leading dimensions without knowing the rank of x
  • einshape("n...->n(...)", x) performs a 'batch flatten'
  • einshape("ij->ijk", x, k=3) inserts a trailing dimension and tiles along it
  • einshape("ij->i(nj)", x, n=3) tiles along the second dimension

See jax_ops.py for the JAX implementation of the einshape function. Alternatively, the parser and engine are exposed in engine.py, allowing analogous implementations in TensorFlow or other frameworks.

Installation

Einshape can be installed with the following command:

pip3 install git+https://github.com/deepmind/einshape

Einshape works with either JAX or TensorFlow. To allow for that, it lists neither as a requirement, so it is necessary to install JAX or TensorFlow separately.

Usage

JAX version:

(ij)", a) # b is [1, 2, 3, 4] ">
from einshape import jax_einshape as einshape
from jax import numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = einshape("ij->(ij)", a)
# b is [1, 2, 3, 4]

TensorFlow version:

(ij)", a) # b is [1, 2, 3, 4] ">
from einshape import tf_einshape as einshape
import tensorflow as tf

a = tf.constant([[1, 2], [3, 4]])
b = einshape("ij->(ij)", a)
# b is [1, 2, 3, 4]

Understanding einshape equations

An einshape equation is always of the form {lhs}->{rhs}, where {lhs} and {rhs} are both expressions. An expression represents the axes of an array; the relationship between the two expressions specifies how an array should be transformed.

An expression is a non-empty sequence of the following elements:

Index name

A single letter a-z, representing one axis of an array.

For example, the expressions ab and jq both represent an array of rank 2.

Every index name that is present on the left-hand side of an equation must also be present on the right-hand side. So, ab->a is not a valid equation, but a->ba is valid (and will tile a vector b times).
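For instance, here is a minimal sketch of the a->ba equation (JAX version, assuming the import from the Usage section), where the size of the new leading axis is supplied as a keyword argument:

from einshape import jax_einshape as einshape
from jax import numpy as jnp

v = jnp.array([1, 2, 3])
# "a->ba" introduces a new leading index b and tiles the vector along it;
# b does not appear on the left-hand side, so its size must be given.
w = einshape("a->ba", v, b=2)
# w is [[1, 2, 3],
#       [1, 2, 3]]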

Ellipsis

..., representing any axes of an array that are not otherwise represented in the expression. This is similar to the use of -1 as an axis in a reshape operation.

For example, a...b can represent any array of rank 2 or more: a will refer to the first axis and b to the last. The equation ...ab->...ba will swap the last two axes of an array.

An expression may not include more than one ellipsis (because that would be ambiguous). Like an index name, an ellipsis must be present in both halves of an equation or neither.
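A short sketch of the axis-swapping equation above, assuming JAX:

from einshape import jax_einshape as einshape
from jax import numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
# "...ab->...ba" swaps the last two axes, whatever the rank of x.
y = einshape("...ab->...ba", x)
assert jnp.array_equal(y, jnp.swapaxes(x, -1, -2))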

Group

({components}), where components is a sequence of index names and ellipsis elements. The entire group corresponds to a single axis of the array; the group's components represent factors of the axis size. This can be used to reshape an axis into many axes. All the factors except at most one must be specified using keyword arguments.

For example, einshape('(ab)->ab', x, a=10) reshapes an array of rank 1 (whose length must be a multiple of 10) into an array of rank 2 (whose first dimension is of length 10).

Groups may not be nested.
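To illustrate, a minimal sketch of splitting and re-combining an axis (JAX version):

from einshape import jax_einshape as einshape
from jax import numpy as jnp

x = jnp.arange(6)  # shape (6,)
# Split the single axis into two factors: a=2 is given, so b=3 is
# inferred from the axis length (which must be divisible by 2).
y = einshape("(ab)->ab", x, a=2)
# y has shape (2, 3): [[0, 1, 2], [3, 4, 5]]

# Re-combining needs no keyword arguments, since both factor sizes
# can be read off the input shape.
z = einshape("ab->(ab)", y)
assert jnp.array_equal(z, x)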

Unit

The digit 1, representing a single axis of length 1. This is useful for expanding and squeezing unit dimensions.

For example, the equation 1...->... squeezes a leading axis (which must have length one).
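For example, a minimal sketch of squeezing and expanding unit axes (JAX version):

from einshape import jax_einshape as einshape
from jax import numpy as jnp

x = jnp.ones((1, 3, 4))
# "1...->..." squeezes the leading unit axis, like squeeze(x, axis=0).
y = einshape("1...->...", x)
assert y.shape == (3, 4)

# The reverse equation "...->1..." inserts a leading unit axis,
# like expand_dims(x, axis=0).
z = einshape("...->1...", y)
assert z.shape == (1, 3, 4)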

Disclaimer

This is not an official Google product.

