STAT 39000: Project 11 — Spring 2022

Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools surrounding said topics are the `pytorch` and `tensorflow` libraries — `JAX` is another tool by Google growing in popularity. These tools are libraries used to build and use complex models. If available, they can take advantage of GPUs to speed up parallelizable code by a hundred or even thousand fold.

Context: This is the third of a series of 4 projects focused on using `pytorch` and `JAX` to solve numeric problems.

Scope: Python, JAX

Learning Objectives
• Compare and contrast `pytorch` and `JAX`.

• Differentiate functions using `JAX`.

• Understand what "JIT" is and why it is useful.

• Understand when a value or operation should be static vs. traced.

• Vectorize functions using the `vmap` function from `JAX`.

• Understand how random number generators work in `JAX`.

Make sure to read about, and use, the template found here, and the important information about project submissions here.

Dataset(s)

The following questions will use the following dataset(s):

• `/depot/datamine/data/sim/train.csv`

Questions

Question 1

`JAX` is a library for high performance computing. It falls into the same category as other popular packages like `numpy`, `pytorch`, and `tensorflow`. `JAX` is a product of Google / DeepMind that takes a completely different approach than their other product, `tensorflow`.

Like the other popular libraries, `JAX` can utilize GPUs/TPUs to greatly speed up computation. Let’s take a look.

Here is a snippet of code from previous projects that uses `pytorch` and calculates predictions 10000 times.

 Of course, this is the same calculation since our betas aren’t being updated yet, but just bear with me.
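```python
import pandas as pd
import torch
import jax
import jax.numpy as jnp

# read in the dataset listed above
dat = pd.read_csv('/depot/datamine/data/sim/train.csv')
```

```python
%%time

x_train = torch.tensor(dat['x'].to_numpy())
y_train = torch.tensor(dat['y'].to_numpy())

# parameters, using the same starting values as the JAX version
beta0 = torch.tensor(5.0, requires_grad=True)
beta1 = torch.tensor(4.0, requires_grad=True)
beta2 = torch.tensor(3.0, requires_grad=True)

num_epochs = 10000

for idx in range(num_epochs):
    # same prediction every time, since the betas are never updated
    y_predictions = beta0 + beta1*x_train + beta2*x_train**2
```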

Approximately how much time does it take to run this second chunk of code (after we have already read in our data)?

Here is the equivalent `JAX` code:

```python
%%time

x_train = jnp.array(dat['x'].to_numpy())
y_train = jnp.array(dat['y'].to_numpy())

beta0 = 5
beta1 = 4
beta2 = 3

num_epochs = 10000

for idx in range(num_epochs):
    y_predictions = beta0 + beta1*x_train + beta2*x_train**2
```

How much time does this take?

At this point in time you may be questioning how `JAX` could possibly be worth it. At first glance, the new code does look a bit cleaner, but not clean enough to use code that is around 3 times slower.

This is where `JAX`’s first trick, or transformation, comes into play. When we refer to a transformation, think of it as an operation on some function that produces another function as an output.

The first transformation we will talk about is `jax.jit`. "JIT" stands for "Just In Time" and refers to a "Just in time" compiler. Essentially, just in time compilation is a trick that can be used to greatly speed up the execution of some code by compiling the code. In a nutshell, the compiled version of the code has a wide variety of optimizations that speed your code up.

Lots of our computation time is spent inside our loop, specifically when we are calculating our `y_predictions`. Let’s see if we can use the jit transformation to speed up our `JAX` code with little to no extra effort.

Write a function called `model` that accepts two arguments. The first argument is a tuple containing our parameters: `beta0`, `beta1`, and `beta2`. The second is our input to our function (our x values) called `x`. `model` should then unpack our tuple of parameters into `beta0`, `beta1`, and `beta2`, and then return predictions (the same formula shown above, twice). Replace the code as follows.

```python
# replace this line
y_predictions = beta0 + beta1*x_train + beta2*x_train**2

# with
y_predictions = model((beta0, beta1, beta2), x_train)
```

Run and time the code again. No difference? Well, we didn’t use our jit transformation yet! Using the transformation is easy. `JAX` provides two equivalent ways. You can either decorate your `model` function with the `@jax.jit` decorator, or simply apply the transformation to your function and save the new, jit compiled function and use it instead.

```python
def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)
```

Re-run your code using the JIT transformation. Is it faster now?

 It is important to note that `pytorch` does have some `jit` functionality, and there is also a package called `numba` which can help with this as well. However, it is not as straightforward to perform the same operation using either as it is using `JAX`.
Items to submit
• Code used to solve this problem.

• Output from running the code.

Question 2

At this point in time you may be considering slapping `@jax.jit` on all your functions — unfortunately it is not quite so simple! First of all, the previous comparison was actually not fair at all. Why? `JAX` has asynchronous dispatch by default. What this means is that, by default, `JAX` will return control to Python as soon as possible, even if it is before the function has been fully evaluated.

What does this mean? It means that our finished example from question 1 may be returning a not-yet-complete result, greatly throwing off our performance measurements. So how can we synchronously wait for execution to finish? This is easy: simply call the `block_until_ready` method on the value returned by your jit compiled `model` function.
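As a rough illustration of why this matters for timing, the sketch below times the same jit compiled call with and without `block_until_ready`. The function `poly` and the array `x` are made up for this example and are not part of the project. Actual timings vary by machine, so none are shown.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def poly(x):
    # hypothetical stand-in for the model: an elementwise polynomial
    return 5.0 + 4.0 * x + 3.0 * x**2

x = jnp.linspace(-10.0, 10.0, 1_000_000)

# warm-up call so compilation time is not included in either measurement
poly(x).block_until_ready()

# control may return here before the computation has finished
start = time.perf_counter()
result = poly(x)
dispatch_seconds = time.perf_counter() - start

# block_until_ready waits for the result, so this covers the full computation
start = time.perf_counter()
result = poly(x).block_until_ready()
blocked_seconds = time.perf_counter() - start

print(f"without blocking: {dispatch_seconds:.6f}s")
print(f"with blocking:    {blocked_seconds:.6f}s")
```

The warm-up call is a design choice for this sketch: it keeps the one-time compilation cost out of both measurements so that only dispatch and execution are compared.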

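```python
def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)

# wait for the computation to finish before returning
my_func_jit1(x_train).block_until_ready()

# or

my_func_jit2(x_train).block_until_ready()
```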

Re-run your code from before. You should find that the results are unchanged: it turns out that really was a serious speedup from before. Great. Let’s move on from this part of things. Back to our question. Why can’t we just slap `@jax.jit` on any function and expect a speedup?

Take the following function.

```python
def train(params, x, y, epochs):
    def _model(params, x):
        beta0, beta1, beta2 = params
        return beta0 + beta1*x + beta2*x**2

    mses = []
    for _ in range(epochs):
        y_predictions = _model(params, x)
        mse = jnp.sum((y_predictions - y)**2)

fast_train = jax.jit(train)

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
```

If you try running it you will get an error saying something along the lines of "TracerIntegerConversionError". The problem with this function, and why it cannot be jit compiled, is the `epochs` argument. By default, `JAX` "traces" every parameter of a function to determine the function’s effect on inputs of a specific shape and type. Control flow cannot depend on traced values, and in this case `epochs` is relied on to determine how many times to loop. In addition, the shapes of all input and output values of a function must be able to be determined ahead of time.
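To make the tracing idea concrete, here is a small sketch that reproduces the same kind of error on a toy function. The function `add_n_times` is hypothetical and not part of the project. Under `jax.jit`, `n` is replaced by a tracer, and `range` cannot turn a tracer into a Python integer.

```python
import jax
import jax.numpy as jnp

def add_n_times(x, n):
    # hypothetical toy function: the loop length depends on the argument n
    total = jnp.zeros_like(x)
    for _ in range(n):
        total = total + x
    return total

x = jnp.array([1.0, 2.0, 3.0])

# without jit, n is an ordinary Python int and the loop runs normally
plain_result = add_n_times(x, 3)

# with jit, n is traced, so range(n) raises an error
error_name = None
try:
    jax.jit(add_n_times)(x, 3)
except jax.errors.TracerIntegerConversionError as err:
    error_name = type(err).__name__

print(plain_result)
print(error_name)
```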

How do we fix this? Well, it is not always possible. However, we can choose to mark parameters as static, meaning not traced. If the offending parameter is marked as static, the function can be JIT compiled. The catch is that any time a call to the function is made and the value of the static parameter is changed, the function will have to be recompiled with that new static value. So, this is only useful if you will only occasionally change the parameter. This sounds like our case! We only want to occasionally change the number of epochs, so perfect.

You can mark a parameter as static by specifying the argument position using the `static_argnums` argument to `jax.jit`, or by specifying the argument name using the `static_argnames` argument to `jax.jit`.
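The recompilation behavior can be seen in a minimal sketch like the following. It assumes a made-up function, `scale_n_times`, and a made-up counter, `trace_count`, that records how often Python actually executes the function body. Under `jax.jit` that only happens when the function is traced, which is when a new compilation is needed.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX traces the Python body

@partial(jax.jit, static_argnums=(1,))
def scale_n_times(x, n):
    # hypothetical toy function; n is static, so range(n) is allowed
    global trace_count
    trace_count += 1
    for _ in range(n):
        x = x * 2.0
    return x

x = jnp.array([1.0, 2.0, 3.0])

scale_n_times(x, 3)  # first call: traced and compiled
scale_n_times(x, 3)  # same static value: the compiled version is reused
count_after_same_n = trace_count

scale_n_times(x, 4)  # new static value: traced and compiled again
count_after_new_n = trace_count

print(count_after_same_n, count_after_new_n)  # 1 2
```

The same effect could be had by passing `static_argnames="n"` instead of `static_argnums=(1,)`; the positional form is used here only to keep the sketch short.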

Force the `epochs` argument to be static, and use the `jax.jit` decorator to compile the function. Test out the function using the following code cells, in order.

```python
%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
```

```python
%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)
```

```python
%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 9999)
```

Do your best to explain why the last code cell was once again slower.

 If you aren’t sure why, reread the question text — we hint at the "catch" in the text.
Items to submit
• Code used to solve this problem.

• Output from running the code.

Question 3

We learned that one of the coolest parts of the `pytorch` package was the automatic differentiation feature. It saves a lot of time doing some calculus and coding up resulting equations. Recall that in `pytorch` this differentiation was baked into the `backward` method of our MSE. This is quite different from the way we think about the equations when looking at the math, and is quite confusing.

`JAX` has the same functionality, but it is much cleaner and easier to use. We will provide you with a simple example, and explain the math as we go along.

Let’s say our function is $f(x) = 2x^2$. We can start by writing a function.

```python
def squared(x):
    return 2*x**2
```

Fantastic, so far pretty easy.

The derivative w.r.t. `x` is $4x$. Doing this in `JAX` is as easy as applying the `jax.grad` transformation to the function.

```python
squared_deriv = jax.grad(squared)
```

Okay, test out both functions as follows.

```python
my_array = jnp.array([1.0, 2.0, 3.0])

squared(4.0) # 32.0
squared(my_array) # [2.0, 8.0, 18.0]
squared_deriv(4.0) # 16.0
squared_deriv(my_array) # uh oh! Something went wrong!
```
 A very perceptive student pointed out that we originally passed array values that were ints to `jax.grad`. This will fail. You can read more about why here.

On the last line, you probably received a message or error saying something along the lines of "Gradient only defined for scalar-output functions". What this means is that the resulting derivative function is not vectorized. As you may have guessed, this is easily fixed. Another key transformation that `JAX` provides is called `vmap`. `vmap` takes a function and creates a vectorized version of the function. See the following.

```python
vectorized_deriv_squared = jax.vmap(squared_deriv)
vectorized_deriv_squared(my_array) # [4.0, 8.0, 12.0]
```

Heck yes! That is pretty cool, and very powerful. It is so much more understandable than the magic happening in the `pytorch` world too!
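The same pattern carries over to other functions. As a sketch, the example below differentiates a made-up function, `wave`, whose derivative is known in closed form ($\cos(x) + 2x$), so the results can be checked against the math. The function and variable names are chosen for this illustration only. It also shows that an int input is rejected by `jax.grad`.

```python
import jax
import jax.numpy as jnp

def wave(x):
    # hypothetical example function: f(x) = sin(x) + x^2
    return jnp.sin(x) + x**2

wave_deriv = jax.grad(wave)                   # works on a single float
vectorized_wave_deriv = jax.vmap(wave_deriv)  # works elementwise on an array

values = jnp.array([0.0, 1.0, 2.0])

single = wave_deriv(0.0)                      # cos(0) + 2*0
many = vectorized_wave_deriv(values)
expected = jnp.cos(values) + 2.0 * values     # derivative worked out by hand

print(single)
print(jnp.allclose(many, expected))

# jax.grad needs float inputs; an int input raises a TypeError
int_input_failed = False
try:
    wave_deriv(2)
except TypeError:
    int_input_failed = True

print(int_input_failed)
```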

Dig back into your memory about any equation you may have had in the past where you needed to find a derivative. Create a Python function, find the derivative, and test it out on both a single value, like `4.0` as well as an array, like `jnp.array([1.0,2.0,3.0])`. Don’t hesitate to make it extra fun and include some functions like `jnp.cos`, `jnp.sin`, etc. Did everything work as expected?

Items to submit
• Code used to solve this problem.

• Output from running the code.

Question 4

Okay, great, but that was a straightforward example. What if we have multiple parameters we’d like to take partial derivatives with respect to? `jax.grad` can handle that too!

Read this excellent example in the official JAX documentation.

 The JAX documentation is pretty excellent! If you are interested, I would recommend reading through it; it is very well written.

Given the following (should be familiar) model, create a function called `get_partials` that accepts an argument `params` (a tuple containing beta0, beta1, and beta2, in order) and an argument `x`, that can be either a single value (a scalar), or a `jnp.array` with multiple values. This function should return a single value for each of the 3 partial derivatives, where `x` is plugged into each of the 3 partial derivatives to calculate each value, OR, 3 arrays of results where there are 3 values for each value in the input array.

```python
@jax.jit
def model(params, x):
    beta0, beta1, beta2 = params
    return beta0 + beta1*x + beta2*x**2
```
example using it
```python
model((1.0, 2.0, 3.0), 4.0) # 57
model((1.0, 2.0, 3.0), jnp.array((4.0, 5.0, 6.0))) # [57, 86, 121]
```

Since we have 3 parameters, we will have 3 partial derivatives, and our new function should output a value for each of our 3 partial derivatives, for each value passed as `x`. To be explicit and allow you to check your work, the results should be the same as the following.

```python
params = (5.0, 4.0, 3.0)
get_partials(params, x_train)
```
output
```
((DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],            dtype=float32, weak_type=True),
DeviceArray([-15.94824   , -11.117526  , -10.4780855 ,  -8.867778  ,
-8.799367  ,  -8.140428  ,  -7.8744955 ,  -7.72306   ,
-6.9281745 ,  -6.2731333 ,  -6.2275624 ,  -5.7271757 ,
-5.1857414 ,  -5.150156  ,  -4.8792663 ,  -4.663747  ,
-4.58701   ,  -4.1310377 ,  -4.0215836 ,  -4.019455  ,
-3.5578184 ,  -3.4748363 ,  -3.4004524 ,  -3.1221437 ,
-3.0421085 ,  -2.941131  ,  -2.8603644 ,  -2.8294718 ,
-2.7050996 ,  -1.9493109 ,  -1.7873074 ,  -1.2773769 ,
-1.1804487 ,  -1.1161369 ,  -1.1154363 ,  -0.8590109 ,
-0.81457555,  -0.7386795 ,  -0.57577926,  -0.5536533 ,
-0.51964295,  -0.12334588,   0.11549235,   0.14650635,
0.24305418,   0.2876291 ,   0.3942046 ,   0.6342466 ,
0.8256681 ,   1.2047065 ,   1.9168468 ,   1.9493027 ,
1.9587051 ,   2.3490443 ,   2.7015095 ,   2.8161156 ,
2.8648841 ,   2.946292  ,   3.1312609 ,   3.1810293 ,
4.503682  ,   5.114829  ,   5.1591663 ,   5.205859  ,
5.622392  ,   5.852435  ,   6.21313   ,   6.4066596 ,
6.655888  ,   6.781989  ,   7.1651325 ,   7.957219  ,
8.349893  ,  11.266327  ,  13.733376  ],            dtype=float32, weak_type=True),
DeviceArray([2.54346375e+02, 1.23599388e+02, 1.09790276e+02,
7.86374817e+01, 7.74288559e+01, 6.62665634e+01,
6.20076790e+01, 5.96456566e+01, 4.79996033e+01,
3.93521996e+01, 3.87825356e+01, 3.28005409e+01,
2.68919144e+01, 2.65241070e+01, 2.38072395e+01,
2.17505341e+01, 2.10406590e+01, 1.70654716e+01,
1.61731339e+01, 1.61560173e+01, 1.26580715e+01,
1.20744877e+01, 1.15630760e+01, 9.74778175e+00,
9.25442410e+00, 8.65025234e+00, 8.18168449e+00,
8.00591087e+00, 7.31756353e+00, 3.79981303e+00,
3.19446778e+00, 1.63169169e+00, 1.39345896e+00,
1.24576163e+00, 1.24419820e+00, 7.37899661e-01,
6.63533330e-01, 5.45647442e-01, 3.31521749e-01,
3.06531966e-01, 2.70028800e-01, 1.52142067e-02,
1.33384829e-02, 2.14641113e-02, 5.90753369e-02,
8.27304944e-02, 1.55397251e-01, 4.02268738e-01,
6.81727827e-01, 1.45131791e+00, 3.67430139e+00,
3.79978085e+00, 3.83652544e+00, 5.51800919e+00,
7.29815340e+00, 7.93050718e+00, 8.20756149e+00,
8.68063641e+00, 9.80479431e+00, 1.01189480e+01,
2.02831535e+01, 2.61614761e+01, 2.66169968e+01,
2.71009693e+01, 3.16112938e+01, 3.42509956e+01,
3.86029854e+01, 4.10452881e+01, 4.43008461e+01,
4.59953766e+01, 5.13391228e+01, 6.33173370e+01,
6.97207031e+01, 1.26930122e+02, 1.88605606e+02],            dtype=float32, weak_type=True)),)
```
```python
get_partials((1.0,2.0,3.0), jnp.array((4.0,)))
```
output
```
((DeviceArray([1.], dtype=float32, weak_type=True),
DeviceArray([4.], dtype=float32, weak_type=True),
DeviceArray([16.], dtype=float32, weak_type=True)),)
```
 To specify which arguments to take the partial derivative with respect to, use the `argnums` argument to `jax.grad`. In our case, our first argument is really 3 parameters all at once, so if you did `argnums=(0,)` it would take 3 partial derivatives. If you specified `argnums=(0,1)` it would take 4 — that last one being with respect to x.
 To vectorize your resulting function, use `jax.vmap`. This time, since we have many possible arguments, we will need to specify the `in_axes` argument to `jax.vmap`. `in_axes` will accept a tuple of values, one value per parameter to our function. Since our function has 2 arguments, `params` and `x`, this tuple should have 2 values. We should put `None` for arguments that we don’t want to vectorize over (in this case, `params` stays the same for each call, so the associated `in_axes` value for `params` should be `None`). Our second argument, `x`, should be able to be a vector, so you should put `0` for the associated `in_axes` value for `x`. This is confusing! However, considering how powerful the `get_partials` function is and how much is baked into it, it is probably acceptable to have to sit and think a bit to figure this out.
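As a sketch of how `argnums` and `in_axes` fit together, the example below uses a deliberately simpler, made-up model with two parameters. The names `line`, `line_partials`, and `vectorized_line_partials` are hypothetical; this is one possible arrangement, not the required solution.

```python
import jax
import jax.numpy as jnp

def line(params, x):
    # hypothetical two-parameter model: intercept + slope * x
    intercept, slope = params
    return intercept + slope * x

# argnums=(0,) differentiates with respect to every value inside params
line_partials = jax.grad(line, argnums=(0,))

# params is shared across calls (None); x is mapped over its first axis (0)
vectorized_line_partials = jax.vmap(line_partials, in_axes=(None, 0))

params = (1.0, 2.0)
xs = jnp.array([4.0, 5.0, 6.0])

# the result is a one-element tuple (one entry per argnum) holding the partials
((d_intercept, d_slope),) = vectorized_line_partials(params, xs)

print(d_intercept)  # partial w.r.t. intercept: 1 for every x
print(d_slope)      # partial w.r.t. slope: equal to x
```

The nested tuple in the result mirrors the nesting of the inputs: the outer tuple comes from `argnums=(0,)`, and the inner tuple has one entry per value in `params`.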
Items to submit
• Code used to solve this problem.

• Output from running the code.

 Please make sure to double check that your submission is complete, and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure what you think you submitted was what you actually submitted. In addition, please review our submission guidelines before submitting your project.