-
MoE Gating and Dispatch
A core component of a Mixture of Experts (MoE) model is the 'gating network', which decides which expert(s) each token is sent to, often via a `top-k` selection. Your task is to implement...
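As a rough illustration of `top-k` gating, the sketch below picks the `k` highest-scoring experts per token and renormalizes their scores with a softmax. It uses plain NumPy; the function name `top_k_gating` and the choice to softmax over only the selected logits are assumptions, not something the exercise prescribes.

```python
import numpy as np

def top_k_gating(logits, k):
    """Hypothetical helper: choose k experts per token.

    logits: (num_tokens, num_experts) raw gating scores.
    Returns (indices, weights), each of shape (num_tokens, k).
    """
    # Indices of the k largest logits in each row, highest first.
    indices = np.argsort(-logits, axis=-1)[:, :k]
    selected = np.take_along_axis(logits, indices, axis=-1)
    # Softmax over the selected logits only, so weights sum to 1 per token.
    selected = selected - selected.max(axis=-1, keepdims=True)
    weights = np.exp(selected)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return indices, weights

logits = np.array([[1.0, 3.0, 2.0, 0.0],
                   [0.5, 0.1, 0.2, 4.0]])
indices, weights = top_k_gating(logits, k=2)
```

Here token 0 is routed to experts 1 and 2, and token 1 to experts 3 and 0.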
-
MoE Aggregator: Combining Expert Outputs
After tokens have been dispatched to and processed by their respective experts, the outputs need to be combined based on the weights from the gating network. This exercise focuses on this...
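One way to picture the combination step is a weighted sum over each token's selected experts. The NumPy sketch below assumes, for simplicity, that every expert's output is available for every token in a dense array; the name `combine_expert_outputs` and this dense layout are illustrative assumptions.

```python
import numpy as np

def combine_expert_outputs(expert_outputs, indices, weights):
    """Hypothetical helper: weighted sum of the selected experts' outputs.

    expert_outputs: (num_experts, num_tokens, d), dense for simplicity.
    indices, weights: (num_tokens, k) from the gating step.
    Returns: (num_tokens, d).
    """
    num_tokens = indices.shape[0]
    token_ids = np.arange(num_tokens)[:, None]       # (num_tokens, 1)
    # Gather, for each token, the outputs of its k chosen experts.
    picked = expert_outputs[indices, token_ids]      # (num_tokens, k, d)
    return (weights[..., None] * picked).sum(axis=1)

# Toy data: expert e outputs the constant value e for every token.
expert_outputs = np.stack([np.full((2, 3), float(e)) for e in range(4)])
indices = np.array([[1, 2], [3, 0]])
weights = np.array([[0.75, 0.25], [0.5, 0.5]])
combined = combine_expert_outputs(expert_outputs, indices, weights)
```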
-
Building a Simple Mixture of Experts (MoE) Layer
Now, let's combine the concepts of dispatching and aggregating into a full, albeit simplified, `torch.nn.Module` for a Mixture of Experts layer. This layer will replace a standard feed-forward...
-
Batched Expert Forward Pass with Einops
A naive implementation of an MoE layer loops over the experts one at a time, which is inefficient. A much better approach is to perform a single, batched matrix multiplication for all expert...
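The difference between the loop and the batched form can be sketched as follows. This uses `np.einsum` rather than einops, and it assumes tokens have already been grouped into a `(num_experts, tokens_per_expert, d_in)` array with one weight matrix per expert; those shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_experts, tokens_per_expert, d_in, d_out = 4, 5, 8, 16
x = rng.normal(size=(num_experts, tokens_per_expert, d_in))
w = rng.normal(size=(num_experts, d_in, d_out))

# Naive version: one matrix multiplication per expert.
looped = np.stack([x[e] @ w[e] for e in range(num_experts)])

# Batched version: a single contraction over the input dimension "i",
# with the expert axis "e" carried along as a batch dimension.
batched = np.einsum("eti,eio->eto", x, w)
```

Both produce an array of shape `(num_experts, tokens_per_expert, d_out)` with matching values.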
-
Sparse MoE Top-K Gating
### Description In a Mixture of Experts (MoE) model, the gating network is a crucial component that determines which 'expert' subnetworks process each token. [1] A common strategy is **top-k...
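A sparse gate matrix, with non-zero weights only at each token's top-k experts, might be built along these lines. This is a NumPy sketch; the name `sparse_gates` and the masking-with-negative-infinity trick are assumptions rather than the exercise's required approach.

```python
import numpy as np

def sparse_gates(logits, k):
    """Hypothetical helper: (num_tokens, num_experts) gates, zero outside the top-k."""
    # argpartition places the k largest logits (smallest negated values) first.
    top_idx = np.argpartition(-logits, k - 1, axis=-1)[:, :k]
    top_logits = np.take_along_axis(logits, top_idx, axis=-1)
    # Start from -inf everywhere so non-selected experts get exp(-inf) = 0.
    masked = np.full_like(logits, -np.inf)
    np.put_along_axis(masked, top_idx, top_logits, axis=-1)
    masked = masked - masked.max(axis=-1, keepdims=True)
    gates = np.exp(masked)
    return gates / gates.sum(axis=-1, keepdims=True)

logits = np.array([[1.0, 3.0, 2.0, 0.0]])
gates = sparse_gates(logits, k=2)
```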
-
Hierarchical Patch Merging with Einops
### Description In hierarchical vision transformers like the Swin Transformer, **patch merging** is used to downsample the feature map, effectively reducing the number of tokens while increasing...
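The reshaping at the heart of patch merging can be sketched in plain NumPy (rather than einops) so the index arithmetic is explicit. The sketch assumes a channels-last layout `(B, H, W, C)` and a 2x2 merge; it does not reproduce Swin's exact channel ordering or the linear projection that usually follows.

```python
import numpy as np

def patch_merge(x, factor=2):
    """Hypothetical helper: (B, H, W, C) -> (B, H/f, W/f, f*f*C)."""
    b, h, w, c = x.shape
    assert h % factor == 0 and w % factor == 0
    # Split each spatial axis into (coarse position, offset within the block).
    x = x.reshape(b, h // factor, factor, w // factor, factor, c)
    # Bring the two block offsets next to the channel axis.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h // factor, w // factor, factor * factor * c)

x = np.arange(16).reshape(1, 4, 4, 1)
merged = patch_merge(x)
```

Each output token now holds the four values of one 2x2 neighborhood, so the token count drops by 4 while the channel count grows by 4.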
-
Sliding Window Attention Preparation
### Description Full self-attention has quadratic complexity with respect to sequence length, which is prohibitive for very long sequences. Models like Longformer introduce **sliding window...
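One common ingredient of sliding window attention is a band mask that lets each position attend only to nearby positions. The NumPy sketch below is an assumption about what such preparation could look like; the name `sliding_window_mask` and the symmetric window are illustrative.

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    """Hypothetical helper: True where query i may attend to key j, i.e. |i - j| <= window."""
    idx = np.arange(seq_len)
    return np.abs(idx[:, None] - idx[None, :]) <= window

mask = sliding_window_mask(seq_len=5, window=1)
```

With this mask the number of allowed pairs grows linearly in the sequence length for a fixed window, instead of quadratically.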
-
MoE Gating: Top-K Selection
### Description In a Mixture of Experts (MoE) model, a gating network is responsible for routing each input token to a subset of 'expert' networks. [6, 14] A common strategy is Top-K gating, where...
-
Batch-wise Matrix Transposition
### Description Given a batch of matrices, transpose each matrix in the batch. The input tensor has a shape of `(B, H, W)`, and the output should be `(B, W, H)`. ### Starter Code ```python import...
-
Global Average Pooling
### Description Implement global average pooling, a common operation in convolutional neural networks. For a batch of feature maps of shape `(B, C, H, W)`, you need to compute the mean of each...
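Global average pooling reduces to a mean over the two spatial axes. A minimal NumPy sketch, assuming the output drops the spatial dimensions and has shape `(B, C)`:

```python
import numpy as np

def global_avg_pool(x):
    """(B, C, H, W) -> (B, C): average each feature map over H and W."""
    return x.mean(axis=(2, 3))

x = np.arange(2 * 3 * 2 * 2, dtype=float).reshape(2, 3, 2, 2)
pooled = global_avg_pool(x)
```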
-
Multi-Head Attention: Splitting Heads
### Description In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape `(B, N, D)`, where `D` is the embedding dimension, you need to...
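Splitting heads is a reshape followed by an axis swap. The NumPy sketch below assumes the target layout is `(B, H, N, D // H)`; the name `split_heads` is illustrative.

```python
import numpy as np

def split_heads(x, num_heads):
    """(B, N, D) -> (B, H, N, D // H)."""
    b, n, d = x.shape
    assert d % num_heads == 0
    # Split the embedding axis into (head, per-head dim), then move heads forward.
    return x.reshape(b, n, num_heads, d // num_heads).transpose(0, 2, 1, 3)

x = np.arange(2 * 3 * 8).reshape(2, 3, 8)
heads = split_heads(x, num_heads=4)
```

Head `h` of token `n` receives the contiguous slice `x[:, n, h * 2 : (h + 1) * 2]` in this example.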
-
Multi-Head Attention: Merging Heads
### Description The inverse of splitting heads. After computing attention for each head, you need to merge them back. Given a tensor of shape `(B, H, N, D//H)`, you need to merge it back to `(B,...
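Merging undoes the split by swapping the axes back and collapsing the head and per-head dimensions. A NumPy sketch, assuming the merged layout is `(B, N, D)`:

```python
import numpy as np

def merge_heads(x):
    """(B, H, N, D // H) -> (B, N, D)."""
    b, h, n, dh = x.shape
    # Put the token axis before the head axis, then fuse (head, per-head dim).
    return x.transpose(0, 2, 1, 3).reshape(b, n, h * dh)

original = np.arange(2 * 3 * 8).reshape(2, 3, 8)
# Split into 4 heads by hand, then check that merging restores the input.
split = original.reshape(2, 3, 4, 2).transpose(0, 2, 1, 3)
merged = merge_heads(split)
```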
-
Tile a Tensor
### Description Given a tensor, repeat its values along one or more dimensions. For example, given a tensor of shape `(H, W)`, you might want to create a batch of `B` identical copies, resulting...
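For the `(H, W)` to `(B, H, W)` case, one possible NumPy sketch uses `np.tile`, which prepends new axes when the repetition tuple is longer than the input's rank:

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]])        # shape (H, W) = (2, 3)
B = 4

# Repeat B times along a new leading axis; H and W are repeated once (unchanged).
batch = np.tile(x, (B, 1, 1))    # shape (B, H, W)
```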
-
Concatenate Tensors Along a New Axis
### Description Given a list of tensors of the same shape, concatenate them along a new axis. For example, given 3 tensors of shape `(H, W)`, you want to create a single tensor of shape `(3, H,...
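In NumPy terms this is what `np.stack` does, as opposed to `np.concatenate`, which joins along an existing axis. A small sketch with made-up shapes:

```python
import numpy as np

tensors = [np.full((2, 3), i) for i in range(3)]   # three arrays of shape (H, W)

# Stack along a new leading axis: the result has shape (3, H, W).
stacked = np.stack(tensors, axis=0)
```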
-
Channel-wise Max Pooling
### Description Perform max pooling over the channel dimension. Given a tensor of shape `(B, C, H, W)`, find the maximum value across all channels for each spatial location. The output should have...
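A NumPy sketch of the reduction is below. Whether the channel axis should be dropped or kept as size 1 is treated here as an open choice, exposed through a hypothetical `keepdims` parameter.

```python
import numpy as np

def channel_max(x, keepdims=False):
    """(B, C, H, W) -> (B, H, W), or (B, 1, H, W) when keepdims=True."""
    return x.max(axis=1, keepdims=keepdims)

x = np.arange(1 * 3 * 2 * 2).reshape(1, 3, 2, 2)
out = channel_max(x)
out_kept = channel_max(x, keepdims=True)
```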
-
Swap Height and Width
### Description For a batch of images or feature maps, swap the height and width dimensions. The input shape is `(B, C, H, W)` and the output shape should be `(B, C, W, H)`. ### Starter Code...
-
Permute Dimensions Cyclically
### Description Perform a cyclic permutation of the dimensions of a tensor. For a tensor of shape `(D1, D2, D3, D4)`, a cyclic permutation would result in `(D4, D1, D2, D3)`. ### Starter Code...
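The `(D1, D2, D3, D4)` to `(D4, D1, D2, D3)` permutation can be written as a single transpose. A NumPy sketch:

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# New axis order: the old last axis comes first, the rest shift right by one.
cycled = np.transpose(x, (3, 0, 1, 2))   # shape (5, 2, 3, 4)
```

`np.moveaxis(x, -1, 0)` expresses the same permutation.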
-
Flatten Leading Dimensions
### Description Given a tensor with multiple leading dimensions, flatten them into a single dimension. For example, transform a tensor of shape `(D1, D2, D3, D4)` into `(D1*D2, D3, D4)`. ###...
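A NumPy sketch of the flattening, with a hypothetical parameter `n` for how many leading dimensions to fuse:

```python
import numpy as np

def flatten_leading(x, n=2):
    """Fuse the first n axes into one, e.g. (D1, D2, D3, D4) -> (D1*D2, D3, D4)."""
    return x.reshape(-1, *x.shape[n:])

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
flat = flatten_leading(x)
```

Row `i * D2 + j` of the result holds what was at `x[i, j]`.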
-
Unflatten a Dimension
### Description This is the inverse of flattening. Given a tensor where the first dimension is a product of two other dimensions, unflatten it. For example, transform a tensor of shape `(D1*D2,...
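Unflattening needs one of the two factors to be supplied, since the product alone does not determine them. A NumPy sketch, assuming `D1` is given and the target shape is `(D1, D2, ...)`:

```python
import numpy as np

def unflatten_leading(x, d1):
    """(D1*D2, ...) -> (D1, D2, ...); D2 is inferred from the first axis."""
    assert x.shape[0] % d1 == 0
    return x.reshape(d1, -1, *x.shape[1:])

x = np.arange(6 * 4 * 5).reshape(6, 4, 5)
unflat = unflatten_leading(x, d1=2)
```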
-
Pixel Unshuffle (Pixel to Channel)
### Description This is another name for the space-to-depth operation, common in super-resolution models. It involves rearranging blocks of spatial data into the channel dimension. Given a tensor...
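The rearrangement can be sketched in NumPy as a reshape, a transpose, and another reshape. The sketch assumes a `(B, C, H, W)` input, a downscale factor `r`, and the channel ordering used by PyTorch's `PixelUnshuffle` (output channel `c * r * r + i * r + j`).

```python
import numpy as np

def pixel_unshuffle(x, r):
    """(B, C, H, W) -> (B, C*r*r, H/r, W/r)."""
    b, c, h, w = x.shape
    assert h % r == 0 and w % r == 0
    # Split H and W into (coarse position, offset within the r x r block).
    x = x.reshape(b, c, h // r, r, w // r, r)
    # Move both block offsets next to the channel axis.
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(b, c * r * r, h // r, w // r)

x = np.arange(16).reshape(1, 1, 4, 4)
out = pixel_unshuffle(x, r=2)
```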
-
Pixel Shuffle (Channel to Pixel)
### Description The inverse of pixel unshuffle, also known as depth-to-space. It is used to upscale an image by rearranging elements from the channel dimension into spatial blocks. Given a tensor...
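A NumPy sketch of the upscaling direction, assuming a `(B, C*r*r, H, W)` input and the channel ordering used by PyTorch's `PixelShuffle`:

```python
import numpy as np

def pixel_shuffle(x, r):
    """(B, C*r*r, H, W) -> (B, C, H*r, W*r)."""
    b, crr, h, w = x.shape
    assert crr % (r * r) == 0
    c = crr // (r * r)
    # Split channels into (output channel, row offset, column offset).
    x = x.reshape(b, c, r, r, h, w)
    # Interleave each offset with its spatial axis.
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(b, c, h * r, w * r)

x = np.arange(4).reshape(1, 4, 1, 1)
out = pixel_shuffle(x, r=2)
```

Here four channels at a single pixel become one channel with a 2x2 spatial block.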
-
Grouped Matrix Multiplication
### Description Perform matrix multiplication on groups of matrices within a batch. For an input tensor of shape `(B, G, M, K)` and another of shape `(B, G, K, N)`, the output should be of shape...
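With an explicit index string the contraction reads directly off the shapes: `k` is summed, while `b` and `g` are carried as batch axes. A NumPy sketch with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4, 5))   # (B, G, M, K)
b = rng.normal(size=(2, 3, 5, 6))   # (B, G, K, N)

# One matrix product per (batch, group) pair.
out = np.einsum("bgmk,bgkn->bgmn", a, b)   # (B, G, M, N)
```

The `@` operator broadcasts over leading axes and gives the same result.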
-
Bilinear Attention Pooling
### Description In some attention mechanisms, you need to compute a bilinear interaction between two sets of features. Given two tensors of shapes `(B, N, D)` and `(B, M, D)`, compute a bilinear...
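One reading of a bilinear interaction is a score `x_n^T W y_m` for every pair of positions. The NumPy sketch below takes that reading; the weight matrix `w`, the function name, and the `(B, N, M)` output shape are all assumptions.

```python
import numpy as np

def bilinear_scores(x, y, w):
    """x: (B, N, D), y: (B, M, D), w: (D, D) -> scores of shape (B, N, M)."""
    # Contract x with w over d, then with y over e.
    return np.einsum("bnd,de,bme->bnm", x, w, y)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
y = rng.normal(size=(2, 5, 4))
scores = bilinear_scores(x, y, np.eye(4))
```

With `w` set to the identity this reduces to plain dot-product scores between the two feature sets.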
-
Extract Diagonal from a Batch of Matrices
### Description Given a batch of square matrices of shape `(B, N, N)`, extract the diagonal of each matrix. The output should be a tensor of shape `(B, N)`. This can be achieved with...
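Repeating an index on the input side of an einsum selects the diagonal. A NumPy sketch:

```python
import numpy as np

x = np.arange(2 * 3 * 3).reshape(2, 3, 3)   # (B, N, N)

# "bii" reads entries where row index == column index, for every batch item.
diag = np.einsum("bii->bi", x)              # (B, N)
```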
-
Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
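The three pieces (model, loss, update rule) can be sketched as below. This is NumPy with a hand-derived MSE gradient, not JAX: in a JAX solution `jax.grad` applied to the loss would take the place of the hypothetical `mse_grads` function. The data, learning rate, and step count are made up for illustration.

```python
import numpy as np

def predict(params, x):
    w, b = params
    return x @ w + b

def mse_loss(params, x, y):
    return np.mean((predict(params, x) - y) ** 2)

def mse_grads(params, x, y):
    # Hand-derived gradient of the MSE with respect to w and b.
    err = predict(params, x) - y
    grad_w = 2.0 * x.T @ err / len(y)
    grad_b = 2.0 * err.mean()
    return grad_w, grad_b

def gd_step(params, x, y, lr=0.1):
    # One plain gradient descent update.
    w, b = params
    grad_w, grad_b = mse_grads(params, x, y)
    return w - lr * grad_w, b - lr * grad_b

# Noiseless toy data generated from y = 3x + 1.
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 1))
y = 3.0 * x[:, 0] + 1.0

params = (np.zeros(1), 0.0)
for _ in range(200):
    params = gd_step(params, x, y)
```

After training, `params` should sit close to the slope 3 and intercept 1 used to generate the data.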