- Implementing One-Hot Encoding with `scatter_`

  Your task is to create a function that performs one-hot encoding on a tensor of integer labels. This is a common preprocessing step for categorical data in machine learning. You will be given a 1D...
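A minimal sketch of the idea, assuming the labels arrive as a 1D `int64` tensor and the number of classes is passed in explicitly. The function name `one_hot_scatter` and the `float32` output are illustrative choices, not part of the exercise statement.

```python
import torch


def one_hot_scatter(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode a 1D tensor of integer labels (hypothetical helper)."""
    # Start from an all-zero (num_labels, num_classes) matrix.
    out = torch.zeros(labels.shape[0], num_classes, dtype=torch.float32)
    # For every row i, write 1.0 into column labels[i].
    out.scatter_(1, labels.unsqueeze(1), 1.0)
    return out


labels = torch.tensor([0, 2, 1])
encoded = one_hot_scatter(labels, num_classes=3)
print(encoded)
```

The index tensor must have the same number of dimensions as the output, which is why the labels are unsqueezed to shape `(N, 1)` before the call.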
- Advanced Indexing with `gather` for NLP

  In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token in each sequence from a tensor of shape...
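The teaser cuts off before giving the tensor shape, so the sketch below assumes activations of shape `(batch, seq_len, hidden)` and a separate `lengths` tensor holding each sequence's true length. Both the layout and the name `last_token_activations` are assumptions made for illustration.

```python
import torch


def last_token_activations(hidden: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Pick the activation of the last real token of each sequence (hypothetical helper)."""
    batch, _, dim = hidden.shape
    # Position of the last token per sequence, broadcast across the hidden dimension.
    idx = (lengths - 1).view(batch, 1, 1).expand(batch, 1, dim)
    # Gather along the sequence axis, then drop the singleton sequence dimension.
    return hidden.gather(1, idx).squeeze(1)


hidden = torch.arange(2 * 3 * 2, dtype=torch.float32).view(2, 3, 2)
lengths = torch.tensor([2, 3])  # first sequence has 2 tokens, second has 3
last = last_token_activations(hidden, lengths)
print(last)
```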
- Image Patch Extraction with `select` and `narrow`

  In computer vision, a common operation is to extract patches from an image. Your task is to write a function that extracts a patch of a given size from a specific starting location in an image...
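As a rough illustration, the sketch below assumes a channels-first image of shape `(C, H, W)` and a square patch. The function name `extract_patch` and its parameters are hypothetical.

```python
import torch


def extract_patch(image: torch.Tensor, top: int, left: int, size: int) -> torch.Tensor:
    """Return a (C, size, size) view of the image starting at (top, left)."""
    # narrow(dim, start, length) slices one dimension without copying data.
    return image.narrow(1, top, size).narrow(2, left, size)


image = torch.arange(16.0).view(1, 4, 4)  # one channel, 4x4 pixels
patch = extract_patch(image, top=1, left=2, size=2)
# select(dim, index) drops a dimension; here it picks channel 0 of the patch.
channel0 = patch.select(0, 0)
print(channel0)
```

Both `narrow` and `select` return views, so writing into the patch would also modify the original image.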
- Sparse Updates with `scatter_add_`

  In graph neural networks and other sparse data applications, you often need to update a tensor based on sparse indices. Your exercise is to implement a function that takes a tensor of `values`, a...
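The remaining arguments are cut off in the teaser, so this sketch assumes a 1D `indices` tensor of the same length as `values` plus an output size. The point it illustrates is that `scatter_add_` accumulates values that share an index instead of overwriting them.

```python
import torch


def sparse_accumulate(values: torch.Tensor, indices: torch.Tensor, size: int) -> torch.Tensor:
    """Sum values into an output of length `size` at the given indices (hypothetical helper)."""
    out = torch.zeros(size, dtype=values.dtype)
    # Entries that target the same index are added together.
    out.scatter_add_(0, indices, values)
    return out


values = torch.tensor([1.0, 2.0, 3.0, 4.0])
indices = torch.tensor([0, 2, 0, 1])
result = sparse_accumulate(values, indices, size=3)
print(result)
```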
- Implementing a Simplified Beam Search Decoder with `gather` and `scatter`

  Beam search is a popular decoding algorithm used in machine translation and text generation. A key step in beam search is to select the top-k most likely next tokens and update the corresponding...
- Selecting RoIs (Regions of Interest) with `index_select`

  In object detection tasks, after a region proposal network (RPN) suggests potential object locations, these regions of interest (RoIs) need to be extracted from the feature map for further...
- Creating a Label Mapping with `scatter` for Domain Adaptation

  In domain adaptation, you might need to map labels from a source domain to a target domain. Imagine you have a set of labels and a mapping that specifies how each old label corresponds to a new...
- Replicating `torch.nn.Embedding` with `gather`

  The `torch.nn.Embedding` layer is fundamental in many deep learning models, especially in NLP. Your task is to replicate its forward pass functionality using `torch.gather`. You'll create a...
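One possible way to express the lookup, assuming a weight matrix of shape `(vocab_size, embedding_dim)` and integer token ids of arbitrary shape. The function name `embedding_via_gather` is hypothetical, and the sketch covers only the plain forward lookup (no padding index or max-norm handling).

```python
import torch
import torch.nn.functional as F


def embedding_via_gather(weight: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Look up rows of `weight` with torch.gather (hypothetical helper)."""
    dim = weight.shape[1]
    # One index row per token, repeated across the embedding dimension.
    flat_idx = token_ids.reshape(-1, 1).expand(-1, dim)
    # Gathering along dim 0 copies row token_ids[i] of the weight matrix.
    rows = weight.gather(0, flat_idx)
    return rows.view(*token_ids.shape, dim)


weight = torch.arange(12.0).view(4, 3)     # vocabulary of 4, embedding size 3
token_ids = torch.tensor([[3, 0], [1, 1]])
out = embedding_via_gather(weight, token_ids)
reference = F.embedding(token_ids, weight)  # the built-in lookup, for comparison
print(torch.equal(out, reference))
```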
- Efficiently Updating a Sub-Tensor with `scatter_`

  Sometimes you need to update a portion of a larger tensor with values from a smaller tensor, where the update locations are specified by an index tensor. This is a common pattern in various...
- NumPy to JAX: The Basics

  This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
- The JAX approach to PRNG

  JAX handles pseudo-random number generation (PRNG) differently than NumPy, which uses a global state. JAX, on the other hand, makes the state of the PRNG explicit. This is a design choice that...
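A short sketch of what explicit PRNG state looks like in practice. The seed and array shapes are arbitrary choices for illustration.

```python
import jax

# The PRNG state is an explicit key that you create and pass around yourself.
key = jax.random.PRNGKey(0)

# Split the key before each use: keep one part, consume the other.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))   # same key, so the same numbers as `a`

key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))  # fresh key, so different numbers

print(a, b, c)
```

Reusing a key reproduces the same samples, which is why keys are normally split once per random call.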
- The need for speed: `jit`

  JAX's `jit` transformation compiles your Python function with XLA, which can lead to significant speedups. This is because the compiler can fuse operations together and the compiled function avoids much of the overhead of Python's interpreter. In this...
- Taking gradients with `grad`

  One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
- Taking gradients with `grad` II

  By default, `jax.grad` will take the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to many of the function's...
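A small sketch of differentiating with respect to more than one argument through `argnums`. The linear model, the mean-squared-error loss, and the toy data are assumptions chosen to keep the numbers easy to check by hand.

```python
import jax
import jax.numpy as jnp


def loss(W, b, x, y):
    """Mean squared error of a linear model (illustrative choice)."""
    pred = jnp.dot(x, W) + b
    return jnp.mean((pred - y) ** 2)


W = jnp.array([1.0, 2.0])
b = 0.5
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])

# argnums=(0, 1) asks for gradients w.r.t. both W and b, returned as a tuple.
dW, db = jax.grad(loss, argnums=(0, 1))(W, b, x, y)
print(dW, db)
```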
- Vectorizing with `vmap`

  Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single...
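To make the batching idea concrete, the sketch below writes a prediction function for a single example and lets `vmap` add the batch dimension. The function `predict` and its inputs are hypothetical.

```python
import jax
import jax.numpy as jnp


def predict(W, b, x):
    """Prediction for a single example x of shape (d,)."""
    return jnp.dot(W, x) + b


# in_axes says which arguments carry a batch axis: only x (axis 0) does.
batched_predict = jax.vmap(predict, in_axes=(None, None, 0))

W = jnp.array([1.0, 2.0])
b = 0.5
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # batch of 3 examples
ys = batched_predict(W, b, xs)
print(ys)
```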
- Updating model parameters

  In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
- A simple MLP

  Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you are to build a 2-layer MLP for a regression problem. 1....
- PyTrees

  A PyTree is any nested structure of dictionaries, lists, and tuples. JAX is designed to work with PyTrees, which allows for a more organized way of handling model parameters. In this exercise, you...
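As an illustration of handling parameters as a PyTree, the sketch below stores them in a nested dictionary and applies one update to every leaf with `jax.tree_util.tree_map`. The layer names, the stand-in gradients, and the learning rate are all made up for the example.

```python
import jax
import jax.numpy as jnp

# Model parameters organized as a nested dictionary (a PyTree).
params = {
    "layer1": {"W": jnp.ones((2, 2)), "b": jnp.zeros(2)},
    "layer2": {"W": jnp.ones((2, 1)), "b": jnp.zeros(1)},
}

# Stand-in gradients with the same structure as the parameters.
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# One gradient-descent step applied leaf by leaf.
learning_rate = 0.1
updated = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(updated["layer1"]["W"])
```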
- A simple CNN

  In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
- Understanding `jit` and tracing

  When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
- Conditionals with `jit`

  Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
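One common way around this is `jax.lax.cond`, which keeps both branches in the compiled graph and picks one at run time. The sketch below is an illustrative example; the function `negate_or_square` is hypothetical.

```python
import jax


@jax.jit
def negate_or_square(x):
    # A Python `if x < 0:` would fail here because x is a traced value.
    # lax.cond traces both branches and selects one when the function runs.
    return jax.lax.cond(x < 0, lambda v: -v, lambda v: v * v, x)


neg_case = negate_or_square(-3.0)
pos_case = negate_or_square(2.0)
print(neg_case, pos_case)
```

Both branches must return values of the same shape and dtype, since either one may be chosen at run time.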
- Loops with `jit`

  Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the number of iterations depends on a traced value. JAX provides `jax.lax.fori_loop` and...
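A minimal sketch of a loop whose upper bound is a traced value, written with `jax.lax.fori_loop`. The function `sum_first_n` is a made-up example, and the sketch assumes JAX's default 32-bit integer mode.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_first_n(n):
    """Sum the integers 0, 1, ..., n - 1 with a traced loop bound."""
    def body(i, acc):
        # The carried value must keep the same shape and dtype every iteration.
        return acc + i

    # fori_loop(lower, upper, body_fun, init_val)
    return jax.lax.fori_loop(0, n, body, jnp.int32(0))


total = sum_first_n(5)
print(total)
```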
- Forward-mode vs. Reverse-mode autodiff

  JAX supports both forward-mode and reverse-mode automatic differentiation. While `grad` uses reverse-mode, you can use `jax.jvp` for forward-mode, which computes Jacobian-vector products....
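For a scalar function the two modes should agree, which the sketch below checks on an arbitrary test function `f` (chosen here only for illustration).

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x


# Forward mode: push the tangent 1.0 through f at x = 2.0.
primal_out, tangent_out = jax.jvp(f, (2.0,), (1.0,))

# Reverse mode: grad computes the same derivative for a scalar function.
reverse_grad = jax.grad(f)(2.0)

print(primal_out, tangent_out, reverse_grad)
```

`jax.jvp` returns the function value together with the directional derivative, so a single forward pass yields both.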
- Custom VJP

  For some functions, you may want to define a custom vector-Jacobian product (VJP). This can be useful for numerical stability or for implementing algorithms that are not easily expressed in terms...
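To illustrate the numerical-stability case, the sketch below attaches a hand-written backward rule to `log(1 + exp(x))` using `jax.custom_vjp`. The choice of function is an assumption for the example; the exercise itself may use a different one.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


def log1pexp_fwd(x):
    # Return the output plus the residuals the backward pass needs.
    return log1pexp(x), x


def log1pexp_bwd(x, g):
    # The derivative is sigmoid(x), written so a large x gives 1 instead of inf / inf.
    return (g * (1.0 - 1.0 / (1.0 + jnp.exp(x))),)


log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

grad_at_zero = jax.grad(log1pexp)(0.0)
grad_at_large = jax.grad(log1pexp)(100.0)  # exp(100.0) overflows in float32
print(grad_at_zero, grad_at_large)
```

Without the custom rule, automatic differentiation at `x = 100.0` would evaluate `inf / inf` in 32-bit floats and return `nan`.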
- Working with a NN library: Flax

  While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and...