-
Implementing One-Hot Encoding with `scatter_`
Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
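`scatter_` is a PyTorch method. JAX arrays are immutable, so the closest analogue is an indexed functional update with `.at[...].set(...)`. The sketch below is that JAX analogue, not PyTorch's API; `one_hot_scatter` is a hypothetical name, and the result is checked against the built-in `jax.nn.one_hot`.

```python
import jax
import jax.numpy as jnp

def one_hot_scatter(labels, num_classes):
    """One-hot encode a 1D integer array with an indexed scatter (JAX analogue of scatter_)."""
    n = labels.shape[0]
    out = jnp.zeros((n, num_classes), dtype=jnp.float32)
    # Write 1.0 at (row i, column labels[i]) for every i; .at[].set returns a new array.
    return out.at[jnp.arange(n), labels].set(1.0)

labels = jnp.array([0, 2, 1])
encoded = one_hot_scatter(labels, 3)
reference = jax.nn.one_hot(labels, 3)
```

Unlike PyTorch, JAX does not raise on out-of-bounds indices in scatter-style updates; such updates are dropped, so validate the labels first if that matters.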
-
NumPy to JAX: The Basics
This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
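Most `jnp` functions mirror their NumPy counterparts, but one difference surfaces early: JAX arrays are immutable. A minimal illustration (the variable names are arbitrary) comparing an in-place NumPy write with JAX's functional `.at[...].set(...)` update:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(5.0)
x_np[0] = 10.0  # NumPy arrays are mutable: in-place assignment works

x_jnp = jnp.arange(5.0)
# jnp arrays are immutable: `x_jnp[0] = 10.0` would raise a TypeError.
y_jnp = x_jnp.at[0].set(10.0)  # functional update returns a new array; x_jnp is unchanged
```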
-
The JAX approach to PRNG
JAX handles pseudo-random number generation (PRNG) differently from NumPy: where NumPy keeps a global random state, JAX makes the state of the PRNG explicit. This is a design choice that...
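A short sketch of the explicit-key workflow: the same key always yields the same numbers, and `jax.random.split` derives fresh keys for independent draws.

```python
import jax

key = jax.random.PRNGKey(0)  # the PRNG state is an explicit value, not hidden global state

# Reusing a key reproduces exactly the same samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Split to obtain a new key; never reuse a key you have already consumed.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
```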
-
The need for speed: `jit`
JAX's `jit` function compiles your Python function with XLA, which can lead to significant speedups: XLA can fuse sequences of operations into fewer kernels, and the compiled function avoids the overhead of dispatching each operation from the Python interpreter. In this...
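A minimal example of wrapping an elementwise function with `jax.jit`. The first call includes compilation time. Because JAX dispatches work asynchronously, call `block_until_ready()` before stopping a timer when benchmarking.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Several elementwise ops that XLA can fuse into a single kernel.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.arange(5.0)
y = selu_jit(x).block_until_ready()  # wait for the asynchronous computation to finish
```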
-
Taking gradients with `grad` II
By default, `jax.grad` takes the gradient with respect to the first argument of the function. However, we often want gradients with respect to several of the function's...
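One way to differentiate with respect to more than one argument is the `argnums` parameter of `jax.grad`. The toy `loss` below is hypothetical. With $w=2$, $b=1$, $x=3$, the inner value is $7$, so $\partial/\partial w = 2 \cdot 7 \cdot 3 = 42$ and $\partial/\partial b = 2 \cdot 7 = 14$.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x):
    return jnp.sum((w * x + b) ** 2)

# argnums=(0, 1): gradient w.r.t. w and b, returned as a tuple in the same order.
dw, db = jax.grad(loss, argnums=(0, 1))(2.0, 1.0, 3.0)
```

An alternative design is to pack all parameters into a single pytree (e.g. a dict) passed as the first argument, so the default `argnums=0` returns a gradient with the same structure.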
-
Understanding `jit` and tracing
When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
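Two consequences of tracing, sketched below. First, Python side effects run only while tracing, and retracing happens when input shapes or dtypes change. Second, Python control flow cannot branch on traced values. The counter and function names are illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: executes during tracing, not on cached calls
    return x * 2

f(jnp.ones(3))
f(jnp.ones(3))  # same shape/dtype: reuses the compiled version, no retrace
f(jnp.ones(4))  # new shape: traced and compiled again

@jax.jit
def g(x):
    if x.sum() > 0:  # Python `if` on a traced value fails under jit
        return x
    return -x

try:
    g(jnp.ones(3))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

@jax.jit
def g_fixed(x):
    return jnp.where(x.sum() > 0, x, -x)  # traceable, data-dependent selection
```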
-
Numerical Stability: Log-Sum-Exp
When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ can be...
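The standard fix subtracts the maximum $m = \max_i x_i$ first, using $\log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)$, so every exponent is at most $0$. A small sketch comparing the naive form, a hand-written stable version, and `jax.scipy.special.logsumexp`:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def logsumexp_stable(x):
    m = jnp.max(x)
    # Shifted exponents are <= 0, so exp cannot overflow.
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

x = jnp.array([1000.0, 1000.0])
naive = jnp.log(jnp.sum(jnp.exp(x)))  # exp(1000) overflows float32, giving inf
stable = logsumexp_stable(x)          # 1000 + log(2)
library = logsumexp(x)
```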
-
L2 Regularization Gradient
L2 regularization (the penalty used in ridge regression, and closely related to weight decay) is a common technique to prevent overfitting in machine learning models by adding a penalty proportional to the square of the...
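A sketch assuming the convention $\frac{\lambda}{2}\lVert w \rVert^2$ (the exercise's exact scaling is not shown here), whose gradient is $\lambda w$. The value of `lam` is a hypothetical choice; `jax.grad` confirms the closed form.

```python
import jax
import jax.numpy as jnp

lam = 0.1  # hypothetical regularization strength

def l2_penalty(w):
    # Assumed convention: (lam / 2) * ||w||^2, so the gradient is lam * w.
    return 0.5 * lam * jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
g = jax.grad(l2_penalty)(w)
```

With the other common convention, $\lambda \lVert w \rVert^2$, the gradient doubles to $2\lambda w$.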
-
Sparse MoE Top-K Gating
### Description In a Mixture of Experts (MoE) model, the gating network is a crucial component that determines which 'expert' subnetworks process each token. [1] A common strategy is **top-k...
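A minimal top-k gating sketch, assuming one common variant: softmax over only the k selected logits, scattered back into a dense `(tokens, experts)` gate matrix. Noise injection and load-balancing losses are omitted, and `top_k_gating` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def top_k_gating(logits, k):
    """logits: (T, E). Returns dense gates (T, E) with k nonzeros per row, plus chosen indices."""
    top_vals, top_idx = jax.lax.top_k(logits, k)     # each (T, k), sorted descending
    top_weights = jax.nn.softmax(top_vals, axis=-1)  # renormalize over the chosen experts
    rows = jnp.arange(logits.shape[0])[:, None]      # (T, 1) broadcasts against (T, k)
    gates = jnp.zeros_like(logits).at[rows, top_idx].set(top_weights)
    return gates, top_idx

logits = jnp.array([[1.0, 3.0, 2.0, 0.0],
                    [0.5, 0.1, 0.2, 4.0]])
gates, idx = top_k_gating(logits, k=2)
```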
-
Hierarchical Patch Merging with Einops
### Description In hierarchical vision transformers like the Swin Transformer, **patch merging** is used to downsample the feature map, effectively reducing the number of tokens while increasing...
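The core rearrangement is expressible as the einops pattern `'b (h p1) (w p2) c -> b h w (p1 p2 c)'` with `p1=p2=2`. Below is the same reshape written with plain `jax.numpy`, assuming channels-last input. The channel order inside each merged token is a convention, and Swin additionally applies a LayerNorm and a linear projection afterwards, which are not shown.

```python
import jax.numpy as jnp

def patch_merge(x):
    """(B, H, W, C) -> (B, H/2, W/2, 4C): stack each 2x2 neighbourhood into channels.
    Equivalent einops pattern: 'b (h p1) (w p2) c -> b h w (p1 p2 c)', p1=2, p2=2."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // 2, 2, W // 2, 2, C)  # b h p1 w p2 c
    x = x.transpose(0, 1, 3, 2, 4, 5)          # b h w p1 p2 c
    return x.reshape(B, H // 2, W // 2, 4 * C)

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
y = patch_merge(x)
```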
-
Sliding Window Attention Preparation
### Description Full self-attention has a quadratic complexity with respect to sequence length, which is prohibitive for very long sequences. Models like Longformer introduce **sliding window...
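One simple preparation step is a banded mask, assuming each token may attend to positions within a hypothetical `half_window` on either side. Masked scores are set to $-\infty$ before the softmax. Note that a dense mask like this still costs $O(n^2)$ memory; efficient sliding-window kernels avoid materializing it.

```python
import jax
import jax.numpy as jnp

def sliding_window_mask(seq_len, half_window):
    """Boolean (seq_len, seq_len) mask: True where |i - j| <= half_window."""
    idx = jnp.arange(seq_len)
    return jnp.abs(idx[:, None] - idx[None, :]) <= half_window

mask = sliding_window_mask(5, 1)
scores = jnp.zeros((5, 5))  # placeholder attention scores
weights = jax.nn.softmax(jnp.where(mask, scores, -jnp.inf), axis=-1)
```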
-
MoE Gating: Top-K Selection
### Description In a Mixture of Experts (MoE) model, a gating network is responsible for routing each input token to a subset of 'expert' networks. [6, 14] A common strategy is Top-K gating, where...
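Another way to express Top-K selection keeps the full logit matrix: mask every logit below the k-th largest to $-\infty$, then apply a softmax. This sketch (hypothetical name `keep_top_k_softmax`) assumes no ties at the k-th value; with ties, more than k experts survive.

```python
import jax
import jax.numpy as jnp

def keep_top_k_softmax(logits, k):
    """Zero out all but the k largest logits per row (via -inf) and softmax the rest."""
    kth_largest = jax.lax.top_k(logits, k)[0][..., -1:]  # (T, 1): k-th largest value per row
    masked = jnp.where(logits >= kth_largest, logits, -jnp.inf)
    return jax.nn.softmax(masked, axis=-1)

logits = jnp.array([[1.0, 3.0, 2.0, 0.0]])
probs = keep_top_k_softmax(logits, k=2)
```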
-
Batch-wise Matrix Transposition
### Description Given a batch of matrices, transpose each matrix in the batch. The input tensor has a shape of `(B, H, W)`, and the output should be `(B, W, H)`. ### Starter Code ```python import...
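In `jax.numpy`, transposing every matrix in a batch just swaps the last two axes. `batch_transpose` is a hypothetical name.

```python
import jax.numpy as jnp

def batch_transpose(x):
    """(B, H, W) -> (B, W, H): transpose each matrix in the batch."""
    # Equivalent: x.transpose(0, 2, 1) or jnp.einsum('bhw->bwh', x)
    return jnp.swapaxes(x, 1, 2)

x = jnp.arange(24).reshape(2, 3, 4)
y = batch_transpose(x)
```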
-
Global Average Pooling
### Description Implement global average pooling, a common operation in convolutional neural networks. For a batch of feature maps of shape `(B, C, H, W)`, you need to compute the mean of each...
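A minimal sketch: average over the two spatial axes. Whether the result should be `(B, C)` or `(B, C, 1, 1)` depends on the exercise; `keepdims=True` gives the latter.

```python
import jax.numpy as jnp

def global_avg_pool(x):
    """(B, C, H, W) -> (B, C): mean over the spatial axes H and W."""
    return jnp.mean(x, axis=(2, 3))  # pass keepdims=True for shape (B, C, 1, 1)

# Channel c is filled with the constant value c, so each channel's mean is c.
x = jnp.arange(3.0)[None, :, None, None] * jnp.ones((2, 3, 4, 5))
y = global_avg_pool(x)
```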
-
Multi-Head Attention: Splitting Heads
### Description In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape `(B, N, D)`, where `D` is the embedding dimension, you need to...
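A sketch assuming the target layout `(B, H, N, D//H)`, the same layout the merging step takes as input. Reshape the feature axis into `(H, D//H)`, then move the head axis forward.

```python
import jax.numpy as jnp

def split_heads(x, num_heads):
    """(B, N, D) -> (B, H, N, D // H)."""
    B, N, D = x.shape
    assert D % num_heads == 0, "embedding dim must be divisible by num_heads"
    x = x.reshape(B, N, num_heads, D // num_heads)  # (B, N, H, D/H)
    return x.transpose(0, 2, 1, 3)                  # (B, H, N, D/H)

x = jnp.arange(2 * 5 * 8).reshape(2, 5, 8)
y = split_heads(x, 4)
```

Head `h` holds features `h * (D//H)` through `(h + 1) * (D//H) - 1` of every token.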
-
Multi-Head Attention: Merging Heads
### Description The inverse of splitting heads. After computing attention for each head, you need to merge them back. Given a tensor of shape `(B, H, N, D//H)`, you need to merge it back to `(B,...
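The inverse moves the head axis back next to the per-head features before flattening. Reshaping `(B, H, N, D//H)` directly to `(B, N, D)` without the transpose would silently mix tokens across heads.

```python
import jax.numpy as jnp

def merge_heads(x):
    """(B, H, N, D // H) -> (B, N, D)."""
    B, H, N, Dh = x.shape
    # Put heads adjacent to their features, then flatten (H, Dh) into D.
    return x.transpose(0, 2, 1, 3).reshape(B, N, H * Dh)

x = jnp.arange(2 * 4 * 5 * 2).reshape(2, 4, 5, 2)
y = merge_heads(x)
```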
-
Tile a Tensor
### Description Given a tensor, repeat its values along one or more dimensions. For example, given a tensor of shape `(H, W)`, you might want to create a batch of `B` identical copies, resulting...
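`jnp.tile` follows NumPy's semantics: when `reps` has more entries than the array has dimensions, new leading axes are added. A sketch of the batching example (hypothetical helper `make_batch`):

```python
import jax.numpy as jnp

def make_batch(x, B):
    """(H, W) -> (B, H, W): B identical copies stacked on a new leading axis."""
    return jnp.tile(x, (B, 1, 1))

x = jnp.arange(6).reshape(2, 3)
batch = make_batch(x, 4)
wide = jnp.tile(x, (1, 2))  # repeat along an existing axis: (2, 3) -> (2, 6)
```

`jnp.repeat(x[None], B, axis=0)` produces the same batch.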
-
Concatenate Tensors Along a New Axis
### Description Given a list of tensors of the same shape, concatenate them along a new axis. For example, given 3 tensors of shape `(H, W)`, you want to create a single tensor of shape `(3, H,...
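`jnp.stack` creates the new axis, whereas `jnp.concatenate` joins along an existing one. A quick comparison:

```python
import jax.numpy as jnp

tensors = [jnp.full((2, 3), i) for i in range(3)]  # three (H, W) = (2, 3) tensors

stacked = jnp.stack(tensors, axis=0)             # new leading axis: (3, 2, 3)
concatenated = jnp.concatenate(tensors, axis=0)  # existing axis grows: (6, 3)
```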
-
Channel-wise Max Pooling
### Description Perform max pooling over the channel dimension. Given a tensor of shape `(B, C, H, W)`, find the maximum value across all channels for each spatial location. The output should have...
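A one-line reduction over axis 1. Keeping a singleton channel axis with `keepdims=True` is an assumption here; drop it if the expected output is `(B, H, W)`.

```python
import jax.numpy as jnp

def channel_max_pool(x):
    """(B, C, H, W) -> (B, 1, H, W): max across channels at each spatial location."""
    return jnp.max(x, axis=1, keepdims=True)

x = jnp.arange(24.0).reshape(2, 3, 2, 2)  # values increase with channel index
y = channel_max_pool(x)
```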
-
Swap Height and Width
### Description For a batch of images or feature maps, swap the height and width dimensions. The input shape is `(B, C, H, W)` and the output shape should be `(B, C, W, H)`. ### Starter Code...
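Swapping H and W is a transpose that leaves the batch and channel axes in place. A minimal sketch:

```python
import jax.numpy as jnp

def swap_hw(x):
    """(B, C, H, W) -> (B, C, W, H)."""
    return x.transpose(0, 1, 3, 2)  # equivalently jnp.swapaxes(x, -1, -2)

x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = swap_hw(x)
```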
-
Permute Dimensions Cyclically
### Description Perform a cyclic permutation of the dimensions of a tensor. For a tensor of shape `(D1, D2, D3, D4)`, a cyclic permutation would result in `(D4, D1, D2, D3)`. ### Starter Code...
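Moving the last axis to the front yields the cyclic shift `(D1, D2, D3, D4) -> (D4, D1, D2, D3)`, so `y[l, i, j, k] == x[i, j, k, l]`. `jnp.moveaxis` does this for any rank:

```python
import jax.numpy as jnp

def cyclic_permute(x):
    """Rotate axes right by one: (D1, ..., Dn) -> (Dn, D1, ..., Dn-1)."""
    return jnp.moveaxis(x, -1, 0)  # for 4D input, same as x.transpose(3, 0, 1, 2)

x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = cyclic_permute(x)
```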
-
Flatten Leading Dimensions
### Description Given a tensor with multiple leading dimensions, flatten them into a single dimension. For example, transform a tensor of shape `(D1, D2, D3, D4)` into `(D1*D2, D3, D4)`. ###...
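A reshape with `-1` lets JAX infer `D1*D2` while the trailing axes stay untouched. Row `i * D2 + j` of the result is `x[i, j]`.

```python
import jax.numpy as jnp

def flatten_leading(x):
    """(D1, D2, D3, ...) -> (D1*D2, D3, ...)."""
    return x.reshape(-1, *x.shape[2:])

x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = flatten_leading(x)
```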
-
Unflatten a Dimension
### Description This is the inverse of flattening. Given a tensor where the first dimension is a product of two other dimensions, unflatten it. For example, transform a tensor of shape `(D1*D2,...
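Unflattening needs one of the two factors to be known. This sketch assumes `D1` is passed in (hypothetical parameter name) and infers `D2` with `-1`.

```python
import jax.numpy as jnp

def unflatten_leading(x, d1):
    """(D1*D2, ...) -> (D1, D2, ...), with D1 given and D2 inferred."""
    assert x.shape[0] % d1 == 0, "leading dim must be divisible by d1"
    return x.reshape(d1, -1, *x.shape[1:])

x = jnp.arange(6 * 4).reshape(6, 4)
y = unflatten_leading(x, 2)  # y[i, j] == x[i * 3 + j]
```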
-
Pixel Unshuffle (Pixel to Channel)
### Description This is another name for the space-to-depth operation, common in super-resolution models. It involves rearranging blocks of spatial data into the channel dimension. Given a tensor...
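A sketch assuming channels-first input and a downscale factor `r`. Each output channel is indexed as `c * r * r + r1 * r + r2`, where `(r1, r2)` is the offset within an `r x r` block; in einops notation this is `'b c (h r1) (w r2) -> b (c r1 r2) h w'`.

```python
import jax.numpy as jnp

def pixel_unshuffle(x, r):
    """(B, C, H, W) -> (B, C*r*r, H/r, W/r)."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // r, r, W // r, r)  # b c h r1 w r2
    x = x.transpose(0, 1, 3, 5, 2, 4)          # b c r1 r2 h w
    return x.reshape(B, C * r * r, H // r, W // r)

x = jnp.arange(16.0).reshape(1, 1, 4, 4)
y = pixel_unshuffle(x, 2)
```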
-
Pixel Shuffle (Channel to Pixel)
### Description The inverse of pixel unshuffle, also known as depth-to-space. It is used to upscale an image by rearranging elements from the channel dimension into spatial blocks. Given a tensor...
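The inverse reads the channel axis as `(C, r1, r2)` and interleaves the two offsets back into the spatial axes (einops: `'b (c r1 r2) h w -> b c (h r1) (w r2)'`). This sketch assumes the same channel ordering as the space-to-depth direction.

```python
import jax.numpy as jnp

def pixel_shuffle(x, r):
    """(B, C*r*r, H, W) -> (B, C, H*r, W*r)."""
    B, Crr, H, W = x.shape
    C = Crr // (r * r)
    x = x.reshape(B, C, r, r, H, W)    # b c r1 r2 h w
    x = x.transpose(0, 1, 4, 2, 5, 3)  # b c h r1 w r2
    return x.reshape(B, C, H * r, W * r)

x = jnp.arange(16.0).reshape(1, 4, 2, 2)
y = pixel_shuffle(x, 2)
```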