-
Implement a Simple Linear Regression Model in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
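A minimal sketch of one possible solution (the parameter dict, toy data, and learning rate are illustrative choices, not part of the task):

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Linear model: y = w * x + b
    return params["w"] * x + params["b"]

def mse_loss(params, x, y):
    # Mean Squared Error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.1):
    # One gradient-descent step: p <- p - lr * dL/dp for every parameter.
    grads = jax.grad(mse_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Toy data generated from y = 2x + 1.
x = jnp.linspace(-1.0, 1.0, 32)
y = 2.0 * x + 1.0
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
for _ in range(200):
    params = update(params, x, y)
```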
-
Build a Simple Neural Network with Flax
Using Flax, a neural network library for JAX, build a simple Multi-Layer Perceptron (MLP). The MLP should have an input layer, one hidden layer with a ReLU activation function, and an output layer....
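A minimal Flax sketch (the layer sizes and the dummy batch are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        # Hidden layer with ReLU, then a linear output layer.
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP()
x = jnp.ones((4, 8))                           # batch of 4, feature dim 8
params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters
y = model.apply(params, x)                     # forward pass
```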
-
Implement a Basic Optimizer with Optax
Use Optax, a gradient-processing and optimization library for JAX, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model, a loss function, and then use `optax.sgd` to update the...
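A possible shape for the update loop (the linear model and data here are placeholders):

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    # Simple linear model with MSE loss.
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
optimizer = optax.sgd(learning_rate=1e-2)
opt_state = optimizer.init(params)

x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```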
-
Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the...
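One way this could look, assuming normalization over the last (feature) axis with learnable `gamma` and `beta`:

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize across the last (feature) axis, then scale and shift.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

x = jnp.arange(12.0).reshape(3, 4)
out = layer_norm(x, gamma=jnp.ones(4), beta=jnp.zeros(4))
```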
-
Implement Dropout
Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training time. Remember that dropout should only be...
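A sketch using inverted dropout (scaling at train time), with an explicit `training` flag:

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate=0.5, training=True):
    # At inference time dropout is a no-op.
    if not training:
        return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    # Inverted dropout: scale kept units so the expected value is unchanged.
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(0)
out = dropout(key, jnp.ones((2, 4)), rate=0.5, training=True)
```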
-
Build a Custom ReLU Activation Function
Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as: $\mathrm{ReLU}(x) = \max(0, x)$
**Verification:**
- For...
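A sketch of the function and its derivative via `jax.grad` (the evaluation points are chosen arbitrarily):

```python
import jax
import jax.numpy as jnp

def relu(x):
    return jnp.maximum(0.0, x)

drelu = jax.grad(relu)
print(drelu(3.0))   # 1.0 -- slope is 1 for x > 0
print(drelu(-2.0))  # 0.0 -- slope is 0 for x < 0
```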
-
Vectorized Operations with vmap
You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a...
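A sketch with a hypothetical per-example function `predict_one`; `in_axes=(None, 0)` broadcasts the weights and maps over the batch axis:

```python
import jax
import jax.numpy as jnp

def predict_one(w, x):
    # Processes a single example: x has shape (features,).
    return jnp.dot(w, x)

w = jnp.ones(3)
batch = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples
# Map over axis 0 of `batch` while broadcasting `w` to every call.
predict_batch = jax.vmap(predict_one, in_axes=(None, 0))
print(predict_batch(w, batch))  # shape (4,)
```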
-
Custom Gradient with `jax.custom_vjp`
Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function...
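A sketch using the numerically stable `log1pexp` example; the backward rule returns the sigmoid directly instead of letting autodiff propagate through an overflowing `exp`:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

def log1pexp_fwd(x):
    # Save x as the residual for the backward pass.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # Stable gradient: g * sigmoid(x), avoiding inf/inf in the naive rule.
    return (g * (1.0 - 1.0 / (1.0 + jnp.exp(x))),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)
print(jax.grad(log1pexp)(100.0))  # 1.0, no NaN
```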
-
Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX. This will involve using `jax.lax.conv_general_dilated`. You will need to manage the kernel initialization and the forward pass logic....
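A minimal sketch using NHWC/HWIO layouts (the kernel shape and initialization are illustrative):

```python
import jax
import jax.numpy as jnp

def conv2d(x, kernel):
    # x: (N, H, W, C_in), kernel: (kh, kw, C_in, C_out)
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

key = jax.random.PRNGKey(0)
kernel = jax.random.normal(key, (3, 3, 1, 8)) * 0.1  # 3x3 kernel, 1 -> 8 channels
x = jnp.ones((2, 28, 28, 1))
print(conv2d(x, kernel).shape)  # (2, 28, 28, 8)
```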
-
Data Parallelism with `pmap`
Parallelize a training step across multiple devices using `jax.pmap` (if you don't have GPUs/TPUs, you can simulate multiple CPU devices by setting the `--xla_force_host_platform_device_count` XLA flag). This is a fundamental technique for scaling up training. **Task:** Take...
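A sketch of a pmapped step that averages gradients with `jax.lax.pmean` (the toy model is a placeholder; on a CPU-only machine this runs on however many host devices the XLA flag above exposes):

```python
import functools
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def loss_fn(params, x, y):
    return jnp.mean((x * params["w"] - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across devices so all replicas stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Replicate params and shard data across the leading (device) axis.
params = jax.tree_util.tree_map(
    lambda p: jnp.broadcast_to(p, (n,) + p.shape), {"w": jnp.array(0.0)})
x = jnp.ones((n, 8))
y = jnp.ones((n, 8))
params = train_step(params, x, y)
```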
-
Debugging JAX code with `jax.debug.print`
In JAX, standard Python `print` statements don't work as expected within `jit`-compiled functions because they execute once at trace time rather than when the compiled function runs. [11] The solution is to use `jax.debug.print`. [11, 23]...
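A small example; the printed values appear every time the compiled function runs:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs at execution time, on every call, even under jit.
    jax.debug.print("x = {}", x)
    y = x ** 2
    jax.debug.print("y = {y}", y=y)
    return y

f(jnp.array(3.0))
```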
-
Combine `vmap` and `pmap`
For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices, and `vmap` for model ensembling on each device....
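A sketch of one such composition: `vmap` over an ensemble axis inside, `pmap` over a device axis outside (all shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def apply_model(w, x):
    # One model applied to a batch of examples.
    return jnp.tanh(jnp.dot(x, w))

n_dev = jax.local_device_count()
n_models, batch, dim = 4, 8, 3

# vmap over an ensemble of models per device, then pmap over devices.
ensemble = jax.vmap(apply_model, in_axes=(0, None))
parallel_ensemble = jax.pmap(ensemble, in_axes=(0, 0))

ws = jnp.ones((n_dev, n_models, dim))
xs = jnp.ones((n_dev, batch, dim))
print(parallel_ensemble(ws, xs).shape)  # (n_dev, n_models, batch)
```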
-
Implement a Neural Ordinary Differential Equation
### Description
Instead of modeling a function directly, a Neural ODE models its derivative with a neural network. The output is then found by integrating this derivative over time. [1] Your task...
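A minimal sketch, assuming a fixed-step Euler integrator written with `jax.lax.scan` (the MLP dynamics and all shapes are illustrative); for adaptive, differentiable integration, `jax.experimental.ode.odeint` is an alternative:

```python
import jax
import jax.numpy as jnp

def dynamics(params, y, t):
    # A tiny MLP models dy/dt as a function of the current state and time.
    inp = jnp.concatenate([y, jnp.array([t])])
    h = jnp.tanh(params["w1"] @ inp + params["b1"])
    return params["w2"] @ h + params["b2"]

def odeint_euler(params, y0, t0=0.0, t1=1.0, steps=50):
    # Fixed-step Euler integration; jax.lax.scan keeps it jit-friendly.
    dt = (t1 - t0) / steps
    def step(y, t):
        return y + dt * dynamics(params, y, t), None
    ts = t0 + dt * jnp.arange(steps)
    y1, _ = jax.lax.scan(step, y0, ts)
    return y1

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": jax.random.normal(k1, (16, 3)), "b1": jnp.zeros(16),
          "w2": jax.random.normal(k2, (2, 16)) * 0.1, "b2": jnp.zeros(2)}
y1 = odeint_euler(params, jnp.ones(2))  # differentiable w.r.t. params
```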
-
Model-Agnostic Meta-Learning (MAML) Update Step
### Description
Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that trains a model's initial parameters so that the model can adapt to a new task with only a few gradient steps. [1]...
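A sketch of a second-order MAML update for a toy linear model; taking `jax.grad` of `maml_loss` differentiates through the inner adaptation step:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

def inner_update(params, x, y, alpha=0.1):
    # One task-specific adaptation step on the support set.
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)

def maml_loss(params, x_sup, y_sup, x_qry, y_qry):
    # Evaluate the adapted parameters on the query set; differentiating
    # this w.r.t. params goes through the inner gradient step.
    adapted = inner_update(params, x_sup, y_sup)
    return loss(adapted, x_qry, y_qry)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros(1)}
x_sup, y_sup = jnp.ones((5, 3)), jnp.ones((5, 1))
x_qry, y_qry = jnp.ones((5, 3)), jnp.ones((5, 1))
meta_grads = jax.grad(maml_loss)(params, x_sup, y_sup, x_qry, y_qry)
```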
-
Build a Transformer Encoder Block from Scratch
### Description
The Transformer architecture is built upon a fundamental component: the Encoder block. [1] Each block is responsible for processing a sequence of embeddings and refining them. Your...
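A compact Flax sketch using post-layer-norm residual sublayers (`nn.SelfAttention` handles the attention; the dimensions are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class EncoderBlock(nn.Module):
    num_heads: int = 4
    d_model: int = 64
    d_ff: int = 256

    @nn.compact
    def __call__(self, x):
        # Multi-head self-attention sublayer with residual + LayerNorm.
        attn = nn.SelfAttention(num_heads=self.num_heads)(x)
        x = nn.LayerNorm()(x + attn)
        # Position-wise feed-forward sublayer with residual + LayerNorm.
        h = nn.Dense(self.d_ff)(x)
        h = nn.Dense(self.d_model)(nn.relu(h))
        return nn.LayerNorm()(x + h)

block = EncoderBlock()
x = jnp.ones((2, 10, 64))  # (batch, seq_len, d_model)
params = block.init(jax.random.PRNGKey(0), x)
out = block.apply(params, x)  # same shape as x
```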
-
Differentiable Additive Synthesizer
### Description
Differentiable Digital Signal Processing (DDSP) is a technique that combines classic signal processing with deep learning by making the parameters of synthesizers learnable via...
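One possible sketch of a differentiable harmonic (additive) synthesizer; the fundamental frequency and per-harmonic amplitudes would be the learnable parameters:

```python
import jax.numpy as jnp

def additive_synth(f0, amplitudes, sample_rate=16000, duration=0.5):
    # Sum of harmonics k*f0 with per-harmonic amplitudes; every op is
    # differentiable, so gradients flow back to f0 and the amplitudes.
    t = jnp.arange(int(sample_rate * duration)) / sample_rate
    k = jnp.arange(1, amplitudes.shape[0] + 1)             # harmonic numbers
    phases = 2.0 * jnp.pi * f0 * k[:, None] * t[None, :]   # (harmonics, samples)
    return jnp.sum(amplitudes[:, None] * jnp.sin(phases), axis=0)

audio = additive_synth(f0=220.0, amplitudes=jnp.array([1.0, 0.5, 0.25, 0.125]))
```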
-
Soft Actor-Critic (SAC) Critic Loss
### Description
Soft Actor-Critic (SAC) is a state-of-the-art reinforcement learning algorithm known for its stability and sample efficiency. [1] A key component is its critic (or Q-network)...
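A sketch of the critic loss, assuming the target-network Q-values and next-action log-probabilities are computed elsewhere; the toy `q_apply` is a stand-in for a real Q-network:

```python
import jax
import jax.numpy as jnp

def sac_critic_loss(q_params, q_apply, target_q1, target_q2,
                    obs, actions, rewards, dones, next_log_probs,
                    gamma=0.99, alpha=0.2):
    # Soft Bellman target: r + gamma * (min(Q1', Q2') - alpha * log pi(a'|s')).
    target_q = jnp.minimum(target_q1, target_q2) - alpha * next_log_probs
    target = rewards + gamma * (1.0 - dones) * target_q
    target = jax.lax.stop_gradient(target)  # targets are held fixed
    q = q_apply(q_params, obs, actions)
    return jnp.mean((q - target) ** 2)

# Toy Q-network for illustration: a linear function of [obs, action].
def q_apply(params, obs, actions):
    return jnp.concatenate([obs, actions], axis=-1) @ params

params = jnp.zeros((6,))
obs, actions = jnp.ones((4, 4)), jnp.ones((4, 2))
rewards, dones = jnp.ones(4), jnp.zeros(4)
target_q1, target_q2 = jnp.ones(4), 0.5 * jnp.ones(4)
next_log_probs = -jnp.ones(4)
loss = sac_critic_loss(params, q_apply, target_q1, target_q2,
                       obs, actions, rewards, dones, next_log_probs)
```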
-
Implement a Knowledge Distillation Loss
### Description
Knowledge Distillation is a model compression technique where a small "student" model is trained to mimic a larger, pre-trained "teacher" model. [1] This is achieved by training...
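A sketch combining a temperature-scaled soft-target term with the usual hard-label cross-entropy (`T` and `alpha` are conventional hyperparameters):

```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soft-target term: cross-entropy against the temperature-scaled teacher
    # (equal to the KL up to the teacher's entropy), scaled by T^2 so
    # gradient magnitudes stay comparable across temperatures.
    soft_t = jax.nn.softmax(teacher_logits / T)
    log_soft_s = jax.nn.log_softmax(student_logits / T)
    kd = -jnp.mean(jnp.sum(soft_t * log_soft_s, axis=-1)) * T ** 2
    # Hard-target term: standard cross-entropy with the true labels.
    log_p = jax.nn.log_softmax(student_logits)
    ce = -jnp.mean(jnp.take_along_axis(log_p, labels[:, None], axis=-1))
    return alpha * kd + (1.0 - alpha) * ce

student, teacher = jnp.zeros((2, 10)), jnp.ones((2, 10))
loss = distillation_loss(student, teacher, labels=jnp.array([3, 7]))
```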
-
Masked Autoencoder (MAE) Input Preprocessing
### Description
Masked Autoencoders (MAE) are a powerful self-supervised learning technique for vision transformers. The core idea is simple: randomly mask a large portion of the input image...
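A sketch of patchify-then-mask preprocessing; the kept indices are returned so a decoder could later restore patch order:

```python
import jax
import jax.numpy as jnp

def patchify(img, patch=4):
    # (H, W, C) -> (num_patches, patch*patch*C)
    h, w, c = img.shape
    x = img.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, patch * patch * c)

def random_masking(key, patches, mask_ratio=0.75):
    n = patches.shape[0]
    n_keep = int(n * (1.0 - mask_ratio))
    # Shuffle patch indices and keep the first n_keep (the "visible" set).
    ids = jax.random.permutation(key, n)
    keep_ids = ids[:n_keep]
    return patches[keep_ids], keep_ids

img = jnp.ones((32, 32, 3))
visible, keep_ids = random_masking(jax.random.PRNGKey(0), patchify(img))
```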
-
Neural Cellular Automata (NCA) Update Step
### Description
Neural Cellular Automata (NCA) are a fascinating generative model where complex global patterns emerge from simple, local rules learned by a neural network. [1] A grid of "cells,"...
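A sketch of one update step, assuming the common identity-plus-Sobel perception kernels and a tiny per-cell MLP (all shapes and the fire rate are illustrative):

```python
import jax
import jax.numpy as jnp

def nca_step(key, state, w1, b1, w2, fire_rate=0.5):
    # state: (H, W, C). Perceive each cell's neighborhood with fixed filters.
    sobel_x = jnp.array([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]) / 8.0
    identity = jnp.zeros((3, 3)).at[1, 1].set(1.0)
    kernels = jnp.stack([identity, sobel_x, sobel_x.T])  # (3, 3, 3)

    def per_channel(img2d):
        # Convolve one channel with all three perception kernels.
        conv = lambda k: jax.scipy.signal.convolve2d(img2d, k, mode="same")
        return jnp.stack([conv(k) for k in kernels], axis=-1)

    percept = jax.vmap(per_channel, in_axes=-1, out_axes=-1)(state)
    percept = percept.reshape(state.shape[0], state.shape[1], -1)  # (H, W, 3C)

    # Per-cell update rule: a small shared MLP on the perception vector.
    h = jnp.maximum(percept @ w1 + b1, 0.0)
    update = h @ w2

    # Stochastic update: only a random subset of cells fires each step.
    fire = jax.random.bernoulli(key, fire_rate, state.shape[:2] + (1,))
    return state + update * fire

key = jax.random.PRNGKey(0)
C, hidden = 8, 32
state = jnp.zeros((16, 16, C)).at[8, 8, :].set(1.0)  # single seed cell
w1 = jax.random.normal(key, (3 * C, hidden)) * 0.1
b1 = jnp.zeros(hidden)
w2 = jnp.zeros((hidden, C))  # zero-init so the first update is a no-op
state = nca_step(key, state, w1, b1, w2)
```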
-
Bayesian Neural Network Layer
### Description
In a standard neural network, weights are single point estimates. In a Bayesian Neural Network (BNN), we learn a probability distribution over each weight. [1] This allows for...
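A sketch of a reparameterized dense layer: each forward pass draws weights from N(mu, softplus(rho)^2). The KL penalty against a prior would live in the loss and is not shown here:

```python
import jax
import jax.numpy as jnp

def bayesian_dense(key, params, x):
    # Reparameterization trick: w = mu + sigma * eps, eps ~ N(0, I).
    # sigma = softplus(rho) keeps the standard deviation positive.
    sigma = jax.nn.softplus(params["rho"])
    eps = jax.random.normal(key, params["mu"].shape)
    w = params["mu"] + sigma * eps
    return x @ w + params["b"]

key = jax.random.PRNGKey(0)
params = {"mu": jnp.zeros((4, 2)),
          "rho": -3.0 * jnp.ones((4, 2)),  # small initial sigma
          "b": jnp.zeros(2)}
x = jnp.ones((8, 4))
out = bayesian_dense(key, params, x)  # a different draw per RNG key
```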
-
Deep Canonical Correlation Analysis (DCCA) Loss
### Description
Canonical Correlation Analysis (CCA) is a statistical method for finding correlations between two sets of variables. Deep CCA (DCCA) uses neural networks to first project two...
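One way to sketch the DCCA objective, regularizing the covariances for stability (note the SVD gradient can be ill-conditioned when singular values nearly coincide):

```python
import jax
import jax.numpy as jnp

def dcca_loss(h1, h2, eps=1e-4):
    # h1, h2: (batch, features) projections from the two networks.
    n = h1.shape[0]
    h1 = h1 - h1.mean(axis=0)
    h2 = h2 - h2.mean(axis=0)
    s11 = h1.T @ h1 / (n - 1) + eps * jnp.eye(h1.shape[1])
    s22 = h2.T @ h2 / (n - 1) + eps * jnp.eye(h2.shape[1])
    s12 = h1.T @ h2 / (n - 1)

    def inv_sqrt(s):
        # Inverse matrix square root via eigendecomposition (s is SPD).
        w, v = jnp.linalg.eigh(s)
        return (v * (1.0 / jnp.sqrt(w))) @ v.T

    t = inv_sqrt(s11) @ s12 @ inv_sqrt(s22)
    # Total correlation = sum of singular values of T; negate to minimize.
    return -jnp.sum(jnp.linalg.svd(t, compute_uv=False))

h1 = jax.random.normal(jax.random.PRNGKey(0), (32, 8))
h2 = jax.random.normal(jax.random.PRNGKey(1), (32, 8))
loss = dcca_loss(h1, h2)
```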
-
Siamese Network for One-Shot Image Verification
### Description
Your task is to implement a Siamese network that can determine if two images are of the same class, given only one or a few examples of that class at test time. You'll train a...
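A sketch with a shared encoder and a contrastive loss (the encoder is a toy MLP; the margin and shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

def embed(params, x):
    # Shared encoder: both images pass through the *same* parameters.
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def contrastive_loss(params, x1, x2, same, margin=1.0):
    # same = 1 for matching pairs, 0 otherwise.
    d = jnp.linalg.norm(embed(params, x1) - embed(params, x2), axis=-1)
    return jnp.mean(same * d ** 2 +
                    (1.0 - same) * jnp.maximum(margin - d, 0.0) ** 2)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": jax.random.normal(k1, (784, 128)) * 0.01, "b1": jnp.zeros(128),
          "w2": jax.random.normal(k2, (128, 32)) * 0.01, "b2": jnp.zeros(32)}
x1, x2 = jnp.ones((16, 784)), jnp.zeros((16, 784))
same = jnp.zeros(16)
grads = jax.grad(contrastive_loss)(params, x1, x2, same)
```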
-
Physics-Informed Neural Network (PINN) for an ODE
### Description
Solve a simple Ordinary Differential Equation (ODE) using a Physics-Informed Neural Network. A PINN is a neural network that is trained to satisfy both the data and the underlying...
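A sketch for the toy ODE y' = -y with y(0) = 1, whose exact solution is e^{-x}; the residual term uses `jax.grad` through the network at each collocation point:

```python
import jax
import jax.numpy as jnp

def net(params, x):
    # Scalar-in, scalar-out MLP approximating y(x).
    h = jnp.tanh(params["w1"] * x + params["b1"])   # (hidden,)
    return jnp.dot(params["w2"], h) + params["b2"]

def pinn_loss(params, xs):
    # Residual of dy/dx + y = 0 at collocation points, plus the
    # initial-condition penalty (y(0) - 1)^2.
    y = jax.vmap(lambda x: net(params, x))(xs)
    dy = jax.vmap(jax.grad(lambda x: net(params, x)))(xs)
    residual = jnp.mean((dy + y) ** 2)
    ic = (net(params, 0.0) - 1.0) ** 2
    return residual + ic

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w1": jax.random.normal(k1, (32,)), "b1": jnp.zeros(32),
          "w2": jax.random.normal(k2, (32,)) * 0.1, "b2": jnp.array(0.0)}
xs = jnp.linspace(0.0, 2.0, 64)
grads = jax.grad(pinn_loss)(params, xs)  # train to approximate y = exp(-x)
```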
-
Graph Convolutional Network for Node Classification
### Description
Implement a simple Graph Convolutional Network (GCN) to perform node classification on a graph dataset like Cora. [1] A GCN layer aggregates information from a node's neighbors to...
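A minimal sketch of the symmetric normalization and a single GCN layer on a toy 3-node graph:

```python
import jax
import jax.numpy as jnp

def normalize_adj(adj):
    # A_hat = D^{-1/2} (A + I) D^{-1/2}, the symmetric GCN normalization.
    a = adj + jnp.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / jnp.sqrt(a.sum(axis=1))
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(a_hat, x, w):
    # Aggregate neighbor features, then transform and apply ReLU.
    return jax.nn.relu(a_hat @ x @ w)

adj = jnp.array([[0., 1., 0.],
                 [1., 0., 1.],
                 [0., 1., 0.]])       # toy 3-node path graph
x = jnp.eye(3)                        # one-hot node features
w = jax.random.normal(jax.random.PRNGKey(0), (3, 4)) * 0.1
h = gcn_layer(normalize_adj(adj), x, w)  # (3, 4) node embeddings
```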