## Custom Gradient with `jax.custom_vjp`
Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function...
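A minimal sketch of the classic numerically stable example, `log1pexp`, following the pattern from the JAX documentation (illustrative, not the only valid solution):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Save x as the residual needed by the backward pass.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # The naive gradient exp(x) / (1 + exp(x)) overflows for large x;
    # sigmoid(x) is the numerically stable equivalent.
    return (g * jax.nn.sigmoid(x),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(100.0))  # 1.0 instead of nan
```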
## Implement a Convolutional Layer

Implement a 2D convolutional layer from scratch in JAX. This will involve using `jax.lax.conv_general_dilated`. You will need to manage the kernel initialization and the forward pass logic...
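A minimal sketch, assuming NCHW inputs and an OIHW kernel layout; the `dimension_numbers` strings tell XLA how to interpret each axis:

```python
import jax
import jax.numpy as jnp

def init_conv(key, in_ch, out_ch, k):
    # He-style scaling for a kernel of shape (out_ch, in_ch, k, k).
    scale = jnp.sqrt(2.0 / (in_ch * k * k))
    return scale * jax.random.normal(key, (out_ch, in_ch, k, k))

def conv2d(kernel, x, stride=1):
    # x: (N, C, H, W) -> (N, out_ch, H', W')
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(stride, stride),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )

key = jax.random.PRNGKey(0)
kernel = init_conv(key, in_ch=3, out_ch=8, k=3)
x = jnp.ones((1, 3, 32, 32))
print(conv2d(kernel, x).shape)  # (1, 8, 32, 32)
```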
## Data Parallelism with `pmap`
Parallelize a training step across multiple devices (e.g., multiple CPU cores if you don't have GPUs/TPUs) using `jax.pmap`. This is a fundamental technique for scaling up training. **Task:** Take...
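A sketch of the usual pattern: `pmap` the step with a named axis and average gradients with `jax.lax.pmean`. The toy linear loss is illustrative; inputs and parameters must carry a leading device axis, and on a CPU-only machine you can expose multiple devices with `XLA_FLAGS=--xla_force_host_platform_device_count=8`:

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica stays in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Before calling: shard the batch and replicate the parameters, e.g.
#   x = x.reshape(jax.local_device_count(), -1, *x.shape[1:])
#   params = jax.device_put_replicated(params, jax.local_devices())
```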
## Combine `vmap` and `pmap`
For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices, and `vmap` for model ensembling on each device....
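A minimal sketch of the composition, assuming each device holds one data shard plus its own slice of ensemble members (`predict` is a stand-in for a real model):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # One ensemble member applied to one shard of data.
    return jnp.tanh(x @ w)

# vmap over the ensemble axis of the weights (data shared per device).
ensemble_predict = jax.vmap(predict, in_axes=(0, None))
# pmap over devices: each gets its own weights slice and data shard.
parallel_predict = jax.pmap(ensemble_predict, in_axes=(0, 0))

n_dev = jax.local_device_count()
ws = jnp.ones((n_dev, 4, 16, 8))   # (devices, ensemble, in, out)
xs = jnp.ones((n_dev, 32, 16))     # (devices, batch, in)
print(parallel_predict(ws, xs).shape)  # (n_dev, 4, 32, 8)
```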
## Implement a Neural Ordinary Differential Equation

### Description

Instead of modeling a function directly, a Neural ODE models its derivative with a neural network. The output is then found by integrating this derivative over time. [1] Your task...
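A minimal sketch with a fixed-step Euler integrator (the two-layer MLP dynamics and the integrator are illustrative; `jax.experimental.ode.odeint` offers an adaptive solver):

```python
import jax
import jax.numpy as jnp

def dynamics(params, h, t):
    # A tiny MLP models dh/dt as a function of state and time.
    inp = jnp.concatenate([h, jnp.array([t])])
    z = jnp.tanh(inp @ params["w1"] + params["b1"])
    return z @ params["w2"] + params["b2"]

def odeint_euler(params, h0, t0=0.0, t1=1.0, steps=50):
    dt = (t1 - t0) / steps
    def step(h, i):
        t = t0 + i * dt
        return h + dt * dynamics(params, h, t), None
    h, _ = jax.lax.scan(step, h0, jnp.arange(steps))
    return h
```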
## Model-Agnostic Meta-Learning (MAML) Update Step

### Description

Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that trains a model's initial parameters such that it can adapt to a new task with only a few gradient steps. [1]...
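A sketch of the inner/outer structure with a single inner step (the tiny regression loss is a placeholder); note that JAX differentiates straight through the inner update, giving the second-order MAML gradient for free:

```python
import jax
import jax.numpy as jnp

def task_loss(params, x, y):
    pred = jnp.tanh(x @ params["w"]) @ params["v"]
    return jnp.mean((pred - y) ** 2)

def inner_update(params, x_s, y_s, inner_lr=0.1):
    grads = jax.grad(task_loss)(params, x_s, y_s)
    return jax.tree_util.tree_map(lambda p, g: p - inner_lr * g, params, grads)

def maml_loss(params, x_s, y_s, x_q, y_q):
    # Adapt on the support set, then evaluate on the query set.
    adapted = inner_update(params, x_s, y_s)
    return task_loss(adapted, x_q, y_q)

# Outer (meta) gradient, including second-order terms:
#   meta_grads = jax.grad(maml_loss)(params, x_s, y_s, x_q, y_q)
```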
## Build a Transformer Encoder Block from Scratch

### Description

The Transformer architecture is built upon a fundamental component: the Encoder block. [1] Each block is responsible for processing a sequence of embeddings and refining them. Your...
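A compact sketch of the block structure, simplified to a single attention head and a parameter-free layer norm; a full solution adds multiple heads, learned scale/offset, dropout, and batching via `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)

def attention(params, x):
    # Single-head scaled dot-product attention; x: (seq, d_model).
    q, k, v = x @ params["wq"], x @ params["wk"], x @ params["wv"]
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

def encoder_block(params, x):
    # Pre-LN residual structure: attention, then a two-layer MLP.
    x = x + attention(params, layer_norm(x))
    h = jax.nn.gelu(layer_norm(x) @ params["w1"] + params["b1"])
    return x + h @ params["w2"] + params["b2"]
```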
## Soft Actor-Critic (SAC) Critic Loss

### Description

Soft Actor-Critic (SAC) is a state-of-the-art reinforcement learning algorithm known for its stability and sample efficiency. [1] A key component is its critic (or Q-network)...
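A sketch of the loss computation, assuming the Q-values and the next-action log-probability have already been produced by the critic and policy networks:

```python
import jax
import jax.numpy as jnp

def sac_critic_target(r, done, next_q1, next_q2, next_log_pi,
                      gamma=0.99, alpha=0.2):
    # Clipped double-Q value of the next state, minus the entropy term.
    next_v = jnp.minimum(next_q1, next_q2) - alpha * next_log_pi
    return r + gamma * (1.0 - done) * next_v

def critic_loss(q1, q2, target):
    # Both critics regress onto the same (stop-gradient) Bellman target.
    target = jax.lax.stop_gradient(target)
    return jnp.mean((q1 - target) ** 2) + jnp.mean((q2 - target) ** 2)
```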
## Neural Cellular Automata (NCA) Update Step

### Description

Neural Cellular Automata (NCA) are a fascinating generative model where complex global patterns emerge from simple, local rules learned by a neural network. [1] A grid of "cells,"...
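A minimal sketch of one update step. Real NCAs use fixed Sobel filters for perception and an "alive" mask; here perception is simplified to rolled neighbour differences to keep the example short:

```python
import jax
import jax.numpy as jnp

def perceive(state):
    # Each cell sees its own state plus axis-aligned neighbour
    # differences (a stand-in for the usual Sobel filters).
    dx = jnp.roll(state, -1, axis=1) - jnp.roll(state, 1, axis=1)
    dy = jnp.roll(state, -1, axis=0) - jnp.roll(state, 1, axis=0)
    return jnp.concatenate([state, dx, dy], axis=-1)  # (H, W, 3C)

def nca_step(params, state, key, fire_rate=0.5):
    p = perceive(state)
    hidden = jax.nn.relu(p @ params["w1"])
    ds = hidden @ params["w2"]  # per-cell residual update, (H, W, C)
    # Stochastic update: each cell fires independently with prob fire_rate.
    mask = jax.random.bernoulli(key, fire_rate, state.shape[:2])[..., None]
    return state + ds * mask
```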
## Bayesian Neural Network Layer

### Description

In a standard neural network, weights are single point estimates. In a Bayesian Neural Network (BNN), we learn a probability distribution over each weight. [1] This allows for...
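A sketch of a mean-field Gaussian layer using the reparameterization trick; `mu` and `rho` are the learned variational parameters, with `sigma = softplus(rho)` keeping the standard deviation positive:

```python
import jax
import jax.numpy as jnp

def bayes_linear(params, key, x):
    # Sample weights via w = mu + sigma * eps (reparameterization trick).
    sigma = jax.nn.softplus(params["rho"])
    eps = jax.random.normal(key, params["mu"].shape)
    w = params["mu"] + sigma * eps
    return x @ w

def kl_to_standard_normal(params):
    # KL(q(w) || N(0, I)) for a factorized Gaussian posterior.
    sigma = jax.nn.softplus(params["rho"])
    return jnp.sum(0.5 * (sigma**2 + params["mu"]**2 - 1.0) - jnp.log(sigma))
```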
## Siamese Network for One-Shot Image Verification

### Description

Your task is to implement a Siamese network that can determine if two images are of the same class, given only one or a few examples of that class at test time. You'll train a...
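A sketch of the shared-encoder-plus-contrastive-loss skeleton; the tiny MLP encoder is a placeholder for a real CNN:

```python
import jax.numpy as jnp

def embed(params, x):
    # Shared encoder applied to both images (same weights, twice).
    h = jnp.maximum(x @ params["w1"] + params["b1"], 0.0)
    return h @ params["w2"]

def contrastive_loss(params, x1, x2, same, margin=1.0):
    diff = embed(params, x1) - embed(params, x2)
    d = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-9)
    # Pull same-class pairs together; push others past the margin.
    return jnp.mean(same * d**2
                    + (1 - same) * jnp.maximum(margin - d, 0.0)**2)
```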
## Physics-Informed Neural Network (PINN) for an ODE

### Description

Solve a simple Ordinary Differential Equation (ODE) using a Physics-Informed Neural Network. A PINN is a neural network that is trained to satisfy both the data and the underlying...
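A sketch for the concrete ODE dy/dx = -y with y(0) = 1 (this choice of equation is illustrative); `jax.grad` supplies the derivative of the network output with respect to its input:

```python
import jax
import jax.numpy as jnp

def net(params, x):
    # Scalar-in, scalar-out MLP approximating y(x).
    h = jnp.tanh(jnp.array([x]) @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[0]

def pinn_loss(params, xs):
    y = jax.vmap(lambda x: net(params, x))(xs)
    dy = jax.vmap(jax.grad(lambda x: net(params, x)))(xs)
    residual = dy + y            # enforce dy/dx = -y at collocation points
    bc = net(params, 0.0) - 1.0  # enforce the initial condition y(0) = 1
    return jnp.mean(residual**2) + bc**2
```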
## Graph Convolutional Network for Node Classification

### Description

Implement a simple Graph Convolutional Network (GCN) to perform node classification on a graph dataset like Cora. [1] A GCN layer aggregates information from a node's neighbors to...
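A sketch of the layer and the standard symmetric normalization, assuming a dense adjacency matrix for simplicity:

```python
import jax
import jax.numpy as jnp

def normalize_adjacency(a):
    # A_hat = D^(-1/2) (A + I) D^(-1/2), the symmetric GCN normalization.
    a = a + jnp.eye(a.shape[0])  # add self-loops
    d_inv_sqrt = 1.0 / jnp.sqrt(a.sum(axis=1))
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(w, a_hat, h):
    # H' = ReLU(A_hat H W): aggregate neighbors, then transform.
    return jax.nn.relu(a_hat @ h @ w)
```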
## HyperNetwork for Weight Generation

### Description

Implement a simple HyperNetwork. A HyperNetwork is a neural network that generates the weights for another, larger network (the "target network"). [1] This allows for dynamic...
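A minimal sketch in which a small MLP maps a conditioning embedding `z` to the weight matrix of a target linear layer (all names and shapes here are illustrative):

```python
import jax.numpy as jnp

def hypernet(params, z, in_dim, out_dim):
    # The hypernetwork's own weights are in `params`; its output is
    # the flattened weight matrix of the target layer.
    h = jnp.tanh(z @ params["w1"] + params["b1"])
    flat = h @ params["w2"] + params["b2"]
    return flat.reshape(in_dim, out_dim)

def target_forward(generated_w, x):
    # The target network owns no weights; it uses the generated ones.
    return x @ generated_w
```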
## Normalizing Flow for Density Estimation

### Description

Implement a simple 2D Normalizing Flow model. Normalizing Flows transform a simple base distribution (like a Gaussian) into a more complex distribution by applying a sequence of...
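A sketch of one affine coupling layer, the workhorse transform of RealNVP-style flows; the small conditioning MLP is illustrative:

```python
import jax.numpy as jnp

def affine_coupling(params, x):
    # Split features; transform one half conditioned on the other.
    x1, x2 = jnp.split(x, 2, axis=-1)
    h = jnp.tanh(x1 @ params["w1"] + params["b1"])
    s = jnp.tanh(h @ params["ws"])   # log-scale, squashed for stability
    t = h @ params["wt"]             # shift
    y2 = x2 * jnp.exp(s) + t
    # The Jacobian is triangular, so log|det| is just the sum of s.
    log_det = jnp.sum(s, axis=-1)
    return jnp.concatenate([x1, y2], axis=-1), log_det
```

Stacking several such layers (permuting which half is transformed) and accumulating `log_det` gives the change-of-variables log-likelihood.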
## Spiking Neuron with Leaky Integrate-and-Fire

### Description

Implement a single Leaky Integrate-and-Fire (LIF) neuron, the fundamental building block of many Spiking Neural Networks (SNNs). Unlike traditional neurons, LIF neurons operate on...
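A minimal sketch of the membrane dynamics unrolled over time with `jax.lax.scan` (the decay factor `beta` and reset-by-subtraction rule are common choices, not the only ones):

```python
import jax
import jax.numpy as jnp

def lif_neuron(current, beta=0.9, threshold=1.0):
    def step(v, i_in):
        v = beta * v + i_in                       # leaky integration
        spike = (v >= threshold).astype(jnp.float32)
        v = v - spike * threshold                 # reset by subtraction
        return v, spike
    _, spikes = jax.lax.scan(step, jnp.zeros(()), current)
    return spikes

print(lif_neuron(jnp.full(20, 0.3)))  # spike train over 20 time steps
```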
## Implementing a Siamese Network with Triplet Loss

Building on the previous exercise, let's switch to **Triplet Loss**. This loss function is more powerful, as it enforces a margin between the anchor-positive distance and the anchor-negative distance. The...
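For reference, the loss itself is a few lines, sketched in JAX like the earlier examples (in PyTorch, `torch.nn.TripletMarginLoss` implements the same idea):

```python
import jax.numpy as jnp

def triplet_loss(anchor, positive, negative, margin=1.0):
    # Squared Euclidean distances between embeddings, per example.
    d_pos = jnp.sum((anchor - positive) ** 2, axis=-1)
    d_neg = jnp.sum((anchor - negative) ** 2, axis=-1)
    # Hinge: the positive must beat the negative by at least `margin`.
    return jnp.mean(jnp.maximum(d_pos - d_neg + margin, 0.0))
```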
## Implementing Self-Supervised Learning with BYOL
Implement the core logic of **Bootstrap Your Own Latent (BYOL)**. BYOL is a self-supervised learning method that learns image representations without using negative pairs. It consists of two...
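A sketch of the two BYOL-specific pieces, again in JAX: the normalized regression loss between the online predictor output and the (stop-gradient) target projection, and the EMA update of the target network:

```python
import jax
import jax.numpy as jnp

def byol_loss(online_pred, target_proj):
    # Negative cosine similarity of L2-normalized vectors: 2 - 2*cos(p, z).
    # Apply jax.lax.stop_gradient to target_proj in the training step.
    p = online_pred / jnp.linalg.norm(online_pred, axis=-1, keepdims=True)
    z = target_proj / jnp.linalg.norm(target_proj, axis=-1, keepdims=True)
    return jnp.mean(2.0 - 2.0 * jnp.sum(p * z, axis=-1))

def ema_update(target_params, online_params, tau=0.99):
    # The target network is a slow moving average of the online network;
    # it never receives gradients.
    return jax.tree_util.tree_map(
        lambda t, o: tau * t + (1.0 - tau) * o, target_params, online_params)
```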
## Implementing a Multi-Headed Attention Mechanism

Expand on the previous attention exercise by implementing a **Multi-Headed Attention mechanism** from scratch. A single head computes the scaled dot-product attention you have already implemented. Multi-head...
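A compact sketch for a single unbatched sequence (batching is left to `jax.vmap`); the heads are formed by reshaping the projected queries, keys, and values:

```python
import jax
import jax.numpy as jnp

def multi_head_attention(params, x, num_heads):
    seq, d_model = x.shape
    d_head = d_model // num_heads

    def split(t):  # (seq, d_model) -> (heads, seq, d_head)
        return t.reshape(seq, num_heads, d_head).transpose(1, 0, 2)

    q = split(x @ params["wq"])
    k = split(x @ params["wk"])
    v = split(x @ params["wv"])
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)
    out = jax.nn.softmax(scores, axis=-1) @ v           # (heads, seq, d_head)
    out = out.transpose(1, 0, 2).reshape(seq, d_model)  # concatenate heads
    return out @ params["wo"]
```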
## Implementing a Masked Language Model
Implement a **Masked Language Model (MLM)**, a technique at the heart of models like BERT. Given a sentence, you'll randomly mask some of the words and then train a model to predict those masked...
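A sketch of the two MLM-specific pieces, masking and the masked-position loss; BERT's 80/10/10 replacement scheme is simplified here to always substituting the mask token:

```python
import jax
import jax.numpy as jnp

def mask_tokens(key, tokens, mask_id, mask_prob=0.15):
    # Randomly pick positions to mask; the labels are the original
    # tokens at those positions and are ignored everywhere else.
    mask = jax.random.bernoulli(key, mask_prob, tokens.shape)
    inputs = jnp.where(mask, mask_id, tokens)
    return inputs, mask

def mlm_loss(logits, tokens, mask):
    # Cross-entropy computed only at the masked positions.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, tokens[..., None], axis=-1)[..., 0]
    return jnp.sum(nll * mask) / jnp.maximum(mask.sum(), 1)
```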
## Building a Graph Autoencoder
Implement a **Graph Autoencoder (GAE)** for graph representation learning. The encoder will use a GNN to produce node embeddings, and the decoder will reconstruct the graph's adjacency matrix from...
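The decoder side is the simplest part; a sketch using the standard inner-product decoder and a numerically stable binary cross-entropy on the logits:

```python
import jax
import jax.numpy as jnp

def decode(z):
    # Inner-product decoder: edge probability = sigmoid(z_i . z_j).
    return jax.nn.sigmoid(z @ z.T)

def reconstruction_loss(z, adj):
    # BCE-with-logits between the predicted and true adjacency,
    # written in its overflow-safe form.
    logits = z @ z.T
    return jnp.mean(jnp.maximum(logits, 0) - logits * adj
                    + jnp.log1p(jnp.exp(-jnp.abs(logits))))
```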
## Adversarial Training for Robustness
Implement **adversarial training** on a simple classification model like a small CNN on MNIST. The goal is to make the model robust to adversarial attacks. You'll need to generate adversarial...
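A sketch of the simplest attack, the Fast Gradient Sign Method (FGSM); `jax.grad` with `argnums=1` differentiates the loss with respect to the input rather than the parameters:

```python
import jax
import jax.numpy as jnp

def fgsm(params, loss_fn, x, y, eps=0.1):
    # Perturb the input in the direction that maximally increases
    # the loss, then clip back into the valid pixel range.
    grad_x = jax.grad(loss_fn, argnums=1)(params, x, y)
    return jnp.clip(x + eps * jnp.sign(grad_x), 0.0, 1.0)

# Adversarial training step: compute gradients on the perturbed batch.
#   x_adv = fgsm(params, loss_fn, x, y)
#   grads = jax.grad(loss_fn)(params, x_adv, y)
```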
## Neural Style Transfer
Implement **Neural Style Transfer**. Given a content image and a style image, generate a new image that combines the content of the former with the style of the latter. Use a pre-trained VGG...
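A sketch of the style term only; extracting the feature maps from the pre-trained network is not shown, and the `(H, W, C)` layout is an assumption:

```python
import jax.numpy as jnp

def gram_matrix(features):
    # features: (H, W, C) activations from one layer of the pre-trained
    # network; the Gram matrix captures channel co-activation statistics.
    h, w, c = features.shape
    f = features.reshape(h * w, c)
    return (f.T @ f) / (h * w * c)

def style_loss(gen_features, style_features):
    g = gram_matrix(gen_features) - gram_matrix(style_features)
    return jnp.sum(g ** 2)
```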
## Training a Variational Autoencoder (VAE)
Implement and train a **Variational Autoencoder (VAE)** on a dataset like MNIST. The encoder should map the input to a latent space distribution (mean and variance), and the decoder should...
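A sketch of the two VAE-specific pieces, the reparameterization trick and the closed-form KL term against a standard normal prior:

```python
import jax
import jax.numpy as jnp

def reparameterize(key, mu, logvar):
    # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu, logvar.
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps

def kl_divergence(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.
    return -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
```

The full objective adds a reconstruction term (e.g., Bernoulli or Gaussian likelihood of the decoder output) to the mean KL.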
## Implementing the Adam Optimizer from Scratch
Implement the **Adam optimizer from scratch** as a subclass of `torch.optim.Optimizer`. You'll need to manage the first-moment vector (moving average of gradients) and the second-moment vector...
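The exercise targets a `torch.optim.Optimizer` subclass; as a reference for the update rule itself, here is a functional sketch in JAX to match the earlier examples. The per-parameter state (`m`, `v`, step count `t`) maps directly onto the PyTorch optimizer's `state` dict:

```python
import jax
import jax.numpy as jnp

def adam_update(params, grads, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Moving averages of the gradient and its elementwise square.
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g**2, v, grads)

    def step(p, m_, v_):
        m_hat = m_ / (1 - b1**t)  # bias correction for the first moment
        v_hat = v_ / (1 - b2**t)  # bias correction for the second moment
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

    params = jax.tree_util.tree_map(step, params, m, v)
    return params, m, v
```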