ML Katas

Build a Custom ReLU Activation Function

Difficulty: easy (<10 mins) · Tags: autograd, jax, activation-function, custom-function

Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use jax.grad to find its derivative. The ReLU function is defined as:

ReLU(x) = max(0, x)

Verification:
- For x > 0, the gradient should be 1.
- For x < 0, the gradient should be 0.
- At x = 0, the gradient is mathematically undefined, but JAX will return a concrete value (commonly 0 or 1, depending on how you implement the max). Verify this behavior.
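
A minimal sketch of one possible solution, assuming `jnp.maximum` is used to express the max; other formulations (`jnp.where`, `jax.nn.relu`) are equally valid and may give a different value at x = 0:

```python
import jax
import jax.numpy as jnp

def relu(x):
    # ReLU(x) = max(0, x)
    return jnp.maximum(0.0, x)

# jax.grad returns a function that computes d relu / dx for a scalar input.
relu_grad = jax.grad(relu)

print(relu_grad(3.0))   # expected: 1.0 (x > 0)
print(relu_grad(-2.0))  # expected: 0.0 (x < 0)
print(relu_grad(0.0))   # mathematically undefined; observe what JAX returns here
```

Note that the inputs are floats (3.0, not 3): jax.grad requires real- or complex-valued arguments and will raise an error for integer inputs.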