Build a Custom ReLU Activation Function
Implement the Rectified Linear Unit (ReLU) activation function in JAX, then use jax.grad
to find its derivative. The ReLU function is defined as:

ReLU(x) = max(0, x)
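A minimal sketch of one way to do this (the names relu and relu_grad are illustrative, not required by the exercise):

```python
import jax
import jax.numpy as jnp

def relu(x):
    # ReLU(x) = max(0, x): positive inputs pass through, negatives become 0.
    return jnp.maximum(0.0, x)

# jax.grad differentiates a scalar-valued function with respect to its first argument.
relu_grad = jax.grad(relu)

print(relu(2.0), relu(-3.0))            # 2.0 0.0
print(relu_grad(2.0), relu_grad(-3.0))  # 1.0 0.0
```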
Verification:
- For x > 0, the gradient should be 1.
- For x < 0, the gradient should be 0.
- At x = 0, the gradient is mathematically undefined, but JAX still returns a concrete value; exactly which one depends on how you implement the function (jax.nn.relu, for instance, returns 0). Verify this behavior, for example with the checks sketched after this list.
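One possible set of checks, assuming the relu and relu_grad definitions sketched above; the value at zero varies with the implementation, so it is printed rather than asserted:

```python
import jax
import jax.numpy as jnp

def relu(x):
    return jnp.maximum(0.0, x)

relu_grad = jax.grad(relu)

# x > 0: the gradient should be exactly 1.
assert relu_grad(5.0) == 1.0
# x < 0: the gradient should be exactly 0.
assert relu_grad(-5.0) == 0.0
# x = 0: the true derivative is undefined; JAX returns whatever value the
# chosen primitives' subgradient rules produce (jax.nn.relu, for example,
# is defined so that its gradient at 0 is 0.0).
print("gradient at x = 0:", relu_grad(0.0))
```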