- Grouped Matrix Multiplication. Perform matrix multiplication on groups of matrices within a batch. For an input tensor of shape `(B, G, M, K)` and another of shape `(B, G, K, N)`, the output should be of shape...
- Bilinear Attention Pooling. In some attention mechanisms, you need to compute a bilinear interaction between two sets of features. Given two tensors of shapes `(B, N, D)` and `(B, M, D)`, compute a bilinear...
- Extract Diagonal from a Batch of Matrices. Given a batch of square matrices of shape `(B, N, N)`, extract the diagonal of each matrix. The output should be a tensor of shape `(B, N)`. This can be achieved with...
- Build a Custom ReLU Activation Function. Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as: $\mathrm{ReLU}(x) = \max(0, x)$ **Verification:** - For...
- Vectorized Operations with `vmap`. You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a...
- Debugging JAX Code with `jax.debug.print`. In JAX, standard Python `print` statements don't always work as expected within `jit`-compiled functions because they execute at trace time. [11] The solution is to use `jax.debug.print`. [11, 23]...
- Custom `nn.Module` with a Non-standard Initialization. Create a **custom `nn.Module`** for a simple feed-forward layer. Instead of the default PyTorch initialization, you'll apply a specific, non-standard initialization scheme. For example, you could...
- Implementing Gradient Clipping. Implement **gradient clipping** in your training loop. This technique is used to prevent exploding gradients, which can be a problem in RNNs and other deep networks. After the backward pass...
- Matrix Multiplication Basics. Implement a function in PyTorch that multiplies two matrices using `torch.mm`. Problem: Write a function `matmul(A, B)` that takes two 2D tensors `A` and `B` and returns their matrix product. -...
- ReLU Activation Function. Implement the ReLU (Rectified Linear Unit) function in PyTorch. Problem: Write a function `relu(x)` that takes a 1D tensor and replaces all negative values with 0. - **Input:** A tensor `x` of...
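
For the last exercise above, a minimal sketch of `relu(x)` in PyTorch might look like the following. It assumes that `torch.clamp` is an acceptable way to zero out negative values. The exercise's full input and output constraints are truncated in this listing, so this is one possible approach and not necessarily the intended solution; the sample tensor is illustrative only.

```python
import torch


def relu(x: torch.Tensor) -> torch.Tensor:
    # Clamp every element to a minimum of 0: negative values become 0,
    # non-negative values pass through unchanged.
    return torch.clamp(x, min=0)


# Illustrative input, not taken from the exercise statement.
x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
y = relu(x)
```

The built-in `torch.relu(x)` computes the same function and can be used to cross-check the result.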