- **Implementing One-Hot Encoding with `scatter_`**: Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
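A minimal sketch of what such a function might look like, assuming PyTorch and a `(labels, num_classes)` signature (both assumptions, since the prompt is truncated):

```python
import torch

def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Start from zeros of shape (N, num_classes), then scatter a 1.0
    # into column labels[i] of row i.
    out = torch.zeros(labels.shape[0], num_classes)
    return out.scatter_(1, labels.unsqueeze(1), 1.0)

print(one_hot(torch.tensor([0, 2, 1]), num_classes=3))
```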
- **NumPy to JAX: The Basics**: This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
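The prompt is cut off, but the `W`, `x`, `b` naming suggests an affine map; a hedged sketch of how that looks with `jax.numpy`:

```python
import jax.numpy as jnp

def affine(W: jnp.ndarray, x: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    # jnp mirrors the NumPy API: @ is matrix multiplication, + broadcasts.
    return W @ x + b

W = jnp.ones((3, 2))
x = jnp.ones(2)
b = jnp.zeros(3)
print(affine(W, x, b))  # [2. 2. 2.]
```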
- **The JAX approach to PRNG**: JAX handles pseudo-random number generation (PRNG) differently than NumPy, which uses a global state. JAX, on the other hand, makes the state of the PRNG explicit. This is a design choice that...
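The explicit-state workflow looks like this (a small sketch using the public `jax.random` API):

```python
import jax

key = jax.random.PRNGKey(0)           # the PRNG state is an explicit value
key, subkey = jax.random.split(key)   # derive a fresh subkey before each draw
x = jax.random.normal(subkey, (3,))   # the same key always yields the same numbers
print(x)
```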
- **The need for speed: `jit`**: JAX's `jit` function will compile your Python code, which can lead to significant speedups. This is because JAX can fuse operations together, removing the overhead of Python's interpreter. In this...
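For reference, `jit` is typically applied as a decorator; the first call traces and compiles, and later calls with the same shapes reuse the compiled program:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # All three ops are fused into one XLA computation.
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.arange(1000.0)
print(f(x))  # first call compiles; subsequent calls hit the cache
```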
- **Taking gradients with `grad` II**: By default, `jax.grad` will take the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to many of the function's...
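`argnums` is the mechanism for this; a minimal example:

```python
import jax

def f(w, b):
    return w * w + 3.0 * b

# argnums selects which positional arguments to differentiate with respect to.
dw, db = jax.grad(f, argnums=(0, 1))(2.0, 1.0)
print(dw, db)  # 4.0 3.0
```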
- **Understanding `jit` and tracing**: When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications...
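Two implications worth seeing concretely: Python side effects run only during tracing, and Python control flow on traced values fails (use `jnp.where` or `lax.cond` instead). A small sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing!")                 # side effect: runs only while tracing
    return jnp.where(x > 0, x, 0.0)   # a Python `if x > 0:` would raise on a tracer

f(jnp.array([1.0, -1.0]))  # prints "tracing!"
f(jnp.array([2.0, -2.0]))  # same shape/dtype -> cached, nothing printed
```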
- **Numerical Stability: Log-Sum-Exp**: When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ can be...
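The standard fix is the shift-by-max identity $\log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)$ with $m = \max_i x_i$. A sketch in `jax.numpy`:

```python
import jax.numpy as jnp

def logsumexp(x):
    # Subtract the max so the largest exponent is exp(0) = 1: no overflow.
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

x = jnp.array([1000.0, 1000.0])
print(logsumexp(x))  # ~1000.693; the naive version overflows to inf
```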
- **L2 Regularization Gradient**: L2 regularization (also known as Ridge Regression or weight decay) is a common technique to prevent overfitting in machine learning models by adding a penalty proportional to the square of the...
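Assuming the conventional penalty $\frac{\lambda}{2}\lVert w \rVert^2$ (the exercise's exact convention is cut off), the gradient is simply $\lambda w$, which is easy to verify with `jax.grad`:

```python
import jax
import jax.numpy as jnp

def l2_penalty(w, lam=0.1):
    # (lam / 2) * ||w||^2; the 1/2 makes the gradient exactly lam * w.
    return 0.5 * lam * jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
print(jax.grad(l2_penalty)(w))  # [0.1, -0.2, 0.3] == 0.1 * w
```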
- **Sparse MoE Top-K Gating**: In a Mixture of Experts (MoE) model, the gating network is a crucial component that determines which 'expert' subnetworks process each token. [1] A common strategy is top-k...
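One common form of top-k gating takes a softmax over only the k selected logits; the details beyond the truncated prompt are assumptions. A PyTorch sketch:

```python
import torch
import torch.nn.functional as F

def top_k_gating(logits: torch.Tensor, k: int = 2):
    # logits: (num_tokens, num_experts). Keep the k largest logits per token
    # and renormalize the routing weights over just those k experts.
    vals, idx = torch.topk(logits, k, dim=-1)
    weights = F.softmax(vals, dim=-1)
    return weights, idx  # each of shape (num_tokens, k)

w, experts = top_k_gating(torch.randn(4, 8), k=2)
```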
- **Hierarchical Patch Merging with Einops**: In hierarchical vision transformers like the Swin Transformer, **patch merging** is used to downsample the feature map, effectively reducing the number of tokens while increasing...
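The core rearrangement (before the linear projection Swin applies afterwards) is a single einops pattern; a channels-last `(B, H, W, C)` layout is assumed here:

```python
import torch
from einops import rearrange

x = torch.randn(1, 8, 8, 32)  # (B, H, W, C), with H and W even
# Fold each 2x2 spatial block into the channel dimension: (B, H/2, W/2, 4*C).
merged = rearrange(x, "b (h p1) (w p2) c -> b h w (p1 p2 c)", p1=2, p2=2)
print(merged.shape)  # torch.Size([1, 4, 4, 128])
```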
- **Sliding Window Attention Preparation**: Full self-attention has quadratic complexity with respect to sequence length, which is prohibitive for very long sequences. Models like Longformer introduce sliding window...
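The prompt is truncated, so the exact deliverable is unclear; one plausible preparation step is building the banded attention mask, sketched here in PyTorch:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where |i - j| <= window: each query sees only a local band of keys.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (i - j).abs() <= window

print(sliding_window_mask(6, 1).int())
```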
- **MoE Gating: Top-K Selection**: In a Mixture of Experts (MoE) model, a gating network is responsible for routing each input token to a subset of 'expert' networks. [6, 14] A common strategy is Top-K gating, where...
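A variant of the earlier sketch that returns a dense gate vector (zeros for unselected experts), which some MoE implementations prefer; the renormalization choice is an assumption:

```python
import torch

def topk_dense_gates(logits: torch.Tensor, k: int = 2) -> torch.Tensor:
    # Softmax over all experts, then zero out everything outside the top-k.
    probs = torch.softmax(logits, dim=-1)
    vals, idx = torch.topk(probs, k, dim=-1)
    gates = torch.zeros_like(probs).scatter_(-1, idx, vals)
    return gates / gates.sum(dim=-1, keepdim=True)  # renormalize to sum to 1

print(topk_dense_gates(torch.randn(2, 4)))
```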
- **Batch-wise Matrix Transposition**: Given a batch of matrices, transpose each matrix in the batch. The input tensor has a shape of `(B, H, W)`, and the output should be `(B, W, H)`...
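In PyTorch this is one call (a sketch; the starter code itself is cut off):

```python
import torch

x = torch.randn(4, 3, 5)    # (B, H, W)
y = x.transpose(1, 2)       # swap the last two dims -> (B, W, H)
assert y.shape == (4, 5, 3)
```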
- **Global Average Pooling**: Implement global average pooling, a common operation in convolutional neural networks. For a batch of feature maps of shape `(B, C, H, W)`, you need to compute the mean of each...
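Assuming the usual `(B, C)` output, a PyTorch sketch:

```python
import torch

x = torch.randn(2, 16, 7, 7)   # (B, C, H, W)
pooled = x.mean(dim=(2, 3))    # average over both spatial dims -> (B, C)
assert pooled.shape == (2, 16)
```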
- **Multi-Head Attention: Splitting Heads**: In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape `(B, N, D)`, where `D` is the embedding dimension, you need to...
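The standard reshape-then-transpose recipe, assuming the target layout is `(B, H, N, D//H)`:

```python
import torch

B, N, D, H = 2, 10, 64, 8
x = torch.randn(B, N, D)
# (B, N, D) -> (B, N, H, D//H) -> (B, H, N, D//H)
heads = x.view(B, N, H, D // H).transpose(1, 2)
assert heads.shape == (B, H, N, D // H)
```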
- **Multi-Head Attention: Merging Heads**: The inverse of splitting heads. After computing attention for each head, you need to merge them back. Given a tensor of shape `(B, H, N, D//H)`, you need to merge it back to `(B,...
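Reversing the previous recipe; note the `contiguous()` needed before `view` after a transpose:

```python
import torch

B, H, N, Dh = 2, 8, 10, 8
x = torch.randn(B, H, N, Dh)
# (B, H, N, D//H) -> (B, N, H, D//H) -> (B, N, D)
merged = x.transpose(1, 2).contiguous().view(B, N, H * Dh)
assert merged.shape == (B, N, H * Dh)
```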
- **Tile a Tensor**: Given a tensor, repeat its values along one or more dimensions. For example, given a tensor of shape `(H, W)`, you might want to create a batch of `B` identical copies, resulting...
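Two PyTorch idioms for the `(H, W)` to `(B, H, W)` case, one lazy and one materialized:

```python
import torch

x = torch.randn(3, 4)                      # (H, W)
batch = x.unsqueeze(0).expand(5, -1, -1)   # (B, H, W) as a view, no copy
tiled = x.repeat(5, 1, 1)                  # same shape, memory actually copied
assert batch.shape == tiled.shape == (5, 3, 4)
```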
- **Concatenate Tensors Along a New Axis**: Given a list of tensors of the same shape, concatenate them along a new axis. For example, given 3 tensors of shape `(H, W)`, you want to create a single tensor of shape `(3, H,...
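`torch.stack` is the operation that adds a new axis (`torch.cat` only joins existing ones):

```python
import torch

tensors = [torch.randn(3, 4) for _ in range(3)]
stacked = torch.stack(tensors, dim=0)   # new leading axis -> (3, H, W)
assert stacked.shape == (3, 3, 4)
```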
- **Channel-wise Max Pooling**: Perform max pooling over the channel dimension. Given a tensor of shape `(B, C, H, W)`, find the maximum value across all channels for each spatial location. The output should have...
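The expected output shape is cut off; assuming `(B, 1, H, W)`:

```python
import torch

x = torch.randn(2, 16, 7, 7)              # (B, C, H, W)
pooled, _ = x.max(dim=1, keepdim=True)    # max over channels -> (B, 1, H, W)
assert pooled.shape == (2, 1, 7, 7)
```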
- **Swap Height and Width**: For a batch of images or feature maps, swap the height and width dimensions. The input shape is `(B, C, H, W)` and the output shape should be `(B, C, W, H)`...
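Again a single `transpose`:

```python
import torch

x = torch.randn(2, 3, 4, 5)   # (B, C, H, W)
y = x.transpose(2, 3)         # (B, C, W, H)
assert y.shape == (2, 3, 5, 4)
```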
- **Permute Dimensions Cyclically**: Perform a cyclic permutation of the dimensions of a tensor. For a tensor of shape `(D1, D2, D3, D4)`, a cyclic permutation would result in `(D4, D1, D2, D3)`...
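`permute` with the rotated index order does it:

```python
import torch

x = torch.randn(2, 3, 4, 5)     # (D1, D2, D3, D4)
y = x.permute(3, 0, 1, 2)       # move the last dim to the front
assert y.shape == (5, 2, 3, 4)  # (D4, D1, D2, D3)
```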
- **Flatten Leading Dimensions**: Given a tensor with multiple leading dimensions, flatten them into a single dimension. For example, transform a tensor of shape `(D1, D2, D3, D4)` into `(D1*D2, D3, D4)`...
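`Tensor.flatten` with start/end dims, or an equivalent `reshape`:

```python
import torch

x = torch.randn(2, 3, 4, 5)   # (D1, D2, D3, D4)
y = x.flatten(0, 1)           # merge dims 0..1 -> (D1*D2, D3, D4)
assert y.shape == (6, 4, 5)
# Equivalent: x.reshape(-1, *x.shape[2:])
```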
- **Unflatten a Dimension**: This is the inverse of flattening. Given a tensor where the first dimension is a product of two other dimensions, unflatten it. For example, transform a tensor of shape `(D1*D2,...
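Assuming the target is `(D1, D2, D3, D4)` (the example is cut off), `reshape` with the known factors works:

```python
import torch

x = torch.randn(6, 4, 5)            # (D1*D2, D3, D4) with D1=2, D2=3
y = x.reshape(2, 3, *x.shape[1:])   # split dim 0 back into (D1, D2)
assert y.shape == (2, 3, 4, 5)
```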
- **Pixel Unshuffle (Pixel to Channel)**: This is another name for the space-to-depth operation, common in super-resolution models. It involves rearranging blocks of spatial data into the channel dimension. Given a tensor...
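PyTorch ships this as `F.pixel_unshuffle`; a quick shape check with downscale factor 2:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)                   # (B, C, H, W), r = 2
y = F.pixel_unshuffle(x, downscale_factor=2)  # (B, C*r*r, H/r, W/r)
assert y.shape == (1, 16, 4, 4)
```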
- **Pixel Shuffle (Channel to Pixel)**: The inverse of pixel unshuffle, also known as depth-to-space. It is used to upscale an image by rearranging elements from the channel dimension into spatial blocks. Given a tensor...
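And the inverse via `F.pixel_shuffle`:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 4, 4)              # (B, C*r*r, H, W), r = 2
y = F.pixel_shuffle(x, upscale_factor=2)  # (B, C, H*r, W*r)
assert y.shape == (1, 4, 8, 8)
```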