Build a Simple Neural Network with Flax
Using Flax, JAX's neural network library, build a simple Multi-Layer Perceptron (MLP). The MLP should consist of an input layer, one hidden layer with a ReLU activation, and an output layer.
Task: Define a Flax nn.Module for the MLP, initialize its parameters, and apply it to a dummy input tensor.
Verification:
- The output of the network should have the correct shape.
- You should be able to access the initialized parameters of the model.