-
Implementing One-Hot Encoding with `scatter_`
Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
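The prompt is truncated, but the core pattern is writing ones into a zero tensor at the label positions. A minimal sketch, with a function name and shapes of my own choosing:

```python
import torch

def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # labels: 1D tensor of integer class indices, shape (N,)
    out = torch.zeros(labels.size(0), num_classes)
    # scatter_ writes 1.0 at column labels[i] of row i
    out.scatter_(1, labels.unsqueeze(1), 1.0)
    return out

print(one_hot(torch.tensor([0, 2, 1]), 3))
# tensor([[1., 0., 0.],
#         [0., 0., 1.],
#         [0., 1., 0.]])
```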
-
Advanced Indexing with `gather` for NLP
In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token in each sequence from a tensor of shape...
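The target shape is cut off above, so the sketch below assumes activations of shape `(batch, seq_len, hidden_dim)` and a `lengths` tensor of valid sequence lengths:

```python
import torch

def last_token_states(hidden: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # hidden: (batch, seq_len, hidden_dim); lengths: (batch,) valid lengths
    batch, _, dim = hidden.shape
    # Build an index of shape (batch, 1, hidden_dim) pointing at position lengths - 1
    idx = (lengths - 1).view(batch, 1, 1).expand(batch, 1, dim)
    return hidden.gather(1, idx).squeeze(1)  # (batch, hidden_dim)

h = torch.randn(2, 5, 4)
print(last_token_states(h, torch.tensor([3, 5])).shape)  # torch.Size([2, 4])
```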
-
Image Patch Extraction with `select` and `narrow`
In computer vision, a common operation is to extract patches from an image. Your task is to write a function that extracts a patch of a given size from a specific starting location in an image...
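A possible sketch, assuming a `(C, H, W)` image and a `(top, left, height, width)` patch specification; `narrow` returns a view, so no data is copied:

```python
import torch

def extract_patch(img: torch.Tensor, top: int, left: int, h: int, w: int) -> torch.Tensor:
    # img: (C, H, W); narrow(dim, start, length) slices along one dimension
    return img.narrow(1, top, h).narrow(2, left, w)

img = torch.arange(2 * 6 * 6).reshape(2, 6, 6)
print(extract_patch(img, 1, 2, 3, 3).shape)  # torch.Size([2, 3, 3])
# select drops a dimension instead: img.select(0, 0) is channel 0, shape (6, 6)
```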
-
Sparse Updates with `scatter_add_`
In graph neural networks and other sparse data applications, you often need to update a tensor based on sparse indices. Your exercise is to implement a function that takes a tensor of `values`, a...
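One way such a function might look; repeated indices accumulate rather than overwrite, which is the point of `scatter_add_` over `scatter_`:

```python
import torch

def sparse_accumulate(size: int, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Accumulates values[i] into out[indices[i]]; duplicate indices are summed
    out = torch.zeros(size)
    out.scatter_add_(0, indices, values)
    return out

idx = torch.tensor([0, 2, 0, 1])
vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(sparse_accumulate(4, idx, vals))  # tensor([4., 4., 2., 0.])
```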
-
Implementing a simplified Beam Search Decoder with `gather` and `scatter`
Beam search is a popular decoding algorithm used in machine translation and text generation. A key step in beam search is to select the top-k most likely next tokens and update the corresponding...
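The exercise details are truncated, but a single simplified beam-search step might combine `topk` with index arithmetic like this (all names and shapes are illustrative):

```python
import torch

def beam_step(beam_scores: torch.Tensor, log_probs: torch.Tensor, k: int):
    # beam_scores: (beam,); log_probs: (beam, vocab) for the next token
    total = beam_scores.unsqueeze(1) + log_probs       # (beam, vocab)
    top_scores, flat_idx = total.view(-1).topk(k)      # best k over all beams
    beam_idx = flat_idx // log_probs.size(1)           # which beam each candidate extends
    token_idx = flat_idx % log_probs.size(1)           # which token it appends
    # Surviving sequences would then be reordered with e.g.
    # sequences = sequences.index_select(0, beam_idx)
    return top_scores, beam_idx, token_idx

scores = torch.zeros(2)
lp = torch.log_softmax(torch.randn(2, 5), dim=-1)
print(beam_step(scores, lp, 2))
```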
-
Selecting RoIs (Regions of Interest) with `index_select`
In object detection tasks, after a region proposal network (RPN) suggests potential object locations, these regions of interest (RoIs) need to be extracted from the feature map for further...
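A minimal sketch, assuming the proposals are already stacked into a `(num_proposals, C, H, W)` feature tensor and we keep a subset by index:

```python
import torch

def select_rois(features: torch.Tensor, roi_indices: torch.Tensor) -> torch.Tensor:
    # index_select picks whole rows along dim 0, one per kept proposal
    return features.index_select(0, roi_indices)

feats = torch.randn(10, 256, 7, 7)
kept = torch.tensor([0, 3, 7])
print(select_rois(feats, kept).shape)  # torch.Size([3, 256, 7, 7])
```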
-
Creating a Label Mapping with `scatter` for Domain Adaptation
In domain adaptation, you might need to map labels from a source domain to a target domain. Imagine you have a set of labels and a mapping that specifies how each old label corresponds to a new...
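One plausible approach, building a lookup table with `scatter_` and then indexing it; the exact task spec is cut off, so treat this as illustrative:

```python
import torch

def remap_labels(labels: torch.Tensor, old: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    # Build a lookup table: table[old[i]] = new[i], then index it with the labels
    table = torch.zeros(int(old.max()) + 1, dtype=new.dtype)
    table.scatter_(0, old, new)
    return table[labels]

old = torch.tensor([0, 1, 2])
new = torch.tensor([2, 0, 1])
labels = torch.tensor([0, 0, 2, 1])
print(remap_labels(labels, old, new))  # tensor([2, 2, 1, 0])
```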
-
Replicating `torch.nn.Embedding` with `gather`
The `torch.nn.Embedding` layer is fundamental in many deep learning models, especially in NLP. Your task is to replicate its forward pass functionality using `torch.gather`. You'll create a...
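A sketch of the idea: expand the token ids so each one selects an entire row of the weight matrix via `gather` along dimension 0 (names are mine):

```python
import torch

def embedding_forward(weight: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    # weight: (vocab, dim); ids: (batch, seq_len) token indices
    batch, seq_len = ids.shape
    dim = weight.size(1)
    # Each flattened id is repeated across the embedding dimension
    idx = ids.reshape(-1, 1).expand(-1, dim)   # (batch*seq_len, dim)
    out = weight.gather(0, idx)                # (batch*seq_len, dim)
    return out.view(batch, seq_len, dim)

w = torch.randn(10, 4)
ids = torch.tensor([[1, 3], [0, 9]])
torch.testing.assert_close(embedding_forward(w, ids), w[ids])
```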
-
Efficiently Updating a Sub-Tensor with `scatter_`
Sometimes you need to update a portion of a larger tensor with values from a smaller tensor, where the update locations are specified by an index tensor. This is a common pattern in various...
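For example, a row-wise update might look like the following sketch (the function name and shapes are assumptions):

```python
import torch

def update_rows(dest: torch.Tensor, row_idx: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Writes src[i] into dest[row_idx[i]] for each i, in place
    idx = row_idx.unsqueeze(1).expand_as(src)   # (k, cols)
    return dest.scatter_(0, idx, src)

dest = torch.zeros(5, 3)
src = torch.ones(2, 3)
print(update_rows(dest, torch.tensor([1, 4]), src))  # rows 1 and 4 become ones
```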
-
NumPy to JAX: The Basics
This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
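The enumerated steps are truncated, but the names `W`, `x`, and `b` suggest an affine map; a sketch under that assumption:

```python
import jax.numpy as jnp

def affine(W: jnp.ndarray, x: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    # Identical to the NumPy expression W @ x + b, but on JAX arrays
    return jnp.dot(W, x) + b

W = jnp.arange(6.0).reshape(2, 3)
x = jnp.ones(3)
b = jnp.zeros(2)
print(affine(W, x, b))  # [ 3. 12.]
```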
-
The JAX approach to PRNG
JAX handles pseudo-random number generation (PRNG) differently from NumPy, which uses a global state. JAX instead makes the state of the PRNG explicit. This is a design choice that...
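Whatever the truncated details, the core pattern is always the same: create a key, split it, and pass subkeys to sampling functions:

```python
import jax

key = jax.random.PRNGKey(0)          # explicit PRNG state
key, subkey = jax.random.split(key)  # split before each use
x = jax.random.normal(subkey, (3,))  # same subkey always yields the same numbers
print(x)
```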
-
The need for speed: `jit`
JAX's `jit` function will compile your Python code, which can lead to significant speedups. This is because JAX can fuse operations together, removing the overhead of Python's interpreter. In this...
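A minimal illustration of the mechanism: decorate a function and let XLA fuse its elementwise operations into compiled code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Several elementwise ops that XLA can fuse into one kernel
    return jnp.sum(jnp.tanh(x) ** 2 + 1.0)

x = jnp.ones((1000,))
print(f(x))  # first call compiles; later calls reuse the compiled code
```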
-
Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
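The regression details are cut off, so the sketch below uses a generic mean-squared-error loss for a linear model:

```python
import jax
import jax.numpy as jnp

def loss(W, x, y):
    # Mean squared error of a linear model
    pred = jnp.dot(x, W)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((4, 3))
y = jnp.zeros(4)
W = jnp.ones(3)
g = jax.grad(loss)(W, x, y)  # gradient w.r.t. W, the first argument
print(g.shape)  # (3,)
```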
-
Taking gradients with `grad` II
By default, `jax.grad` will take the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to many of the function's...
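The `argnums` keyword selects which arguments are differentiated, as in this sketch (the loss is illustrative):

```python
import jax
import jax.numpy as jnp

def loss(W, b, x, y):
    return jnp.mean((jnp.dot(x, W) + b - y) ** 2)

x, y = jnp.ones((4, 3)), jnp.zeros(4)
W, b = jnp.ones(3), 0.0
# argnums=(0, 1) returns a tuple of gradients, one per selected argument
dW, db = jax.grad(loss, argnums=(0, 1))(W, b, x, y)
print(dW.shape, db)  # (3,) and a scalar
```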
-
Vectorizing with `vmap`
Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single...
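For instance, a single-example prediction function can be batched over one argument while broadcasting the other:

```python
import jax
import jax.numpy as jnp

def predict(W, x):
    # Written for a single example x of shape (3,)
    return jnp.dot(W, x)

W = jnp.ones((2, 3))
xs = jnp.ones((5, 3))  # a batch of 5 examples
# Map over the batch axis of xs only; W is shared across every example
batched = jax.vmap(predict, in_axes=(None, 0))
print(batched(W, xs).shape)  # (5, 2)
```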
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
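A sketch of one SGD step for a linear model like the one above; the parameter shapes and learning rate are assumptions:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    W, b = params
    return jnp.mean((jnp.dot(x, W) + b - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    grads = jax.grad(loss)(params, x, y)
    # SGD update: move each parameter against its gradient
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = (jnp.ones(3), 0.0)
x, y = jnp.ones((8, 3)), jnp.zeros(8)
params = train_step(params, x, y)
```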
-
A simple MLP
Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you are to build a 2-layer MLP for a regression problem. 1....
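One way to write the forward pass of such a 2-layer MLP (layer sizes are arbitrary):

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Two layers: a hidden layer with a ReLU, then a linear output
    W1, b1, W2, b2 = params
    h = jax.nn.relu(jnp.dot(x, W1) + b1)
    return jnp.dot(h, W2) + b2

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (jax.random.normal(k1, (3, 16)), jnp.zeros(16),
          jax.random.normal(k2, (16, 1)), jnp.zeros(1))
print(mlp(params, jnp.ones((4, 3))).shape)  # (4, 1)
```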
-
PyTrees
A PyTree is any nested structure of containers such as dictionaries, lists, and tuples, with arrays or scalars at the leaves. JAX is designed to work with PyTrees, which allows for a more organized way of handling model parameters. In this exercise, you...
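For example, parameters stored as a nested dict form a PyTree, and `tree_map` applies a function to every leaf while preserving the structure:

```python
import jax
import jax.numpy as jnp

# Model parameters organized as a nested dict: a PyTree
params = {
    "layer1": {"W": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"W": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}
# Scale every leaf without unpacking the structure by hand
scaled = jax.tree_util.tree_map(lambda p: 0.5 * p, params)
print(jax.tree_util.tree_structure(scaled))
```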
-
A simple CNN
In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
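A minimal call to `jax.lax.conv_general_dilated`, assuming NCHW inputs and OIHW kernels (the exercise's exact shapes are truncated):

```python
import jax
import jax.numpy as jnp

def conv2d(x, kernel):
    # x: (N, C_in, H, W); kernel: (C_out, C_in, kH, kW)
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )

x = jnp.ones((2, 3, 8, 8))
k = jnp.ones((16, 3, 3, 3))
print(conv2d(x, k).shape)  # (2, 16, 8, 8)
```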
-
Understanding `jit` and tracing
When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
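A quick way to see tracing in action: a Python `print` inside the function runs only while it is being traced, not on every call:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs during tracing, so x shows up as an abstract tracer, not values
    print("tracing with", x)
    return x * 2

f(jnp.arange(3))  # prints a tracer: the function is being traced
f(jnp.arange(3))  # prints nothing: the compiled version is reused
f(jnp.arange(4))  # new shape, so JAX retraces and recompiles
```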
-
Conditionals with `jit`
Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
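A sketch of the `jax.lax.cond` replacement for a data-dependent `if`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_like(x):
    # `if x < 0:` would fail under jit; lax.cond traces both branches instead
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v, x)

print(abs_like(jnp.float32(-3.0)))  # 3.0
```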
-
Loops with `jit`
Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the loop's duration depends on a traced value. JAX provides `jax.lax.fori_loop` and...
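For example, a loop whose bound may be a traced value can be written with `jax.lax.fori_loop`:

```python
import jax

@jax.jit
def sum_to(n):
    # A Python `for i in range(n)` would fail when n is a traced value;
    # fori_loop(lower, upper, body, init) threads the accumulator through
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, 0)

print(sum_to(5))  # 0 + 1 + 2 + 3 + 4 = 10
```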
-
Forward-mode vs. Reverse-mode autodiff
JAX supports both forward-mode and reverse-mode automatic differentiation. While `grad` uses reverse-mode, you can use `jax.jvp` for forward-mode, which computes Jacobian-vector products....
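A small example: `jax.jvp` evaluates the function and its directional derivative in one forward pass:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

# Push the tangent v through f alongside the primal x
x, v = jnp.float32(1.0), jnp.float32(1.0)
y, y_dot = jax.jvp(f, (x,), (v,))
print(y, y_dot)  # f(1) and f'(1) = sin(1) + 1*cos(1)
```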
-
Custom VJP
For some functions, you may want to define a custom vector-Jacobian product (VJP). This can be useful for numerical stability or for implementing algorithms that are not easily expressed in terms...
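The usual pattern pairs a forward function that saves residuals with a hand-written backward rule; a minimal sketch:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1p_stable(x):
    return jnp.log1p(x)

def fwd(x):
    # Save x as the residual needed by the backward pass
    return log1p_stable(x), x

def bwd(x, g):
    # Hand-written gradient: d/dx log(1 + x) = 1 / (1 + x)
    return (g / (1.0 + x),)

log1p_stable.defvjp(fwd, bwd)
print(jax.grad(log1p_stable)(jnp.float32(1.0)))  # 0.5
```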
-
Working with a NN library: Flax
While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and...
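The exercise specifics are truncated; a minimal Flax model, just to show the `init`/`apply` pattern, might look like:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(16)(x)   # layers are declared inline under @nn.compact
        x = nn.relu(x)
        return nn.Dense(1)(x)

model = MLP()
# Flax keeps parameters outside the module, as an explicit PyTree
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
y = model.apply(params, jnp.ones((4, 3)))
print(y.shape)  # (4, 1)
```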