A Python library for differentiable optimal control on accelerators.

Related tags

Deep Learningtrajax
Overview

trajax

A Python library for differentiable optimal control on accelerators.

Trajax builds on JAX and hence code written with Trajax supports JAX's transformations. In particular, Trajax's solvers:

  1. Are automatically efficiently differentiable, via jax.grad.
  2. Scale up to parallel instances via jax.vmap and jax.pmap.
  3. Can run on CPUs, GPUs, and TPUs without code changes, and support end-to-end compilation with jax.jit.
  4. Are made available from Python, written with NumPy.

In Trajax, differentiation through the solution of a trajectory optimization problem is done more efficiently than by differentiating the solver implementation directly. Specifically, Trajax defines custom differentiation routines for its solvers. It registers these with JAX so that they are picked up whenever using JAX's autodiff features (e.g. jax.grad) to differentiate functions that call a Trajax solver.

This is a research project, not an official Google product.

Trajax is currently a work in progress, maintained by a few individuals at Google Research. While we are actively using Trajax in our own research projects, expect there to be bugs and rough edges compared to commercially available solvers.

Trajectory optimization and optimal control

We consider classical optimal control tasks concerning optimizing trajectories of a given discrete time dynamical system by solving the following problem. Given a cost function c, dynamics function f, and initial state x0, the goal is to compute:

argmin(lambda X, U: sum(c(X[t], U[t], t) for t in range(T)) + c_final(X[T]))

subject to the constraint that X[0] == x0 and that:

all(X[t + 1] == f(X[t], U[t], t) for t in range(T))

There are many resources for more on trajectory optimization, including Dynamic Programming and Optimal Control by Dimitri Bertsekas and Underactuated Robotics by Russ Tedrake.

API

In describing the API, it will be useful to abbreviate a JAX/NumPy floating point ndarray of shape (a, b, …) as a type denoted F[a, b, …]. Assume n is the state dimension, d is the control dimension, and T is the time horizon.

Problem setup convention/signature

Setting up a problem requires writing two functions, cost and dynamics, with type signatures:

cost(state: F[n], action: F[d], time_step: int) : float
dynamics(state: F[n], action: F[d], time_step: int) : F[n]

Note that even if a dimension n or d is 1, the corresponding state or action representation is still a rank-1 ndarray (i.e. a vector, of length 1).

Because Trajax uses JAX, the cost and dynamics functions must be written in a functional programming style as required by JAX. See the JAX readme for details on writing JAX-friendly functional code. By and large, functions that have no side effects and that use jax.numpy in place of numpy are likely to work.

Solvers

If we abbreviate the type of the above two functions as CostFn and DynamicsFn, then our solvers have the following type signature prefix in common:

solver(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], initial_actions: F[T, d], *solver_args, **solver_kwargs): SolverOutput

SolverOutput is a tuple of (F[T + 1, n], F[T, d], float, *solver_outputs). The first three tuple components represent the optimal state trajectory, optimal control sequence, and the optimal objective value achieved, respectively. The remaining *solver_outputs are specific to the particular solver (such as number of iterations, norm of the final gradient, etc.).

There are currently four solvers provided: ilqr, scipy_minimize, cem, and random_shooting. Each extends the signatures above with solver-specific arguments and output values. Details are provided in each solver function's docstring.

Underlying the ilqr implementation is a time-varying LQR routine, which solves a special case of the above problem, where costs are convex quadratic and dynamics are affine. To capture this, both are represented as matrices. This routine is also made available as tvlqr.

Objectives

One might want to write a custom solver, or work with an objective function for any other reason. To that end, Trajax offers the optimal control objective in the form of an API function:

objective(cost: CostFn, dynamics: DynamicsFn, initial_state: F[n], actions: F[T, d]): float

Combining this function with JAX's autodiff capabilities offers, for example, a starting point for writing a first-order custom solver. For example:

def improve_controls(cost, dynamics, U, x0, eta, num_iters):
  grad_fn = jax.grad(trajax.objective, argnums=(2,))
  for i in range(num_iters):
    U = U - eta * grad_fn(cost, dynamics, U, x0)
  return U

The solvers provided by Trajax are actually built around this objective function. For instance, the scipy_minimize solver simply calls scipy.minimize.minimize with the gradient and Hessian-vector product functions derived from objective using jax.grad and jax.hessian.

Limitations

​​Just as Trajax inherits the autodiff, compilation, and parallelism features of JAX, it also inherits its corresponding limitations. Functions such as the cost and dynamics given to a solver must be written using jax.numpy in place of standard numpy, and must conform to a functional style; see the JAX readme. Due to the complexity of trajectory optimizer implementations, initial compilation times can be long.

Owner
Google
Google ❤️ Open Source
Google
A clean and scalable template to kickstart your deep learning project 🚀 ⚡ 🔥

Lightning-Hydra-Template A clean and scalable template to kickstart your deep learning project 🚀 ⚡ 🔥 Click on Use this template to initialize new re

Hyunsoo Cho 1 Dec 20, 2021
Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset

Semantic Segmentation on MIT ADE20K dataset in PyTorch This is a PyTorch implementation of semantic segmentation models on MIT ADE20K scene parsing da

MIT CSAIL Computer Vision 4.5k Jan 08, 2023
LegoDNN: a block-grained scaling tool for mobile vision systems

Table of contents 1 Introduction 1.1 Major features 1.2 Architecture 2 Code and Installation 2.1 Code 2.2 Installation 3 Repository of DNNs in vision

41 Dec 24, 2022
Code for "Primitive Representation Learning for Scene Text Recognition" (CVPR 2021)

Primitive Representation Learning Network (PREN) This repository contains the code for our paper accepted by CVPR 2021 Primitive Representation Learni

Ruijie Yan 76 Jan 02, 2023
A Survey on Deep Learning Technique for Video Segmentation

A Survey on Deep Learning Technique for Video Segmentation A Survey on Deep Learning Technique for Video Segmentation Wenguan Wang, Tianfei Zhou, Fati

Tianfei Zhou 112 Dec 12, 2022
Code for "SRHEN: Stepwise-Refining Homography Estimation Network via Parsing Geometric Correspondences in Deep Latent Space"

SRHEN This is a better and simpler implementation for "SRHEN: Stepwise-Refining Homography Estimation Network via Parsing Geometric Correspondences in

1 Oct 28, 2022
A general-purpose encoder-decoder framework for Tensorflow

READ THE DOCUMENTATION CONTRIBUTING A general-purpose encoder-decoder framework for Tensorflow that can be used for Machine Translation, Text Summariz

Google 5.5k Jan 07, 2023
Arbitrary Distribution Modeling with Censorship in Real Time 59 2 60 3 Bidding Advertising for KDD'21

Arbitrary_Distribution_Modeling This repo implements the Neighborhood Likelihood Loss (NLL) and Arbitrary Distribution Modeling (ADM, with Interacting

7 Jan 03, 2023
DropNAS: Grouped Operation Dropout for Differentiable Architecture Search

DropNAS: Grouped Operation Dropout for Differentiable Architecture Search DropNAS, a grouped operation dropout method for one-level DARTS, with better

weijunhong 4 Aug 15, 2022
Minimal PyTorch implementation of Generative Latent Optimization from the paper "Optimizing the Latent Space of Generative Networks"

Minimal PyTorch implementation of Generative Latent Optimization This is a reimplementation of the paper Piotr Bojanowski, Armand Joulin, David Lopez-

Thomas Neumann 117 Nov 27, 2022
A 10000+ hours dataset for Chinese speech recognition

WenetSpeech Official website | Paper A 10000+ Hours Multi-domain Chinese Corpus for Speech Recognition Download Please visit the official website, rea

310 Jan 03, 2023
Repo for the ACMMM20 submission: "Personalized breath based biometric authentication with wearable multimodality".

personalized-breath Repo for the ACMMM20 submission: "Personalized breath based biometric authentication with wearable multimodality". Guideline To ex

Manh-Ha Bui 2 Nov 15, 2021
[NeurIPS 2020] Semi-Supervision (Unlabeled Data) & Self-Supervision Improve Class-Imbalanced / Long-Tailed Learning

Rethinking the Value of Labels for Improving Class-Imbalanced Learning This repository contains the implementation code for paper: Rethinking the Valu

Yuzhe Yang 656 Dec 28, 2022
MetaShift: A Dataset of Datasets for Evaluating Contextual Distribution Shifts and Training Conflicts (ICLR 2022)

MetaShift: A Dataset of Datasets for Evaluating Distribution Shifts and Training Conflicts This repo provides the PyTorch source code of our paper: Me

88 Jan 04, 2023
Code release for the ICML 2021 paper "PixelTransformer: Sample Conditioned Signal Generation".

PixelTransformer Code release for the ICML 2021 paper "PixelTransformer: Sample Conditioned Signal Generation". Project Page Installation Please insta

Shubham Tulsiani 24 Dec 17, 2022
Read and write layered TIFF ImageSourceData and ImageResources tags

Read and write layered TIFF ImageSourceData and ImageResources tags Psdtags is a Python library to read and write the Adobe Photoshop(r) specific Imag

Christoph Gohlke 4 Feb 05, 2022
Using deep learning model to detect breast cancer.

Breast-Cancer-Detection Breast cancer is the most frequent cancer among women, with around one in every 19 women at risk. The number of cases of breas

1 Feb 13, 2022
Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Finding an Unsupervised Image Segmenter in each of your Deep Generative Models Description Recent research has shown that numerous human-interpretable

Luke Melas-Kyriazi 61 Oct 17, 2022
This is the replication package for paper submission: Towards Training Reproducible Deep Learning Models.

This is the replication package for paper submission: Towards Training Reproducible Deep Learning Models.

0 Feb 02, 2022
Predicting Semantic Map Representations from Images with Pyramid Occupancy Networks

This is the code associated with the paper Predicting Semantic Map Representations from Images with Pyramid Occupancy Networks, published at CVPR 2020.

Thomas Roddick 219 Dec 20, 2022