## Build a Custom ReLU Activation Function

Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as $\mathrm{ReLU}(x) = \max(0, x)$.

**Verification:**

- For...
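A minimal sketch of one way this could be solved, assuming scalar floating-point inputs (`jax.grad` differentiates a scalar-valued function):

```python
import jax
import jax.numpy as jnp

def relu(x):
    # ReLU(x) = max(0, x), applied elementwise
    return jnp.maximum(0.0, x)

# jax.grad returns a function that computes the derivative of relu.
relu_grad = jax.grad(relu)

print(relu(3.0), relu(-2.0))            # 3.0 0.0
print(relu_grad(3.0), relu_grad(-2.0))  # 1.0 (slope for x > 0), 0.0 (slope for x < 0)
```

The value reported at exactly $x = 0$ is a convention chosen by the framework, since ReLU is not differentiable there.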
## Vectorized Operations with vmap

You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a...
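A sketch of the pattern, using a hypothetical single-example function `predict(w, x)` that computes a dot product (the names and shapes here are illustrative):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Works on a single data point: x is one feature vector of shape (3,).
    return jnp.dot(w, x)

w = jnp.ones(3)
batch = jnp.arange(12.0).reshape(4, 3)   # a batch of 4 data points

# in_axes=(None, 0): do not map over w, map over axis 0 of the batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
print(batched_predict(w, batch))          # shape (4,), one output per data point
```

`jax.vmap` pushes the batch dimension into the computation itself, so there is no Python-level loop and the result remains a single jit-compatible operation.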
## Debugging JAX Code with `jax.debug.print`

In JAX, standard Python `print` statements don't always work as expected within `jit`-compiled functions because they execute at trace time. [11] The solution is to use `jax.debug.print`. [11, 23]...
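A small illustration of the difference, using a hypothetical jitted function `f`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = x ** 2
    # A plain print(x) here would run once at trace time and show an abstract
    # tracer; jax.debug.print prints the runtime values on every call.
    jax.debug.print("x = {x}, y = {y}", x=x, y=y)
    return y

f(jnp.array(3.0))   # prints: x = 3.0, y = 9.0
f(jnp.array(4.0))   # prints: x = 4.0, y = 16.0
```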
## Implementing Gradient Clipping

Implement **gradient clipping** in your training loop. This technique is used to prevent exploding gradients, which can be a problem in RNNs and other deep networks. After the backward pass...
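A sketch of where clipping fits in a PyTorch training step, using a toy linear model and an arbitrary `max_norm` of 1.0 (both are illustrative assumptions):

```python
import torch

model = torch.nn.Linear(10, 1)                  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

x, y = torch.randn(32, 10), torch.randn(32, 1)  # dummy batch

optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()                                 # backward pass fills .grad
# Rescale gradients so their global L2 norm is at most max_norm,
# then take the optimizer step on the clipped gradients.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```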
## Matrix Multiplication Basics

Implement a function in PyTorch that multiplies two matrices using `torch.mm`.

### Problem

Write a function `matmul(A, B)` that takes two 2D tensors `A` and `B` and returns their matrix product.

- ...
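One straightforward implementation, with a small worked example to check the result by hand:

```python
import torch

def matmul(A, B):
    # torch.mm computes the matrix product of two 2D tensors:
    # (n, m) @ (m, p) -> (n, p).
    return torch.mm(A, B)

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))   # tensor([[19., 22.], [43., 50.]])
```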
## ReLU Activation Function

Implement the ReLU (Rectified Linear Unit) function in PyTorch. 

### Problem

Write a function `relu(x)` that takes a 1D tensor and replaces all negative values with 0.

- **Input:** A tensor `x` of...
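One possible implementation; `torch.clamp` with `min=0` zeroes out the negative entries elementwise:

```python
import torch

def relu(x):
    # Keep non-negative values, replace negative values with 0.
    return torch.clamp(x, min=0)

x = torch.tensor([-1.0, 0.0, 2.5, -3.2])
print(relu(x))   # tensor([0.0000, 0.0000, 2.5000, 0.0000])
```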