Geometric Algebra package for JAX

Overview

JAXGA - JAX Geometric Algebra

Build status PyPI

GitHub | Docs

JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing blade indices and signs and then JITting the function doing the actual calculations.

Installation

Install using pip: pip install jaxga

Requirements:

Usage

Unlike most other Geometric Algebra packages, it is not necessary to pre-specify an algebra. JAXGA can either be used with the MultiVector class or by using lower-level functions which is useful for example when using JAX's jit or automatic differentaition.

The MultiVector class provides operator overloading and is constructed with an array of values and their corresponding basis blades. The basis blades are encoded as tuples, for example the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).

MultiVector example

import jax.numpy as jnp
from jaxga.mv import MultiVector

a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)

b = MultiVector(
    values=4 * jnp.ones([2], dtype=jnp.float32),
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)

c = a * b
print(c)

Output: Multivector(8.0 e_{1, 2, 3})

The lower-level functions also deal with values and blades. Functions are provided that take the blades and return a function that does the actual calculation. The returned function is JITted and can also be automatically differentiated with JAX. Furthermore, some operations like the geometric product take a signature function that takes a basis vector index and returns their square.

Lower-level function example

import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply

a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)

b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)

mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)

Output: C indices: ((1, 2, 3),) C values: [8.]

Some notes

  • Both the MultiVector and lower-level function approaches support batches: the axes after the first one (which indexes the basis blades) are treated as batch indices.
  • The MultiVector class can also take a signature in its constructor (default is square to 1 for all basis vectors). Doing operations with MultiVectors with different signatures is undefined.
  • The jaxga.signatures submodule contains a few predefined signature functions.
  • get_mv_multiply and similar functions cache their result by their inputs.
  • The flaxmodules submodule provides flax (a popular neural network library for jax) modules with Geometric Algebra operations.
  • Because we don't deal with a specific algebra, the dual needs an input that specifies the dimensionality of the space in which we want to find the dual element.

Benchmarks

N-d vector * N-d vector, batch size 100, N=range(1, 10), CPU

JaxGA stores only the non-zero basis blade coefficients. TFGA and Clifford on the other hand store all GA elements as full multivectors including all zeros. As a result, JaxGA does better than these for high dimensional algebras.

Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.

JAXGA (CPU) tfga (CPU) clifford
benchmark-results benchmark-results benchmark-results

N-d vector * N-d vector, batch size 100, N=range(1, 50, 5), CPU

Below is a benchmark for higher dimensions that TFGA and Clifford could not handle. Note that the X axis isn't sorted naturally.

benchmark-results

Owner
Robin Kahlow
Software / Machine Learning Engineer
Robin Kahlow
A PyTorch implementation of "From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network" (ICCV2021)

From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network The official code of VisionLAN (ICCV2021). VisionLAN successfully a

81 Dec 12, 2022
Optimus: the first large-scale pre-trained VAE language model

Optimus: the first pre-trained Big VAE language model This repository contains source code necessary to reproduce the results presented in the EMNLP 2

314 Dec 19, 2022
Disentangled Cycle Consistency for Highly-realistic Virtual Try-On, CVPR 2021

Disentangled Cycle Consistency for Highly-realistic Virtual Try-On, CVPR 2021 [WIP] The code for CVPR 2021 paper 'Disentangled Cycle Consistency for H

ChongjianGE 94 Dec 11, 2022
Pytorch implementation of Supporting Clustering with Contrastive Learning, NAACL 2021

Supporting Clustering with Contrastive Learning SCCL (NAACL 2021) Dejiao Zhang, Feng Nan, Xiaokai Wei, Shangwen Li, Henghui Zhu, Kathleen McKeown, Ram

231 Jan 05, 2023
Code for the paper: "On the Bottleneck of Graph Neural Networks and Its Practical Implications"

On the Bottleneck of Graph Neural Networks and its Practical Implications This is the official implementation of the paper: On the Bottleneck of Graph

75 Dec 22, 2022
Official Keras Implementation for UNet++ in IEEE Transactions on Medical Imaging and DLMIA 2018

UNet++: A Nested U-Net Architecture for Medical Image Segmentation UNet++ is a new general purpose image segmentation architecture for more accurate i

Zongwei Zhou 1.8k Jan 07, 2023
AWS provides a Python SDK, "Boto3" ,which can be used to access the AWS-account from the local.

Boto3 - The AWS SDK for Python Boto3 is the Amazon Web Services (AWS) Software Development Kit (SDK) for Python, which allows Python developers to wri

Shreyas Srivastava 1 Oct 25, 2021
Mesh TensorFlow: Model Parallelism Made Easier

Mesh TensorFlow - Model Parallelism Made Easier Introduction Mesh TensorFlow (mtf) is a language for distributed deep learning, capable of specifying

1.3k Dec 26, 2022
MultiTaskLearning - Multi Task Learning for 3D segmentation

Multi Task Learning for 3D segmentation Perception stack of an Autonomous Drivin

2 Sep 22, 2022
Code and models for ICCV2021 paper "Robust Object Detection via Instance-Level Temporal Cycle Confusion".

Robust Object Detection via Instance-Level Temporal Cycle Confusion This repo contains the implementation of the ICCV 2021 paper, Robust Object Detect

Xin Wang 69 Oct 13, 2022
Recreate CenternetV2 based on MMDET.

Introduction This project is trying to Recreate CenternetV2 based on MMDET, which is proposed in paper Probabilistic two-stage detection. This project

25 Dec 09, 2022
Nicely is a real-time Feedback and Intervention Program Depression is a prevalent issue across all age groups, socioeconomic classes, and cultural identities.

Nicely is a real-time Feedback and Intervention Program Depression is a prevalent issue across all age groups, socioeconomic classes, and cultural identities.

1 Jan 16, 2022
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

GCNet for Object Detection By Yue Cao, Jiarui Xu, Stephen Lin, Fangyun Wei, Han Hu. This repo is a official implementation of "GCNet: Non-local Networ

Jerry Jiarui XU 1.1k Dec 29, 2022
Augmentation for Single-Image-Super-Resolution

SRAugmentation Augmentation for Single-Image-Super-Resolution Implimentation CutBlur Cutout CutMix Cutup CutMixup Blend RGBPermutation Identity OneOf

Yubo 6 Jun 27, 2022
Implementation for Panoptic-PolarNet (CVPR 2021)

Panoptic-PolarNet This is the official implementation of Panoptic-PolarNet. [ArXiv paper] Introduction Panoptic-PolarNet is a fast and robust LiDAR po

Zixiang Zhou 126 Jan 01, 2023
An abstraction layer for mathematical optimization solvers.

MathOptInterface Documentation Build Status Social An abstraction layer for mathematical optimization solvers. Replaces MathProgBase. Citing MathOptIn

JuMP-dev 284 Jan 04, 2023
SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning

SPCL SPCL: A New Framework for Domain Adaptive Semantic Segmentation via Semantic Prototype-based Contrastive Learning Update on 2021/11/25: ArXiv Ver

Binhui Xie (谢斌辉) 11 Oct 29, 2022
Code release for NeurIPS 2020 paper "Co-Tuning for Transfer Learning"

CoTuning Official implementation for NeurIPS 2020 paper Co-Tuning for Transfer Learning. [News] 2021/01/13 The COCO 70 dataset used in the paper is av

THUML @ Tsinghua University 35 Sep 23, 2022
Neural network chess engine trained on Gary Kasparov's games.

Neural Chess It's not the best chess engine, but it is a chess engine. Proof of concept neural network chess engine (feed-forward multi-layer perceptr

3 Jun 22, 2022
Reference implementation for Structured Prediction with Deep Value Networks

Deep Value Network (DVN) This code is a python reference implementation of DVNs introduced in Deep Value Networks Learn to Evaluate and Iteratively Re

Michael Gygli 55 Feb 02, 2022