# TDM 40200: Project 1 — 2023

Motivation: `JAX` is a Python library used for high-performance numerical computing and machine learning research. In the upcoming series of projects, this library will be used to build a small neural network from scratch.

Context: This is the first project of the semester. We are going to start slowly by reviewing the `JAX` library, which we will use in the following series of projects.

Scope: Python, `JAX`, `numpy`

Learning Objectives
• Differentiate functions using `JAX`.

• Understand what "JIT" is and why it is useful.

• Understand when a value or operation should be static vs. traced.

• Understand how random number generation works in `JAX`.

• Use `JAX` for basic matrix manipulation.

Make sure to read about and use the template found here, and read the important information about project submissions here.

## Dataset(s)

The following questions will use the following dataset(s):

• `/anvil/projects/tdm/data`

## Questions

The following sets of documentation will be useful for this project.

### Question 1

Use `JAX` to perform the following numeric operations.

1. Create and display a 2x3 array called `first` of the values 1 through 6 in row-major order.

2. Create and display a 3x2 array called `second` where each element is the value 2.

 Use the `JAX` equivalent of this function.
3. Multiply `first` and `second` together into a resulting `third` and display the result.

4. Display the shape of `third`.

5. Multiply all values of `third` by 2 and display the result.

6. Multiply `third`, element-wise, by the matrix formed from the values 1 through 4 in row-major order and display the result.

 Use the `JAX` equivalent of this function.
7. Use the `JAX` equivalent of this function to add a row to `second` containing the values 3 and 4. Save the result to `fourth`.

8. Use the `JAX` equivalent of this function to add a column to `fourth` containing the value 1 repeated 4 times. Save the result to `fifth`.

9. Given the `JAX` array created by the following code, change the 1 to 7.

``my_array = jnp.array([[2,2,2], [2,1,3]])``
10. Read this section of the `JAX` docs and explain (in your own words) why the following code that should solve the previous question does not work.

``my_array[1,1] = 7``
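As a warm-up for the operations listed above, here is a minimal sketch of the difference between matrix multiplication and element-wise multiplication, plus one possible way to stack a row and a column onto an array. The arrays `a` and `b` and their values are hypothetical and chosen only for illustration. `jnp.vstack` and `jnp.hstack` are one option among several and are not necessarily the functions the hints refer to.

```python
import jax.numpy as jnp

# Hypothetical example arrays, unrelated to the arrays in the question.
a = jnp.arange(1, 5).reshape(2, 2)   # values 1..4 in row-major order
b = jnp.full((2, 2), 10)             # every element is 10

matrix_product = a @ b               # matrix multiplication
elementwise_product = a * b          # element-wise multiplication

# Append a row, then a column (one possible approach).
with_row = jnp.vstack([a, jnp.array([5, 6])])
with_column = jnp.hstack([with_row, jnp.zeros((3, 1), dtype=a.dtype)])

print(matrix_product)
print(elementwise_product)
print(with_column.shape)
```

Note that `@` and `*` give different results even though both operands have the same shape here, which is why it matters which one a question asks for.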
Items to submit
• Code used to solve this problem.

• Output from running the code.

• A markdown cell containing an explanation on why `my_array[1,1] = 7` does not work in `JAX`.

### Question 2

Read this section of the `JAX` documentation to recall the way you can generate random numbers using `JAX`.
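As a quick refresher, the sketch below shows the explicit key handling that the `JAX` documentation describes: a key is split into subkeys, and a random function called with the same key always returns the same value. The seed `42` and the names `subkey_a`, `subkey_b`, and `draw_*` are hypothetical choices for illustration only.

```python
import jax

# Hypothetical seed, chosen only for this illustration.
key = jax.random.PRNGKey(42)

# Split the key into a new key and two independent subkeys.
key, subkey_a, subkey_b = jax.random.split(key, num=3)

draw_1 = jax.random.normal(subkey_a)
draw_2 = jax.random.normal(subkey_a)  # same key, so the same value as draw_1
draw_3 = jax.random.normal(subkey_b)  # different key, so an independent draw

print(draw_1, draw_2, draw_3)
```

Because each value depends only on the key passed in, the order in which the draws are computed does not affect the result.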

Check out the following code.

```python
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
```

If this were written using `JAX`, `JAX` may well attempt to parallelize the `bar` and `baz` functions to run at the same time. As a result, whether `bar` or `baz` ran first would be unclear. This, unfortunately, would change the results of `foo`, leaving the code difficult to replicate reliably.

To illustrate this, `foo` could be the result of either of the following functions.

```python
def foo1(): return bar() + 2*baz()

# or

def foo2(): return 2*baz() + bar()
```

If `bar` executes first (like in `foo1`), you will end up with a different result than if `baz` executes first (like in `foo2`). The following code illustrates this.

```python
import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo1(): return bar() + 2*baz()

def foo2(): return 2*baz() + bar()

def foo(*funcs):
    functions = list(funcs)
    random.shuffle(functions)
    # call whichever function ended up first after shuffling
    return functions[0]()
```

Running the following will sometimes give you the result `1.9791922366721637`, and sometimes `1.812816374227069`.

```python
np.random.seed(0)
foo(foo1, foo2)
```

The way `JAX` generates random values is different, and prevents such issues. At the same time, generating random values in `JAX` is not as straightforward as in `NumPy`. Fill in the `?` parts of the following code. The resulting code should reliably output the same value, `2.3250647`, regardless of whether `bar` or `baz` executes first.

```python
import random

import jax

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2*baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?
```

```python
# the following code will always produce 2.3250647, regardless of whether bar or baz executes first
# this means this code is reproducible even in the scenario where the `bar` and `baz` functions are parallelized
key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys[0], subkeys[1])))
```
Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 3

At the heart of `JAX` is the automatic differentiation system. This system allows `JAX` to compute gradients of functions, automatically. This is extremely powerful. Write a function called `my_function` that accepts the value `x` and returns the value `14x^3 + 13x`. Test it out given the value of `x=17`.

Next, use the powerful `JAX` function `grad` to create a new function called `my_gradient` that accepts the value `x` and returns the gradient of `my_function` at `x`. Test it out given the value of `x=17`. What was the result? Does the result match the value when you plug `x=17` into the derivative of `my_function`?

 `17` is an integer and `17.0` is a float. The `jax.grad` function requires real or complex valued inputs.
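To illustrate the note about integer versus float inputs, here is a minimal sketch using a different, hypothetical function (`square_plus_linear`, not the function from the question). Its derivative is `2x + 3`, so the gradient at `x = 2.0` should be `7.0`.

```python
import jax

# Hypothetical function for illustration: f(x) = x^2 + 3x, so f'(x) = 2x + 3.
def square_plus_linear(x):
    return x**2 + 3.0 * x

derivative = jax.grad(square_plus_linear)

# A float input works.
value = derivative(2.0)
print(value)

# An integer input is rejected, because grad needs real or complex inputs.
integer_input_rejected = False
try:
    derivative(2)
except TypeError:
    integer_input_rejected = True
print(integer_input_rejected)
```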
Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 4

Another key utility in `JAX` is the `jit` function. JIT stands for "just in time". Just-in-time compilation is a technique that can, in some situations, greatly reduce the execution time of code by compiling it. The compiled version of the code has a myriad of optimizations applied to it that speed up your code.

Take the following, arbitrary code, and execute it in your notebook.

```python
%%time

key = jax.random.PRNGKey(0)

def my_model(keys):
    # keys holds two independent subkeys, one per random draw
    return 14*jax.random.normal(keys[0])**2 + 13*jax.random.normal(keys[1])

for i in range(100000):
    key, *subkeys = jax.random.split(key, 3)
    my_value = my_model(subkeys)
```

How long did it take? Now, use the `@jax.jit` decorator to apply the `jit` transformation to your `my_model` function to use just in time compilation to speed up the code. Did it work?

Well, actually, just slapping the `@jax.jit` decorator on the function is not good enough. Why? Because `JAX` has asynchronous dispatch by default. What this means is that, by default, `JAX` will return control to Python as soon as possible, even if it is before the function has been fully evaluated. So while it may appear as if all of the 100000 loops have been executed, in reality, they may not have been.

To properly test whether the JIT trick has sped things up, we need to synchronously wait for our code to finish executing. This can be easily accomplished with the `block_until_ready` method, which is available on the arrays that `JAX` functions return.

For example, the following code will synchronously wait for the `my_func` function to finish executing.

```python
@jax.jit
def my_func():
    return 1

# wait until the result has actually been computed
my_func().block_until_ready()
```

Repeat the experiment but make sure we are synchronously waiting for the code to finish executing. How long did it take? Did it work?
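For reference, here is a minimal sketch of timing a JIT-compiled function outside of a notebook, using `time.perf_counter` instead of `%%time`. The function `scaled_sum` and the array `data` are hypothetical and unrelated to `my_model`. The first call includes compilation time, so it is usually slower than later calls; actual timings depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical function and data, for illustration only.
@jax.jit
def scaled_sum(x):
    return jnp.sum(x * 2.0 + 1.0)

data = jnp.ones((1000, 1000))

# First call: tracing and compilation happen here.
start = time.perf_counter()
first = scaled_sum(data).block_until_ready()
first_call_seconds = time.perf_counter() - start

# Second call: reuses the compiled code.
start = time.perf_counter()
second = scaled_sum(data).block_until_ready()
second_call_seconds = time.perf_counter() - start

print(f"first call:  {first_call_seconds:.6f} s")
print(f"second call: {second_call_seconds:.6f} s")
```

Calling `block_until_ready()` before stopping the timer ensures the measurement covers the computation itself rather than just the time needed to dispatch it.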

Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 5

At this point you may be thinking: let's go! I can just slap `@jax.jit` on all my functions and make magically fast code! Well, not so fast. There are some caveats to using `JAX` and `jit` that you should be aware of.

By default, `JAX` will "trace" a function's parameters to determine what the function does to inputs of a specific shape and type. In `JAX`, control flow cannot depend on these "traced" values. For example, the following code will not work because `num_loops` is relied on in order to determine how many times to loop.

```python
%%time

def my_function(x, num_loops):

    for i in range(num_loops):
        pass

fast_my_function = jax.jit(my_function)

fast_my_function(14, 1000000)
```

How do we fix this? A fix is not always possible, but we can choose to mark certain arguments as static, meaning not "traced". A function whose control flow depends only on static parameters can be JIT compiled. The catch is that any time the function is called with a static value it has not yet been compiled for, the function has to be recompiled with that new static value. So, this is only useful if you will only occasionally change the parameter.

It just so happens that the provided snippet of code is a good candidate for this, as the user will only occasionally decide to change the value of `num_loops` when running the code!

You can mark a parameter as static by specifying the argument position using the `static_argnums` argument to `jax.jit`, or by specifying the argument name using the `static_argnames` argument to `jax.jit`.
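The sketch below illustrates the `static_argnames` form together with the recompilation behavior described above. The function `add_repeatedly` and the counter `trace_count` are hypothetical names used only for this illustration. The counter works because ordinary Python side effects inside a JIT-compiled function run only while `JAX` is tracing it, not on cached calls.

```python
from functools import partial

import jax

# Counts how many times JAX traces (and therefore compiles) the function.
trace_count = 0

# Hypothetical function: `repeats` is static, so it can drive a Python loop.
@partial(jax.jit, static_argnames=("repeats",))
def add_repeatedly(x, repeats):
    global trace_count
    trace_count += 1  # Python side effect: only runs during tracing
    for _ in range(repeats):
        x = x + 1.0
    return x

result_a = add_repeatedly(0.0, repeats=3)  # traced and compiled for repeats=3
result_b = add_repeatedly(5.0, repeats=3)  # reuses the compiled version
result_c = add_repeatedly(0.0, repeats=4)  # new static value: traced again

print(result_a, result_b, result_c)
print("number of traces:", trace_count)
```

Changing the traced argument `x` does not trigger a new trace as long as its shape and type stay the same, while each new value of the static argument does.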

Force the `num_loops` argument to be static and use the `jax.jit` decorator to compile the function. Test out the function, in order, using the following code cells.

```python
%%time

def my_function(x, num_loops):

    for i in range(num_loops):
        pass

fast_my_function = jax.jit(my_function, static_argnums=(1,))
```

```python
%%time

fast_my_function(14, 1000000)
```

```python
%%time

fast_my_function(14, 1000000)
```

```python
%%time

fast_my_function(14, 999999)
```

Do your best to explain why the last code cell was once again slower.

In addition, the shapes or dimensions of all inputs and outputs must be able to be determined ahead of time. For example, the following will fail.

```python
%%time

def my_function(x, arr_cols):

    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function)

fast_my_function(5, 5)
```

You can, once again, fix this by specifying the static parameters.

```python
%%time

def my_function(x, arr_cols):

    my_array = jnp.full((2, arr_cols), 5)

fast_my_function = jax.jit(my_function, static_argnums=(1,))
```

And, once again, `JAX` will recompile every time that the static argument changes.

```python
%%time

# slow, first time compiling
fast_my_function(5, 5)
```

```python
%%time

# fast, already compiled with static argument of 5
fast_my_function(5, 5)
```

```python
%%time

# slow, recompiling with static argument of 6
fast_my_function(5, 6)
```
Items to submit
• Code used to solve this problem.

• Output from running the code.

 Please make sure to double-check that your submission is complete and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted is what you actually submitted. In addition, please review our submission guidelines before submitting your project.