-
Advanced Indexing with `gather` for NLP
In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token in each sequence from a tensor of shape...
-
Sparse Updates with `scatter_add_`
In graph neural networks and other sparse data applications, you often need to update a tensor based on sparse indices. Your exercise is to implement a function that takes a tensor of `values`, a...
-
Selecting RoIs (Regions of Interest) with `index_select`
In object detection tasks, after a region proposal network (RPN) suggests potential object locations, these regions of interest (RoIs) need to be extracted from the feature map for further...
-
Replicating `torch.nn.Embedding` with `gather`
The `torch.nn.Embedding` layer is fundamental in many deep learning models, especially in NLP. Your task is to replicate its forward pass functionality using `torch.gather`. You'll create a...
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
-
A simple MLP
Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you will build a 2-layer MLP for a regression problem. 1....
-
PyTrees
A PyTree is a nested structure of containers such as dictionaries, lists, and tuples, whose leaves are typically arrays. JAX is designed to work with PyTrees, which allows for a more organized way of handling model parameters. In this exercise, you...
-
A simple CNN
In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
-
Conditionals with `jit`
Standard Python control flow, like `if` statements, fails under `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
-
Loops with `jit`
Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the number of iterations depends on a traced value. JAX provides `jax.lax.fori_loop` and...
-
Working with a NN library: Flax
While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and...
-
Working with an optimizer library: Optax
Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
-
Checkpointing
When training large models, it is important to save the model's parameters periodically. This is known as checkpointing and allows you to resume training from a saved state in case of an...
-
Softmax's Numerical Stability: The Max Trick
While the standard softmax formula $\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$ is mathematically correct, a direct implementation can lead to numerical instability due to potential...
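The max trick can be sketched in plain Python as follows. This is a minimal illustration, assuming the instability in question is overflow of the exponential for large logits; the function names `softmax_naive` and `softmax_stable` are hypothetical and not part of the exercise.

```python
import math

def softmax_naive(z):
    # Direct translation of the formula: exp(z_i) / sum_j exp(z_j).
    exps = [math.exp(v) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(z):
    # Subtracting max(z) leaves the result unchanged mathematically,
    # but guarantees every exponent is <= 0, so exp() cannot overflow.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1000.0, 1001.0, 1002.0]

# math.exp raises OverflowError for arguments this large.
try:
    softmax_naive(logits)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable = softmax_stable(logits)
reference = softmax_naive([0.0, 1.0, 2.0])  # same logits shifted by 1000
```

Because softmax is invariant to adding a constant to every logit, `stable` and `reference` agree up to floating-point rounding, while the naive version fails on the unshifted inputs.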
-
The Gradients of Activation Functions
Activation functions introduce non-linearity into neural networks, and their derivatives are what backpropagation relies on. 1. **Sigmoid**: Given $\sigma(x) = \frac{1}{1 + e^{-x}}$, derive...
-
Cross-Entropy: A Measure of Surprise
Cross-entropy loss is fundamental for classification tasks. Let's build some intuition for its formulation. 1. **Definition**: For a binary classification problem, the binary cross-entropy (BCE)...
-
L2 Regularization's Gradient Impact
L2 regularization (closely related to weight decay, and equivalent to it under plain SGD) is a common technique to reduce overfitting. 1. **Loss Function**: Consider a simple linear regression loss with L2 regularization: $J(\mathbf{w},...
-
The Stabilizing Power of Batch Normalization
Batch Normalization (BatchNorm) is a crucial technique for stabilizing and accelerating deep neural network training. 1. **Normalization Step**: Given a mini-batch of activations $X = \{x_1, x_2,...
-
The Implicit Higher Dimension of Kernels
Support Vector Machines (SVMs) can use the "kernel trick" to find non-linear decision boundaries without explicitly mapping data to high-dimensional spaces. 1. **Linear...
-
Riding the Momentum Wave in Optimization
Stochastic Gradient Descent (SGD) with momentum is a popular optimization algorithm that often converges faster and more stably than plain SGD. 1. **Update Rule**: The update rule for SGD with...
-
KL Divergence Calculation and Interpretation
The Kullback-Leibler (KL) Divergence (also known as relative entropy) is a non-symmetric measure of how one probability distribution $P$ is different from a second, reference probability...
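The asymmetry mentioned above can be checked numerically with a short sketch. It assumes discrete distributions given as Python lists and uses the natural logarithm (so the result is in nats); the helper name `kl_divergence` and the example distributions are illustrative assumptions.

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) = sum_i p_i * log(p_i / q_i) for discrete distributions."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue          # 0 * log(0 / q) is taken to be 0 by convention
        if qi == 0.0:
            return math.inf   # P puts mass where Q has none
        total += pi * math.log(pi / qi)
    return total

p = [0.5, 0.5]
q = [0.9, 0.1]

kl_pq = kl_divergence(p, q)  # divergence of P from reference Q
kl_qp = kl_divergence(q, p)  # swapping the arguments gives a different value
kl_pp = kl_divergence(p, p)  # a distribution has zero divergence from itself
```

Comparing `kl_pq` with `kl_qp` shows why KL divergence is not a distance metric: the two directions generally disagree.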
-
Numerical Gradient Verification
Understanding and correctly implementing backpropagation is crucial in deep learning. A common way to debug backpropagation is using numerical gradient checking. This involves approximating the...
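One common form of gradient checking compares an analytic gradient against a central finite difference, one parameter at a time. The sketch below assumes that form; the toy function `loss`, its hand-derived gradient, and the step size `eps=1e-5` are illustrative choices, not taken from the exercise.

```python
def loss(w):
    # Toy scalar function of a parameter list: sum of cubes plus a cross term.
    return sum(wi ** 3 for wi in w) + w[0] * w[1]

def analytic_grad(w):
    # Hand-derived gradient of `loss`, standing in for a backprop result.
    g = [3.0 * wi ** 2 for wi in w]
    g[0] += w[1]
    g[1] += w[0]
    return g

def numerical_grad(f, w, eps=1e-5):
    # Central difference: (f(w + eps*e_i) - f(w - eps*e_i)) / (2*eps).
    grad = []
    for i in range(len(w)):
        w_plus = list(w)
        w_minus = list(w)
        w_plus[i] += eps
        w_minus[i] -= eps
        grad.append((f(w_plus) - f(w_minus)) / (2.0 * eps))
    return grad

def relative_error(a, b):
    # Scale-independent comparison; the floor avoids division by zero.
    return abs(a - b) / max(abs(a), abs(b), 1e-12)

w = [0.5, -1.2, 2.0]
errors = [
    relative_error(a, n)
    for a, n in zip(analytic_grad(w), numerical_grad(loss, w))
]
max_error = max(errors)
```

A small `max_error` suggests the analytic gradient is consistent with the function; a large one usually points to a bug in the derivative.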
-
Linear Regression via Gradient Descent
Linear regression is a foundational supervised learning algorithm. Given a dataset of input features $X$ and corresponding target values $y$, the goal is to find a linear relationship $y =...
-
Matrix Multiplication and Efficiency
Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C =...
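As a baseline for thinking about efficiency, the textbook definition $C_{ij} = \sum_p A_{ip} B_{pj}$ can be written as a triple loop over plain Python lists. This is only a reference sketch under that assumption; the function name `matmul` is hypothetical, and the loop performs $m \cdot k \cdot n$ multiply-adds.

```python
def matmul(A, B):
    """Multiply an m x k matrix A by a k x n matrix B (lists of rows)."""
    m, k = len(A), len(A[0])
    if len(B) != k:
        raise ValueError("inner dimensions must match")
    n = len(B[0])
    C = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for p in range(k):
            a_ip = A[i][p]  # reused across the whole inner loop
            for j in range(n):
                C[i][j] += a_ip * B[p][j]
    return C

A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]      # 2 x 3
B = [[7.0, 8.0],
     [9.0, 10.0],
     [11.0, 12.0]]         # 3 x 2

C = matmul(A, B)           # 2 x 2 result
```

The loop order here (`i`, `p`, `j`) walks rows of `B` contiguously, which is one of the simple reorderings often discussed when comparing naive and optimized implementations.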
-
MoE Gating and Dispatch
A core component of a Mixture of Experts model is the 'gating network' which determines which expert(s) each token should be sent to. This is often a `top-k` selection. Your task is to implement...