STAT 39000: Project 11 — Spring 2022

Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools in that space are the pytorch and tensorflow libraries; JAX is another tool from Google that is growing in popularity. These libraries are used to build and use complex models, and when GPUs are available they can speed up parallelizable code a hundredfold or even a thousandfold.

Context: This is the third of a series of 4 projects focused on using pytorch and JAX to solve numeric problems.

Scope: Python, JAX

Learning Objectives
  • Compare and contrast pytorch and JAX.

  • Differentiate functions using JAX.

  • Understand what "JIT" is and why it is useful.

  • Understand when a value or operation should be static vs. traced.

  • Vectorize functions using the vmap function from JAX.

  • Understand how random number generators work in JAX.

Make sure to read about and use the template found here, and to review the important information about project submissions here.

Dataset(s)

The following questions will use the following dataset(s):

  • /depot/datamine/data/sim/train.csv

Questions

Question 1

JAX is a library for high performance computing. It falls into the same category as other popular packages like numpy, pytorch, and tensorflow. JAX is a product of Google / Deepmind that takes a completely different approach from their other product, tensorflow.

Like the other popular libraries, JAX can utilize GPUs/TPUs to greatly speed up computation. Let’s take a look.

Here is a snippet of code from previous projects that uses pytorch and calculates predictions 10000 times.

Of course, this calculates the same predictions every time since our betas aren’t being updated yet, but just bear with me.

import pandas as pd
import torch
import jax
import jax.numpy as jnp

dat = pd.read_csv("/depot/datamine/data/sim/train.csv")

%%time

x_train = torch.tensor(dat['x'].to_numpy())
y_train = torch.tensor(dat['y'].to_numpy())

beta0 = torch.tensor(5, requires_grad=True, dtype=torch.float)
beta1 = torch.tensor(4, requires_grad=True, dtype=torch.float)
beta2 = torch.tensor(3, requires_grad=True, dtype=torch.float)

num_epochs = 10000

for idx in range(num_epochs):

    y_predictions = beta0 + beta1*x_train + beta2*x_train**2

Approximately how much time does it take to run this second chunk of code (after we have already read in our data)?

Here is the equivalent JAX code:

%%time

x_train = jnp.array(dat['x'].to_numpy())
y_train = jnp.array(dat['y'].to_numpy())

beta0 = 5
beta1 = 4
beta2 = 3

num_epochs = 10000

for idx in range(num_epochs):

    y_predictions = beta0 + beta1*x_train + beta2*x_train**2

How much time does this take?

At this point in time you may be questioning how JAX could possibly be worth it. At first glance, the new code does look a bit cleaner, but not so much cleaner that it justifies code that is around 3 times slower.

This is where JAX's first trick, or transformation, comes into play. When we refer to a transformation, think of it as an operation on some function that produces another function as an output.

The first transformation we will talk about is jax.jit. "JIT" stands for "Just In Time" and refers to a "just in time" compiler. Essentially, just in time compilation is a trick that can greatly speed up the execution of some code by compiling it; in a nutshell, the compiled version includes a wide variety of optimizations that make your code run faster.

Lots of our computation time is spent inside our loop, specifically when we are calculating our y_predictions. Let’s see if we can use the jit transformation to speed up our JAX code with little to no extra effort.

Write a function called model that accepts two arguments. The first argument is a tuple containing our parameters: beta0, beta1, and beta2. The second is the input to our function (our x values), called x. model should unpack the tuple of parameters into beta0, beta1, and beta2, and then return the predictions (using the same formula shown twice above). Replace the code as follows.

# replace this line
y_predictions = beta0 + beta1*x_train + beta2*x_train**2

# with
y_predictions = model((beta0, beta1, beta2), x_train)
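
If it helps to see the shape of the function being asked for, the same model appears again in Question 4, so here is a minimal sketch for reference:

def model(params, x):
    beta0, beta1, beta2 = params
    return beta0 + beta1*x + beta2*x**2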

Run and time the code again. No difference? Well, we didn’t use our jit transformation yet! Using the transformation is easy. JAX provides two equivalent ways. You can either decorate your model function with the @jax.jit decorator, or simply apply the transformation to your function and save the new, jit compiled function and use it instead.

def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)

Re-run your code using the JIT transformation. Is it faster now?

It is important to note that pytorch does have some jit functionality, and there is also a package called numba which can help with this as well; however, performing the same operation is not as straightforward with either of them as it is with JAX.

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 2

At this point in time you may be considering slapping @jax.jit on all your functions. Unfortunately, it is not quite so simple! First of all, the previous comparison was actually not fair at all. Why? JAX uses asynchronous dispatch by default, meaning that JAX will return control to Python as soon as possible, even if that is before the function has been fully evaluated.

What does this mean? It means that our finished example from question 1 may be returning a not-yet-complete result, greatly throwing off our performance measurements. So how can we synchronously wait for execution to finish? This is easy: simply call the block_until_ready method on the array returned by your jit compiled model function.

def my_func(x):
    return x**2

@jax.jit
def my_func_jit1(x):
    return x**2

my_func_jit2 = jax.jit(my_func)

# block_until_ready is called on the array returned by the jitted function
my_func_jit1(x_train).block_until_ready()

# or

my_func_jit2(x_train).block_until_ready()
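
In the timed loop from question 1, this might look something like the following sketch (assuming model has been jit compiled as described above); blocking on the final result ensures the timing reflects the full computation:

%%time

num_epochs = 10000

for idx in range(num_epochs):
    y_predictions = model((beta0, beta1, beta2), x_train)

# force the asynchronously dispatched work to finish before the cell's timer stops
y_predictions.block_until_ready()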

Re-run your code from before. You should find that the results are essentially unchanged; it turns out that really was a genuine speedup. Great, let’s move on from this part of things. Back to our question: why can’t we just slap @jax.jit on any function and expect a speedup?

Take the following function.

def train(params, x, y, epochs):
    def _model(params, x):
        beta0, beta1, beta2 = params
        return beta0 + beta1*x + beta2*x**2

    mses = []
    for _ in range(epochs):
        y_predictions = _model(params, x)
        mse = jnp.sum((y_predictions - y)**2)
        mses.append(mse)

    return mses

fast_train = jax.jit(train)

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)

If you try running it you will get an error saying something along the lines of "TracerIntegerConversionError". The problem with this function, and the reason it cannot be jit compiled, is the epochs argument. By default, JAX "traces" each argument in order to determine the function's effect on inputs of a specific shape and type. Python control flow cannot depend on traced values, and in this case epochs is used to determine how many times to loop. In addition, the shapes of all input and output values of a function must be determinable ahead of time.

How do we fix this? Well, it is not always possible; however, we can choose to mark specific parameters as static, i.e. not traced. If a parameter is marked as static, the function can be JIT compiled using the concrete value of that parameter. The catch is that any time the function is called with a different value for a static parameter, it has to be recompiled for that new value. So this is only useful if you change the parameter occasionally. That sounds like our case! We only want to occasionally change the number of epochs, so perfect.

You can mark a parameter as static by specifying the argument position using the static_argnums argument to jax.jit, or by specifying the argument name using the static_argnames argument to jax.jit.
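
For example (a sketch, assuming train is defined exactly as above, so that epochs is in argument position 3), either of the following would mark epochs as static:

# mark `epochs` as static by position...
fast_train = jax.jit(train, static_argnums=(3,))

# ...or, equivalently, by name
fast_train = jax.jit(train, static_argnames=("epochs",))

If you prefer the decorator style, functools.partial can be used to supply the static_argnums or static_argnames arguments to jax.jit in a decorator.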

Force the epochs argument to be static and use jax.jit to compile the function. Test out the function, in order, using the following code cells.

%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)

%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 10000)

%%time

fast_train((beta0, beta1, beta2), x_train, y_train, 9999)

Do your best to explain why the last code cell was once again slower.

If you aren’t sure why, reread the question text — we hint at the "catch" in the text.

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 3

We learned that one of the coolest parts of the pytorch package was its automatic differentiation feature. It saves a lot of time that would otherwise be spent doing calculus and coding up the resulting equations. Recall that in pytorch this differentiation was baked into the backward method of our MSE. This is quite different from the way we think about the equations when looking at the math, and can be confusing.

JAX has the same functionality, but it is much cleaner and easier to use. We will provide you with a simple example, and explain the math as we go along.

Let’s say our function is $f(x) = x^2$. We can start by writing a function.

def squared(x):
    return x**2

Fantastic, so far pretty easy.

The derivative w.r.t. x is $2x$. Doing this in JAX is as easy as applying the jax.grad transformation to the function.

squared_deriv = jax.grad(squared)

Okay, test out both functions as follows.

my_array = jnp.array([1.0, 2.0, 3.0])

squared(4.0) # 16.0
squared(my_array) # [1.0, 4.0, 9.0]
squared_deriv(4.0) # 8.0
squared_deriv(my_array) # uh oh! Something went wrong!

A very perceptive student pointed out that passing array values that are ints to jax.grad will fail, which is why my_array uses floats. You can read more about why here.

On the last line, you probably received a message or error saying something along the lines of "Gradient only defined for scalar-output functions." What this means is that the resulting derivative function is not vectorized. As you may have guessed, this is easily fixed. Another key transformation that JAX provides is called vmap. vmap takes a function and creates a vectorized version of that function. See the following.

vectorized_deriv_squared = jax.vmap(squared_deriv)
vectorized_deriv_squared(my_array) # [2.0, 4.0, 6.0]

Heck yes! That is pretty cool, and very powerful. It is so much more understandable than the magic happening in the pytorch world too!

Dig back into your memory for any equation you have come across in the past where you needed to find a derivative. Create a Python function, find the derivative, and test it out on both a single value, like 4.0, and an array, like jnp.array([1.0, 2.0, 3.0]). Don’t hesitate to make it extra fun and include some functions like jnp.cos, jnp.sin, etc. Did everything work as expected? A sketch of the workflow appears below.
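
As a purely illustrative sketch (using jnp.sin as the example function; pick your own equation for the question), the workflow is the same as above:

# f(x) = sin(x), so the derivative is f'(x) = cos(x)
def my_fun(x):
    return jnp.sin(x)

my_fun_deriv = jax.grad(my_fun)
my_fun_deriv_vec = jax.vmap(my_fun_deriv)

my_fun_deriv(4.0)                             # roughly cos(4.0) = -0.6536
my_fun_deriv_vec(jnp.array([1.0, 2.0, 3.0]))  # roughly [0.5403, -0.4161, -0.9900]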

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 4

Okay, great, but that was a straightforward example. What if we have multiple parameters we’d like to take partial derivatives with respect to? jax.grad can handle that too!

Read this excellent example in the official JAX documentation.

The JAX documentation is excellent! If you are interested, I would recommend reading through it; it is very well written.

Given the following (should be familiar) model, create a function called get_partials that accepts an argument params (a tuple containing beta0, beta1, and beta2, in order) and an argument x, which can be either a single value (a scalar) or a jnp.array with multiple values. This function should return a single value for each of the 3 partial derivatives, where x is plugged into each of the 3 partial derivatives to calculate each value, OR 3 arrays of results, each containing one value per value in the input array.

@jax.jit
def model(params, x):
    beta0, beta1, beta2 = params
    return beta0 + beta1*x + beta2*x**2

Example using it:

model((1.0, 2.0, 3.0), 4.0) # 57
model((1.0, 2.0, 3.0), jnp.array((4.0, 5.0, 6.0))) # [57, 86, 121]

Since we have 3 parameters, we will have 3 partial derivatives, and our new function should output a value for each of our 3 partial derivatives, for each value passed as x. To be explicit and allow you to check your work, the results should be the same as the following.

params = (5.0, 4.0, 3.0)
get_partials(params, x_train)

Output:

((DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
              dtype=float32, weak_type=True),
  DeviceArray([-15.94824   , -11.117526  , -10.4780855 ,  -8.867778  ,
                -8.799367  ,  -8.140428  ,  -7.8744955 ,  -7.72306   ,
                -6.9281745 ,  -6.2731333 ,  -6.2275624 ,  -5.7271757 ,
                -5.1857414 ,  -5.150156  ,  -4.8792663 ,  -4.663747  ,
                -4.58701   ,  -4.1310377 ,  -4.0215836 ,  -4.019455  ,
                -3.5578184 ,  -3.4748363 ,  -3.4004524 ,  -3.1221437 ,
                -3.0421085 ,  -2.941131  ,  -2.8603644 ,  -2.8294718 ,
                -2.7050996 ,  -1.9493109 ,  -1.7873074 ,  -1.2773769 ,
                -1.1804487 ,  -1.1161369 ,  -1.1154363 ,  -0.8590109 ,
                -0.81457555,  -0.7386795 ,  -0.57577926,  -0.5536533 ,
                -0.51964295,  -0.12334588,   0.11549235,   0.14650635,
                 0.24305418,   0.2876291 ,   0.3942046 ,   0.6342466 ,
                 0.8256681 ,   1.2047065 ,   1.9168468 ,   1.9493027 ,
                 1.9587051 ,   2.3490443 ,   2.7015095 ,   2.8161156 ,
                 2.8648841 ,   2.946292  ,   3.1312609 ,   3.1810293 ,
                 4.503682  ,   5.114829  ,   5.1591663 ,   5.205859  ,
                 5.622392  ,   5.852435  ,   6.21313   ,   6.4066596 ,
                 6.655888  ,   6.781989  ,   7.1651325 ,   7.957219  ,
                 8.349893  ,  11.266327  ,  13.733376  ],
               dtype=float32, weak_type=True),
  DeviceArray([2.54346375e+02, 1.23599388e+02, 1.09790276e+02,
               7.86374817e+01, 7.74288559e+01, 6.62665634e+01,
               6.20076790e+01, 5.96456566e+01, 4.79996033e+01,
               3.93521996e+01, 3.87825356e+01, 3.28005409e+01,
               2.68919144e+01, 2.65241070e+01, 2.38072395e+01,
               2.17505341e+01, 2.10406590e+01, 1.70654716e+01,
               1.61731339e+01, 1.61560173e+01, 1.26580715e+01,
               1.20744877e+01, 1.15630760e+01, 9.74778175e+00,
               9.25442410e+00, 8.65025234e+00, 8.18168449e+00,
               8.00591087e+00, 7.31756353e+00, 3.79981303e+00,
               3.19446778e+00, 1.63169169e+00, 1.39345896e+00,
               1.24576163e+00, 1.24419820e+00, 7.37899661e-01,
               6.63533330e-01, 5.45647442e-01, 3.31521749e-01,
               3.06531966e-01, 2.70028800e-01, 1.52142067e-02,
               1.33384829e-02, 2.14641113e-02, 5.90753369e-02,
               8.27304944e-02, 1.55397251e-01, 4.02268738e-01,
               6.81727827e-01, 1.45131791e+00, 3.67430139e+00,
               3.79978085e+00, 3.83652544e+00, 5.51800919e+00,
               7.29815340e+00, 7.93050718e+00, 8.20756149e+00,
               8.68063641e+00, 9.80479431e+00, 1.01189480e+01,
               2.02831535e+01, 2.61614761e+01, 2.66169968e+01,
               2.71009693e+01, 3.16112938e+01, 3.42509956e+01,
               3.86029854e+01, 4.10452881e+01, 4.43008461e+01,
               4.59953766e+01, 5.13391228e+01, 6.33173370e+01,
               6.97207031e+01, 1.26930122e+02, 1.88605606e+02],
              dtype=float32, weak_type=True)),)
get_partials((1.0,2.0,3.0), jnp.array((4.0,)))

Output:

((DeviceArray([1.], dtype=float32, weak_type=True),
  DeviceArray([4.], dtype=float32, weak_type=True),
  DeviceArray([16.], dtype=float32, weak_type=True)),)

To specify which arguments to take the partial derivative with respect to, use the argnums argument to jax.grad. In our case, our first argument is really 3 parameters all at once, so if you did argnums=(0,) it would take 3 partial derivatives. If you specified argnums=(0,1) it would take 4 — that last one being with respect to x.

To vectorize your resulting function, use jax.vmap. This time, since we have many possible arguments, we will need to specify the in_axes argument to jax.vmap. in_axes will accept a tuple of values — one value per parameter to our function. Since our function has 2 arguments: params and x, this tuple should have 2 values. We should put None for arguments that we don’t want to vectorize over (in this case, params stays the same for each call, so the associated in_axes value for params should be None). Our second argument, x, should be able to be a vector, so you should put 0 for the associated in_axes value for x.

This is confusing! However, considering how much power is baked into the get_partials function, it is probably acceptable to have to sit and think a bit to figure this out.
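
If it helps to see how these pieces could fit together, here is one hedged sketch (assuming model is defined as above and x is passed as a jnp.array, as in the example calls); it is not the only valid way to structure get_partials:

# take partial derivatives with respect to the params tuple (argument 0) only,
# then vectorize over x while holding params fixed
get_partials = jax.vmap(jax.grad(model, argnums=(0,)), in_axes=(None, 0))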

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Please make sure to double check that your submission is complete, and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted was what you actually submitted.

In addition, please review our submission guidelines before submitting your project.