-
Grouped Matrix Multiplication
### Description
Perform matrix multiplication on groups of matrices within a batch. For an input tensor of shape `(B, G, M, K)` and another of shape `(B, G, K, N)`, the output should be of shape...
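A minimal sketch of one way to do this, assuming a PyTorch setting (the output shape `(B, G, M, N)` follows from the matrix-multiplication rules); `torch.matmul` already broadcasts over the leading batch and group dimensions, and the shapes below are illustrative:

```python
import torch

def grouped_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Group-wise product: (B, G, M, K) @ (B, G, K, N) -> (B, G, M, N)."""
    # torch.matmul batches over all leading dimensions, so B and G are handled at once.
    return torch.matmul(a, b)

# Quick check against an explicit einsum formulation.
a = torch.randn(2, 3, 4, 5)
b = torch.randn(2, 3, 5, 6)
out = grouped_matmul(a, b)
assert out.shape == (2, 3, 4, 6)
assert torch.allclose(out, torch.einsum('bgmk,bgkn->bgmn', a, b), atol=1e-5)
```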
-
Bilinear Attention Pooling
### Description
In some attention mechanisms, you need to compute a bilinear interaction between two sets of features. Given two tensors of shapes `(B, N, D)` and `(B, M, D)`, compute a bilinear...
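One common reading of a bilinear interaction is a pairwise score `xᵀ W y`, producing a `(B, N, M)` score map. The sketch below assumes that formulation and a PyTorch setting; the learned weight `W` of shape `(D, D)` is a hypothetical parameter, not something specified by the exercise:

```python
import torch

B, N, M, D = 2, 4, 5, 8
x = torch.randn(B, N, D)   # first feature set
y = torch.randn(B, M, D)   # second feature set
W = torch.randn(D, D)      # assumed learned bilinear weight

# scores[b, n, m] = x[b, n] @ W @ y[b, m]
scores = torch.einsum('bnd,de,bme->bnm', x, W, y)
assert scores.shape == (B, N, M)
```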
-
Extract Diagonal from a Batch of Matrices
### Description
Given a batch of square matrices of shape `(B, N, N)`, extract the diagonal of each matrix. The output should be a tensor of shape `(B, N)`. This can be achieved with...
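A minimal sketch assuming PyTorch: `torch.diagonal` over the last two dimensions returns exactly the `(B, N)` tensor described, and an einsum with a repeated index gives the same result:

```python
import torch

x = torch.randn(3, 4, 4)  # batch of square matrices

# Take the diagonal of the last two dims; the result has shape (B, N).
diags = torch.diagonal(x, dim1=-2, dim2=-1)
assert diags.shape == (3, 4)

# Equivalent einsum formulation for comparison.
assert torch.allclose(diags, torch.einsum('bii->bi', x))
```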
-
Build a Custom ReLU Activation Function
Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as: $\mathrm{ReLU}(x) = \max(0, x)$

**Verification:**
- For...
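A minimal sketch of the function and its derivative via `jax.grad`; the sample inputs are illustrative, not the exercise's own verification values:

```python
import jax
import jax.numpy as jnp

def relu(x):
    # ReLU(x) = max(0, x)
    return jnp.maximum(0.0, x)

# jax.grad differentiates scalar-output functions, so evaluate point-wise.
drelu = jax.grad(relu)

print(relu(3.0), drelu(3.0))    # 3.0, 1.0 (positive input: slope 1)
print(relu(-2.0), drelu(-2.0))  # 0.0, 0.0 (negative input: slope 0)
```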
-
Vectorized Operations with vmap
You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a...
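A short sketch of the pattern: the single-example function below (`predict`, a hypothetical dot-product model) is batched with `jax.vmap`, mapping over axis 0 of the data while holding the weights fixed:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Processes a single data point x (a vector) with weights w.
    return jnp.dot(w, x)

w = jnp.ones(3)
batch = jnp.arange(12.0).reshape(4, 3)  # 4 data points of dimension 3

# Map predict over axis 0 of `batch`; `w` is broadcast unchanged (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(w, batch))  # shape (4,), one output per data point
```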
-
Debugging JAX code with `jax.debug.print`
In JAX, standard Python `print` statements don't behave as expected within `jit`-compiled functions: they run only once, at trace time, and see abstract tracer values rather than the actual runtime data. [11] The solution is to use `jax.debug.print`. [11, 23]...
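A small example of the difference: inside the `jit`-compiled function below, `jax.debug.print` reports the runtime value on every call, whereas a plain `print` in the same place would only show a tracer during tracing:

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # print(x) here would fire once at trace time and show a tracer object.
    # jax.debug.print runs at execution time and shows the concrete value.
    jax.debug.print("x = {x}", x=x)
    return x * 2

double(jnp.arange(3.0))
```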
-
Custom `nn.Module` with a Non-standard Initialization
Create a **custom `nn.Module`** for a simple feed-forward layer. Instead of the default PyTorch initialization, you'll apply a specific, non-standard initialization scheme. For example, you could...
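A minimal sketch of such a module; the layer sizes and the particular scheme used here (small uniform weights, constant positive bias) are illustrative choices, not prescribed by the exercise:

```python
import torch
import torch.nn as nn

class CustomLinear(nn.Module):
    """A feed-forward layer whose parameters use a hand-picked, non-default init."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Assumed non-standard scheme: small uniform weights, constant bias.
        nn.init.uniform_(self.linear.weight, a=-0.01, b=0.01)
        nn.init.constant_(self.linear.bias, 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

layer = CustomLinear(8, 4)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```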
-
Implementing Gradient Clipping
Implement **gradient clipping** in your training loop. This technique is used to prevent exploding gradients, which can be a problem in RNNs and other deep networks. After the backward pass...
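A sketch of where clipping fits in a PyTorch training step, using `torch.nn.utils.clip_grad_norm_`; the model, data, and `max_norm=1.0` are placeholder choices:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x, y = torch.randn(32, 10), torch.randn(32, 1)

opt.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
# Clip after backward() and before step(): rescales gradients so their
# global norm is at most max_norm (1.0 is an assumed, tunable value).
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```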
-
Matrix Multiplication Basics
Implement a function in PyTorch that multiplies two matrices using `torch.mm`.

### Problem
Write a function `matmul(A, B)` that takes two 2D tensors `A` and `B` and returns their matrix product.
-...
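A minimal implementation matching the stated signature; the example shapes are illustrative:

```python
import torch

def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Matrix product of two 2D tensors using torch.mm."""
    return torch.mm(A, B)

A = torch.randn(3, 4)
B = torch.randn(4, 5)
print(matmul(A, B).shape)  # torch.Size([3, 5])
```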
-
ReLU Activation Function
Implement the ReLU (Rectified Linear Unit) function in PyTorch.

### Problem
Write a function `relu(x)` that takes a 1D tensor and replaces all negative values with 0.
- **Input:** A tensor `x` of...
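A minimal implementation; `torch.clamp` is one of several equivalent ways to zero out the negative entries, and the sample tensor is illustrative:

```python
import torch

def relu(x: torch.Tensor) -> torch.Tensor:
    """Replace negative entries of a 1D tensor with 0."""
    return torch.clamp(x, min=0)

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
print(relu(x))  # tensor([0.0000, 0.0000, 0.0000, 1.5000, 3.0000])
```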