# STAT 39000: Project 12 — Spring 2022

Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools surrounding said topics are the pytorch and tensorflow libraries — JAX is another tool by Google growing in popularity. These tools are libraries used to build and use complex models. If available, they can take advantage of GPUs to speed up parallelizable code by a hundred or even thousand fold.

Context: This is the last of a series of 4 projects focused on using pytorch and JAX to solve numeric problems.

Scope: Python, JAX

Learning Objectives
• Compare and contrast pytorch and JAX.

• Differentiate functions using JAX.

• Understand what "JIT" is and why it is useful.

• Understand when a value or operation should be static vs. traced.

• Vectorize functions using the vmap function from JAX.

• Understand how random number generators work in JAX.

Make sure to read about, and use, the template found here, and read the important information about project submissions here.

## Dataset(s)

The following questions will use the following dataset(s):

• `/depot/datamine/data/`

## Questions

### Question 1

Last week's project was a bit fast paced, so we will slow things down considerably to compensate and give you a chance to digest and explore more. We will:

• Learn how `JAX` handles generating random numbers differently than most other packages.

• Write a function in `numpy` to calculate the Hamming distance between a given image hash and the remaining (around 123k) image hashes.

• Play around with the hash data and do some sanity checks.

Let’s start by taking a look at the documentation for random number generation. Carefully read the page — it’s not that long.

The documentation gives the following example.

``````import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())``````

It then goes on to say that `JAX` may try to parallelize the `bar` and `baz` functions. As a result, we would not know which would run first, `bar` or `baz`. This would change the results of `foo`. Below, we’ve modified the code to emulate this.

``````import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo1(): return bar() + 2 * baz()

def foo2(): return 2*baz() + bar()

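def foo(*funcs):
    functions = list(funcs)
    # shuffle to emulate not knowing which function runs first
    random.shuffle(functions)
    # call whichever function ended up first after the shuffle
    return functions[0]()``````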
``````np.random.seed(0)
foo(foo1, foo2)``````
output
```# sometimes this
1.9791922366721637

# sometimes this
1.812816374227069```
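The order dependence can also be shown without any shuffling. The following is a minimal sketch that reseeds the global `numpy` generator before each evaluation order; the variable names `first` and `second` are illustrative and not part of the project template.

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

# Order 1: bar() consumes the first draw, baz() the second.
np.random.seed(0)
first = bar() + 2 * baz()

# Order 2: baz() consumes the first draw, bar() the second.
np.random.seed(0)
second = 2 * baz() + bar()

print(first)
print(second)
```

Both runs start from the same seed, yet the results differ because the global state hands out the same two draws to different functions. These should correspond to the two values shown in the output above.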

`JAX` deals with this very differently. Its solution is clean and effective and allows such code to be parallelized, but managing and passing around keys can be a bit more cumbersome. Create a modified version of this code that uses `JAX` and passes keys around explicitly. Fill in the `?` parts.

``````import random

import jax

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2 * baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?``````
``````key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys, subkeys)))``````
output
```# always
2.3250647```
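The reason the output never changes is that a `JAX` random function depends only on the key it receives, not on any hidden global state. Here is a small sketch of that behavior, separate from the exercise. It assumes `jax.random.normal` as the sampler, and the helper `draw` and the names `k1` and `k2` are hypothetical.

```python
import jax

key = jax.random.PRNGKey(0)
# Split the parent key into a new parent key plus two subkeys.
key, k1, k2 = jax.random.split(key, num=3)

def draw(k):
    # Hypothetical helper: the result depends only on the key passed in.
    return float(jax.random.normal(k))

a_then_b = (draw(k1), draw(k2))
b_then_a = (draw(k2), draw(k1))

# The same key gives the same value, regardless of call order.
print(a_then_b[0] == b_then_a[1])  # True
print(a_then_b[1] == b_then_a[0])  # True
```

Reusing a key reproduces the same draw, which is why each independent draw needs its own subkey.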
Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 2

Write a function called `get_distances_np` that accepts two strings: a filename (`fm_hash`) and a path (`path`).

`get_distances_np` should return a numpy array of the distances between the hash for `fm_hash` and every other image hash in `path`.
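As a reminder, the Hamming distance between two hashes is the number of bit positions in which they differ. A toy illustration, assuming the hashes are stored as boolean `numpy` arrays (the 8-bit values `h1` and `h2` are made up for the example):

```python
import numpy as np

# Two toy 8-bit hashes stored as boolean arrays.
h1 = np.array([1, 0, 1, 1, 0, 0, 1, 0], dtype=bool)
h2 = np.array([1, 1, 1, 0, 0, 0, 1, 1], dtype=bool)

# The Hamming distance is the number of positions where the bits differ.
distance = np.count_nonzero(h1 != h2)
print(distance)  # 3
```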

For this question, use the dataset of hashed images found in `/depot/datamine/data/coco/hashed02/`. An example of a call to `get_distances_np` would look like the following.

``````from pathlib import Path
import imagehash
import numpy as np``````
``````%%time

hshs = get_distances_np("000000000008.jpg", "/depot/datamine/data/coco/hashed02/")
hshs.shape # (123387, 1)``````

How long does it take to run this function?

Make plots and/or summary statistics to check out the distribution of the distances. How does it look?
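One possible way to summarize such distances is sketched below on synthetic data. The random 64-bit hashes are an assumption that stands in for the real result, so the numbers will not match the actual dataset.

```python
import numpy as np

# Synthetic stand-in for the real data: 1,000 random 64-bit hashes.
rng = np.random.default_rng(0)
fake_hashes = rng.integers(0, 2, size=(1000, 64)).astype(bool)

# Distances from the first hash to every hash (including itself).
distances = np.count_nonzero(fake_hashes[0] != fake_hashes, axis=1)

print(distances.min(), distances.max(), distances.mean())
print(np.percentile(distances, [25, 50, 75]))

# Frequency of each possible distance from 0 to 64, ready for a bar plot.
counts = np.bincount(distances, minlength=65)
```

The `counts` array can be passed to any plotting library as bar heights, with the distance values 0 through 64 on the horizontal axis.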

Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 3

What do you think about the design of the `get_distances_np` function, considering that we are interested in pairwise Hamming distances?

At its core, we essentially have a vector of 123k values. If we were to get the pairwise distances, the resulting distances would fill the upper triangle of a 123k by 123k matrix. This would be a very large amount of data, considering it is just numeric data — more than can easily fit in memory.
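A quick back-of-the-envelope check of that claim, using the 123,387 hashes from Question 2. The bytes-per-distance figures are assumptions about how the result might be stored.

```python
n = 123_387  # number of image hashes

# Number of distinct unordered pairs: the strict upper triangle of an n-by-n matrix.
n_pairs = n * (n - 1) // 2
print(n_pairs)  # 7612114191

# Assumed storage costs: 1 byte per distance (uint8 covers 0..64) vs. 8 bytes (int64).
print(n_pairs * 1 / 1e9)  # about 7.6 GB
print(n_pairs * 8 / 1e9)  # about 60.9 GB
```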

In addition, the part of the function from question 2 that takes the majority of the run time is not the numeric computations, but rather the opening and reading of the 123k hashes, which accounts for approximately 55 of the 65-70 seconds. With this in mind, let’s back up, and break this problem down further.

Write a code cell containing code that will read in all of the hashes into a `numpy` array of size (123387, 64).

This array contains the hashes for each of the 123k images. Each row is the hash of an image. Let’s call the resulting (123387, 64) array `hashes`.

Given what we know, the following is a very fast function that will find the Hamming distances between a single image and all of the other images.

``````def hamming_distance(hash1, hash2):
    # hash1 has shape (64,), hash2 has shape (n, 64); broadcasting compares hash1 with every row
    return np.sum(hash1 != hash2, axis=1)``````
``````%%time

hamming_distance(hashes[0], hashes)``````

This runs in approximately 46 ms. Repeating the calculation once for each of the 123,387 images would take about 94-95 minutes.
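The estimate follows from multiplying the per-call time by the number of images, as this short calculation shows (the variable names are illustrative):

```python
n = 123_387               # number of image hashes
seconds_per_call = 0.046  # roughly 46 ms per one-vs-all call

total_minutes = n * seconds_per_call / 60
print(round(total_minutes, 1))  # 94.6
```

Note that one call per image visits every unordered pair twice, so skipping the duplicate comparisons would roughly halve the numeric work.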

Convert your `numpy` array into a `JAX` array, and create an equivalent function. How fast does this function run? What would the approximate runtime be for the total calculation?

Remember to use `jax.jit` to speed up the function. Also recall that the first run of the compiled function will be slow since it needs to be compiled. After that, future uses of the function will be faster.

Make sure to take into consideration the slower first run. What would the approximate total runtime be using the `JAX` function?
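When timing a jitted function, it helps to time the first call separately and to wait for the result to be ready, since `JAX` dispatches work asynchronously. The sketch below uses a toy function, `row_sums`, as a stand-in rather than the Hamming distance function, and the helper `timed` is hypothetical.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def row_sums(x):
    # Toy stand-in for a numeric kernel; not the Hamming distance function.
    return jnp.sum(x * x, axis=1)

x = jnp.asarray(np.random.default_rng(0).random((1000, 64)), dtype=jnp.float32)

def timed(f, arg):
    start = time.perf_counter()
    # Block until the computation has finished so the timing is meaningful.
    out = f(arg).block_until_ready()
    return out, time.perf_counter() - start

out, first = timed(row_sums, x)  # includes tracing and compilation
out, later = timed(row_sums, x)  # reuses the compiled function
print(f"first call: {first:.4f}s, later call: {later:.6f}s")
```

A total estimate can then be formed as the first-call time plus the later-call time multiplied by the remaining number of calls.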

Items to submit
• Code used to solve this problem.

• Output from running the code.

### Question 4

Don’t worry, I’m not going to make you run these calculations. Instead, answer one of the following two questions.

1. Pick 2 images / image hashes and get the closest 3 images by Hamming distance for each. Note the distances and display the images. At those distances, can you perceive any sort of "closeness" between the images?

2. Randomly sample (using `JAX` methods) n (more than 4, please) images and calculate all of the pairwise distances. Create a set of plots showing the distributions of distances. Explore the distances, and the dataset, and write 1-2 sentences about any interesting observations you are able to make, or 1-2 sentences on how you could use the information to do something cool.
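For the second option, one possible approach is sketched below on synthetic data. It assumes the hashes are available as a boolean array. The random `hashes` array, the sample size of 6, and all variable names are placeholders, not the real dataset.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, data_key, sample_key = jax.random.split(key, num=3)

# Synthetic stand-in for `hashes`: 100 random 64-bit hashes.
hashes = jax.random.bernoulli(data_key, 0.5, shape=(100, 64))

# Sample 6 distinct row indices without replacement.
idx = jax.random.choice(sample_key, hashes.shape[0], shape=(6,), replace=False)
sample = hashes[idx]

# Broadcasting compares every sampled hash with every other: result is (6, 6).
pairwise = jnp.sum(sample[:, None, :] != sample[None, :, :], axis=-1)

# Keep each unordered pair once: the strict upper triangle has 6 * 5 / 2 = 15 entries.
rows, cols = jnp.triu_indices(6, k=1)
distances = pairwise[rows, cols]
print(distances.shape)  # (15,)
```

The `distances` vector holds one value per unordered pair and can be plotted directly as a histogram.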

Items to submit
• Code used to solve this problem.

• Output from running the code.

Please make sure to double check that your submission is complete, and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted is what you actually submitted. In addition, please review our submission guidelines before submitting your project.