ML Katas

Build a Simple Neural Network with Flax

medium (<30 mins) mlp jax flax neural-network
this year by E

Using Flax, a neural network library built on JAX, build a simple multi-layer perceptron (MLP). The MLP should have an input layer, one hidden layer with a ReLU activation function, and an output layer.
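For reference, the forward pass described above can be written as follows. The symbols (batch size $B$, widths $d_{\text{in}}$, $d_h$, $d_{\text{out}}$) are notation introduced here for illustration, not part of the kata. The weight layout $x W$ with $W \in \mathbb{R}^{\text{in} \times \text{out}}$ matches how a Flax `Dense` kernel is stored, and each bias is broadcast over the batch dimension.

```latex
\begin{aligned}
h &= \mathrm{ReLU}(x W_1 + b_1), && x \in \mathbb{R}^{B \times d_{\text{in}}},\; W_1 \in \mathbb{R}^{d_{\text{in}} \times d_h},\; b_1 \in \mathbb{R}^{d_h} \\
y &= h W_2 + b_2, && W_2 \in \mathbb{R}^{d_h \times d_{\text{out}}},\; b_2 \in \mathbb{R}^{d_{\text{out}}},\; y \in \mathbb{R}^{B \times d_{\text{out}}} \\
N_{\text{params}} &= d_{\text{in}} d_h + d_h + d_h d_{\text{out}} + d_{\text{out}}
  = 784 \cdot 32 + 32 + 32 \cdot 10 + 10 = 25450 && \text{(example sizes)}
\end{aligned}
```

The example sizes ($d_{\text{in}} = 784$, $d_h = 32$, $d_{\text{out}} = 10$) are arbitrary choices; the output shape $B \times d_{\text{out}}$ is what the verification step checks.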

Task: Define a Flax nn.Module for the MLP. Initialize its parameters and apply it to a dummy input tensor.

Verification:
- The output of the network should have shape (batch size, number of output units).
- You should be able to access the initialized parameters of the model.