A simple CNN
In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use jax.lax.conv_general_dilated to implement the convolution.
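The snippet below sketches how `jax.lax.conv_general_dilated` can be called to convolve a batch of images with four 3x3 filters. The input shape (8 single-channel 10x10 images), the random data, and the choice of NHWC/HWIO layouts with `"VALID"` padding are assumptions for illustration. If `dimension_numbers` is omitted, the function expects NCHW inputs and OIHW kernels instead.

```python
import jax
import jax.numpy as jnp

# Assumed input: a batch of 8 single-channel 10x10 images in NHWC layout.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 10, 10, 1))

# 4 filters of size 3x3 over 1 input channel, stored in HWIO layout.
kernel = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 1, 4))

y = jax.lax.conv_general_dilated(
    x,                      # lhs: the input batch
    kernel,                 # rhs: the filters
    window_strides=(1, 1),  # stride 1 in both spatial dimensions
    padding="VALID",        # no padding, so each spatial side shrinks by 2
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)  # (8, 8, 8, 4)

# The operation is a cross-correlation (the kernel is not flipped):
# compare one output entry with a hand-computed weighted sum.
manual = jnp.sum(x[0, 0:3, 0:3, 0] * kernel[:, :, 0, 0])
print(jnp.allclose(y[0, 0, 0, 0], manual, atol=1e-3))
```

With `"VALID"` padding each spatial dimension goes from 10 to 10 - 3 + 1 = 8, and the last axis holds one channel per filter.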
- Implement a CNN with one convolutional layer followed by a ReLU activation and a dense layer. The convolutional layer should have 4 filters of size 3x3. The dense layer should map the flattened output of the convolutional layer to a single output value.
- As before, implement a loss function, a jit'ed training step, and a training loop.
- Plot the loss as a function of the training step. It should be decreasing.
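The steps in the list above can be put together roughly as follows. This is a minimal sketch under stated assumptions, not a reference solution: the synthetic data (256 random 8x8 single-channel images whose target is their mean pixel value), the `"VALID"` padding, the initialization scale, mean squared error as the loss, plain full-batch gradient descent with `LR = 0.05`, and the names `init_params`, `model`, `loss_fn`, and `train_step` are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Assumed problem size: 8x8 single-channel inputs, 4 filters, "VALID" padding.
H = W = 8
C_OUT = 4
FLAT = (H - 2) * (W - 2) * C_OUT  # a 3x3 VALID conv shrinks each side by 2
LR = 0.05                         # assumed learning rate for plain gradient descent


def init_params(key):
    """Hypothetical initializer: small random weights, zero biases."""
    k_conv, k_dense = jax.random.split(key)
    return {
        "conv_w": 0.1 * jax.random.normal(k_conv, (3, 3, 1, C_OUT)),  # HWIO
        "conv_b": jnp.zeros((C_OUT,)),
        "dense_w": 0.1 * jax.random.normal(k_dense, (FLAT, 1)),
        "dense_b": jnp.zeros((1,)),
    }


def model(params, x):
    """Map a (batch, H, W, 1) NHWC input to (batch,) predictions."""
    z = jax.lax.conv_general_dilated(
        x,
        params["conv_w"],
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    z = jax.nn.relu(z + params["conv_b"])  # per-filter bias, then ReLU
    z = z.reshape(z.shape[0], -1)          # flatten to (batch, FLAT)
    out = z @ params["dense_w"] + params["dense_b"]
    return out.squeeze(-1)                 # one scalar per example


def loss_fn(params, x, y):
    """Mean squared error, assumed here as the regression loss."""
    return jnp.mean((model(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One full-batch gradient-descent step; returns new params and the loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


# Synthetic data (assumption): the target is each image's mean pixel value.
k_data, k_init = jax.random.split(jax.random.PRNGKey(0))
x_train = jax.random.normal(k_data, (256, H, W, 1))
y_train = jnp.mean(x_train, axis=(1, 2, 3))

params = init_params(k_init)
losses = []
for step in range(200):
    params, loss = train_step(params, x_train, y_train)
    losses.append(float(loss))  # loss measured before this step's update

print(f"loss at step 0: {losses[0]:.4f}, at step 199: {losses[-1]:.4f}")
```

With `"VALID"` padding, the convolution turns each 8x8 input into a 6x6x4 feature map, so the dense layer has 6 * 6 * 4 = 144 inputs. Keeping the parameters in a dict lets `jax.value_and_grad` return gradients with the same structure, which `jax.tree_util.tree_map` then uses for the update. For the requested plot, passing `losses` to `matplotlib.pyplot.plot` and labeling the axes (training step vs. loss) is enough.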