JAX is a Google research project built upon native Python and NumPy functions to improve machine learning research. The official JAX page describes the core of the project as "an extensible system for composable function transformations," which means that JAX takes ordinary Python functions and transforms them into JAX-based functions that support gradients, backpropagation, just-in-time compilation, and other JAX augmentations.
JAX deals with more complex ideas, such as neural networks and XLA, which are rooted in linear algebra and compilers, topics more advanced than much of what we cover in projects. The following is a list of incredibly useful resources for learning the foundations of JAX.
Google lists the following code at the top of their JAX page:
```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # inputs to the next layer
    return outputs                  # no activation on last layer

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```
This short example provides the two main functions of a deep learning algorithm, `predict` and `loss`, adapted for JAX functionality. We'll break down the code segment as an entry-level analysis of both JAX and deep learning:
- `jax.numpy` is JAX's adapted version of the NumPy API, created to prevent standard NumPy functionality from breaking JAX functions where the two packages differ. Make sure to use `jax.numpy` functions instead of regular NumPy ones.
- `jax` is the main library, from which important functions like `grad`, `jit`, and `vmap` are imported.
- `predict` simulates the neural network's predictions based on the dot product of the weights and activation values, added to the biases, all of which are given in the `params` parameter. The next layer of neurons is then calculated from the current layer, eventually returning the last layer once `params` is fully processed.
- `loss` uses a standard squared-error loss calculation, comparing the current predictions with the `targets` that the user defines.
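To see what the snippet computes without any JAX machinery, here is a pure-NumPy analogue of the same `predict`/`loss` pair. The layer sizes, random parameters, and batch shapes below are illustrative assumptions, not part of the original example:

```python
import numpy as np

def predict(params, inputs):
    # params is a list of (W, b) pairs, one per layer
    for W, b in params:
        outputs = inputs @ W + b   # weighted sum plus bias
        inputs = np.tanh(outputs)  # activation feeds the next layer
    return outputs                 # last layer: no activation

def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return np.sum((preds - targets) ** 2)  # sum of squared errors

# Tiny two-layer network: 3 inputs -> 4 hidden -> 2 outputs
rng = np.random.default_rng(0)
params = [(rng.normal(size=(3, 4)), np.zeros(4)),
          (rng.normal(size=(4, 2)), np.zeros(2))]
inputs = rng.normal(size=(5, 3))   # batch of 5 examples
targets = np.zeros((5, 2))
print(loss(params, inputs, targets))  # a single scalar error value
```

Swapping `np` for `jnp` recovers the JAX version; the mathematical content is identical.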
This mirrors standard NumPy deep learning very closely, but JAX shortens the runtime in important ways, which we describe next.
Autograd and XLA are the two fundamental components of JAX, with XLA (Accelerated Linear Algebra) handling the runtime and compilation aspects of JAX. Take the following example, adapted from the JAX page:
```python
def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x * x + x * 2.0 * x + x

x = jnp.ones((2000, 2000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)
```

```
3.97 ms ± 2.53 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
52.1 ms ± 1.83 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
```
JAX is designed to work with CPUs, GPUs, and TPUs, each generally quicker than the last for this kind of workload. The example output comes from the most basic CPU setup, and JAX's `jit` function still ran significantly faster than the native Python function.
The discussion around compile times and runtimes seems like an arbitrary conversation when we’re dealing with small datasets — who cares if my code executes in 5 milliseconds instead of 15? This optimization, however, is vital for neural networks.
Consider a simple deep learning task of identifying a lowercase letter from an image with 36x36 pixel resolution. The input layer would have 36 * 36 = 1296 neurons and the output layer would have 26 neurons, one for every letter. Without any hidden layers, we're already over 33,000 connections, and in reality we'd need hidden layers for detecting tiny parts of letters, patterns, or some other method of transitioning between image and output. A program that might take an hour on a standard system might now take 30 seconds using TPUs and `jit` compilation — now the conversation is not arbitrary.
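The connection count above is plain arithmetic, and it grows quickly once hidden layers enter the picture. A quick sketch (the hidden-layer size of 128 is an invented example, not from the original):

```python
input_neurons = 36 * 36   # one neuron per pixel
output_neurons = 26       # one per lowercase letter

# Fully connecting input directly to output
connections = input_neurons * output_neurons
print(input_neurons)  # 1296
print(connections)    # 33696 -- already over 33,000 with no hidden layers

# Adding even one modest hidden layer grows this quickly
hidden = 128
with_hidden = input_neurons * hidden + hidden * output_neurons
print(with_hidden)    # 169216
```

Every one of those connections is a multiply-and-add per forward pass, which is why compilation and hardware acceleration stop being academic concerns.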
`vmap` is a function that provides "auto-vectorization" for whatever batch you have. Batches are essentially variably-sized samples of your population of training data used in one iteration, after which the model is updated. Imagine the straightforward solution of looping through every image in your batch: each image yields a vector of activation values, which is then multiplied by the model's weight matrix to produce the next layer's vector. This process works, but it is incredibly slow, as a separate intermediate result is created with each iteration.
With `vmap`, loops are pushed down to the most primitive level possible. This speeds things up, as iterating over simple elements is quicker than doing the same with complex elements. For our purposes, this means that the activation vectors are processed as a single activation matrix — as Google puts it, "at every layer, we're doing matrix-matrix multiplication rather than matrix-vector multiplication."
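The matrix-matrix versus matrix-vector point can be demonstrated in plain NumPy (the shapes here are illustrative assumptions): looping over the batch performs one matrix-vector product per example, while stacking the batch performs a single matrix-matrix product with identical results.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))      # weight matrix: 3 inputs -> 4 outputs
batch = rng.normal(size=(8, 3))  # batch of 8 examples, 3 features each

# Loop version: one matrix-vector product per example
looped = np.stack([W @ x for x in batch])

# Batched version: a single matrix-matrix product over the whole batch
batched = batch @ W.T

assert np.allclose(looped, batched)  # same numbers, one fused operation
```

`vmap` performs this kind of transformation automatically, on arbitrary functions rather than just a single matrix product.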
The code for this has a unique format. Pay close attention to the following implementation:
```python
from functools import partial
from jax import vmap

predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```
`vmap` wraps the `predict` function in one set of parentheses, and the resulting vectorized function is then called with the parameters and/or input batch in another set of parentheses.
If you recall the XLA/Autograd duo that composes JAX, automatic differentiation comes from Autograd and shares its API. JAX uses `grad` for calculating gradients, which allows for differentiation to any order.
We'll recontextualize why this matters for machine learning. The goal of any good model is to reduce the error present — we obviously want the model to be good at predicting things, otherwise there's no point. The gradient of a function, in this case the error, points in the direction of steepest increase, so moving against it decreases the error fastest. In other words, in any-dimensional space, the gradient will tell us which weights in the model need adjusting, and in which direction.
Once you understand the importance of gradients, the function implementation becomes trivial — the returned gradient function just takes a number as a parameter and evaluates the gradient at that point. Google gives the example of the hyperbolic tangent function, and we get the following results after using `grad`:
```python
def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(2.0))   # Evaluate it at x = 2.0
```
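Without running JAX at all, we can sanity-check what `grad` should return here: the derivative of tanh is 1 − tanh²(x), and a central finite difference at x = 2.0 agrees. The finite-difference check is our addition, not part of the JAX example:

```python
import math

def tanh(x):
    y = math.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

def tanh_prime(x):
    # Analytic derivative: 1 - tanh(x)**2
    return 1.0 - tanh(x) ** 2

# Central finite difference as an independent check
h = 1e-6
x = 2.0
numeric = (tanh(x + h) - tanh(x - h)) / (2 * h)

print(tanh_prime(x))  # about 0.0707, what grad_tanh(2.0) should produce
print(abs(numeric - tanh_prime(x)) < 1e-6)
```

Chaining `grad(grad(tanh))` would give the second derivative in the same way, which is what "differentiation to any order" means in practice.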
And that’s it! Combining all of the features we’ve shown will give you a great leap into your machine learning project, and it’s all streamlined to make the code easier to follow.