-
Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the...
-
Implementing Dropout
Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training time. Remember that dropout should only be...
-
Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX. This will involve using `jax.lax.conv_general_dilated`. You will need to manage the kernel initialization and the forward pass logic....
1