-
Batched Expert Forward Pass with Einops
A naive implementation of an MoE layer loops over the experts and runs one small matrix multiplication per expert, which leaves the hardware underused. A much better approach is to perform a single, batched matrix multiplication for all expert...
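A minimal NumPy sketch of the batched idea follows. It assumes tokens have already been dispatched into a fixed-capacity buffer of `T` tokens per expert (the dispatch step is not shown, and the shapes are illustrative). A single `np.einsum` contraction replaces the Python loop. In recent einops versions the same contraction can be written as `einops.einsum(x, w, "e t d, e d h -> e t h")`.

```python
import numpy as np

def batched_expert_forward(x, w):
    """One einsum over all experts.

    x: (E, T, D) tokens already dispatched to each expert (assumed fixed capacity T)
    w: (E, D, H) one weight matrix per expert
    returns: (E, T, H)
    """
    return np.einsum("etd,edh->eth", x, w)


def looped_expert_forward(x, w):
    # Naive reference: one matmul per expert inside a Python loop.
    return np.stack([x[e] @ w[e] for e in range(x.shape[0])])


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 16))   # 4 experts, 8 tokens each, model dim 16
w = rng.normal(size=(4, 16, 32))  # hidden dim 32
out = batched_expert_forward(x, w)
print(out.shape)  # (4, 8, 32)
```

The loop version is kept only as a correctness reference; both produce the same result.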
-
Sliding Window Attention Preparation
Full self-attention has quadratic time and memory complexity in the sequence length, which is prohibitive for very long sequences. Models like Longformer introduce **sliding window...
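The description is cut off, so the following is only a sketch of two common preparation steps under stated assumptions. The first is a symmetric band mask, where token `i` may attend to key `j` only when `|i - j| <= window`. The second gathers each position's key neighbourhood with zero padding at the edges. It uses NumPy's `sliding_window_view` (NumPy 1.20+). The helper names are hypothetical.

```python
import numpy as np

def sliding_window_mask(n, window):
    """(n, n) boolean mask, True where query i may attend to key j, i.e. |i - j| <= window."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= window


def unfold_windows(k, window):
    """Gather each position's neighbourhood: (B, N, D) -> (B, N, 2*window + 1, D).

    Positions past the sequence edges are zero-padded; combine with a mask
    so padded slots receive no attention weight.
    """
    padded = np.pad(k, ((0, 0), (window, window), (0, 0)))
    # sliding_window_view appends the window axis last: (B, N, D, 2*window + 1)
    views = np.lib.stride_tricks.sliding_window_view(padded, 2 * window + 1, axis=1)
    return views.transpose(0, 1, 3, 2)


mask = sliding_window_mask(6, window=1)
k = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)
windows = unfold_windows(k, window=1)
print(windows.shape)  # (2, 5, 3, 3)
```

The mask form is simplest but still materializes an `N x N` matrix. The unfolded form only stores `N x (2w + 1)` keys per sequence, which is where the memory saving comes from.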
-
MoE Gating: Top-K Selection
In a Mixture of Experts (MoE) model, a gating network is responsible for routing each input token to a subset of 'expert' networks. [6, 14] A common strategy is Top-K gating, where...
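Here is a small NumPy sketch of one Top-K gating variant. It selects the `k` largest router logits per token and renormalizes them with a softmax over just those `k` values. Other variants apply the softmax over all experts before selecting, so treat this choice as an assumption.

```python
import numpy as np

def top_k_gating(logits, k):
    """logits: (T, E) router scores. Returns weights (T, k) and expert indices (T, k)."""
    # argsort is ascending: take the last k columns and reverse so the best expert is first.
    idx = np.argsort(logits, axis=-1)[:, -k:][:, ::-1]
    top = np.take_along_axis(logits, idx, axis=-1)
    top = top - top.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(top)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the selected k only
    return weights, idx


logits = np.array([[1.0, 3.0, 0.5, 2.0],
                   [0.0, -1.0, 4.0, 4.5]])
weights, experts = top_k_gating(logits, k=2)
print(experts.tolist())  # [[1, 3], [3, 2]]
```

Each token's output would then be the weighted sum of its selected experts' outputs, using `weights` as the mixing coefficients.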
-
Multi-Head Attention: Splitting Heads
In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape `(B, N, D)`, where `D` is the embedding dimension, you need to...
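The following is a NumPy sketch of the split, assuming `D` is divisible by the number of heads `H`. It factors the last axis into `(H, D//H)` and then moves the head axis in front of the sequence axis. The einops one-liner for the same operation is `rearrange(x, "b n (h d) -> b h n d", h=H)`.

```python
import numpy as np

def split_heads(x, num_heads):
    """(B, N, D) -> (B, H, N, D // H)."""
    b, n, d = x.shape
    if d % num_heads != 0:
        raise ValueError("D must be divisible by the number of heads")
    # Factor D into (H, D//H), then move the head axis in front of the sequence axis.
    return x.reshape(b, n, num_heads, d // num_heads).transpose(0, 2, 1, 3)


x = np.arange(2 * 3 * 8).reshape(2, 3, 8)
heads = split_heads(x, num_heads=4)
print(heads.shape)  # (2, 4, 3, 2)
```

Head `h` receives the contiguous feature slice `x[..., h*(D//H):(h+1)*(D//H)]`.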
-
Multi-Head Attention: Merging Heads
This is the inverse of splitting heads. After computing attention for each head, the heads must be merged back. Given a tensor of shape `(B, H, N, D//H)`, you need to merge it back to `(B,...
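Here is a NumPy sketch of the merge. It moves the head axis next to the per-head feature axis and flattens the two, which concatenates the heads along the feature dimension (einops: `rearrange(x, "b h n d -> b n (h d)")`). In PyTorch, the tensor is non-contiguous after `.transpose`/`.permute`, so `.view` would fail there. Use `.reshape` or `.contiguous().view(...)` instead.

```python
import numpy as np

def merge_heads(x):
    """(B, H, N, Dh) -> (B, N, H * Dh); heads are concatenated along the feature axis."""
    b, h, n, dh = x.shape
    # Move heads next to the feature axis, then flatten (H, Dh) into one axis.
    return x.transpose(0, 2, 1, 3).reshape(b, n, h * dh)


x = np.arange(2 * 4 * 3 * 2).reshape(2, 4, 3, 2)   # B=2, H=4, N=3, Dh=2
merged = merge_heads(x)
print(merged.shape)  # (2, 3, 8)
```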
-
Pixel Unshuffle (Pixel to Channel)
Pixel unshuffle is another name for the space-to-depth operation, common in super-resolution models. It rearranges blocks of spatial data into the channel dimension. Given a tensor...
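A NumPy sketch of space-to-depth follows, assuming channels-first `(B, C, H*r, W*r)` input and a downscale factor `r`. The output channel index is ordered as `c*r*r + row_offset*r + col_offset`. This is the same ordering as the einops pattern `"b c (h r1) (w r2) -> b (c r1 r2) h w"`, and it is intended to match PyTorch's `pixel_unshuffle`. That equivalence is worth double-checking before relying on it.

```python
import numpy as np

def pixel_unshuffle(x, r):
    """Space-to-depth: (B, C, H*r, W*r) -> (B, C*r*r, H, W)."""
    b, c, hr, wr = x.shape
    h, w = hr // r, wr // r
    x = x.reshape(b, c, h, r, w, r)       # split each spatial axis into (block index, offset)
    x = x.transpose(0, 1, 3, 5, 2, 4)     # (B, C, r, r, H, W): offsets move next to channels
    return x.reshape(b, c * r * r, h, w)


img = np.arange(16).reshape(1, 1, 4, 4)
out = pixel_unshuffle(img, r=2)
print(out.shape)  # (1, 4, 2, 2)
```

Each output channel holds one strided sub-grid of the input. For example, channel 1 is `img[..., 0::2, 1::2]`.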
-
Pixel Shuffle (Channel to Pixel)
Pixel shuffle, also known as depth-to-space, is the inverse of pixel unshuffle. It upscales an image by rearranging elements from the channel dimension into spatial blocks. Given a tensor...
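Here is the inverse operation as a NumPy sketch, under the same channels-first assumption and the same `c*r*r + row_offset*r + col_offset` channel ordering as the unshuffle step. In einops this is `"b (c r1 r2) h w -> b c (h r1) (w r2)"`. The transpose interleaves each offset axis with its spatial axis before flattening.

```python
import numpy as np

def pixel_shuffle(x, r):
    """Depth-to-space: (B, C*r*r, H, W) -> (B, C, H*r, W*r)."""
    b, crr, h, w = x.shape
    c = crr // (r * r)
    x = x.reshape(b, c, r, r, h, w)       # (B, C, r, r, H, W)
    x = x.transpose(0, 1, 4, 2, 5, 3)     # (B, C, H, r, W, r): interleave offsets with spatial axes
    return x.reshape(b, c, h * r, w * r)


feat = np.arange(16).reshape(1, 4, 2, 2)
up = pixel_shuffle(feat, r=2)
print(up.shape)  # (1, 1, 4, 4)
```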
-
Grouped Matrix Multiplication
Perform matrix multiplication on groups of matrices within a batch. For an input tensor of shape `(B, G, M, K)` and another of shape `(B, G, K, N)`, the output should be of shape...
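Here is a NumPy sketch using an explicit einsum, so the shared `b` and `g` axes and the contracted `k` axis are visible. Plain `a @ b` gives the same result here, because `matmul` treats all leading axes as batch dimensions. The einsum form just makes the contraction explicit.

```python
import numpy as np

def grouped_matmul(a, b):
    """(B, G, M, K) x (B, G, K, N) -> (B, G, M, N): an independent matmul per (batch, group)."""
    return np.einsum("bgmk,bgkn->bgmn", a, b)


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4, 5))
b = rng.normal(size=(2, 3, 5, 6))
c = grouped_matmul(a, b)
print(c.shape)  # (2, 3, 4, 6)
```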
-
Bilinear Attention Pooling
In some attention mechanisms, you need to compute a bilinear interaction between two sets of features. Given two tensors of shapes `(B, N, D)` and `(B, M, D)`, compute a bilinear...
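The task text is truncated, so this NumPy sketch rests on several assumptions. It uses a hypothetical weight matrix `W` of shape `(D, D)` and the score `s[b, n, m] = x[b, n] · W · y[b, m]`. It then applies a softmax over the `M` axis and pools `y` with those weights. The exercise may define the pooling step differently.

```python
import numpy as np

def bilinear_attention_pool(x, y, w):
    """x: (B, N, D), y: (B, M, D), w: (D, D) hypothetical learnable weight.

    Returns raw scores (B, N, M), attention weights (B, N, M) and pooled features (B, N, D).
    """
    scores = np.einsum("bnd,de,bme->bnm", x, w, y)     # s[b, n, m] = x[b, n] @ W @ y[b, m]
    shifted = scores - scores.max(axis=-1, keepdims=True)
    attn = np.exp(shifted)
    attn /= attn.sum(axis=-1, keepdims=True)           # softmax over the M features of y
    pooled = attn @ y                                   # weighted sum of y for every x position
    return scores, attn, pooled


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))
y = rng.normal(size=(2, 6, 8))
w = rng.normal(size=(8, 8))
scores, attn, pooled = bilinear_attention_pool(x, y, w)
```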
-
Build a Simple Neural Network with Flax
Using Flax, JAX's neural network library, build a simple Multi-Layer Perceptron (MLP). The MLP should have an input layer, one hidden layer with a ReLU activation function, and an output layer....
-
Implement a Basic Optimizer with Optax
Use Optax, JAX's optimization library, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model, a loss function, and then use `optax.sgd` to update the...
-
Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the...
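The normalization math is sketched below in NumPy, assuming normalization over the last (feature) axis with a learned per-feature scale `gamma` and shift `beta`. The same functions (`mean`, `var`, `sqrt`) exist in `jax.numpy`, so the body should port to JAX unchanged.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last (feature) axis, then apply a learned scale and shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)          # biased (population) variance
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(4, 16))
out = layer_norm(x, gamma=np.ones(16), beta=np.zeros(16))
```

With `gamma = 1` and `beta = 0`, every row of `out` has mean close to 0 and variance close to 1, independent of the batch. This per-sample independence is what separates LayerNorm from BatchNorm.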
-
Implementing Dropout
Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training. Remember that dropout should only be...
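Here is a NumPy sketch of inverted dropout, the usual formulation. During training, each unit is kept with probability `1 - rate`, and survivors are scaled by `1 / (1 - rate)` so expected activations match inference, where the input passes through unchanged. The NumPy `Generator` stands in for randomness. In JAX you would pass an explicit PRNG key instead, e.g. `jax.random.bernoulli(key, 1 - rate, x.shape)` for the keep mask.

```python
import numpy as np

def dropout(x, rate, rng, training):
    """Inverted dropout: zero each unit with probability `rate` during training,
    scale survivors by 1 / (1 - rate) so the expected activation is unchanged,
    and return the input untouched at inference time."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


x = np.ones(10_000)
y_train = dropout(x, rate=0.3, rng=np.random.default_rng(0), training=True)
y_eval = dropout(x, rate=0.3, rng=np.random.default_rng(0), training=False)
```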
-
Differentiable Additive Synthesizer
Differentiable Digital Signal Processing (DDSP) is a technique that combines classic signal processing with deep learning by making the parameters of synthesizers learnable via...
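The exact synthesizer interface is truncated, so this NumPy sketch assumes a harmonic additive synth in the spirit of DDSP. Its inputs are a per-sample fundamental frequency `f0` (Hz) and per-sample amplitudes for `K` harmonics. Phase is integrated with a cumulative sum, and harmonics at or above Nyquist are silenced. Every operation (`cumsum`, `sin`, multiply, sum) is differentiable, so in an autodiff framework gradients could flow back to `f0` and the amplitudes.

```python
import numpy as np

def harmonic_synth(f0, amplitudes, sample_rate):
    """f0: (T,) fundamental frequency per sample in Hz.
    amplitudes: (T, K) amplitude of harmonic k + 1 per sample.
    Returns mono audio of shape (T,)."""
    k = np.arange(1, amplitudes.shape[1] + 1)                # harmonic numbers 1..K
    phase = 2 * np.pi * np.cumsum(f0) / sample_rate          # integrated phase of the fundamental
    freqs = f0[:, None] * k[None, :]                         # (T, K) frequency of each harmonic
    # Silence harmonics at or above Nyquist to avoid aliasing.
    amps = np.where(freqs < sample_rate / 2, amplitudes, 0.0)
    return np.sum(amps * np.sin(phase[:, None] * k[None, :]), axis=-1)


sr = 16000
T = 1600
f0 = np.full(T, 100.0)                 # constant 100 Hz tone
amps = np.zeros((T, 3))
amps[:, 0] = 1.0                       # only the fundamental is audible
audio = harmonic_synth(f0, amps, sr)
```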
-
Implement a Knowledge Distillation Loss
Knowledge Distillation is a model compression technique where a small "student" model is trained to mimic a larger, pre-trained "teacher" model. [1] This is achieved by training...
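Here is a NumPy sketch of the widely used temperature-softened formulation. The soft term is the KL divergence between the teacher's and student's softmax outputs at temperature `T`, scaled by `T²` so its gradients stay on a comparable scale as `T` changes. It is blended with ordinary cross-entropy on the hard labels. The default `temperature` and `alpha` are illustrative, not values from the exercise.

```python
import numpy as np

def log_softmax(z, t=1.0):
    z = z / t
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.5):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)."""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    kl = np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=-1).mean()
    soft = kl * temperature ** 2
    # Hard-label cross-entropy at T = 1.
    hard = -log_softmax(student_logits)[np.arange(len(labels)), labels].mean()
    return alpha * soft + (1.0 - alpha) * hard


student = np.array([[2.0, 1.0, 0.0]])
teacher = np.array([[3.0, 0.5, -1.0]])
labels = np.array([0])
loss = distillation_loss(student, teacher, labels)
```

Working in log space avoids `log(0)` when a softened probability underflows.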
-
Masked Autoencoder (MAE) Input Preprocessing
Masked Autoencoders (MAE) are a self-supervised learning technique for vision transformers. The core idea is simple: randomly mask a large portion of the input image...
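A NumPy sketch of the preprocessing follows. It assumes channels-first images, square `p x p` patches, and per-sample random masking done by argsorting uniform noise, with the 75% ratio used in the MAE paper. It returns the visible patches, a binary mask in the original patch order (1 = masked), and the inverse permutation a decoder would need to restore order.

```python
import numpy as np

def patchify(imgs, p):
    """(B, C, H, W) -> (B, L, p*p*C) with L = (H/p) * (W/p)."""
    b, c, h, w = imgs.shape
    x = imgs.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)                  # (B, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def random_masking(patches, mask_ratio, rng):
    """Keep a random (1 - mask_ratio) subset of patches per sample."""
    b, l, _ = patches.shape
    len_keep = int(l * (1 - mask_ratio))
    ids_shuffle = np.argsort(rng.random((b, l)), axis=1)   # random permutation per sample
    ids_restore = np.argsort(ids_shuffle, axis=1)          # its inverse
    ids_keep = ids_shuffle[:, :len_keep]
    visible = np.take_along_axis(patches, ids_keep[:, :, None], axis=1)
    mask = np.ones((b, l))
    mask[:, :len_keep] = 0                                 # 0 = kept, in shuffled order
    mask = np.take_along_axis(mask, ids_restore, axis=1)   # back to original patch order
    return visible, mask, ids_restore


rng = np.random.default_rng(0)
imgs = rng.normal(size=(2, 3, 32, 32))
patches = patchify(imgs, p=8)                              # (2, 16, 192)
visible, mask, ids_restore = random_masking(patches, 0.75, rng)
```

Only `visible` (4 of 16 patches here) would be fed to the encoder, which is where MAE's training-cost saving comes from.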
-
Deep Canonical Correlation Analysis (DCCA) Loss
Canonical Correlation Analysis (CCA) is a statistical method for finding linear projections of two sets of variables that are maximally correlated with each other. Deep CCA (DCCA) uses neural networks to first project two...
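Here is a NumPy sketch of the total-correlation objective commonly used for DCCA. It centers both network outputs and forms regularized covariances `Σ11 + rI`, `Σ22 + rI` and the cross-covariance `Σ12`. The total correlation is the sum of singular values of `T = Σ11^{-1/2} Σ12 Σ22^{-1/2}`, and the loss is its negative. The ridge term `r` is a common stabilizing assumption. This sketch only computes the value; in JAX or PyTorch, autodiff through `eigh`/`svd` would supply gradients to the two networks.

```python
import numpy as np

def _inv_sqrt(m):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(m)
    return vecs @ np.diag(vals ** -0.5) @ vecs.T


def dcca_loss(h1, h2, r=1e-4):
    """h1: (n, k1), h2: (n, k2) network outputs for two views. Returns -total correlation."""
    n = h1.shape[0]
    h1 = h1 - h1.mean(axis=0)
    h2 = h2 - h2.mean(axis=0)
    s11 = h1.T @ h1 / (n - 1) + r * np.eye(h1.shape[1])
    s22 = h2.T @ h2 / (n - 1) + r * np.eye(h2.shape[1])
    s12 = h1.T @ h2 / (n - 1)
    t = _inv_sqrt(s11) @ s12 @ _inv_sqrt(s22)
    # Total correlation = trace norm of T = sum of its singular values.
    return -np.linalg.svd(t, compute_uv=False).sum()


rng = np.random.default_rng(0)
h = rng.normal(size=(500, 3))
same = dcca_loss(h, h, r=1e-6)                     # perfectly correlated views
indep = dcca_loss(h, rng.normal(size=(500, 3)))    # independent views
```

Two identical 3-dimensional views give a loss close to -3, one unit per perfectly correlated component. Independent views give a loss near 0.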
-
Gradient Reversal Layer
Implement a Gradient Reversal Layer (GRL), a key component in Domain-Adversarial Neural Networks (DANNs). [1] The GRL acts as an identity function during the forward pass but...
-
Tiny Neural Radiance Fields (NeRF)
Implement a simplified version of a Neural Radiance Field (NeRF) to represent a 2D image. [1] In this 2D setting, the network learns a continuous mapping from spatial coordinates to pixel values. Instead...
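The description is truncated after "Instead...", so this NumPy sketch covers only one assumed ingredient. It applies NeRF-style sinusoidal positional encoding to normalized 2D pixel coordinates. Each coordinate is mapped to `sin(2^k π p)` and `cos(2^k π p)` for `k = 0..F-1`, so a small MLP (not shown) can fit high-frequency image detail. The helper names are hypothetical.

```python
import numpy as np

def pixel_grid(h, w):
    """(h*w, 2) array of (y, x) coordinates normalized to [0, 1]."""
    ys, xs = np.meshgrid(np.linspace(0, 1, h), np.linspace(0, 1, w), indexing="ij")
    return np.stack([ys, xs], axis=-1).reshape(-1, 2)


def fourier_features(coords, num_freqs):
    """coords: (P, 2). Returns (P, 2 + 4 * num_freqs): raw coords followed by sin/cos terms."""
    freqs = 2.0 ** np.arange(num_freqs) * np.pi               # pi, 2pi, 4pi, ...
    angles = coords[:, :, None] * freqs                       # (P, 2, F)
    enc = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (P, 2, 2F)
    return np.concatenate([coords, enc.reshape(len(coords), -1)], axis=-1)


grid = pixel_grid(4, 5)
feats = fourier_features(grid, num_freqs=6)
print(feats.shape)  # (20, 26)
```

The encoded features would be the MLP's input, and its output would be regressed onto the pixel values at `grid`.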
-
Implement Lottery Ticket Hypothesis Pruning
The Lottery Ticket Hypothesis suggests that a randomly initialized, dense network contains a smaller subnetwork (a "winning ticket") that, when trained in isolation, can match the...
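The following NumPy sketch shows the iterative magnitude-pruning loop typically used to find a ticket. Each round, the masked network is trained, the smallest-magnitude 20% of the surviving weights are pruned, and the survivors are rewound to their original initialization. Real training is replaced by a hypothetical random-perturbation stand-in, so only the masking and rewinding logic is meaningful here.

```python
import numpy as np

def magnitude_prune(weights, mask, fraction):
    """Remove `fraction` of the currently surviving weights with the smallest |w|."""
    alive = np.abs(weights[mask])
    k = int(fraction * alive.size)
    if k == 0:
        return mask.copy()
    threshold = np.sort(alive)[k - 1]
    return mask & (np.abs(weights) > threshold)


rng = np.random.default_rng(0)
w_init = rng.normal(size=(64, 32))          # saved initialization
mask = np.ones(w_init.shape, dtype=bool)
masks = []
for _ in range(2):
    # Hypothetical stand-in for training the masked network from its rewound init.
    w_trained = np.where(mask, w_init + rng.normal(scale=0.5, size=w_init.shape), 0.0)
    mask = magnitude_prune(w_trained, mask, fraction=0.2)
    masks.append(mask)

# Candidate winning ticket: surviving weights reset to their ORIGINAL initial values.
w_ticket = np.where(mask, w_init, 0.0)
print([int(m.sum()) for m in masks])  # [1639, 1312]
```

Rewinding to `w_init` rather than re-randomizing is the distinguishing step. The hypothesis is specifically about the original initialization of the surviving weights.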
-
Simple Differentiable Renderer
Modern 3D deep learning often relies on differentiable rendering, allowing gradients to flow from a 2D rendered image back to 3D scene parameters. [1] Your task is to implement a...
-
Custom `nn.Module` with a Non-standard Initialization
Create a **custom `nn.Module`** for a simple feed-forward layer. Instead of the default PyTorch initialization, you'll apply a specific, non-standard initialization scheme. For example, you could...
-
Model Compression with Pruning
Implement **model pruning** to reduce the size and computational cost of a trained model. Start with a simple, over-parameterized model (e.g., a fully-connected network on MNIST). Train it to a...
-
Implementing a Custom Loss Function with `torch.autograd`
Create a **custom loss function** that inherits from `torch.nn.Module` and performs a non-standard calculation, for example a custom Huber loss, which is less sensitive to outliers than Mean...
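The Huber math is sketched below in NumPy, with mean reduction and threshold `delta` assumed. If the `forward` of the `nn.Module` is written with torch ops, autograd derives the backward pass automatically. If it were instead implemented as a `torch.autograd.Function`, `backward` would need to return the gradient computed by `huber_grad`, which is the error clipped to `[-delta, delta]` divided by the element count. PyTorch's built-in `torch.nn.HuberLoss` can serve as a cross-check.

```python
import numpy as np

def huber_loss(pred, target, delta=1.0):
    """Mean Huber loss: quadratic for |err| <= delta, linear beyond it."""
    err = pred - target
    abs_err = np.abs(err)
    quadratic = 0.5 * err ** 2
    linear = delta * (abs_err - 0.5 * delta)
    return np.mean(np.where(abs_err <= delta, quadratic, linear))


def huber_grad(pred, target, delta=1.0):
    """d(loss)/d(pred): the error clipped to [-delta, delta], divided by the element count."""
    err = pred - target
    return np.clip(err, -delta, delta) / err.size


pred = np.array([0.5, 3.0, -2.0])
target = np.zeros(3)
print(huber_loss(pred, target))  # 1.375
```

The clipped gradient is why outliers cannot dominate training. Beyond `delta`, their contribution to the gradient is capped at `delta` instead of growing linearly as in MSE.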
-
Custom `DataLoader` for On-the-Fly Image Generation
Create a **custom `torch.utils.data.Dataset`** that doesn't load data from disk. Instead, the `__getitem__` method should **generate** an image on the fly (e.g., a simple geometric shape, a random...
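Here is a framework-free sketch of the map-style dataset protocol (`__len__` and `__getitem__`) with images generated on demand. The shapes (filled squares and circles on a 32x32 canvas) and the `ShapesDataset` name are hypothetical choices. A real solution would subclass `torch.utils.data.Dataset` and convert the array with `torch.from_numpy`. Seeding a fresh generator from `(seed, idx)` makes each sample reproducible regardless of which `DataLoader` worker or epoch requests it.

```python
import numpy as np

class ShapesDataset:
    """Map-style dataset whose samples are generated on the fly, never read from disk."""

    def __init__(self, length, size=32, seed=0):
        self.length, self.size, self.seed = length, size, seed

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not 0 <= idx < self.length:
            raise IndexError(idx)
        rng = np.random.default_rng((self.seed, idx))   # per-index RNG: reproducible samples
        label = int(rng.integers(2))                    # 0 = square, 1 = circle
        img = np.zeros((1, self.size, self.size), dtype=np.float32)
        r = int(rng.integers(4, self.size // 4))        # shape "radius" in pixels
        cy, cx = rng.integers(r, self.size - r, size=2)
        if label == 0:
            img[0, cy - r:cy + r, cx - r:cx + r] = 1.0
        else:
            yy, xx = np.mgrid[:self.size, :self.size]
            img[0][(yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2] = 1.0
        return img, label


ds = ShapesDataset(10)
img, label = ds[3]
```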