- **Advanced Indexing with `gather` for NLP**: In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token in each sequence from a tensor of shape...
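
The tensor shape is cut off in the description above, so the following is only a minimal sketch that assumes activations of shape `(batch, seq_len, hidden)` and a `lengths` vector holding the number of real tokens per sequence. The function name and the toy data are illustrative, not part of the exercise.

```python
import torch

def last_token_activations(activations: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Return the activation of the last real token in each sequence.

    Assumed shapes: activations (batch, seq_len, hidden), lengths (batch,).
    """
    batch, _, hidden = activations.shape
    # gather needs an index with the same number of dims as the input,
    # so broadcast the last-token position over the hidden dimension.
    idx = (lengths - 1).view(batch, 1, 1).expand(batch, 1, hidden)
    return activations.gather(dim=1, index=idx).squeeze(1)

activations = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
lengths = torch.tensor([2, 3])  # first sequence has one padding position
last = last_token_activations(activations, lengths)
```

Expanding the index over the hidden dimension avoids a Python loop over the batch.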
- **Sparse Updates with `scatter_add_`**: In graph neural networks and other sparse data applications, you often need to update a tensor based on sparse indices. Your exercise is to implement a function that takes a tensor of `values`, a...
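
The full function signature is truncated above, so this sketch assumes a 1-D `values` tensor, a matching 1-D `indices` tensor, and an output length `size`. The name `sparse_accumulate` is hypothetical.

```python
import torch

def sparse_accumulate(values: torch.Tensor, indices: torch.Tensor, size: int) -> torch.Tensor:
    """Sum each value into the output slot named by its index."""
    out = torch.zeros(size, dtype=values.dtype)
    # In-place: out[indices[i]] += values[i]; repeated indices accumulate.
    out.scatter_add_(0, indices, values)
    return out

values = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 0, 1])
result = sparse_accumulate(values, indices, size=4)
```

Index 0 appears twice, so its slot receives both contributions, which is the behavior that distinguishes `scatter_add_` from plain indexed assignment.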
- **Selecting RoIs (Regions of Interest) with `index_select`**: In object detection tasks, after a region proposal network (RPN) suggests potential object locations, these regions of interest (RoIs) need to be extracted from the feature map for further...
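
How the exercise lays out its feature map is not visible in the truncated text. As a rough illustration of `torch.index_select` itself, the sketch below assumes a hypothetical tensor with one row of features per candidate region and keeps a chosen subset of rows.

```python
import torch

# Hypothetical data: 4 candidate regions, 3 features each.
roi_features = torch.arange(12, dtype=torch.float32).reshape(4, 3)
keep = torch.tensor([0, 3])  # indices of the regions to keep

# Select whole rows along dim 0; the result is a new tensor, not a view.
selected = torch.index_select(roi_features, dim=0, index=keep)
```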
- **Replicating `torch.nn.Embedding` with `gather`**: The `torch.nn.Embedding` layer is fundamental in many deep learning models, especially in NLP. Your task is to replicate its forward pass functionality using `torch.gather`. You'll create a...
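
One possible way to express the lookup with `torch.gather` is sketched below. The helper name `embedding_via_gather` and the sizes are assumptions made for illustration; the exercise may structure its solution differently.

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(num_embeddings=5, embedding_dim=3)
token_ids = torch.tensor([[0, 2], [4, 1]])

def embedding_via_gather(weight: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    """Look up rows of `weight` for every id using gather along dim 0."""
    dim = weight.shape[1]
    flat = ids.reshape(-1)
    # Repeat each id across the embedding dimension so that
    # rows[i, j] = weight[flat[i], j].
    index = flat.unsqueeze(1).expand(-1, dim)
    rows = torch.gather(weight, 0, index)
    return rows.reshape(*ids.shape, dim)

out = embedding_via_gather(embedding.weight, token_ids)
```

Comparing `out` against `embedding(token_ids)` is a quick way to check the replica.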
- **Updating model parameters**: In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
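
The remaining steps are truncated, so the following is a sketch under assumptions: a linear model `x @ W + b`, a mean-squared-error loss, targets named `y`, and a fixed learning rate. None of these details are confirmed by the text above.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, not taken from the exercise

def loss_fn(params, x, y):
    W, b = params
    return jnp.mean((x @ W + b - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # value_and_grad returns the loss and gradients with the same structure as params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # synthetic targets
params = (jnp.zeros(3), jnp.zeros(()))

params, loss_before = train_step(params, x, y)
_, loss_after = train_step(params, x, y)
```

Each call returns the loss measured before its update, so `loss_after` reflects the effect of the first step.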
- **PyTrees**: A PyTree is any nested structure of dictionaries, lists, and tuples. JAX is designed to work with PyTrees, which allows for a more organized way of handling model parameters. In this exercise, you...
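
As a small illustration of the idea, the sketch below builds a hypothetical nested parameter structure and applies one function to every array in it with `jax.tree_util.tree_map`. The parameter names are made up.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a dict containing a dict and a list.
params = {
    "dense": {"W": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scales": [jnp.array(2.0)],
}

leaves = jax.tree_util.tree_leaves(params)  # flat list of the arrays
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)  # same nesting, new values
```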
- **Conditionals with `jit`**: Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
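
The remedy is cut off in the text above. One common option is `jax.lax.cond`, which traces both branches and picks one at run time; the function below is an invented example, not the exercise's own.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sqrt_or_negate(x):
    # A Python `if x >= 0:` would fail here because x is a tracer under jit.
    # lax.cond takes the predicate, the two branch functions, and the operand.
    return jax.lax.cond(x >= 0, lambda v: jnp.sqrt(v), lambda v: -v, x)

positive_case = sqrt_or_negate(4.0)
negative_case = sqrt_or_negate(-3.0)
```

Both branches must return values of the same shape and dtype, since they become part of a single compiled computation.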
- **Loops with `jit`**: Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the number of iterations depends on a traced value. JAX provides `jax.lax.fori_loop` and...
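
A minimal sketch of `jax.lax.fori_loop` with a traced upper bound follows. The function `sum_below` is an illustrative example, not taken from the exercise.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_below(n):
    # Adds 0 + 1 + ... + (n - 1). `n` is traced, so a Python
    # `for i in range(n)` would not work inside jit.
    def body(i, acc):
        return acc + i
    init = jnp.zeros((), dtype=jnp.int32)
    return jax.lax.fori_loop(0, n, body, init)

total = sum_below(5)
```

The loop body takes the index and the carried value and must return a carry of the same shape and dtype.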
- **Working with an optimizer library: Optax**: Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
- **Checkpointing**: When training large models, it is important to save the model's parameters periodically. This is known as checkpointing and allows you to resume training from a saved state in case of an...
- **Cross-Entropy: A Measure of Surprise**: Cross-entropy loss is fundamental for classification tasks. Let's build some intuition for its formulation. 1. **Definition**: For a binary classification problem, the binary cross-entropy (BCE)...
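
The definition is truncated above. The sketch below uses the standard per-example form, BCE = -[y log p + (1 - y) log(1 - p)], with a small clipping constant `eps` added as an assumption to avoid `log(0)`.

```python
import math

def binary_cross_entropy(y_true: float, p_pred: float, eps: float = 1e-12) -> float:
    """BCE for one example: y_true in {0, 1}, p_pred = predicted P(y = 1)."""
    p = min(max(p_pred, eps), 1.0 - eps)  # keep the logs finite
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))

confident_right = binary_cross_entropy(1, 0.99)
confident_wrong = binary_cross_entropy(1, 0.01)
coin_flip = binary_cross_entropy(1, 0.5)
```

A confident wrong prediction is penalized far more heavily than a confident correct one, which is the "surprise" reading of the loss.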
- **The Implicit Higher Dimension of Kernels**: Support Vector Machines (SVMs) are powerful, and the "kernel trick" allows them to find non-linear decision boundaries without explicitly mapping data to high-dimensional spaces. 1. **Linear...
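
The specific kernels the exercise covers are cut off. As one illustration of the idea, the sketch below assumes a degree-2 polynomial kernel on 2-D inputs and checks that it equals an ordinary dot product after an explicit 3-D feature map.

```python
import math

def poly2_kernel(x, z):
    """k(x, z) = (x . z)^2, computed entirely in the original 2-D space."""
    return (x[0] * z[0] + x[1] * z[1]) ** 2

def feature_map(x):
    """Explicit map phi(x) = (x1^2, sqrt(2) x1 x2, x2^2) into 3-D."""
    return [x[0] ** 2, math.sqrt(2) * x[0] * x[1], x[1] ** 2]

x, z = [1.0, 2.0], [3.0, -1.0]
implicit = poly2_kernel(x, z)
explicit = sum(a * b for a, b in zip(feature_map(x), feature_map(z)))
```

The two values agree up to floating-point rounding, yet the kernel never constructs the 3-D vectors.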
- **KL Divergence Calculation and Interpretation**: The Kullback-Leibler (KL) Divergence (also known as relative entropy) is a non-symmetric measure of how one probability distribution $P$ differs from a second, reference probability...
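
A minimal sketch for discrete distributions, using the standard definition D_KL(P || Q) = sum_i p_i log(p_i / q_i) in nats. The two example distributions are made up, and the code assumes q_i > 0 wherever p_i > 0.

```python
import math

def kl_divergence(p, q):
    """D_KL(P || Q) for discrete distributions given as probability lists."""
    # Terms with p_i == 0 contribute nothing, by the usual convention.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]
forward = kl_divergence(p, q)  # D_KL(P || Q)
reverse = kl_divergence(q, p)  # D_KL(Q || P)
```

Swapping the arguments gives a different value, which is the non-symmetry mentioned above.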
- **Matrix Multiplication and Efficiency**: Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C =...
- **Build a Simple Neural Network with Flax**: Using Flax, JAX's neural network library, build a simple Multi-Layer Perceptron (MLP). The MLP should have an input layer, one hidden layer with a ReLU activation function, and an output layer....
- **Implement a Basic Optimizer with Optax**: Use Optax, JAX's optimization library, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model, a loss function, and then use `optax.sgd` to update the...
- **Implement Layer Normalization**: Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the...
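
The sentence above is truncated before naming the axis. This sketch assumes the usual convention of normalizing over the last (feature) axis, with a learnable scale `gamma`, shift `beta`, and a small `eps` for numerical stability.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each example over its last axis, then scale and shift."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16)) * 3.0 + 5.0
normed = layer_norm(x, gamma=jnp.ones(16), beta=jnp.zeros(16))
```

With `gamma = 1` and `beta = 0`, each row of the output has mean close to 0 and standard deviation close to 1.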
- **Implementing Dropout**: Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each training step. Remember that dropout should only be...
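
A sketch of "inverted" dropout follows. It assumes the common convention of rescaling kept units by `1 / (1 - rate)` during training so that no rescaling is needed at evaluation time; the exercise may ask for a different variant.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, training):
    """Zero each unit with probability `rate` during training only."""
    if not training or rate == 0.0:
        return x  # identity at evaluation time
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    # Rescale kept units so the expected value of the output matches x.
    return jnp.where(mask, x / keep_prob, 0.0)

x = jnp.ones((4, 8))
train_out = dropout(jax.random.PRNGKey(0), x, rate=0.5, training=True)
eval_out = dropout(jax.random.PRNGKey(0), x, rate=0.5, training=False)
```

JAX has no global random state, so the caller passes an explicit PRNG key and should use a fresh one for each step.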
- **Differentiable Additive Synthesizer**: Differentiable Digital Signal Processing (DDSP) is a technique that combines classic signal processing with deep learning by making the parameters of synthesizers learnable via...
- **Implement a Knowledge Distillation Loss**: Knowledge Distillation is a model compression technique where a small "student" model is trained to mimic a larger, pre-trained "teacher" model. [1] This is achieved by training...
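
The exact loss is truncated above. One widely used form, sketched here as an assumption, is the KL divergence between temperature-softened teacher and student distributions, scaled by the squared temperature. The function name, default temperature, and example logits are illustrative.

```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Mean KL(teacher || student) over softened distributions, times T^2."""
    t = temperature
    log_p_student = jax.nn.log_softmax(student_logits / t, axis=-1)
    log_p_teacher = jax.nn.log_softmax(teacher_logits / t, axis=-1)
    p_teacher = jnp.exp(log_p_teacher)
    kl = jnp.sum(p_teacher * (log_p_teacher - log_p_student), axis=-1)
    return (t ** 2) * jnp.mean(kl)

teacher = jnp.array([[4.0, 1.0, 0.0], [0.5, 3.0, 0.2]])
student = jnp.array([[2.0, 2.0, 0.0], [0.0, 1.0, 1.0]])
loss = distillation_loss(student, teacher)
matched = distillation_loss(teacher, teacher)
```

In practice this term is often combined with the ordinary cross-entropy on the true labels; that weighting is omitted here.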
- **Masked Autoencoder (MAE) Input Preprocessing**: Masked Autoencoders (MAE) are a powerful self-supervised learning technique for vision transformers. The core idea is simple: randomly mask a large portion of the input image...
- **Deep Canonical Correlation Analysis (DCCA) Loss**: Canonical Correlation Analysis (CCA) is a statistical method for finding correlations between two sets of variables. Deep CCA (DCCA) uses neural networks to first project two...
- **Gradient Reversal Layer**: Implement a Gradient Reversal Layer (GRL), a key component in Domain-Adversarial Neural Networks (DANNs). [1] The GRL acts as an identity function during the forward pass but...
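
The backward-pass behavior is cut off above. The sketch below assumes the usual formulation, where the gradient is multiplied by `-lambd` on the way back, and implements it with `torch.autograd.Function`. The class name and the scaling argument `lambd` are illustrative choices.

```python
import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd    # stash the scale for the backward pass
        return x.view_as(x)  # identity in the forward direction

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; `lambd` itself receives no gradient.
        return -ctx.lambd * grad_output, None

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = GradReverse.apply(x, 0.5)
y.sum().backward()
```

After the backward call, `x.grad` holds the reversed gradient while `y` carries the same values as `x`.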
- **Tiny Neural Radiance Fields (NeRF)**: Implement a simplified version of a Neural Radiance Field (NeRF) to represent a 2D image. [1] A NeRF learns a continuous mapping from spatial coordinates to pixel values. Instead...
- **Implement Lottery Ticket Hypothesis Pruning**: The Lottery Ticket Hypothesis suggests that a randomly initialized, dense network contains a smaller subnetwork (a "winning ticket") that, when trained in isolation, can match the...
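
The pruning procedure is not visible in the truncated text. A rough sketch of one common recipe, assumed here, is to build a mask from the trained weights by magnitude and then apply it to the original initialization. The helper `magnitude_mask` and the tiny weight tensors are hypothetical.

```python
import torch

def magnitude_mask(trained: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Return a 0/1 mask that drops the `sparsity` fraction of smallest-magnitude weights."""
    k = int(sparsity * trained.numel())  # number of weights to prune
    if k == 0:
        return torch.ones_like(trained)
    threshold = trained.abs().flatten().kthvalue(k).values
    return (trained.abs() > threshold).to(trained.dtype)

initial = torch.tensor([[0.3, -0.4], [0.2, 0.1]])   # weights at initialization
trained = torch.tensor([[0.1, -0.9], [0.5, -0.2]])  # weights after training

mask = magnitude_mask(trained, sparsity=0.5)
winning_ticket = initial * mask  # surviving weights reset to their initial values
```

The masked initial weights are the candidate "winning ticket" that would then be retrained in isolation.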