JAX
Overview & References
JAX is a Google research project built on native Python and NumPy functions to improve machine learning research. The official JAX page describes the core of the project as "an extensible system for composable function transformations," which means that JAX takes ordinary Python functions and transforms them into new functions that support gradients, backpropagation, just-in-time compilation, and other JAX augmentations.
JAX deals with more complex ideas such as neural networks and XLA, which are based in linear algebra and compilers, topics that are more advanced than much of what we cover in projects. The following is a list of incredibly useful resources for learning the foundations of JAX.

Basic Deep Learning: JAX Edition
Google lists the following code at the top of their JAX page:
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
This short example provides the two main functions of a deep learning algorithm, predict and loss, adapted for JAX functionality. We’ll break down the code segment as an entry analysis of both JAX and deep learning:

jax.numpy is JAX’s adapted version of the NumPy API. It mirrors NumPy closely but operates on JAX arrays, so that its operations can be differentiated, compiled, and vectorized. Make sure to use jax.numpy functions instead of regular numpy functions inside code you want JAX to transform.
jax is the main library, from which important functions like grad, jit, vmap, and pmap are used.
predict simulates the neural network’s predictions based on the dot product of the weights and activation values added to the biases, all of which are given in the params parameter. The next layer of neurons is then calculated using the current layer, eventually returning the last layer when params is fully processed.
loss uses a standard squared-error calculation (the sum of squared differences), comparing the current predictions, preds, with the targets that the user defines.
This mirrors standard NumPy deep learning very closely, but JAX shortens the runtime in very important ways which we soon describe.
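To make the walkthrough above concrete, here is a minimal sketch that actually calls these functions. The layer sizes, the init_params helper, and the all-ones inputs are assumptions made up for illustration; they are not part of the JAX page’s example.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

# Hypothetical helper: one (W, b) pair per layer, 4 inputs -> 8 hidden -> 2 outputs
def init_params(key, sizes=(4, 8, 2)):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, w_key = jax.random.split(key)
        W = 0.1 * jax.random.normal(w_key, (n_in, n_out))
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

params = init_params(jax.random.PRNGKey(0))
inputs = jnp.ones((5, 4))    # a made-up batch of 5 examples
targets = jnp.zeros((5, 2))

grad_loss = jit(grad(loss))  # gradient with respect to params (the first argument)
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))

batch_grads = grad_loss(params, inputs, targets)       # same structure as params
example_grads = perex_grads(params, inputs, targets)   # extra leading axis of size 5
print(batch_grads[0][0].shape, example_grads[0][0].shape)
```

Because loss sums over the batch, adding up the per-example gradients along their leading axis should reproduce the batch gradient.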
Runtime Optimization
jit
Autograd and XLA are the two fundamental components of JAX, with XLA (accelerated linear algebra) handling the runtime and compiling aspects of JAX. Take the following example, adapted from the JAX page:
def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x * x + x * 2.0 * x + x

x = jnp.ones((2000, 2000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)

3.97 ms ± 2.53 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
52.1 ms ± 1.83 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
JAX is designed to work with CPUs, GPUs, and TPUs, with the accelerators typically much faster than a CPU for this kind of array workload. The example output comes from the most basic CPU setup, and the jit-compiled function still ran significantly faster than the uncompiled one.
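Outside a notebook, where %timeit is not available, the same comparison can be sketched with the standard library. This is an illustrative setup rather than the JAX page’s benchmark: the time_call helper is hypothetical, and the numbers will vary by machine. JAX dispatches work asynchronously, so the sketch calls block_until_ready() to wait for each result, and it runs fast_f once beforehand so that compilation time is not counted.

```python
import time

import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x * x + x * 2.0 * x + x

fast_f = jit(slow_f)
x = jnp.ones((2000, 2000))

# Warm-up call: the first call to a jitted function triggers compilation
fast_f(x).block_until_ready()

def time_call(f, x, repeats=10):
    """Average wall-clock seconds per call (hypothetical helper)."""
    start = time.perf_counter()
    for _ in range(repeats):
        f(x).block_until_ready()  # wait for the asynchronous result
    return (time.perf_counter() - start) / repeats

print(f"jit:    {time_call(fast_f, x):.6f} s per call")
print(f"no jit: {time_call(slow_f, x):.6f} s per call")
```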
The discussion around compile times and runtimes can seem like an academic exercise when we’re dealing with small datasets: who cares if my code executes in 5 milliseconds instead of 15? This optimization, however, is vital for neural networks.
Consider a simple deep learning task of identifying a lowercase letter from an image with 36x36 pixel resolution. The input layer would have 36 * 36 = 1296 neurons and the output layer would have 26 neurons, one for every letter. Without any hidden layers, we’re already over 33,000 connections, and in reality, we’d need hidden layers for detecting small parts of letters, patterns, or some other intermediate step between image and output. A program that might take an hour on a standard system might now take 30 seconds using TPUs and jit compiling, and at that point the conversation is anything but academic.
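The connection count above is just the product of adjacent layer sizes, summed over the network. A small sketch of that arithmetic follows; the count_connections helper and the 128-neuron hidden layer are assumptions chosen only for illustration.

```python
def count_connections(layer_sizes):
    """Number of weights in a fully connected network (biases not counted)."""
    return sum(n_in * n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

no_hidden = count_connections([36 * 36, 26])        # 1296 * 26 = 33696
one_hidden = count_connections([36 * 36, 128, 26])  # 1296 * 128 + 128 * 26 = 169216
print(no_hidden, one_hidden)
```

Even one modest hidden layer multiplies the amount of arithmetic per image several times over, which is why the per-operation speedups matter.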
vmap
vmap is a function that provides "auto-vectorization" for whatever batch you have. Batches are essentially variably sized samples of your training data used in one iteration, after which the model is updated. Imagine the simple solution of looping through every image in your batch, representing each image as a vector of activation values. This vector is then multiplied by the model’s weight matrix, resulting in a new vector. This process works, but it is incredibly slow, because a separate matrix-vector product is computed on each iteration.
By using vmap, the loop is pushed down to the function’s most primitive operations. This speeds up execution, since one operation over a whole batch is quicker than many separate operations over single examples. For our purposes, this means that the activation vectors are stacked into an activation matrix; as Google puts it, "at every layer, we’re doing matrix-matrix multiplication rather than matrix-vector multiplication."
The code for this has a unique format. Pay close attention to the following implementation:
from functools import partial
from jax import vmap

predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
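vmap takes the predict function as its argument and returns a batched version of it, which is then called with the parameters and/or input batch in a second set of parentheses. In the second form, in_axes=(None, 0) tells vmap not to map over params and to map over the first axis of input_batch.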
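As a sanity check, both vmap forms should agree with the plain Python loop they replace. The sketch below assumes a tiny made-up network (3 inputs, 4 hidden neurons, 2 outputs) and a batch of 5 examples; only predict comes from the example above.

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs  # no activation on last layer

# Hypothetical parameters and data, chosen only to make the shapes visible
params = [
    (0.1 * jnp.ones((3, 4)), jnp.zeros(4)),
    (0.1 * jnp.ones((4, 2)), jnp.zeros(2)),
]
input_batch = jnp.arange(15.0).reshape(5, 3)  # 5 examples, 3 features each

# The slow version: one matrix-vector product per example
looped = jnp.stack([predict(params, x) for x in input_batch])

# The two vmap forms: one batched computation
batched_partial = vmap(partial(predict, params))(input_batch)
batched_in_axes = vmap(predict, in_axes=(None, 0))(params, input_batch)

print(looped.shape, batched_partial.shape, batched_in_axes.shape)
```

All three results have one row per example, and the batched versions match the loop up to floating-point rounding.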
Autodifferentiation
If you recall the XLA-Autograd duo that makes up JAX, automatic differentiation comes from Autograd and shares its API. JAX uses grad for calculating gradients, and grad can be applied repeatedly to differentiate to any order.
We’ll recontextualize why this matters for machine learning. The goal of any good model is to reduce the error present: we obviously want the model to be good at predicting things, otherwise there’s no point. The gradient of a function, in this case the error, points in the direction of steepest increase, so stepping in the opposite direction reduces the function. In other words, in any-dimensional space, the gradient will tell us which weights in the model need adjusting, and in which direction.
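A minimal gradient descent sketch shows the idea in action. The error function, its target values, the learning rate, and the step count are all made-up assumptions; a real model would use its loss over training data instead.

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical error: squared distance between the weights and a known target
target = jnp.array([1.0, -2.0])

def error(w):
    return jnp.sum((w - target) ** 2)

grad_error = grad(error)

w = jnp.zeros(2)      # starting weights
learning_rate = 0.1   # assumed step size
for _ in range(100):
    # Step against the gradient to reduce the error
    w = w - learning_rate * grad_error(w)

print(w, error(w))
```

After enough steps the weights settle close to the target, where the error is smallest.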
Once you understand the importance of gradients, using grad is straightforward: it takes a function and returns a new function, which in turn takes a number (or array) and evaluates the gradient at that point. Google gives the example of the hyperbolic tangent function, and we get the following result after using grad:
def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)
print(grad_tanh(2.0))

0.07065082
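One way to build confidence in this result is to compare it with the known derivative of tanh, which is 1 - tanh(x)^2, and to apply grad twice for the second derivative. The comparison below is an illustrative check, not part of Google’s example; the analytic formulas are standard calculus.

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):  # Same definition as above
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)         # first derivative
grad2_tanh = grad(grad(tanh))  # second derivative: grad applied to grad

x = 2.0
t = jnp.tanh(x)
analytic_first = 1.0 - t ** 2               # d/dx tanh(x)
analytic_second = -2.0 * t * (1.0 - t ** 2)  # d^2/dx^2 tanh(x)

print(grad_tanh(x), analytic_first)
print(grad2_tanh(x), analytic_second)
```

Each printed pair should agree up to floating-point rounding.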
And that’s it! Combining all of the features we’ve shown will give you a great leap into your machine learning project, and it’s all streamlined to make the code easier to follow.