-
Advanced Indexing with `gather` for NLP
In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token in each sequence from a tensor of shape...
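A minimal sketch of the indexing pattern, assuming a hidden-state tensor of shape (batch, seq_len, hidden) and a `lengths` vector; the shapes and values are illustrative:

```python
import torch

# Hypothetical shapes: batch of 3 sequences, max length 5, hidden size 4.
hidden = torch.randn(3, 5, 4)            # (batch, seq_len, hidden)
lengths = torch.tensor([5, 2, 3])        # true length of each sequence

# Index of the last real token in each sequence, expanded to match gather's requirements.
last_idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, hidden.size(-1))   # (batch, 1, hidden)
last_hidden = hidden.gather(dim=1, index=last_idx).squeeze(1)            # (batch, hidden)

# Sanity check against plain indexing for the second sequence (length 2 -> token index 1).
assert torch.equal(last_hidden[1], hidden[1, 1])
```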
-
Image Patch Extraction with `select` and `narrow`
In computer vision, a common operation is to extract patches from an image. Your task is to write a function that extracts a patch of a given size from a specific starting location in an image...
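A minimal sketch, assuming a (C, H, W) image and a patch specified by its top-left corner and size; the function name and arguments are illustrative:

```python
import torch

def extract_patch(img, top, left, height, width):
    """Extract a (height, width) patch starting at (top, left) from a (C, H, W) image."""
    # narrow(dim, start, length) returns a view of the requested slice along one dimension.
    return img.narrow(1, top, height).narrow(2, left, width)

img = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)
patch = extract_patch(img, top=2, left=3, height=4, width=4)
print(patch.shape)  # torch.Size([3, 4, 4])

# select(dim, index) removes a dimension, e.g. take a single channel from the patch.
first_channel = patch.select(0, 0)  # shape (4, 4)
```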
-
Sparse Updates with `scatter_add_`
In graph neural networks and other sparse data applications, you often need to update a tensor based on sparse indices. Your exercise is to implement a function that takes a tensor of `values`, a...
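A minimal sketch of the accumulation pattern with illustrative values:

```python
import torch

# Hypothetical setup: accumulate 5 values into 4 destination slots.
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
index  = torch.tensor([0, 2, 2, 3, 0])      # destination slot for each value
out = torch.zeros(4)

# scatter_add_ sums values that map to the same index instead of overwriting them.
out.scatter_add_(0, index, values)
print(out)  # tensor([6., 0., 5., 4.])
```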
-
Selecting RoIs (Regions of Interest) with `index_select`
In object detection tasks, after a region proposal network (RPN) suggests potential object locations, these regions of interest (RoIs) need to be extracted from the feature map for further...
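A minimal sketch, assuming the RoI features are rows of a 2-D tensor and `keep` holds the indices to extract; shapes are illustrative:

```python
import torch

# Hypothetical feature tensor: 10 proposal features, each 256-dimensional.
roi_features = torch.randn(10, 256)
keep = torch.tensor([0, 3, 7])              # indices of RoIs kept after filtering

# index_select picks whole rows along dim 0, preserving the order given in `keep`.
selected = roi_features.index_select(0, keep)
print(selected.shape)  # torch.Size([3, 256])
```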
-
Creating a Label Mapping with `scatter` for Domain Adaptation
In domain adaptation, you might need to map labels from a source domain to a target domain. Imagine you have a set of labels and a mapping that specifies how each old label corresponds to a new...
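A minimal sketch of one way to do this: scatter the new labels into a lookup table indexed by the old labels (the mapping itself is illustrative):

```python
import torch

# Hypothetical mapping: source-domain label i -> target-domain label new_of_old[i].
old_labels = torch.tensor([0, 1, 2, 3])
new_of_old = torch.tensor([2, 0, 3, 1])

# Build a lookup table with scatter: table[old] = new.
table = torch.zeros(4, dtype=torch.long).scatter(0, old_labels, new_of_old)

# Remap a batch of source labels in one indexing step.
batch = torch.tensor([3, 3, 0, 1, 2])
print(table[batch])  # tensor([1, 1, 2, 0, 3])
```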
-
Replicating `torch.nn.Embedding` with `gather`
The `torch.nn.Embedding` layer is fundamental in many deep learning models, especially in NLP. Your task is to replicate its forward pass functionality using `torch.gather`. You'll create a...
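A minimal sketch of the gather-based forward pass, checked against `torch.nn.functional.embedding`; the vocabulary size and token ids are illustrative:

```python
import torch

def embedding_gather(weight, token_ids):
    """Re-implement the embedding lookup weight[token_ids] using torch.gather."""
    flat = token_ids.reshape(-1)                            # (N,)
    idx = flat.unsqueeze(1).expand(-1, weight.size(1))      # (N, dim): repeat each id across columns
    rows = weight.gather(0, idx)                            # row flat[n] is picked for every column
    return rows.reshape(*token_ids.shape, weight.size(1))

vocab, dim = 10, 4
weight = torch.randn(vocab, dim)
token_ids = torch.tensor([[1, 5, 3], [0, 2, 9]])

out = embedding_gather(weight, token_ids)
ref = torch.nn.functional.embedding(token_ids, weight)
assert torch.equal(out, ref)
```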
-
Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will set up a regression problem and use `jax.grad` to compute gradients. 1....
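A minimal sketch, assuming a scalar linear model `w * x + b` with an MSE loss; the data and parameter names are illustrative:

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

x = jnp.linspace(0.0, 1.0, 20)
y = 3.0 * x + 1.0
params = (0.0, 0.0)

# jax.grad differentiates with respect to the first argument by default.
grads = jax.grad(mse_loss)(params, x, y)
print(grads)  # a (dL/dw, dL/db) tuple
```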
-
Vectorizing with `vmap`
Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single...
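A minimal sketch: a function written for a single example is batched over axis 0 of its input via `in_axes`:

```python
import jax
import jax.numpy as jnp

# Written for a single example: dot product of one weight vector with one input vector.
def predict(w, x):
    return jnp.dot(w, x)

w = jnp.ones(4)
batch_x = jnp.arange(12.0).reshape(3, 4)     # a batch of 3 inputs

# vmap maps `predict` over axis 0 of x while broadcasting the same w to every example.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(w, batch_x))           # shape (3,)
```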
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
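A minimal sketch of one jitted training step, assuming a linear model with parameters stored in a dict; the shapes and learning rate are illustrative:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # Apply the update leaf-by-leaf so the same rule covers both W and b.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 2))
y = x @ jnp.array([2.0, -1.0]) + 0.5
params = {"W": jnp.zeros(2), "b": jnp.zeros(())}

params = train_step(params, x, y)
print(loss_fn(params, x, y))
```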
-
PyTrees
A PyTree is any nested structure of containers such as dictionaries, lists, and tuples, with arrays or other values at the leaves. JAX is designed to work with PyTrees, which allows for a more organized way of handling model parameters. In this exercise, you...
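A minimal sketch of working with a nested parameter PyTree via `jax.tree_util`:

```python
import jax
import jax.numpy as jnp

# Model parameters organized as a nested PyTree (a dict of dicts of arrays).
params = {
    "layer1": {"W": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"W": jnp.ones((4, 2)), "b": jnp.zeros(2)},
}

# tree_map applies a function to every leaf while preserving the structure.
scaled = jax.tree_util.tree_map(lambda p: 0.5 * p, params)

# tree_leaves flattens the structure into a list of arrays.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(num_params)  # 3*4 + 4 + 4*2 + 2 = 26
```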
-
Conditionals with `jit`
Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
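A minimal sketch using `jax.lax.cond`, which traces both branches and selects one at run time (the function itself is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def safe_log(x):
    # `x > 0` is a traced value, so a Python `if` would fail under jit;
    # lax.cond keeps the choice inside the compiled computation.
    return jax.lax.cond(x > 0, lambda v: jnp.log(v), lambda v: jnp.zeros_like(v), x)

print(safe_log(jnp.array(2.0)))   # log(2)
print(safe_log(jnp.array(-1.0)))  # 0.0
```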
-
Loops with `jit`
Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the number of iterations depends on a traced value. JAX provides `jax.lax.fori_loop` and...
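A minimal sketch using `jax.lax.fori_loop`, where the trip count is itself a traced argument:

```python
import jax
import jax.numpy as jnp

@jax.jit
def partial_sum(n, x):
    # fori_loop(lower, upper, body, init) replaces a Python for-loop whose
    # trip count `n` is only known at run time.
    def body(i, acc):
        return acc + x[i]
    return jax.lax.fori_loop(0, n, body, 0.0)

x = jnp.arange(5.0)
print(partial_sum(3, x))  # 0 + 1 + 2 = 3.0
```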
-
Working with an optimizer library: Optax
Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
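A minimal sketch of the Optax training pattern; a plain dict of parameters stands in for the Flax MLP's params, and the optimizer choice and data are illustrative:

```python
import jax
import jax.numpy as jnp
import optax

params = {"W": jnp.zeros(2), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

x = jnp.ones((8, 2))
y = jnp.full((8,), 3.0)
for _ in range(100):
    params, opt_state, loss = step(params, opt_state, x, y)
print(loss)
```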
-
Checkpointing
When training large models, it is important to save the model's parameters periodically. This is known as checkpointing and allows you to resume training from a saved state in case of an...
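A minimal sketch that pickles the parameter PyTree to disk and restores it; dedicated checkpointing libraries such as Orbax are the usual choice for large models, and the file name here is illustrative:

```python
import pickle
import jax
import jax.numpy as jnp

params = {"W": jnp.ones((3, 3)), "b": jnp.zeros(3)}

# Move arrays to host memory (plain numpy) before writing the checkpoint.
with open("checkpoint.pkl", "wb") as f:
    pickle.dump(jax.device_get(params), f)

# Later: restore and push the leaves back onto the default device.
with open("checkpoint.pkl", "rb") as f:
    restored = jax.tree_util.tree_map(jnp.asarray, pickle.load(f))

assert jnp.allclose(restored["W"], params["W"])
```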
-
Tracing Gradient Descent on a Parabola
Imagine a simple 1D function $f(x) = x^2 - 4x + 5$. Your goal is to find the minimum of this function using Gradient Descent. 1. **Derive the gradient**: What is $\frac{df}{dx}$? 2. **Perform a...
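A minimal sketch of the resulting update rule, using the hand-derived gradient $\frac{df}{dx} = 2x - 4$; the starting point and learning rate are illustrative:

```python
x = 0.0           # starting point
lr = 0.1          # learning rate

for step in range(50):
    grad = 2 * x - 4        # df/dx for f(x) = x^2 - 4x + 5
    x = x - lr * grad       # gradient descent update

print(x)          # approaches the minimum at x = 2, where f(x) = 1
```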
-
Cross-Entropy: A Measure of Surprise
Cross-entropy loss is fundamental for classification tasks. Let's build some intuition for its formulation. 1. **Definition**: For a binary classification problem, the binary cross-entropy (BCE)...
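A minimal sketch of the BCE formula $-(y \log p + (1 - y)\log(1 - p))$ evaluated at two illustrative predictions:

```python
import math

def bce(y, p):
    """Binary cross-entropy for a single prediction with true label y and predicted probability p."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A confident correct prediction is "unsurprising" (low loss)...
print(bce(1, 0.99))  # ~0.01
# ...while a confident wrong prediction is heavily penalized.
print(bce(1, 0.01))  # ~4.6
```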
-
The Implicit Higher Dimension of Kernels
Support Vector Machines (SVMs) are powerful, and the "kernel trick" allows them to find non-linear decision boundaries without explicitly mapping data to high-dimensional spaces. 1. **Linear...
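A minimal numeric sketch for 2-D inputs: the degree-2 polynomial kernel $(x \cdot z)^2$ agrees with an ordinary dot product in the explicit feature space $\phi(x) = (x_1^2, \sqrt{2}\,x_1 x_2, x_2^2)$, so that space never has to be constructed:

```python
import numpy as np

def phi(v):
    # Explicit degree-2 feature map for a 2-D vector.
    return np.array([v[0] ** 2, np.sqrt(2) * v[0] * v[1], v[1] ** 2])

x = np.array([1.0, 2.0])
z = np.array([3.0, -1.0])

kernel = np.dot(x, z) ** 2          # computed entirely in the original 2-D space
explicit = np.dot(phi(x), phi(z))   # computed in the 3-D feature space
print(kernel, explicit)             # both equal 1.0 here
```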
-
KL Divergence Calculation and Interpretation
The Kullback-Leibler (KL) Divergence (also known as relative entropy) is a non-symmetric measure of how one probability distribution $P$ is different from a second, reference probability...
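A minimal sketch of the discrete formula $D_{KL}(P \| Q) = \sum_i p_i \log \frac{p_i}{q_i}$ with illustrative distributions:

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(P || Q) for discrete distributions given as arrays of probabilities."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return np.sum(p * np.log(p / q))

p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]

print(kl_divergence(p, q))  # ~0.025 nats
print(kl_divergence(q, p))  # a different value: KL divergence is not symmetric
```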
-
Matrix Multiplication and Efficiency
Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C =...
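A minimal sketch contrasting the naive triple loop, which performs on the order of $m \times k \times n$ multiply-adds, with the optimized `@` operator:

```python
import numpy as np

def matmul_naive(A, B):
    """Textbook O(m*k*n) matrix multiplication."""
    m, k = A.shape
    k2, n = B.shape
    assert k == k2
    C = np.zeros((m, n))
    for i in range(m):
        for j in range(n):
            for t in range(k):
                C[i, j] += A[i, t] * B[t, j]
    return C

A = np.random.rand(4, 5)
B = np.random.rand(5, 3)
assert np.allclose(matmul_naive(A, B), A @ B)  # `@` dispatches to an optimized BLAS routine
```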
-
Einops Warm-up: Reshaping Tensors for Expert Batching
In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each...
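A minimal sketch with `einops.rearrange`, assuming tokens are already grouped by expert; the axis names and sizes are illustrative:

```python
import numpy as np
from einops import rearrange

batch, seq, dim = 2, 6, 4
num_experts = 3
x = np.random.rand(batch, seq, dim)

# Flatten (batch, seq) into a single token axis...
tokens = rearrange(x, "b s d -> (b s) d")                          # (12, 4)
# ...then split the token axis into (expert, tokens_per_expert).
per_expert = rearrange(tokens, "(e t) d -> e t d", e=num_experts)  # (3, 4, 4)
print(per_expert.shape)
```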
-
Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
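A minimal end-to-end sketch: model, MSE loss, and a jitted gradient-descent update on illustrative synthetic data:

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return params["w"] * x + params["b"]

def mse(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.05):
    grads = jax.grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Synthetic data generated from w = 2, b = -1 (values are illustrative).
x = jnp.linspace(-1.0, 1.0, 50)
y = 2.0 * x - 1.0
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}

for _ in range(500):
    params = update(params, x, y)
print(params)  # w close to 2.0, b close to -1.0
```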
-
Build a Simple Neural Network with Flax
Using Flax, a neural network library for JAX, build a simple Multi-Layer Perceptron (MLP). The MLP should have an input layer, one hidden layer with a ReLU activation function, and an output layer....
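A minimal sketch of such an MLP in `flax.linen`; the layer sizes are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)   # hidden layer
        x = nn.relu(x)                 # ReLU activation
        return nn.Dense(self.out)(x)   # output layer

model = MLP()
x = jnp.ones((4, 8))                              # batch of 4, input dim 8
params = model.init(jax.random.PRNGKey(0), x)     # initialize parameters
y = model.apply(params, x)
print(y.shape)                                    # (4, 1)
```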
-
Implement a Basic Optimizer with Optax
Use Optax, an optimization library for JAX, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model, a loss function, and then use `optax.sgd` to update the...
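A minimal sketch of the `optax.sgd` update loop on an illustrative quadratic loss:

```python
import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, -2.0, 3.0])
loss_fn = lambda p: jnp.sum(p ** 2)

optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

print(params)  # close to zeros
```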
-
Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the...
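A minimal sketch that normalizes over the last (feature) axis and applies a learnable scale and shift; shapes are illustrative:

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each example over its feature axis, then scale by gamma and shift by beta."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

x = jnp.arange(12.0).reshape(3, 4)
gamma, beta = jnp.ones(4), jnp.zeros(4)
out = layer_norm(x, gamma, beta)
print(jnp.mean(out, axis=-1), jnp.var(out, axis=-1))  # ~0 and ~1 per row
```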
-
Implementing Dropout
Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training time. Remember that dropout should only be...
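A minimal sketch of inverted dropout, where surviving activations are rescaled by 1/(1 - rate) during training so nothing needs to change at inference; the rate and shapes are illustrative:

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, training):
    """Zero out a `rate` fraction of units and rescale the rest (inverted dropout)."""
    if not training:            # at inference time dropout is the identity
        return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(42)
x = jnp.ones((2, 8))
print(dropout(key, x, rate=0.5, training=True))   # roughly half the entries are 0, the rest 2.0
print(dropout(key, x, rate=0.5, training=False))  # unchanged
```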