TDM 30200: Project 12 — 2023

Motivation: Machine learning and AI are huge buzzwords in industry, and two of the most popular tools surrounding said topics are the pytorch and tensorflow libraries — JAX is another tool by Google growing in popularity. These tools are libraries used to build and use complex models. If available, they can take advantage of GPUs to speed up parallelizable code by a hundred or even thousand fold.

Context: This is the last of a series of 4 projects focused on using pytorch and JAX to solve numeric problems.

Scope: Python, JAX

Learning Objectives
  • Compare and contrast pytorch and JAX.

  • Differentiate functions using JAX.

  • Understand what "JIT" is and why it is useful.

  • Understand when a value or operation should be static vs. traced.

  • Vectorize functions using the vmap function from JAX.

  • Understand how random number generators work in JAX.

Make sure to read about and use the template found here, and read the important information about project submissions here.

Questions

Question 1

Last week’s project was a bit fast paced, so we will slow things down considerably to try to compensate, and give you a chance to digest and explore more. We will:

  • Learn how JAX handles generating random numbers differently than most other packages.

  • Write a function in numpy to calculate the Hamming distance between a given image hash and the remaining (around 123k) image hashes.

  • Play around with the hash data and do some sanity checks.

Let’s start by taking a look at the documentation for random number generation. Carefully read the page — it’s not that long.

The documentation gives the following example.

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())

It then goes on to say that JAX may try to parallelize the bar and baz functions. As a result, we would not know which would run first, bar or baz. This would change the results of foo. Below, we’ve modified the code to emulate this.

import numpy as np
import random

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo1(): return bar() + 2 * baz()

def foo2(): return 2*baz() + bar()

def foo(*funcs):
    functions = list(funcs)
    random.shuffle(functions)
    return functions[0]()

np.random.seed(0)
foo(foo1, foo2)
output
# sometimes this
1.9791922366721637

# sometimes this
1.812816374227069

JAX deals with this very differently. Its solution is clean and effective, and allows such code to be parallelized, but managing and passing around keys can be a bit more cumbersome. Create a modified version of this code that uses JAX and passes keys around explicitly. Fill in the ? parts.

import random

import jax

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=?)

def bar(key):
    return ?

def baz(key):
    return ?

def foo1(key1, key2):
    return bar(key1) + 2 * baz(key2)

def foo2(key1, key2):
    return 2*baz(key2) + bar(key1)

def foo(funcs, keys):
    functions = list(funcs)
    random.shuffle(functions)
    return ?

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=3)
print(foo((foo1, foo2), (subkeys[0], subkeys[1])))
output
# always
2.3250647
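
As a separate illustration of the key-passing idea (not a filled-in version of the exercise above), the following minimal sketch shows why the order of calls stops mattering: a draw depends only on the key it is given. The seed 42 and the names subkey_a and subkey_b are arbitrary choices made for this example.

```python
import jax

# Arbitrary seed chosen for this illustration.
key = jax.random.PRNGKey(42)

# Split once into a new parent key and two subkeys to hand out.
key, subkey_a, subkey_b = jax.random.split(key, num=3)

# Same key, same draw: the result depends only on the key passed in,
# not on how many random calls happened before it.
first = jax.random.uniform(subkey_a)
repeat = jax.random.uniform(subkey_a)

# A different subkey gives a separate draw.
other = jax.random.uniform(subkey_b)

print(first, repeat, other)
```

Because each function receives its own subkey, reordering or parallelizing the calls cannot change what any one of them returns.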
Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 2

Write a function called get_distances_np that accepts two strings: a filename (fm_hash) and a path (path).

get_distances_np should return a numpy array of the distances between the hash for fm_hash and every other image hash in path.

For this question, use the dataset of hashed images found in /anvil/projects/tdm/data/coco/hashed02/. An example of a call to get_distances_np would look like the following.

from pathlib import Path
import imagehash
import numpy as np

%%time

hshs = get_distances_np("000000000008.jpg", "/anvil/projects/tdm/data/coco/hashed02/")
hshs.shape # (123387, 1)

How long does it take to run this function?

Make plots and/or summary statistics to check out the distribution of the distances. How does it look?

The distance we want you to calculate is the Hamming distance. We’ve provided you with a function that accepts two numpy arrays and returns the Hamming distance between them.

def _hamming_distance(hash1, hash2):
    return sum(~(hash1 == hash2))
Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 3

What do you think about the design of the get_distances_np function, considering that we are interested in pairwise Hamming distances?

At its core, we essentially have a vector of 123k values. If we were to get the pairwise distances, the resulting distances would fill the upper triangle of a 123k by 123k matrix. This would be a very large amount of data, considering it is just numeric data — more than can easily fit in memory.
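
A rough back-of-the-envelope check of that claim, as a sketch: the choice of storage type is an assumption here, since the text above does not say how the distances would be stored.

```python
n_images = 123387

# Number of unordered pairs, i.e. entries in the strict upper triangle.
n_pairs = n_images * (n_images - 1) // 2

# Distances between 64-value hashes range from 0 to 64, so one byte (uint8)
# per distance is enough; a 64-bit integer would use 8 bytes per distance.
gib = 1024 ** 3
print(f"pairs:    {n_pairs:,}")
print(f"as uint8: {n_pairs * 1 / gib:.1f} GiB")
print(f"as int64: {n_pairs * 8 / gib:.1f} GiB")
```

That is roughly 7.6 billion distances, so even the most compact storage runs to several GiB.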

In addition, most of the run time of the function from question 2 goes not to the numeric computations but to opening and reading the 123k hashes: approximately 55 of the 65-70 seconds. With this in mind, let’s back up, and break this problem down further.

Write a code cell containing code that will read in all of the hashes into a numpy array of size (123387, 64).

This array contains the hashes for each of the 123k images. Each row is the hash of an image. Let’s call the resulting (123387, 64) array hashes.

Given what we know, the following is a very fast function that will find the Hamming distances between a single image and all of the other images.

def hamming_distance(hash1, hash2):
    return np.sum(~(hash1 == hash2), axis=1)

%%time

hamming_distance(hashes[0], hashes)

This runs in approximately 16 ms. Calling it once for each of the 123,387 images, to cover every pair, would take about 33 minutes (123,387 × 16 ms ≈ 1,974 s).
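
The speed comes from broadcasting: a single hash of shape (64,) is compared against every row of the (n, 64) array in one vectorized operation. The sketch below shows this on a tiny synthetic array; toy_hashes is a made-up stand-in for the real hashes array.

```python
import numpy as np


def hamming_distance(hash1, hash2):
    # hash1 has shape (64,), hash2 has shape (n, 64); the comparison broadcasts
    # hash1 against every row, and summing over axis=1 gives one count per row.
    return np.sum(~(hash1 == hash2), axis=1)


# Synthetic stand-in for the real data: 5 random 64-value boolean hashes.
rng = np.random.default_rng(0)
toy_hashes = rng.integers(0, 2, size=(5, 64)).astype(bool)

distances = hamming_distance(toy_hashes[0], toy_hashes)
print(distances.shape)  # (5,)
print(distances[0])     # 0, since the first hash is compared with itself
```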

Convert your numpy array into a JAX array, and create an equivalent function. How fast does this function run? What would the approximate runtime be for the total calculation?

Remember to use jax.jit to speed up the function. Also recall that the first run of the compiled function will be slow since it needs to be compiled. After that, future uses of the function will be faster.

Make sure to take into consideration the slower first run. What would the approximate total runtime be using the JAX function?
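
One way to separate the slow first call from later calls is sketched below. It uses a toy function rather than the Hamming distance, and the names row_sums_of_squares and timed_call are hypothetical helpers invented for this example. JAX dispatches work asynchronously, so the sketch calls block_until_ready() before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def row_sums_of_squares(x):
    # Toy stand-in for a per-row computation; not the Hamming distance itself.
    return jnp.sum(x * x, axis=1)


def timed_call(fn, arg):
    # Hypothetical helper: returns the result and the elapsed wall-clock seconds.
    start = time.perf_counter()
    out = fn(arg).block_until_ready()  # wait for the computation to finish
    return out, time.perf_counter() - start


x = jnp.ones((1000, 64))
first_out, first_s = timed_call(row_sums_of_squares, x)    # includes compilation
second_out, second_s = timed_call(row_sums_of_squares, x)  # reuses compiled code

# Extrapolate: one slow first call plus the remaining calls at the steady rate.
n_calls = 123387
estimated_total_s = first_s + (n_calls - 1) * second_s
print(f"first: {first_s:.4f} s, later: {second_s:.6f} s")
print(f"estimated total: {estimated_total_s / 60:.1f} minutes")
```

A single later call is a noisy measurement; averaging over several runs (for example with %timeit in a notebook) would likely give a steadier estimate.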

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Question 4

Don’t worry, I’m not going to make you run these calculations. Instead, answer one of the following two questions.

  1. Pick 2 images / image hashes and get the closest 3 images by Hamming distance for each. Note the distances and display the images. At those distances, can you perceive any sort of "closeness" between the images?

  2. Randomly sample (using JAX methods) n (more than 4, please) images and calculate all of the pairwise distances. Create a set of plots showing the distributions of distances. Explore the distances, and the dataset, and write 1-2 sentences about any interesting observations you are able to make, or 1-2 sentences on how you could use the information to do something cool.
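
Both options can be prototyped on synthetic data before touching the real hashes. The sketch below is one possible approach, not the required one; toy_hashes, query, and k are made-up names and values for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Synthetic stand-in for the real `hashes` array: 100 random 64-value hashes.
rng = np.random.default_rng(0)
toy_hashes = jnp.asarray(rng.integers(0, 2, size=(100, 64)).astype(bool))

# Option 1 idea: the 3 nearest neighbours of one image (index 0 here).
query = 0
distances = jnp.sum(toy_hashes[query] != toy_hashes, axis=1)
order = jnp.argsort(distances)
nearest = order[order != query][:3]  # drop the query itself, keep the 3 closest

# Option 2 idea: sample k distinct row indices with a JAX key,
# then compare every sampled hash with every other one via broadcasting.
key = jax.random.PRNGKey(0)
k = 5
sample_idx = jax.random.choice(key, toy_hashes.shape[0], shape=(k,), replace=False)
sample = toy_hashes[sample_idx]
pairwise = jnp.sum(sample[:, None, :] != sample[None, :, :], axis=-1)  # (k, k)

print(nearest, distances[nearest])
print(pairwise)
```

The pairwise matrix is symmetric with zeros on the diagonal, so only the upper triangle carries distinct distances worth plotting.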

Items to submit
  • Code used to solve this problem.

  • Output from running the code.

Please make sure to double check that your submission is complete, and contains all of your code and output before submitting. If you are on a spotty internet connection, it is recommended to download your submission after submitting it to make sure that what you think you submitted is what you actually submitted.

In addition, please review our submission guidelines before submitting your project.