TDM 40200: Project 1 — 2023

Motivation: JAX is a Python library used for high-performance numerical computing and machine learning research. In the upcoming series of projects, this library will be used to build a small neural network from scratch.

Context: This is the first project of the semester. We are going to start slowly by reviewing the JAX library, which we will use in the following series of projects.

Scope: Python, JAX, NumPy

Learning Objectives
  • Differentiate functions using JAX.

  • Understand what "JIT" is and why it is useful.

  • Understand when a value or operation should be static vs. traced.

  • Understand how random number generation works in JAX.

  • Use JAX for basic matrix manipulation.

Make sure to read about and use the template found here, and to review the important information about project submissions here.

Dataset(s)

The questions below will use the following dataset(s):

  • /anvil/projects/tdm/data

Questions

The following sets of documentation will be useful for this project.

Question 1

Use JAX to perform the following numeric operations.
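For reference, here is a minimal sketch of the kind of jax.numpy idioms these tasks build on (the names, shapes, and values are illustrative only, not the ones required below):

import jax.numpy as jnp

# build a 2x2 array of the values 1 through 4 in row-major order
a = jnp.arange(1, 5).reshape(2, 2)

# build a 2x2 array where every element is 3
b = jnp.full((2, 2), 3)

# matrix-multiply the two arrays and inspect the result
c = jnp.matmul(a, b)
print(c, c.shape)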

  1. Create and display a 2x3 array called first containing the values 1 through 6 in row-major order.

  2. Create and display a 3x2 array called second where each element is the value 2.

    Use the JAX equivalent of this function.

  3. Matrix-multiply first and second, saving the result to a new array third, and display the result.

  4. Display the shape of third.

  5. Multiply all values of third by 2 and display the result.

  6. Multiply third, element-wise, by the matrix formed from the values 1 through 4 in row-major order and display the result.

    Use the JAX equivalent of this function.

  7. Use the JAX equivalent of this function to add a row to second containing the values 3 and 4. Save the result to fourth.

  8. Use the JAX equivalent of this function to add a column to fourth containing the value 1 repeated 4 times. Save the result to fifth.

  9. Given the JAX array created by the following code, change the 1 to 7.

    my_array = jnp.array([[2,2,2], [2,1,3]])
  10. Read this section of the JAX docs and explain (in your own words) why the following code, which looks like it should solve the previous question, does not work.

    my_array[1,1] = 7
Items to submit
  • Code used to solve this problem.

  • Output from running the code.

  • A markdown cell containing an explanation of why my_array[1,1] = 7 does not work in JAX.

Question 2

Read this section of the JAX documentation to review how random number generation works in JAX.
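As a refresher, JAX does not use a global seed; you manage explicit keys and split them whenever you need fresh randomness. Here is a minimal sketch using only standard jax.random calls:

import jax

# create a reproducible starting key
key = jax.random.PRNGKey(0)

# never reuse a key; split off a fresh subkey for each draw
key, subkey = jax.random.split(key)

# draw a random value using the subkey
value = jax.random.uniform(subkey)
print(value)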

Check out the following code.

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())

If this were written using JAX, JAX might attempt to parallelize the bar and baz functions so that they run at the same time. As a result, it would be unclear whether bar or baz ran first. Because NumPy-style random state is global and order-dependent, this would change the result of foo, making the code difficult to reproduce reliably.

To illustrate this, foo could effectively behave like either of the following functions.

def foo1(): return bar() + 2*baz()

# or

def foo2(): return 2*baz() + bar()

If bar executes first (like in foo1), you will end up with a different result than if baz executes first (like in foo2). The following code illustrates this.

import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo1(): return bar() + 2*baz()

def foo2(): return 2*baz() + bar()

def foo(*funcs):
    functions = list(funcs)
    random.shuffle(functions)
    return functions[0]()

Running the following will sometimes give you the result 1.9791922366721637, and sometimes 1.812816374227069.

np.random.seed(0)
foo(foo1, foo2)

JAX generates random values differently, in a way that prevents such issues. At the same time, its approach is not as straightforward as NumPy's. Fill in the ? parts of the following code. The resulting code should reliably output the same value, 2.3250647, regardless of whether bar or baz executes first.

import jax
import random

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2*baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?
# the following code will always produce 2.3250647, regardless of whether bar or baz executes first
# this means this code is reproducible even in the scenario where the `bar` and `baz` functions are parallelized
key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys[0], subkeys[1])))
Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 3

At the heart of JAX is its automatic differentiation system, which allows JAX to compute gradients of functions automatically. This is extremely powerful. Write a function called my_function that accepts the value x and returns the value 14x^3 + 13x. Test it out given the value x=17.

Next, use the powerful JAX function grad to create a new function called my_gradient that accepts the value x and returns the gradient of my_function at x. Test it out given the value of x=17. What was the result? Does the result match the value when you plug x=17 into the derivative of my_function?

Note that 17 is an integer and 17.0 is a float. The jax.grad function requires real- or complex-valued (floating point) inputs, so be sure to pass a float.
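For reference, here is a minimal sketch of jax.grad on a simpler function (an illustrative example, not the solution):

import jax

def square(x):
    return x**2

# grad returns a new function that computes the derivative, here 2x
d_square = jax.grad(square)

print(d_square(3.0))  # prints 6.0; note the floating point input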

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 4

Another key utility in JAX is the jit function. JIT stands for "just in time". Just-in-time compilation is a trick that can, in some situations, greatly reduce the execution time of some code by compiling it. The compiled version of the code has a myriad of optimizations applied to it that speed up your code.
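For reference, a minimal sketch of applying the jit transformation to an illustrative function (not part of the project):

import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # compiled the first time it is called with this input shape and type
    return (2.0 * x).sum()

x = jnp.arange(1000.0)
print(scaled_sum(x))  # later calls with the same input shape reuse the compiled code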

Take the following arbitrary code and execute it in your notebook.

%%time

key = jax.random.PRNGKey(0)

def my_model(keys):
    return 14*jax.random.normal(keys[0])**2 + 13*jax.random.normal(keys[1])

for i in range(100000):
    key, *subkeys = jax.random.split(key, 3)
    my_value = my_model(subkeys)

How long did it take? Now, apply the jit transformation to your my_model function with the @jax.jit decorator to speed up the code. Did it work?

Well, actually, just slapping the @jax.jit decorator on the function is not good enough. Why? Because JAX uses asynchronous dispatch by default. This means that, by default, JAX returns control to Python as soon as possible, even before the function has been fully evaluated. So while it may appear as if all 100000 loop iterations have executed, in reality they may not have finished.

To properly test whether the JIT trick has sped things up, we need to synchronously wait for our code to finish executing. This is easily accomplished with the block_until_ready method, which is available on the JAX arrays returned by JIT-compiled functions.

For example, the following code will synchronously wait for the my_func function to finish executing.

@jax.jit
def my_func():
    return 1

my_func().block_until_ready()

Repeat the experiment, but make sure to synchronously wait for the code to finish executing. How long did it take? Did it work?

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 5

At this point you may be thinking: let's go! I can just slap @jax.jit on all my functions and make magically fast code! Well, not so fast. There are some caveats to using JAX and jit that you should be aware of.

By default, JAX "traces" a function's parameters: it records what the function does to abstract inputs of a specific shape and type, rather than to concrete values. In JAX, Python control flow cannot depend on these "traced" values. For example, the following code will not work because num_loops is relied on to determine how many times to loop.

%%time

def my_function(x, num_loops):

    for i in range(num_loops):
        pass


fast_my_function = jax.jit(my_function)

fast_my_function(14, 1000000)

What is the solution to this problem? How do we fix it? Well, a fix is not always possible; however, we can select certain arguments to be static, that is, not "traced". When a parameter is marked as static, its concrete value is baked into the compiled function, so the function can be JIT compiled even though control flow depends on that parameter. The catch is that any time the function is called with a new value for a static parameter, the function has to be recompiled for that new value. So, this is only useful if you will only occasionally change the parameter.

It just so happens that the provided snippet of code is a good candidate for this, as the user will only occasionally change the value of num_loops when running the code!

You can mark a parameter as static by specifying the argument position using the static_argnums argument to jax.jit, or by specifying the argument name using the static_argnames argument to jax.jit.
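For example, for a function whose second parameter is named num_loops, the following two forms are equivalent:

# mark the second positional argument as static
fast_my_function = jax.jit(my_function, static_argnums=(1,))

# or, equivalently, mark the parameter static by name
fast_my_function = jax.jit(my_function, static_argnames="num_loops")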

Force the num_loops argument to be static and use jax.jit to compile the function. Test out the function, in order, using the following code cells.

%%time

def my_function(x, num_loops):

    for i in range(num_loops):
        pass

fast_my_function = jax.jit(my_function, static_argnums=(1,))
%%time

fast_my_function(14, 1000000)
%%time

fast_my_function(14, 1000000)
%%time

fast_my_function(14, 999999)

Do your best to explain why the last code cell was once again slower.

In addition, the shapes of all inputs and outputs must be determinable at trace time, without looking at the values of traced inputs. For example, the following will fail because the shape of my_array depends on the traced value arr_cols.

%%time

def my_function(x, arr_cols):

    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function)

fast_my_function(5, 5)

You can, once again, fix this by specifying the static parameters.

%%time

def my_function(x, arr_cols):

    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function, static_argnums=(1,))

And, once again, JAX will recompile every time that the static argument changes.

%%time

# slow, first time compiling
fast_my_function(5, 5)
%%time

# fast, already compiled with static argument of 5
fast_my_function(5, 5)
%%time

# slow, recompiling with static argument of 6
fast_my_function(5, 6)
Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Please make sure to double-check that your submission is complete and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it, to make sure that what you think you submitted was what you actually submitted.

In addition, please review our submission guidelines before submitting your project.