STAT 39000: Project 11 — Spring 2022
Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools surrounding said topics are the pytorch and tensorflow libraries. JAX is another tool by Google that is growing in popularity. These tools are libraries used to build and use complex models. If available, they can take advantage of GPUs to speed up parallelizable code by a hundred or even a thousand fold.
Context: This is the third of a series of 4 projects focused on using pytorch and JAX to solve numeric problems.
Scope: Python, JAX
Dataset(s)
The following questions will use the following dataset(s):

/depot/datamine/data/sim/train.csv
Questions
Question 1
JAX is a library for high performance computing. It falls into the same category as other popular packages like numpy, pytorch, and tensorflow. JAX is a product of Google / DeepMind that takes a completely different approach than their other product, tensorflow.
Like the other popular libraries, JAX can utilize GPUs/TPUs to greatly speed up computation. Let’s take a look.
Here is a snippet of code from previous projects that uses pytorch and calculates predictions 10000 times.
Of course, this is the same calculation since our betas aren’t being updated yet, but just bear with me. 
import pandas as pd
import torch
import jax
import jax.numpy as jnp
dat = pd.read_csv("/depot/datamine/data/sim/train.csv")
%%time
x_train = torch.tensor(dat['x'].to_numpy())
y_train = torch.tensor(dat['y'].to_numpy())
beta0 = torch.tensor(5, requires_grad=True, dtype=torch.float)
beta1 = torch.tensor(4, requires_grad=True, dtype=torch.float)
beta2 = torch.tensor(3, requires_grad=True, dtype=torch.float)
num_epochs = 10000
for idx in range(num_epochs):
    y_predictions = beta0 + beta1*x_train + beta2*x_train**2
Approximately how much time does it take to run this second chunk of code (after we have already read in our data)?
Here is the equivalent JAX code:
%%time
x_train = jnp.array(dat['x'].to_numpy())
y_train = jnp.array(dat['y'].to_numpy())
beta0 = 5
beta1 = 4
beta2 = 3
num_epochs = 10000
for idx in range(num_epochs):
    y_predictions = beta0 + beta1*x_train + beta2*x_train**2
How much time does this take?
At this point in time you may be questioning how JAX could possibly be worth it. At first glance, the new code does look a bit cleaner, but not clean enough to justify code that is around 3 times slower.
This is where JAX’s first trick, or transformation, comes into play. When we refer to a transformation, think of it as an operation on some function that produces another function as output.
The first transformation we will talk about is jax.jit. "JIT" stands for "Just In Time" and refers to a just-in-time compiler. Essentially, just-in-time compilation is a trick that can greatly speed up the execution of some code by compiling it. In a nutshell, the compiled version of the code has a wide variety of optimizations that speed your code up.
Lots of our computation time is spent inside our loop, specifically when we are calculating our y_predictions. Let’s see if we can use the jit transformation to speed up our JAX code with little to no extra effort.
Write a function called model that accepts two arguments. The first argument is a tuple containing our parameters: beta0, beta1, and beta2. The second is the input to our function (our x values), called x. model should then unpack our tuple of parameters into beta0, beta1, and beta2, and then return predictions (using the same formula that appears in both snippets above). Replace the code as follows.
# replace this line
y_predictions = beta0 + beta1*x_train + beta2*x_train**2
# with
y_predictions = model((beta0, beta1, beta2), x_train)
Run and time the code again. No difference? Well, we didn’t use our jit transformation yet! Using the transformation is easy. JAX provides two equivalent ways. You can either decorate your model function with the @jax.jit decorator, or apply the transformation to your function, save the new jit-compiled function, and use it instead.
def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)
Rerun your code using the JIT transformation. Is it faster now?
It is important to note that 

Code used to solve this problem.

Output from running the code.
Question 2
At this point in time you may be considering slapping @jax.jit on all your functions. Unfortunately, it is not quite so simple! First of all, the previous comparison was actually not fair at all. Why? JAX has asynchronous dispatch by default. What this means is that, by default, JAX will return control to Python as soon as possible, even if it is before the function has been fully evaluated.
What does this mean? It means that our finished example from question 1 may be returning a not-yet-complete result, greatly throwing off our performance measurements. So how can we synchronously wait for execution to finish? This is easy: simply call the block_until_ready method on the array returned by your jit-compiled model function.
def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)

# block_until_ready is called on the returned array, not on the function itself
my_func_jit1(2.0).block_until_ready()
# or
my_func_jit2(2.0).block_until_ready()
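To make the timing point concrete, here is a minimal sketch of one way to take a fair measurement outside of a notebook cell. The function name quadratic, the input array, and the use of time.perf_counter are illustrative assumptions and not part of the project code. The first call is a warm-up so that compilation time is not counted, and block_until_ready is called on the returned array so the clock only stops once the result actually exists.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def quadratic(x):
    # Hypothetical stand-in for the model: 5 + 4x + 3x^2
    return 5.0 + 4.0 * x + 3.0 * x**2


x = jnp.linspace(-1.0, 1.0, 1_000)

# Warm-up call: the first call traces and compiles the function.
quadratic(x).block_until_ready()

# Timed call: block_until_ready waits for the computation to finish.
start = time.perf_counter()
result = quadratic(x).block_until_ready()
elapsed = time.perf_counter() - start

print(f"elapsed: {elapsed:.6f} s")
```

Dropping the block_until_ready call from the timed line would measure only how long it takes JAX to hand control back to Python, which may be shorter than the computation itself.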
Rerun your code from before. You should find that the results are unchanged; it turns out that really was a serious speedup. Great. Let’s move on from this part of things and get back to our question: why can’t we just slap @jax.jit on any function and expect a speedup?
Take the following function.
def train(params, x, y, epochs):
    def _model(params, x):
        beta0, beta1, beta2 = params
        return beta0 + beta1*x + beta2*x**2

    mses = []
    for _ in range(epochs):
        # use the x argument rather than the global x_train
        y_predictions = _model(params, x)
        mse = jnp.sum((y_predictions - y)**2)

fast_train = jax.jit(train)
fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
If you try running it, you will get an error along the lines of "TracerIntegerConversionError". The problem with this function, and the reason it cannot be jit compiled, is the epochs argument. By default, JAX tries to "trace" the parameters to determine their effect on inputs of a specific shape and type. Control flow cannot depend on traced values. In this case, epochs is relied on to determine how many times to loop. In addition, the shapes of all input and output values of a function must be determinable ahead of time.
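The failure can be reproduced in a few lines. The sketch below uses a hypothetical function, add_n_times, whose loop count comes from an argument. Run eagerly, n is an ordinary Python int. Under jax.jit, n is traced and range(n) has no concrete value to work with.

```python
import jax
import jax.numpy as jnp


def add_n_times(x, n):
    # Hypothetical example: add x to a running total n times.
    total = jnp.zeros_like(x)
    for _ in range(n):  # range() needs a concrete Python int
        total = total + x
    return total


# Without jit, n is a plain Python int and the loop runs normally.
eager_result = add_n_times(jnp.ones(3), 2)

# With jit, n becomes a traced value, so range(n) cannot be evaluated.
try:
    jax.jit(add_n_times)(jnp.ones(3), 2)
    jit_failed = False
except jax.errors.TracerIntegerConversionError:
    jit_failed = True

print(jit_failed)
```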
How do we fix this? It is not always possible; however, we can choose to mark selected parameters as static, meaning not traced. If the offending parameter is marked as static, the function can be JIT compiled. The catch is that any time the function is called with a different value for the static parameter, the function has to be recompiled with that new static value. So, this is only useful if you will only occasionally change the parameter. This sounds like our case! We only want to occasionally change the number of epochs, so perfect.
You can mark a parameter as static by specifying the argument position using the static_argnums argument to jax.jit, or by specifying the argument name using the static_argnames argument to jax.jit.
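As a rough illustration of how a static argument behaves, the sketch below marks the second argument of a hypothetical function, add_n_times, as static using static_argnums. The module-level counter trace_count is an assumption added purely for demonstration: Python side effects inside a jit-compiled function only run while JAX is tracing, so the counter shows how many times the function was traced.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@functools.partial(jax.jit, static_argnums=(1,))
def add_n_times(x, n):
    global trace_count
    trace_count += 1
    total = jnp.zeros_like(x)
    for _ in range(n):  # fine here: n is a plain Python int
        total = total + x
    return total


x = jnp.ones(3)
first = add_n_times(x, 2)   # traced and compiled for n=2
second = add_n_times(x, 2)  # same static value: compiled version is reused
third = add_n_times(x, 3)   # new static value: traced and compiled again

print(trace_count)
```

Passing static_argnames=("n",) instead of static_argnums=(1,) should select the same argument by name.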
Force the epochs argument to be static, and use the jax.jit decorator to compile the function. Test out the function, in order, using the following code cells.
%%time
fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
%%time
fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
%%time
fast_train((beta0, beta1, beta2), x_train, y_train, 9999)
Do your best to explain why the last code cell was once again slower.
If you aren’t sure why, reread the question text — we hint at the "catch" in the text. 

Code used to solve this problem.

Output from running the code.
Question 3
We learned that one of the coolest parts of the pytorch package was the automatic differentiation feature. It saves a lot of time doing calculus and coding up the resulting equations. Recall that in pytorch this differentiation was baked into the backward method of our MSE. This is quite different from the way we think about the equations when looking at the math, and is quite confusing.
JAX has the same functionality, but it is much cleaner and easier to use. We will provide you with a simple example, and explain the math as we go along.
Let’s say our function is $f(x) = 2x^2$. We can start by writing a function.
def squared(x):
    return 2*x**2
Fantastic, so far pretty easy.
The derivative w.r.t. x is $4x$. Doing this in JAX is as easy as applying the jax.grad transformation to the function.
squared_deriv = jax.grad(squared)
Okay, test out both functions as follows.
my_array = jnp.array([1.0, 2.0, 3.0])
squared(4.0) # 32.0
squared(my_array) # [2.0, 8.0, 18.0]
squared_deriv(4.0) # 16.0
squared_deriv(my_array) # uh oh! Something went wrong!
A very perceptive student pointed out that we originally passed array values that were ints to 
On the last line, you probably received an error along the lines of "Gradient only defined for scalar-output functions". What this means is that the resulting derivative function is not vectorized. As you may have guessed, this is easily fixed. Another key transformation that JAX provides is called vmap. vmap takes a function and creates a vectorized version of the function. See the following.
vectorized_deriv_squared = jax.vmap(squared_deriv)
vectorized_deriv_squared(my_array) # [4.0, 8.0, 12.0]
Heck yes! That is pretty cool, and very powerful. It is so much more understandable than the magic happening in the pytorch world too!
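The same grad-then-vmap pattern applies to other functions. As a small sketch, take a hypothetical function $g(x) = x^3$, whose derivative is $3x^2$. The names cubed, cubed_deriv, and cubed_deriv_vec are made up for this illustration.

```python
import jax
import jax.numpy as jnp


def cubed(x):
    # g(x) = x^3
    return x**3


# jax.grad differentiates a scalar-in, scalar-out function: g'(x) = 3x^2
cubed_deriv = jax.grad(cubed)

# jax.vmap maps the derivative over the leading axis of an array
cubed_deriv_vec = jax.vmap(cubed_deriv)

single = cubed_deriv(2.0)                           # 3 * 2^2 = 12
many = cubed_deriv_vec(jnp.array([1.0, 2.0, 3.0]))  # [3, 12, 27]

print(single, many)
```

Note that the inputs are floats. jax.grad expects floating point inputs, so passing integers would raise an error.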
Dig back into your memory about any equation you may have had in the past where you needed to find a derivative. Create a Python function, find the derivative, and test it out on both a single value, like 4.0, as well as an array, like jnp.array([1.0,2.0,3.0]). Don’t hesitate to make it extra fun and include some functions like jnp.cos, jnp.sin, etc. Did everything work as expected?

Code used to solve this problem.

Output from running the code.
Question 4
Okay, great, but that was a straightforward example. What if we have multiple parameters we’d like to take partial derivatives with respect to? jax.grad can handle that too!
Read this excellent example in the official JAX documentation.
The JAX documentation is excellent! If you are interested, I would recommend reading through it; it is very well written.
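As a quick sketch of partial derivatives with jax.grad, consider a hypothetical two-argument function $f(a, b) = a b^2$, with $\partial f / \partial a = b^2$ and $\partial f / \partial b = 2ab$. The function f and the evaluation point are assumptions chosen only for illustration. The argnums argument to jax.grad selects which positional arguments to differentiate with respect to.

```python
import jax


def f(a, b):
    # Hypothetical function: f(a, b) = a * b^2
    return a * b**2


# Differentiate with respect to both positional arguments (0 and 1).
partials = jax.grad(f, argnums=(0, 1))

# At (a, b) = (2, 3): df/da = b^2 = 9 and df/db = 2ab = 12
d_a, d_b = partials(2.0, 3.0)

print(d_a, d_b)
```

When argnums is a tuple, the result is a tuple with one gradient per selected argument, in the same order.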
Given the following (should be familiar) model, create a function called get_partials that accepts an argument params (a tuple containing beta0, beta1, and beta2, in order) and an argument x, which can be either a single value (a scalar) or a jnp.array with multiple values. For a scalar x, this function should return a single value for each of the 3 partial derivatives, with x plugged into each partial derivative. For an array x, it should return 3 arrays of results, each holding one value per element of the input array.
@jax.jit
def model(params, x):
    beta0, beta1, beta2 = params
    return beta0 + beta1*x + beta2*x**2
model((1.0, 2.0, 3.0), 4.0) # 57
model((1.0, 2.0, 3.0), jnp.array((4.0, 5.0, 6.0))) # [57, 86, 121]
Since we have 3 parameters, we will have 3 partial derivatives, and our new function should output a value for each of our 3 partial derivatives, for each value passed as x. To be explicit and allow you to check your work, the results should be the same as the following.
params = (5.0, 4.0, 3.0)
get_partials(params, x_train)
((DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32, weak_type=True),
DeviceArray([-15.94824 , -11.117526 , -10.4780855 , -8.867778 , -8.799367 , -8.140428 , -7.8744955 , -7.72306 , -6.9281745 , -6.2731333 , -6.2275624 , -5.7271757 , -5.1857414 , -5.150156 , -4.8792663 , -4.663747 , -4.58701 , -4.1310377 , -4.0215836 , -4.019455 , -3.5578184 , -3.4748363 , -3.4004524 , -3.1221437 , -3.0421085 , -2.941131 , -2.8603644 , -2.8294718 , -2.7050996 , -1.9493109 , -1.7873074 , -1.2773769 , -1.1804487 , -1.1161369 , -1.1154363 , -0.8590109 , -0.81457555, -0.7386795 , -0.57577926, -0.5536533 , -0.51964295, -0.12334588, -0.11549235, 0.14650635, 0.24305418, 0.2876291 , 0.3942046 , 0.6342466 , 0.8256681 , 1.2047065 , 1.9168468 , 1.9493027 , 1.9587051 , 2.3490443 , 2.7015095 , 2.8161156 , 2.8648841 , 2.946292 , 3.1312609 , 3.1810293 , 4.503682 , 5.114829 , 5.1591663 , 5.205859 , 5.622392 , 5.852435 , 6.21313 , 6.4066596 , 6.655888 , 6.781989 , 7.1651325 , 7.957219 , 8.349893 , 11.266327 , 13.733376 ], dtype=float32, weak_type=True),
DeviceArray([2.54346375e+02, 1.23599388e+02, 1.09790276e+02, 7.86374817e+01, 7.74288559e+01, 6.62665634e+01, 6.20076790e+01, 5.96456566e+01, 4.79996033e+01, 3.93521996e+01, 3.87825356e+01, 3.28005409e+01, 2.68919144e+01, 2.65241070e+01, 2.38072395e+01, 2.17505341e+01, 2.10406590e+01, 1.70654716e+01, 1.61731339e+01, 1.61560173e+01, 1.26580715e+01, 1.20744877e+01, 1.15630760e+01, 9.74778175e+00, 9.25442410e+00, 8.65025234e+00, 8.18168449e+00, 8.00591087e+00, 7.31756353e+00, 3.79981303e+00, 3.19446778e+00, 1.63169169e+00, 1.39345896e+00, 1.24576163e+00, 1.24419820e+00, 7.37899661e-01, 6.63533330e-01, 5.45647442e-01, 3.31521749e-01, 3.06531966e-01, 2.70028800e-01, 1.52142067e-02, 1.33384829e-02, 2.14641113e-02, 5.90753369e-02, 8.27304944e-02, 1.55397251e-01, 4.02268738e-01, 6.81727827e-01, 1.45131791e+00, 3.67430139e+00, 3.79978085e+00, 3.83652544e+00, 5.51800919e+00, 7.29815340e+00, 7.93050718e+00, 8.20756149e+00, 8.68063641e+00, 9.80479431e+00, 1.01189480e+01, 2.02831535e+01, 2.61614761e+01, 2.66169968e+01, 2.71009693e+01, 3.16112938e+01, 3.42509956e+01, 3.86029854e+01, 4.10452881e+01, 4.43008461e+01, 4.59953766e+01, 5.13391228e+01, 6.33173370e+01, 6.97207031e+01, 1.26930122e+02, 1.88605606e+02], dtype=float32, weak_type=True)),)
get_partials((1.0,2.0,3.0), jnp.array((4.0,)))
((DeviceArray([1.], dtype=float32, weak_type=True), DeviceArray([4.], dtype=float32, weak_type=True), DeviceArray([16.], dtype=float32, weak_type=True)),)
To specify which arguments to take the partial derivative with respect to, use the 
To vectorize your resulting function, use This is confusing! However, considering how powerful and all that is baked into the 

Code used to solve this problem.

Output from running the code.
Please make sure to double check that your submission is complete, and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure what you think you submitted was what you actually submitted. In addition, please review our submission guidelines before submitting your project.