Overview & References

JAX is a Google research project built on native Python and NumPy functions to advance machine learning research. The official JAX page describes the core of the project as "an extensible system for composable function transformations," which means that JAX takes ordinary Python functions and transforms them into new functions that support gradients, backpropagation, just-in-time compilation, and other JAX augmentations.

JAX deals with more complex ideas such as neural networks and XLA, which are based in linear algebra and compilers, topics that are more advanced than much of what we cover in projects. The following is a list of incredibly useful resources for learning the foundations of JAX.

  • The GitHub JAX page. We linked this earlier, but it’s your best starting point. Everything you need to understand the project is here and you can branch into Autograd, XLA, neural networks, TPUs, and anything else you might want to understand. Here’s a Google Cloud Tech YouTube video that talks through the content of the GitHub page.

  • This TensorFlow video provides a slower, in-depth look at how the important features of JAX operate.

  • There’s an excellent video series on neural networks and deep learning from 3Blue1Brown that explains how linear algebra creates the foundation for neural networks which, for our purposes, explains why grad is so important when using JAX.

  • As always, library documentation is integral to understanding the inner workings and specifics of your code.

Basic Deep Learning: JAX Edition

Google lists the following code at the top of their JAX page:

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads

This short example provides the two main functions of a deep learning algorithm, predict and loss, adapted for JAX functionality. We’ll break down the code segment as an entry analysis of both JAX and deep learning:

  • jax.numpy is JAX’s implementation of the NumPy API. Its functions can be traced, differentiated, and compiled by JAX, which regular NumPy functions cannot, so make sure to use jax.numpy functions instead of regular numpy functions in any code you hand to JAX.

  • jax is the main library, from which important functions like grad, jit, vmap, and pmap are imported.

  • predict computes the neural network’s predictions. At each layer it takes the dot product of the incoming activation values and the weights, then adds the biases; the weights and biases for every layer are given in the params parameter. Each layer’s output is passed through tanh to become the input of the next layer, and the last layer’s output is returned without an activation once params is fully processed.

  • loss uses a standard squared-error calculation: it sums the squared differences between the current predictions and the targets that the user defines. Because it sums rather than averages, this is the sum of squared errors, not the mean.

This mirrors standard NumPy deep learning very closely, but JAX shortens the runtime in very important ways, which we describe next.
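To see these functions run end to end, here is a minimal sketch that builds a tiny network and calls each one. The layer sizes, random initialization, and all-ones inputs are assumptions made purely for illustration; they are not part of Google's example.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_loss = jit(grad(loss))
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))

# Hypothetical layer sizes: 4 inputs -> 8 hidden units -> 2 outputs.
layer_sizes = [4, 8, 2]
key = jax.random.PRNGKey(0)
params = []
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    key, w_key = jax.random.split(key)
    W = 0.1 * jax.random.normal(w_key, (n_in, n_out))  # small random weights
    b = jnp.zeros(n_out)                               # zero biases
    params.append((W, b))

inputs = jnp.ones((5, 4))    # a batch of 5 made-up examples
targets = jnp.zeros((5, 2))  # made-up targets

preds = predict(params, inputs)                     # shape (5, 2)
total_loss = loss(params, inputs, targets)          # scalar
grads = grad_loss(params, inputs, targets)          # same structure as params
per_example = perex_grads(params, inputs, targets)  # extra leading batch axis
print(preds.shape, total_loss, per_example[0][0].shape)
```

Note that grad differentiates with respect to the first argument, so grads is a list of (dW, db) pairs shaped like params. Since the loss is a sum over examples, adding up the per-example gradients along the batch axis recovers the batch gradient.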

Runtime Optimization


Autograd and XLA are the two fundamental components of JAX, with XLA (Accelerated Linear Algebra) handling the compilation and runtime aspects of JAX. Take the following example, adapted from the JAX page:

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x * x + x * 2.0 * x + x

x = jnp.ones((2000, 2000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)
3.97 ms ± 2.53 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
52.1 ms ± 1.83 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

JAX is designed to run on CPUs, GPUs, and TPUs, with the accelerators typically much faster than a CPU for this kind of array work. The example output comes from the most basic CPU setup, and the jit-compiled function still ran significantly faster than the same function without jit.
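The %timeit lines above only work inside IPython or a notebook. As a rough sketch of the same comparison in a plain Python script, the version below times both functions with time.perf_counter. The time_call helper is a hypothetical name introduced here, and the numbers it prints will vary by machine. It calls block_until_ready() because JAX dispatches work asynchronously, and it runs fast_f once beforehand so that compilation is not counted in the timing.

```python
import time
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x * x + x * 2.0 * x + x

fast_f = jit(slow_f)
x = jnp.ones((2000, 2000))

# Warm-up call: the first call to a jitted function triggers compilation.
fast_f(x).block_until_ready()

def time_call(f, arg, repeats=10):
    """Hypothetical helper: average wall-clock seconds per call."""
    start = time.perf_counter()
    for _ in range(repeats):
        f(arg).block_until_ready()  # wait for the result before stopping the clock
    return (time.perf_counter() - start) / repeats

jit_seconds = time_call(fast_f, x)
eager_seconds = time_call(slow_f, x)
print(f"jit: {jit_seconds * 1e3:.2f} ms, no jit: {eager_seconds * 1e3:.2f} ms")
```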

The discussion around compile times and runtimes seems like an arbitrary conversation when we’re dealing with small datasets — who cares if my code executes in 5 milliseconds instead of 15? This optimization, however, is vital for neural networks.

Consider a simple deep learning task of identifying a lowercase letter from an image with 36x36 pixel resolution. The input layer would have 36 * 36 = 1296 neurons and the output layer would have 26 neurons, one for every letter. Without any hidden layers, we’re already at 1296 * 26 = 33,696 connections, and in reality, we’d need hidden layers for detecting small parts of letters, patterns, or some other intermediate step between image and output. A program that might take an hour on a standard system might instead take 30 seconds using TPUs and jit compiling. Now the conversation is not arbitrary.
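The connection count is easy to check with a few lines of Python. The helper name count_connections and the two 64-neuron hidden layers are assumptions chosen only to show how quickly the number grows.

```python
def count_connections(layer_sizes):
    """Number of weights in a fully connected network with these layer sizes."""
    return sum(n_in * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

# Input layer straight to output layer, as described in the text.
no_hidden = count_connections([36 * 36, 26])

# Hypothetical variant with two hidden layers of 64 neurons each.
with_hidden = count_connections([36 * 36, 64, 64, 26])

print(no_hidden)    # 33696
print(with_hidden)  # 88704
```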


vmap is a function that provides "auto-vectorization" for whatever batch you have. Batches are essentially variably-sized samples of your population of training data used in one iteration, after which the model is updated. Imagine the simple solution of looping through every image in your batch, representing each image as a vector of activation values. This vector is then multiplied by the model’s weight matrix, resulting in a new vector. This process works, but it is incredibly slow, as a separate intermediate result is created with each iteration.

By using vmap, loops are pushed to the most primitive level possible. This speeds up execution, as one operation over the whole batch is quicker than many separate operations over individual examples. For our purposes, this means that the activation vectors are stacked into an activation matrix: as Google puts it, "at every layer, we’re doing matrix-matrix multiplication rather than matrix-vector multiplication."

The code for this has a unique format. Pay close attention to the following implementation:

from functools import partial  # needed for partial(predict, params)
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

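vmap takes the predict function as its argument and returns a new, batched function. That new function is then called with the parameters and/or input batch in a second set of parentheses. In the second form, in_axes=(None, 0) tells vmap not to map over params and to map over the first axis of input_batch.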
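As a sketch of how the two forms relate to the simple loop described earlier, the example below runs all three on a made-up single-layer model and a made-up batch. The shapes and random values are assumptions for illustration only.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

# Hypothetical model: one layer mapping 3 inputs to 2 outputs.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = [(jax.random.normal(key_w, (3, 2)), jnp.zeros(2))]
input_batch = jax.random.normal(key_x, (4, 3))  # 4 examples, 3 features each

# The simple solution: loop over examples one at a time.
looped = jnp.stack([predict(params, example) for example in input_batch])

# The two vmap forms: params is shared, the batch is mapped along axis 0.
via_partial = vmap(partial(predict, params))(input_batch)
via_in_axes = vmap(predict, in_axes=(None, 0))(params, input_batch)

print(looped.shape, via_partial.shape, via_in_axes.shape)
```

All three results have shape (4, 2) and agree up to floating-point rounding; the vmap versions simply avoid the Python-level loop.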


Recall the XLA-Autograd duo that makes up JAX: automatic differentiation comes from Autograd, and JAX shares its API. JAX uses grad for calculating gradients, and because grad returns an ordinary function, it can be applied repeatedly to differentiate to any order.

We’ll recontextualize why this matters for machine learning. The goal of any good model is to reduce the error present. We obviously want the model to be good at predicting things, otherwise there’s no point. The gradient of a function, in this case the error, points in the direction of steepest increase, so moving in the opposite direction reduces the function. In other words, in a space of any dimension, the gradient tells us how each weight in the model should be adjusted.
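To make the idea concrete, here is a minimal gradient descent sketch. The error function, its target values, the learning rate, and the step count are all assumptions invented for this illustration; a real model would use a loss like the one shown earlier.

```python
import jax.numpy as jnp
from jax import grad

def error(w):
    # Hypothetical error: squared distance from a made-up target weight vector.
    target = jnp.array([1.0, -2.0])
    return jnp.sum((w - target) ** 2)

grad_error = grad(error)  # function returning d(error)/dw

w = jnp.zeros(2)          # starting weights
initial_error = error(w)
learning_rate = 0.1       # assumed step size

for _ in range(100):
    # Step against the gradient to reduce the error.
    w = w - learning_rate * grad_error(w)

print(w, error(w))
```

Each step moves the weights a small distance against the gradient, so w ends up very close to the target and the error shrinks toward zero.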

Once you understand the importance of gradients, the function implementation becomes trivial: grad takes a function and returns a new function, which in turn takes a number as a parameter and evaluates the gradient at that point. Google gives the example of the hyperbolic tangent function, which grad turns into a derivative function:

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)
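As a short usage sketch, grad_tanh can be called like any other function, and grad can be nested to get higher-order derivatives. The evaluation point 1.0 and the variable names first, second, and third are chosen here for illustration.

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)

first = grad_tanh(1.0)                       # first derivative at x = 1.0
second = grad(grad(tanh))(1.0)               # second derivative
third = grad(grad(grad(tanh)))(1.0)          # third derivative

# The first derivative of tanh is 1 - tanh(x)**2, which gives a quick check.
print(first, 1.0 - jnp.tanh(1.0) ** 2)
print(second, third)
```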

And that’s it! Combining all of the features we’ve shown will give you a great leap into your machine learning project, and it’s all streamlined to make the code easier to follow.