-
Image Patch Extraction with `select` and `narrow`
In computer vision, a common operation is to extract patches from an image. Your task is to write a function that extracts a patch of a given size from a specific starting location in an image...
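A minimal sketch of the idea, assuming a channel-first `(C, H, W)` tensor and a hypothetical helper `extract_patch(image, top, left, height, width)`. `Tensor.narrow(dim, start, length)` slices one dimension and returns a view, and `Tensor.select(dim, index)` picks a single index and drops that dimension:

```python
import torch

def extract_patch(image: torch.Tensor, top: int, left: int,
                  height: int, width: int) -> torch.Tensor:
    """Hypothetical helper: equivalent to image[:, top:top+height, left:left+width].

    Assumes a channel-first (C, H, W) layout.
    """
    _, h, w = image.shape
    if top < 0 or left < 0 or top + height > h or left + width > w:
        raise ValueError("patch extends outside the image")
    # narrow(dim, start, length) slices one dimension; the result is a view (no copy).
    rows = image.narrow(1, top, height)   # (C, height, W)
    return rows.narrow(2, left, width)    # (C, height, width)

def extract_channel_patch(image, channel, top, left, height, width):
    # Hypothetical helper: select(dim, index) keeps one channel and drops
    # the channel dimension, so (C, height, width) -> (height, width).
    return extract_patch(image, top, left, height, width).select(0, channel)

image = torch.arange(3 * 5 * 5).reshape(3, 5, 5)
patch = extract_patch(image, top=1, left=2, height=2, width=3)
print(patch.shape)  # torch.Size([3, 2, 3])
```

Because `narrow` returns a view, writing into `patch` would also modify `image`; call `.clone()` on the result if an independent copy is needed.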
-
Creating a Label Mapping with `scatter` for Domain Adaptation
In domain adaptation, you might need to map labels from a source domain to a target domain. Imagine you have a set of labels and a mapping that specifies how each old label corresponds to a new...
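One way to sketch this, assuming integer class labels `0..K-1` and a mapping given as two aligned tensors of old and new labels (the helper `build_label_map` and the example mapping are hypothetical): `scatter_` writes each new label into a lookup table at the position of its old label, and plain indexing then relabels a whole tensor at once.

```python
import torch

def build_label_map(old_labels: torch.Tensor, new_labels: torch.Tensor,
                    num_old_classes: int) -> torch.Tensor:
    """Hypothetical helper: lookup table with table[old] = new.

    Old labels missing from the mapping stay at -1 (an assumed sentinel).
    Assumes each old label appears at most once; with duplicate indices,
    which value scatter_ keeps is not guaranteed.
    """
    table = torch.full((num_old_classes,), -1, dtype=torch.long)
    # scatter_(dim, index, src) on 1-D tensors: table[old_labels[i]] = new_labels[i]
    table.scatter_(0, old_labels, new_labels)
    return table

old = torch.tensor([0, 1, 2, 3])
new = torch.tensor([2, 0, 0, 1])  # hypothetical source -> target mapping
table = build_label_map(old, new, num_old_classes=4)

source_labels = torch.tensor([3, 0, 2, 2, 1])
target_labels = table[source_labels]  # gather through the lookup table
print(target_labels)  # tensor([1, 2, 0, 0, 0])
```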
-
Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
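A minimal sketch, assuming a one-feature linear model with parameters `(w, b)` and a mean-squared-error loss (the function names and the data are illustrative, not taken from the exercise). `jax.grad` differentiates with respect to the first argument by default, and the result is checked against the hand-derived gradient $\partial L/\partial w = 2\,\mathrm{mean}((\hat y - y)x)$, $\partial L/\partial b = 2\,\mathrm{mean}(\hat y - y)$:

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# jax.grad returns a function computing dL/dparams; the gradient has the
# same structure as params (here a tuple (dw, db)).
grad_fn = jax.grad(mse_loss)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 3.0 * x + 2.0  # hypothetical noiseless data
params = (jnp.array(0.0), jnp.array(0.0))
dw, db = grad_fn(params, x, y)

# Analytic gradient of the MSE, for comparison.
residual = predict(params, x) - y
dw_manual = 2 * jnp.mean(residual * x)
db_manual = 2 * jnp.mean(residual)
```

At `params = (0, 0)` both methods give `dw = -27` and `db = -13` for this data.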
-
Vectorizing with `vmap`
Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single...
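As an illustration, here is a sketch with a hypothetical `predict_single` written for one example of shape `(features,)`. `jax.vmap` lifts it to a batch; `in_axes=(None, None, 0)` shares the weights and bias across the batch and maps over axis 0 of the inputs:

```python
import jax
import jax.numpy as jnp

def predict_single(w, b, x):
    # Written for ONE example: x has shape (features,), result is a scalar.
    return jnp.dot(w, x) + b

w = jnp.array([1.0, 2.0, 3.0])
b = 0.5
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples, 3 features each

# Map over the leading axis of xs only; w and b are broadcast (not mapped).
predict_batch = jax.vmap(predict_single, in_axes=(None, None, 0))
ys = predict_batch(w, b, xs)  # shape (4,)
```

The batched result matches the hand-vectorized `xs @ w + b`, but `predict_single` itself never had to deal with a batch dimension.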
-
Tracing Gradient Descent on a Parabola
Imagine a simple 1D function $f(x) = x^2 - 4x + 5$. Your goal is to find the minimum of this function using Gradient Descent. 1. **Derive the gradient**: What is $\frac{df}{dx}$? 2. **Perform a...
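Differentiating gives $\frac{df}{dx} = 2x - 4$, which is zero at $x = 2$, where $f(2) = 1$. Each gradient descent step is therefore $x \leftarrow x - \eta(2x - 4)$. A minimal trace in plain Python, assuming a starting point $x_0 = 0$ and learning rate $\eta = 0.1$ (placeholder values, since the exercise's own are not shown here):

```python
def f(x: float) -> float:
    return x**2 - 4*x + 5

def df(x: float) -> float:
    # Derivative of x^2 - 4x + 5
    return 2*x - 4

def gradient_descent(x0: float, lr: float, steps: int) -> list[float]:
    """Return the trajectory [x0, x1, ..., x_steps]."""
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - lr * df(xs[-1]))
    return xs

trajectory = gradient_descent(x0=0.0, lr=0.1, steps=3)
for k, x in enumerate(trajectory):
    print(f"step {k}: x = {x:.3f}, f(x) = {f(x):.4f}")
```

With these values the iterates are $0 \to 0.4 \to 0.72 \to 0.976$. Rewriting the update as $x_{k+1} - 2 = (1 - 2\eta)(x_k - 2)$ shows the distance to the minimum shrinks by a factor of $0.8$ per step, so the iterates converge to $x = 2$.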
-
Einops Warm-up: Reshaping Tensors for Expert Batching
In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each...
-
Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
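A minimal end-to-end sketch, assuming parameters stored in a dict and hypothetical synthetic data $y = 2x - 1$ plus small Gaussian noise. The update step is plain gradient descent applied leaf-by-leaf with `jax.tree_util.tree_map`, compiled with `jax.jit`:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed step size

def model(params, x):
    # Linear model: y_hat = w * x + b
    return params["w"] * x + params["b"]

def mse(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
    grads = jax.grad(mse)(params, x, y)
    # Gradient descent: p <- p - lr * dL/dp for every leaf of the params dict.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 100)
y = 2.0 * x - 1.0 + 0.1 * jax.random.normal(key, x.shape)  # hypothetical data

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
for step in range(500):
    params = update(params, x, y)

print(params, mse(params, x, y))
```

After training, `w` and `b` land close to 2 and -1, and the remaining loss is roughly the noise variance.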
-
Implement a Linear Regression Model
Build a simple linear regression model using `nn.Module`. Requirements: - One input feature, one output. - Train it on synthetic data $y = 3x + 2 + \epsilon$, where $\epsilon$ is random noise. - Use `MSELoss` and `SGD`. Check...
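A minimal sketch that meets the listed requirements, assuming inputs drawn uniformly from $[-1, 1)$, a noise scale of 0.1, and a learning rate of 0.1 (all illustrative choices). The class name `LinearRegression` is hypothetical:

```python
import torch
from torch import nn

torch.manual_seed(0)

class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)  # one input feature, one output

    def forward(self, x):
        return self.linear(x)

# Synthetic data y = 3x + 2 + eps; the noise scale 0.1 is an assumption.
x = torch.rand(200, 1) * 2 - 1  # inputs in [-1, 1)
y = 3 * x + 2 + 0.1 * torch.randn(200, 1)

model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(500):
    optimizer.zero_grad()            # clear gradients from the previous step
    loss = criterion(model(x), y)
    loss.backward()                  # compute d(loss)/d(parameters)
    optimizer.step()                 # SGD parameter update

w = model.linear.weight.item()
b = model.linear.bias.item()
print(f"w = {w:.3f}, b = {b:.3f}, loss = {loss.item():.4f}")
```

The learned `w` and `b` should land near 3 and 2; they will not match exactly because of the noise term.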
-
Checkpointing with `torch.save`
Train a simple feedforward model for one epoch. Save: 1. The model's state dict. 2. The optimizer's state dict. 3. The epoch number. Then load the checkpoint and resume training from the next epoch with the restored model and optimizer state.
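A minimal sketch of the save/load round trip, assuming a small hypothetical network (`make_model`), a random placeholder dataset, and SGD with momentum so the optimizer carries state worth restoring. The dictionary keys are a common convention, not something `torch.save` requires:

```python
import os
import tempfile

import torch
from torch import nn

def make_model():
    # Hypothetical feedforward network; layer sizes are arbitrary placeholders.
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))

def train_one_epoch(model, optimizer, data, targets, batch_size=16):
    loss_fn = nn.MSELoss()
    for i in range(0, len(data), batch_size):
        optimizer.zero_grad()
        loss = loss_fn(model(data[i:i + batch_size]), targets[i:i + batch_size])
        loss.backward()
        optimizer.step()

torch.manual_seed(0)
data, targets = torch.randn(64, 4), torch.randn(64, 1)  # placeholder dataset

model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epoch = 0
train_one_epoch(model, optimizer, data, targets)

# Save model weights, optimizer state (including momentum buffers), and epoch.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
}, path)

# Resume: rebuild objects with the same architecture and hyperparameters,
# then overwrite their state from the checkpoint.
checkpoint = torch.load(path, weights_only=True)
resumed_model = make_model()
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

resumed_model.train()  # ensure training mode (matters if eval() was called before saving)
for epoch in range(start_epoch, start_epoch + 2):
    train_one_epoch(resumed_model, resumed_optimizer, data, targets)
```

Restoring the optimizer state matters here: without it, the momentum buffers would restart from zero and the resumed run would diverge from an uninterrupted one.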