STAT 39000: Project 12 — Spring 2022
Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools surrounding these topics are the pytorch and tensorflow libraries. JAX, a library from Google, is another tool growing in popularity. These libraries are used to build and use complex models. If a GPU is available, they can take advantage of it to speed up parallelizable code a hundredfold or even a thousandfold.
Context: This is the last of a series of 4 projects focused on using pytorch and JAX to solve numeric problems.
Scope: Python, JAX
Questions
Question 1
Last week’s project was a bit fast paced, so we will slow things down considerably to try to compensate and to give you a chance to digest and explore more. We will:

- Learn how JAX handles generating random numbers differently than most other packages.
- Write a function in numpy to calculate the Hamming distance between a given image hash and the remaining (around 123k) image hashes.
- Play around with the hash data and do some sanity checks.
Let’s start by taking a look at the JAX documentation for random number generation. Read the page carefully; it’s not that long.
The documentation gives the following example.
```python
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()

print(foo())
```
It then goes on to say that JAX may try to parallelize the bar and baz functions. As a result, we would not know which would run first, bar or baz, and this would change the result of foo. Below, we’ve modified the code to emulate this.
```python
import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo1(): return bar() + 2 * baz()
def foo2(): return 2*baz() + bar()

def foo(*funcs):
    # Shuffle the candidate functions and run whichever lands first,
    # emulating an unknown execution order of bar and baz.
    functions = list(funcs)
    random.shuffle(functions)
    return functions[0]()

np.random.seed(0)
foo(foo1, foo2)
# sometimes this: 1.9791922366721637
# sometimes this: 1.812816374227069
```
JAX handles this very differently. Its solution is clean and effective, and it allows such code to be parallelized, but managing and passing around keys can be a bit cumbersome. Create a modified version of this code that uses JAX and passes keys around explicitly. Fill in the ? parts.
```python
import random

import jax
import numpy as np

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2 * baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys[0], subkeys[1])))
# always 2.3250647
```
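To see why passing keys makes the result independent of call order, here is a minimal sketch, assuming a recent JAX install. It draws the same two numbers in two different orders. The names k1, k2, a_first and so on are illustrative only, and this is not a filled-in answer to the template above.

```python
import jax

# Derive two subkeys from one root key (seed 0 is an arbitrary choice).
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, num=3)

# Order 1: draw with k1, then with k2.
a_first = jax.random.uniform(k1)
b_first = jax.random.uniform(k2)

# Order 2: draw with k2, then with k1.
b_second = jax.random.uniform(k2)
a_second = jax.random.uniform(k1)

# Each value depends only on the key passed in, not on hidden global state,
# so the call order cannot change it.
print(bool(a_first == a_second), bool(b_first == b_second))  # True True
```

Reusing the same key gives the same number again, which is why the template splits one root key into several subkeys instead of reusing key.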

Code used to solve this problem.

Output from running the code.
Question 2
Write a function called get_distances_np that accepts a filename, fm_hash (a string), and a path, path (a string). get_distances_np should return a numpy array of the distances between the hash for fm_hash and every other image hash in path.
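The numeric core of get_distances_np can be sketched separately from the file reading. The sketch below assumes the hashes are already loaded as 64-bit boolean rows; random synthetic hashes stand in for the real files, and distances_to is a hypothetical helper name, not part of the project.

```python
import numpy as np

# Hypothetical stand-in for the real data: 1,000 random 64-bit hashes stored
# as booleans. The on-disk format of the real hashes is not assumed here.
rng = np.random.default_rng(0)
hashes = rng.integers(0, 2, size=(1000, 64)).astype(bool)

def distances_to(query, all_hashes):
    """Hamming distance between one (64,) hash and every row of (n, 64).

    Returned as an (n, 1) column to match the shape in the example call.
    """
    # query broadcasts against every row; != marks the differing bits.
    return np.count_nonzero(all_hashes != query, axis=1).reshape(-1, 1)

d = distances_to(hashes[0], hashes)
print(d.shape)   # (1000, 1)
print(d[0, 0])   # 0, since a hash is at distance 0 from itself
```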
For this question, use the dataset of hashed images found in /depot/datamine/data/coco/hashed02/. An example of a call to get_distances_np would look like the following.
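```python
from pathlib import Path
import imagehash
import numpy as np
```

```python
%%time
# %%time is a cell magic, so it must be the first line of its own cell.
hshs = get_distances_np("000000000008.jpg", "/depot/datamine/data/coco/hashed02/")
hshs.shape  # (123387, 1)
```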
How long does it take to run this function?
Make plots and/or summary statistics to check out the distribution of the distances. How does it look?
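A sketch of quick summary statistics for the distances, assuming they are integers from 0 to 64 (true for 64-bit hashes). The distances here are simulated from random bits, so they only demonstrate the mechanics; real perceptual hashes are not uniformly random, and their distribution may look quite different.

```python
import numpy as np

# Hypothetical distances: in the project these would come from
# get_distances_np; here they are simulated from random 64-bit hashes.
rng = np.random.default_rng(1)
hashes = rng.integers(0, 2, size=(5000, 64)).astype(bool)
distances = np.count_nonzero(hashes != hashes[0], axis=1)

# Drop the query image itself (distance 0 to itself) before summarizing.
others = distances[1:]

# Distances are integers in [0, 64], so bincount gives the full
# frequency table (one count per possible distance) in one call.
counts = np.bincount(others, minlength=65)

print("min/median/max:", others.min(), np.median(others), others.max())
print("mean, std:", others.mean(), others.std())
print("most common distance:", counts.argmax())
```

For uniformly random 64-bit hashes, distances follow a Binomial(64, 0.5) distribution centered at 32, which can serve as a rough reference point when looking at the real histogram (for example, plotting counts with matplotlib's plt.bar).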

Code used to solve this problem.

Output from running the code.
Question 3
What do you think about the design of the get_distances_np function, considering that we are interested in pairwise Hamming distances?
At its core, we essentially have a vector of 123k values. If we were to compute all of the pairwise distances, the results would fill the upper triangle of a 123k by 123k matrix. That is a very large amount of data for what is just numbers, more than can easily fit in memory.
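A quick arithmetic check of the memory claim. The dtypes are assumptions: uint8 is the smallest NumPy type that holds distances from 0 to 64, and int64 is a common default integer type.

```python
# Back-of-the-envelope memory estimate for all pairwise Hamming distances.
n = 123_387                  # number of image hashes in hashed02/

pairs = n * (n - 1) // 2     # entries in the strict upper triangle
full = n * n                 # entries in the full square matrix

GB = 1e9
print(f"{pairs:,} unique pairs")
print(f"upper triangle as uint8: {pairs * 1 / GB:.1f} GB")
print(f"upper triangle as int64: {pairs * 8 / GB:.1f} GB")
print(f"full matrix as int64:    {full * 8 / GB:.1f} GB")
```

This works out to 7,612,114,191 unique pairs: roughly 7.6 GB as uint8 and 60.9 GB as int64 for the upper triangle alone, and about 121.8 GB for the full int64 matrix.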
In addition, the part of the function from question 2 that takes the majority of the run time is not the numeric computation but the opening and reading of the 123k hashes: approximately 55 of the 65 to 70 seconds. With this in mind, let’s back up and break this problem down further.
Write a code cell containing code that reads all of the hashes into a numpy array of size (123387, 64). This array contains the hashes for each of the 123k images; each row is the hash of one image. Let’s call the resulting (123387, 64) array hashes.
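One possible shape for the stacking step, assuming each stored hash can be turned into an 8x8 boolean array (for example, the .hash attribute of an imagehash ImageHash with the default hash size). The on-disk format in hashed02/ is not assumed here, so random arrays stand in for the loaded hashes.

```python
import numpy as np

# Hypothetical stand-ins for per-image hashes: ten random 8x8 boolean arrays.
rng = np.random.default_rng(2)
per_image = [rng.integers(0, 2, size=(8, 8)).astype(bool) for _ in range(10)]

# Flatten each 8x8 hash to 64 bits and stack them as rows: (n_images, 64).
hashes = np.stack([h.flatten() for h in per_image])
print(hashes.shape, hashes.dtype)   # (10, 64) bool
```

For the full 123k images, preallocating np.empty((n, 64), dtype=bool) and filling one row per file avoids keeping a list of 123k small arrays around.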
Given what we know, the following is a very fast function that will find the Hamming distances between a single image and all of the other images.
```python
def hamming_distance(hash1, hash2):
    # Count the positions where the boolean hashes differ, row by row.
    return np.sum(~(hash1 == hash2), axis=1)
```

```python
%%time
hamming_distance(hashes[0], hashes)
```
This runs in approximately 46 ms. Repeating this calculation once for each of the 123,387 images, to cover every pair, would take about 94 to 95 minutes.
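The runtime estimate is simple arithmetic. Here is a quick sketch using the 46 ms per call quoted in the text.

```python
# Rough total runtime if hamming_distance(hashes[i], hashes) runs once per image.
n = 123_387            # number of hashes
per_call_s = 0.046     # ~46 ms per call, the timing quoted in the text

total_min = n * per_call_s / 60
print(f"about {total_min:.1f} minutes")   # about 94.6 minutes
```

Because each call compares one image with all of the others, this counts every pair twice. Restricting to the upper triangle would roughly halve the work.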
Convert your numpy array into a JAX array, and create an equivalent function. How fast does this function run? What would the approximate runtime be for the total calculation?
Remember to use 
Make sure to take into consideration the slower first run. What would the approximate total runtime be using the JAX function?
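Here is a sketch of one way to time a JAX function fairly, assuming hashes is a (123387, 64) boolean array (synthetic here). hamming_distance_jax is a hypothetical name, and jax.jit is one reasonable choice rather than a requirement of the question. Two details matter. The first call includes tracing and compilation. JAX also dispatches work asynchronously, so the sketch calls block_until_ready() to wait for the result before stopping the timer.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# Synthetic stand-in for the (123387, 64) boolean hash array.
rng = np.random.default_rng(3)
hashes_np = rng.integers(0, 2, size=(123_387, 64)).astype(bool)
hashes_jx = jnp.asarray(hashes_np)

@jax.jit
def hamming_distance_jax(hash1, hash2):
    # Same computation as the NumPy version: count mismatched bits per row.
    return jnp.sum(hash1 != hash2, axis=1)

def timed(f, *args):
    start = time.perf_counter()
    # Wait for the asynchronous result so the timer covers the computation.
    out = f(*args).block_until_ready()
    return out, time.perf_counter() - start

# The first call traces and compiles; later calls with the same shapes reuse it.
_, first = timed(hamming_distance_jax, hashes_jx[0], hashes_jx)
out, later = timed(hamming_distance_jax, hashes_jx[1], hashes_jx)
print(f"first call: {first * 1e3:.1f} ms, later call: {later * 1e3:.1f} ms")
```

The later-call time, multiplied by 123,387 and added to the one-off first-call cost, gives the kind of total-runtime estimate the question asks for.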

Code used to solve this problem.

Output from running the code.
Question 4
Don’t worry, I’m not going to make you run these calculations. Instead, answer one of the following two questions.

Pick 2 images / image hashes and get the closest 3 images by Hamming distance for each. Note the distances and display the images. At those distances, can you perceive any sort of "closeness" in image?

Randomly sample n images (more than 4, please) using JAX methods, and calculate all of the pairwise distances. Create a set of plots showing the distributions of distances. Explore the distances and the dataset, and write 1 to 2 sentences about any interesting observations you are able to make, or 1 to 2 sentences on how you could use the information to do something cool.
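Here is a sketch covering both options, with synthetic hashes in place of the real array. The seed 42, the sample size n = 8, and the query index 0 are arbitrary choices. jax.random.choice with replace=False draws distinct rows.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Synthetic stand-in for the real (123387, 64) boolean hash array.
rng = np.random.default_rng(4)
hashes = jnp.asarray(rng.integers(0, 2, size=(2_000, 64)).astype(bool))

# Option 2: sample n distinct rows with JAX's explicit-key RNG.
key = jax.random.PRNGKey(42)
n = 8
idx = jax.random.choice(key, hashes.shape[0], shape=(n,), replace=False)
sample = hashes[idx]                                   # (n, 64)

# All pairwise distances via broadcasting: (n, 1, 64) vs (1, n, 64) -> (n, n).
pairwise = jnp.sum(sample[:, None, :] != sample[None, :, :], axis=-1)

# Keep each unordered pair once (strict upper triangle) for plots/summaries.
rows, cols = jnp.triu_indices(n, k=1)
pair_dists = pairwise[rows, cols]                      # n * (n - 1) / 2 values

# Option 1: the 3 closest images to one query, excluding the query itself.
query = 0
d = jnp.sum(hashes != hashes[query], axis=1)
d = d.at[query].set(jnp.iinfo(d.dtype).max)            # mask out the self-match
closest3 = jnp.argsort(d)[:3]
print(closest3, d[closest3])
```

The broadcasting step builds an (n, n, 64) array. That is fine for a small sample, but it grows quadratically, so it is not a way to compute all 123k by 123k distances.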

Code used to solve this problem.

Output from running the code.
Please make sure to double check that your submission is complete and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted was what you actually submitted. In addition, please review our submission guidelines before submitting your project.