-
Implementing One-Hot Encoding with `scatter_`
Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
-
NumPy to JAX: The Basics
This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
-
The JAX approach to PRNG
JAX handles pseudo-random number generation (PRNG) differently than NumPy, which uses a global state. JAX, on the other hand, makes the state of the PRNG explicit. This is a design choice that...
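As a minimal sketch of what explicit PRNG state looks like in practice, the snippet below uses `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.normal`. The variable names and the seed are illustrative choices, not part of the exercise.

```python
import jax
import jax.numpy as jnp

# The PRNG state is an explicit value (a "key"), not hidden global state.
key = jax.random.PRNGKey(0)

# Split the key to obtain a fresh subkey before each use.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # reusing the same subkey repeats the draw

# A new split yields an independent stream of numbers.
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))

print(jnp.array_equal(a, b))  # same key, same samples
print(jnp.array_equal(a, c))  # different key, different samples
```

Because the key is passed in explicitly, a function that draws random numbers stays a pure function of its inputs, which is what makes it safe to transform with `jit` or `vmap`.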
-
The need for speed: `jit`
JAX's `jit` transformation compiles your function with XLA, which can lead to significant speedups. The compiler can fuse operations together, and the compiled code avoids the per-operation overhead of Python's interpreter. In this...
-
Taking gradients with `grad` II
By default, `jax.grad` takes the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to several of the function's...
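One way to differentiate with respect to more than one argument is the `argnums` parameter of `jax.grad`. The sketch below assumes a toy `loss` function with hypothetical arguments `w`, `b`, and `x`; it is only meant to show the calling convention.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x):
    # Toy scalar loss, used only to illustrate argnums.
    return jnp.sum((w * x + b) ** 2)

w, b = 2.0, 1.0
x = jnp.array([1.0, 2.0])

# Default behaviour: gradient with respect to the first argument (w).
dw_only = jax.grad(loss)(w, b, x)

# argnums=(0, 1): gradients with respect to w and b, returned as a tuple.
dw, db = jax.grad(loss, argnums=(0, 1))(w, b, x)

print(dw_only, dw, db)
```

For this input the analytic values are `dw = 26` and `db = 16`, since `w * x + b` evaluates to `[3, 5]`.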
-
Understanding `jit` and tracing
When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
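One implication of tracing can be observed directly: Python side effects inside a jitted function run only while JAX traces it, not on every call. The sketch below uses a hypothetical `trace_log` list to record each trace.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the input shape at each trace

@jax.jit
def double(x):
    # This append is a Python side effect, so it runs only during tracing.
    trace_log.append(x.shape)
    return x * 2

double(jnp.ones(3))  # first call: traced and compiled
double(jnp.ones(3))  # same shape and dtype: cached compilation is reused
double(jnp.ones(4))  # new shape: the function is traced again

print(trace_log)
```

Three calls produce only two log entries, one per distinct input shape.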
-
Numerical Stability: Log-Sum-Exp
When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ can be...
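A common remedy is to subtract the maximum before exponentiating, using the identity $\log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)$ with $m = \max_i x_i$. The sketch below is one possible implementation, with a hypothetical function name; JAX also ships `jax.scipy.special.logsumexp`.

```python
import jax.numpy as jnp

def stable_logsumexp(x):
    # Shift by the maximum so the largest exponent is exp(0) = 1.
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

x = jnp.array([1000.0, 1000.0])

naive = jnp.log(jnp.sum(jnp.exp(x)))  # exp(1000) overflows in float32
stable = stable_logsumexp(x)          # approximately 1000 + log(2)

print(naive, stable)
```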
-
L2 Regularization Gradient
L2 regularization (the penalty used in ridge regression, and closely related to weight decay) is a common technique to prevent overfitting in machine learning models by adding a penalty proportional to the square of the...
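As a rough sketch, assume the penalty takes the form $\lambda \sum_i w_i^2$ (the exact scaling used in the exercise is not shown here, and the names `l2_penalty` and `lam` are hypothetical). Its gradient with respect to the weights is then $2 \lambda w$, which can be checked against `jax.grad`.

```python
import jax
import jax.numpy as jnp

def l2_penalty(w, lam):
    # Assumed form of the penalty: lam * sum of squared weights.
    return lam * jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
lam = 0.1

auto_grad = jax.grad(l2_penalty)(w, lam)  # gradient with respect to w
manual_grad = 2.0 * lam * w               # analytic gradient under the same assumption

print(auto_grad, manual_grad)
```

If the penalty is instead defined with a factor of one half, the gradient becomes $\lambda w$.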
-
Sparse MoE Top-K Gating
### Description In a Mixture of Experts (MoE) model, the gating network is a crucial component that determines which 'expert' subnetworks process each token. [1] A common strategy is **top-k...
-
Hierarchical Patch Merging with Einops
### Description In hierarchical vision transformers like the Swin Transformer, **patch merging** is used to downsample the feature map, effectively reducing the number of tokens while increasing...
-
Batch-wise Matrix Transposition
### Description Given a batch of matrices, transpose each matrix in the batch. The input tensor has a shape of `(B, H, W)`, and the output should be `(B, W, H)`. ### Starter Code ```python import...
-
Global Average Pooling
### Description Implement global average pooling, a common operation in convolutional neural networks. For a batch of feature maps of shape `(B, C, H, W)`, you need to compute the mean of each...
-
Tile a Tensor
### Description Given a tensor, repeat its values along one or more dimensions. For example, given a tensor of shape `(H, W)`, you might want to create a batch of `B` identical copies, resulting...
-
Concatenate Tensors Along a New Axis
### Description Given a list of tensors of the same shape, concatenate them along a new axis. For example, given 3 tensors of shape `(H, W)`, you want to create a single tensor of shape `(3, H,...
-
Channel-wise Max Pooling
### Description Perform max pooling over the channel dimension. Given a tensor of shape `(B, C, H, W)`, find the maximum value across all channels for each spatial location. The output should have...
-
Swap Height and Width
### Description For a batch of images or feature maps, swap the height and width dimensions. The input shape is `(B, C, H, W)` and the output shape should be `(B, C, W, H)`. ### Starter Code...
-
Permute Dimensions Cyclically
### Description Perform a cyclic permutation of the dimensions of a tensor. For a tensor of shape `(D1, D2, D3, D4)`, a cyclic permutation would result in `(D4, D1, D2, D3)`. ### Starter Code...
-
Flatten Leading Dimensions
### Description Given a tensor with multiple leading dimensions, flatten them into a single dimension. For example, transform a tensor of shape `(D1, D2, D3, D4)` into `(D1*D2, D3, D4)`. ###...
-
Unflatten a Dimension
### Description This is the inverse of flattening. Given a tensor where the first dimension is a product of two other dimensions, unflatten it. For example, transform a tensor of shape `(D1*D2,...
-
Build a Custom ReLU Activation Function
Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as: $\mathrm{ReLU}(x) = \max(0, x)$ **Verification:** - For...
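One possible implementation, given as a sketch rather than a reference solution, uses `jnp.maximum` and differentiates it with `jax.grad`. Note that `jax.grad` requires floating-point inputs and a scalar output.

```python
import jax
import jax.numpy as jnp

def relu(x):
    # Element-wise max(0, x).
    return jnp.maximum(0.0, x)

# Derivative of ReLU for a scalar input.
relu_grad = jax.grad(relu)

pos = relu_grad(2.0)   # slope is 1 for positive inputs
neg = relu_grad(-3.0)  # slope is 0 for negative inputs

print(relu(jnp.array([-1.0, 0.5])), pos, neg)
```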
-
Vectorized Operations with `vmap`
You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a...
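The idea can be sketched with a hypothetical per-example function `predict`, which computes a dot product for a single data point. `jax.vmap` with `in_axes=(None, 0)` maps over the leading axis of the batch while sharing the weights.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Processes a single data point x of shape (3,).
    return jnp.dot(w, x)

# Share w across the batch (None) and map over axis 0 of xs.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(6.0).reshape(2, 3)  # batch of two data points

out = batched_predict(w, xs)  # shape (2,)
print(out)
```

The result matches `xs @ w`, but `predict` itself never needs to know about the batch dimension.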
-
Debugging JAX code with `jax.debug.print`
In JAX, standard Python `print` statements don't always work as expected within `jit`-compiled functions because they execute at trace time. [11] The solution is to use `jax.debug.print`. [11, 23]...
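The difference can be seen in a small sketch with a hypothetical function `f`: the plain `print` shows a tracer object once at trace time, while `jax.debug.print` shows the runtime value on each call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = x * 2
    print("trace-time print:", y)              # shows a tracer, not a number
    jax.debug.print("runtime value: y = {}", y)  # shows the actual value
    return y + 1

out = f(jnp.array(3.0))
print(out)
```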
-
Implementing Gradient Clipping
Implement **gradient clipping** in your training loop. This technique is used to prevent exploding gradients, which can be a problem in RNNs and other deep networks. After the backward pass...
-
Matrix Multiplication Basics
Implement a function in PyTorch that multiplies two matrices using `torch.mm`. ### Problem Write a function `matmul(A, B)` that takes two 2D tensors `A` and `B` and returns their matrix product. -...
-
ReLU Activation Function
Implement the ReLU (Rectified Linear Unit) function in PyTorch. ### Problem Write a function `relu(x)` that takes a 1D tensor and replaces all negative values with 0. - **Input:** A tensor `x` of...