Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component of many modern neural network architectures, such as Transformers. It normalizes the inputs across the feature dimension.
Formula:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$

Where $\gamma$ and $\beta$ are learnable parameters (gain and bias), and $\epsilon$ is a small constant added for numerical stability.
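One way to structure this is a pure function that normalizes over the last axis and then applies the gain and bias. Below is a minimal sketch, assuming a `layer_norm(x, gamma, beta, eps)` signature (the function and parameter names are illustrative, not prescribed by the exercise):

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize x over its last (feature) axis, then apply gain and bias.

    x:     input array of shape (..., features)
    gamma: learnable gain of shape (features,)
    beta:  learnable bias of shape (features,)
    eps:   small constant for numerical stability
    """
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    normalized = (x - mean) / jnp.sqrt(var + eps)
    return gamma * normalized + beta
```

Computing the mean and variance with `keepdims=True` keeps the statistics broadcastable against `x`, so the same code works for any number of leading batch dimensions.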
Verification:
- The output of your layer norm function should have a mean close to 0 and a standard deviation close to 1 before applying the gain and bias.
- The shape of the output should be the same as the input.
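A quick check along these lines, using the sketch above with $\gamma = 1$ and $\beta = 0$ so the output equals the raw normalized values:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 16)) * 3.0 + 5.0  # arbitrary mean and scale

features = x.shape[-1]
gamma = jnp.ones((features,))   # identity gain
beta = jnp.zeros((features,))   # zero bias

y = layer_norm(x, gamma, beta)

print(jnp.mean(y, axis=-1))  # each entry should be close to 0
print(jnp.std(y, axis=-1))   # each entry should be close to 1
assert y.shape == x.shape    # shape is preserved
```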