-
Implementing One-Hot Encoding with `scatter_`
Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
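The task statement is cut off above, so the following is only a minimal sketch of the technique named in the title. It assumes PyTorch and a 1D tensor of non-negative integer labels; the function name `one_hot` and the `num_classes` parameter are hypothetical.

```python
import torch

def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode a 1D tensor of integer labels (hypothetical signature)."""
    # Start from an all-zero (N, num_classes) matrix.
    out = torch.zeros(labels.shape[0], num_classes)
    # scatter_ needs an int64 index with the same number of dims as `out`,
    # so reshape the labels from (N,) to (N, 1).
    index = labels.to(torch.int64).unsqueeze(1)
    # Along dim=1, write 1.0 into the column selected by each row's label.
    out.scatter_(1, index, 1.0)
    return out

encoded = one_hot(torch.tensor([2, 0, 1]), num_classes=3)
```

Each row of `encoded` contains a single 1.0 in the column given by the corresponding label.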
-
Image Patch Extraction with `select` and `narrow`
In computer vision, a common operation is to extract patches from an image. Your task is to write a function that extracts a patch of a given size from a specific starting location in an image...
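As a rough illustration of the two calls named in the title, the sketch below assumes PyTorch and a channels-first `(C, H, W)` image. The layout, the function name `extract_patch`, and its parameters are assumptions, since the full task statement is not visible here.

```python
import torch

def extract_patch(image, top, left, height, width):
    """Return a (C, height, width) view into a (C, H, W) image (assumed layout)."""
    # narrow(dim, start, length) slices one dimension without copying data.
    rows = image.narrow(1, top, height)
    return rows.narrow(2, left, width)

image = torch.arange(2 * 4 * 5).reshape(2, 4, 5)
patch = extract_patch(image, top=1, left=2, height=2, width=3)

# select(dim, index) drops a dimension; here it picks the first channel.
first_channel = image.select(0, 0)
```

Both `narrow` and `select` return views, so `patch` shares storage with `image`.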
-
Creating a Label Mapping with `scatter` for Domain Adaptation
In domain adaptation, you might need to map labels from a source domain to a target domain. Imagine you have a set of labels and a mapping that specifies how each old label corresponds to a new...
-
NumPy to JAX: The Basics
This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
-
The JAX approach to PRNG
JAX handles pseudo-random number generation (PRNG) differently from NumPy: NumPy relies on a global state, whereas JAX makes the state of the PRNG explicit. This is a design choice that...
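A small sketch of what explicit PRNG state looks like in practice, using `jax.random.PRNGKey` and `jax.random.split`. The seed and array shapes are arbitrary choices for illustration.

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To draw fresh numbers, split the key and consume the new subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
```

Here `a` and `b` are identical, while `c` comes from a different key and is therefore a different draw.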
-
The need for speed: `jit`
JAX's `jit` transformation compiles your Python function with XLA, which can lead to significant speedups: XLA can fuse operations together, and the compiled function avoids much of the Python interpreter's per-operation overhead. In this...
-
Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
-
Taking gradients with `grad` II
By default, `jax.grad` will take the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to many of the function's...
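One way to differentiate with respect to more than the first argument is the `argnums` parameter of `jax.grad`. The sketch below uses a hypothetical scalar linear model and toy data, which are not taken from the exercise itself.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x, y):
    # Mean squared error of a scalar linear model (illustrative only).
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# Default: gradient with respect to the first argument (w) only.
dw_only = jax.grad(loss)(0.0, 0.0, x, y)

# argnums=(0, 1): a tuple of gradients with respect to w and b.
dw, db = jax.grad(loss, argnums=(0, 1))(0.0, 0.0, x, y)
```

With a tuple for `argnums`, the returned gradients come back as a tuple in the same order.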
-
Vectorizing with `vmap`
Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single...
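To make the batching idea concrete, here is a minimal sketch in which a function written for a single example is lifted to a batch with `jax.vmap`. The function `predict` and the toy arrays are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for ONE example: w has shape (D,), x has shape (D,).
    return jnp.dot(w, x)

# in_axes=(None, 0): do not batch over w, batch over axis 0 of x.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])
out = batched_predict(w, xs)  # one prediction per row of xs
```

The `in_axes` argument is what lets the weights be shared across the batch while the inputs vary.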
-
Understanding `jit` and tracing
When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
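One implication of tracing can be observed directly: Python side effects inside a jitted function run only while JAX traces it, not on every call. The sketch below records each trace in a list; the names `trace_log` and `double` are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def double(x):
    # This append is a Python side effect, so it runs only during tracing.
    trace_log.append(x.shape)
    return x * 2

double(jnp.ones(3))  # first call with shape (3,): traced and compiled
double(jnp.ones(3))  # same shape and dtype: cached, no re-trace
double(jnp.ones(4))  # new shape: traced again
```

After these three calls `trace_log` holds two entries, one per distinct input shape.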
-
Tracing Gradient Descent on a Parabola
Imagine a simple 1D function $f(x) = x^2 - 4x + 5$. Your goal is to find the minimum of this function using Gradient Descent. 1. **Derive the gradient**: What is $\frac{df}{dx}$? 2. **Perform a...
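For reference, the derivative is $\frac{df}{dx} = 2x - 4$, which is zero at $x = 2$, where $f(2) = 1$. The sketch below runs plain gradient descent on this function. The starting point `x = 0.0`, the learning rate `0.1`, and the step count are assumptions, since the exercise's own values are cut off above.

```python
def f(x):
    return x ** 2 - 4 * x + 5

def df(x):
    # Analytic derivative of f.
    return 2 * x - 4

x = 0.0    # assumed starting point
lr = 0.1   # assumed learning rate
for step in range(50):
    # Move against the gradient; the distance to x = 2 shrinks by 0.8 per step.
    x = x - lr * df(x)
```

With this learning rate each update multiplies the error `x - 2` by `1 - 2 * lr = 0.8`, so the iterate converges geometrically towards the minimum.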
-
Numerical Stability: Log-Sum-Exp
When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ can be...
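A common remedy is to subtract the maximum before exponentiating, using the identity $\log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)$ with $m = \max_i x_i$. The sketch below uses only the standard library and assumes a non-empty list of finite values.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))) for a non-empty list of finite floats."""
    m = max(xs)
    # After shifting, the largest exponent is exp(0) = 1, so nothing overflows.
    return m + math.log(sum(math.exp(x - m) for x in xs))

large = logsumexp([1000.0, 1000.0])    # naive math.exp(1000.0) would overflow
small = logsumexp([-1000.0, -1000.0])  # naive sum would underflow to 0.0
```

Both results differ from the shared input value by exactly $\log 2$, which the naive formula cannot compute for inputs of this magnitude.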
-
L2 Regularization Gradient
L2 regularization (the penalty used in ridge regression, and closely related to weight decay) is a common technique for reducing overfitting in machine learning models by adding a penalty proportional to the square of the...
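The scaling convention for the penalty is not visible in the truncated text, so the sketch below assumes the form $\frac{\lambda}{2} \lVert w \rVert^2$, whose gradient is $\lambda w$ (with $\lambda \lVert w \rVert^2$ instead, the gradient would be $2 \lambda w$). It uses NumPy and checks the analytic gradient against central finite differences.

```python
import numpy as np

def l2_penalty(w, lam):
    # Assumed convention: (lam / 2) * ||w||^2.
    return 0.5 * lam * np.sum(w ** 2)

def l2_penalty_grad(w, lam):
    # Gradient of the penalty above with respect to w.
    return lam * w

w = np.array([1.0, -2.0, 3.0])
lam = 0.1
grad = l2_penalty_grad(w, lam)

# Central finite differences along each coordinate direction.
eps = 1e-6
numeric = np.array([
    (l2_penalty(w + eps * e, lam) - l2_penalty(w - eps * e, lam)) / (2 * eps)
    for e in np.eye(len(w))
])
```

In a full training loss, this penalty gradient is simply added to the gradient of the data term.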
-
Einops Warm-up: Reshaping Tensors for Expert Batching
In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each...
-
Sparse MoE Top-K Gating
### Description In a Mixture of Experts (MoE) model, the gating network is a crucial component that determines which 'expert' subnetworks process each token. [1] A common strategy is **top-k...
-
Hierarchical Patch Merging with Einops
### Description In hierarchical vision transformers like the Swin Transformer, **patch merging** is used to downsample the feature map, effectively reducing the number of tokens while increasing...
-
Batch-wise Matrix Transposition
### Description Given a batch of matrices, transpose each matrix in the batch. The input tensor has a shape of `(B, H, W)`, and the output should be `(B, W, H)`. ### Starter Code ```python import...
-
Global Average Pooling
### Description Implement global average pooling, a common operation in convolutional neural networks. For a batch of feature maps of shape `(B, C, H, W)`, you need to compute the mean of each...
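A minimal sketch of the operation, using NumPy as a neutral stand-in because the exercise's intended library is not visible here. It assumes the result drops the spatial axes, giving shape `(B, C)`.

```python
import numpy as np

# Toy batch: B=2, C=3, H=2, W=2.
x = np.arange(2 * 3 * 2 * 2, dtype=np.float64).reshape(2, 3, 2, 2)

# Average over the two spatial axes (H and W) for every (batch, channel) pair.
pooled = x.mean(axis=(2, 3))
```

Passing `keepdims=True` to `mean` would instead keep a `(B, C, 1, 1)` shape, which some downstream layers expect.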
-
Tile a Tensor
### Description Given a tensor, repeat its values along one or more dimensions. For example, given a tensor of shape `(H, W)`, you might want to create a batch of `B` identical copies, resulting...
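As an illustration of the `(H, W)` to `(B, H, W)` case, the sketch below uses NumPy as a stand-in for whichever library the exercise intends. It shows both a copying and a non-copying variant.

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])
B = 3

# np.tile copies the data: repeat B times along a new leading axis.
tiled = np.tile(x, (B, 1, 1))

# np.broadcast_to returns a read-only view with the same shape and values.
view = np.broadcast_to(x, (B,) + x.shape)
```

The view costs no extra memory, but it must be copied before it can be written to.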
-
Concatenate Tensors Along a New Axis
### Description Given a list of tensors of the same shape, concatenate them along a new axis. For example, given 3 tensors of shape `(H, W)`, you want to create a single tensor of shape `(3, H,...
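A short sketch of stacking along a new leading axis, again with NumPy standing in for the exercise's intended library. The toy tensors are arbitrary.

```python
import numpy as np

# Three (H, W) = (2, 3) arrays filled with 0, 1 and 2 respectively.
tensors = [np.full((2, 3), i) for i in range(3)]

# np.stack inserts a NEW axis; np.concatenate would join along an existing one.
stacked = np.stack(tensors, axis=0)
```

The result has shape `(3, 2, 3)`, and `stacked[i]` is the `i`-th input tensor.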
-
Channel-wise Max Pooling
### Description Perform max pooling over the channel dimension. Given a tensor of shape `(B, C, H, W)`, find the maximum value across all channels for each spatial location. The output should have...
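A minimal sketch with NumPy as a stand-in library, assuming the channel axis is removed so that the output has shape `(B, H, W)`. The random input is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 4))  # (B, C, H, W)

# Reduce over axis 1 (channels), keeping one value per spatial location.
channel_max = x.max(axis=1)
```

Every entry of `channel_max` is at least as large as the corresponding entry in each individual channel.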
-
Swap Height and Width
### Description For a batch of images or feature maps, swap the height and width dimensions. The input shape is `(B, C, H, W)` and the output shape should be `(B, C, W, H)`. ### Starter Code...
-
Permute Dimensions Cyclically
### Description Perform a cyclic permutation of the dimensions of a tensor. For a tensor of shape `(D1, D2, D3, D4)`, a cyclic permutation would result in `(D4, D1, D2, D3)`. ### Starter Code...
-
Flatten Leading Dimensions
### Description Given a tensor with multiple leading dimensions, flatten them into a single dimension. For example, transform a tensor of shape `(D1, D2, D3, D4)` into `(D1*D2, D3, D4)`. ###...
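The `(D1, D2, D3, D4)` to `(D1*D2, D3, D4)` example can be sketched with a single reshape. NumPy is used here as a stand-in for the exercise's intended library.

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # (D1, D2, D3, D4)

# -1 lets NumPy infer D1*D2; the trailing dimensions are kept unchanged.
flat = x.reshape(-1, *x.shape[2:])
```

Because the data is laid out row-major, `flat[i * D2 + j]` is the same slice as `x[i, j]`.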
-
Unflatten a Dimension
### Description This is the inverse of flattening. Given a tensor where the first dimension is a product of two other dimensions, unflatten it. For example, transform a tensor of shape `(D1*D2,...