ML Katas

A simple CNN

Difficulty: medium (<1 hr). Tags: jit, autodiff, CNN, conv.
Posted this year by E.

In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use jax.lax.conv_general_dilated to implement the convolution.
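For reference, here is a minimal sketch of a single call to jax.lax.conv_general_dilated. It assumes NCHW inputs, OIHW kernels, "VALID" padding, and illustrative 28x28 single-channel images. The exercise does not prescribe this layout or size. The 4 filters of size 3x3 match step 1. Note that XLA convolutions compute a cross-correlation, so the kernel is not flipped.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes (an assumption): a batch of 8 single-channel 28x28 images
# and 4 filters of size 3x3, as in the exercise.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 1, 28, 28))  # NCHW: batch, channels, height, width
w = jax.random.normal(key_w, (4, 1, 3, 3))    # OIHW: out channels, in channels, kh, kw

y = jax.lax.conv_general_dilated(
    x, w,
    window_strides=(1, 1),
    padding="VALID",  # no padding, so each spatial side shrinks by 3 - 1 = 2
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
print(y.shape)  # (8, 4, 26, 26)
```

With "SAME" padding instead, the output would keep the 28x28 spatial size.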

  1. Implement a CNN with one convolutional layer followed by a ReLU activation and a dense layer. The convolutional layer should have 4 filters of size 3x3. The dense layer should map the flattened output of the convolutional layer to a single output value.
  2. As in the previous exercises, implement a loss function, a jit-compiled (jax.jit) training step, and a training loop.
  3. Plot the loss as a function of the training step. It should decrease as training progresses.
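The three steps above can be sketched as follows. This is one possible solution under several stated assumptions, not a reference answer. The data is synthetic and hypothetical: 8x8 single-channel images whose target is their mean pixel intensity. Mean squared error is used as the loss, and plain gradient descent is the optimizer. The helper names (init_params, forward, loss_fn, train_step, train, plot_losses) are illustrative.

```python
import jax
import jax.numpy as jnp

H = W = 8       # hypothetical image size
N_FILTERS = 4   # 4 filters of size 3x3, as the exercise specifies
LR = 1e-2       # hypothetical learning rate for plain gradient descent

def init_params(key):
    k_conv, k_dense = jax.random.split(key)
    flat_dim = N_FILTERS * (H - 2) * (W - 2)  # "VALID" 3x3 conv shrinks each side by 2
    return {
        "conv_w": 0.1 * jax.random.normal(k_conv, (N_FILTERS, 1, 3, 3)),  # OIHW
        "conv_b": jnp.zeros((N_FILTERS,)),
        "dense_w": 0.1 * jax.random.normal(k_dense, (flat_dim, 1)),
        "dense_b": jnp.zeros((1,)),
    }

def forward(params, x):
    # x: (batch, 1, H, W) in NCHW layout
    h = jax.lax.conv_general_dilated(
        x, params["conv_w"], window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))
    h = h + params["conv_b"][None, :, None, None]  # per-filter bias
    h = jax.nn.relu(h)
    h = h.reshape(h.shape[0], -1)                  # flatten everything but the batch dim
    return (h @ params["dense_w"] + params["dense_b"]).squeeze(-1)  # shape (batch,)

def loss_fn(params, x, y):
    # Mean squared error: a common regression loss, assumed here
    return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain gradient descent, applied to every leaf of the params pytree
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

def train(num_steps=500, seed=0):
    k_data, k_init = jax.random.split(jax.random.PRNGKey(seed))
    # Hypothetical synthetic data: target is each image's mean pixel intensity
    x = jax.random.normal(k_data, (256, 1, H, W))
    y = x.mean(axis=(1, 2, 3))
    params = init_params(k_init)
    losses = []
    for _ in range(num_steps):
        params, loss = train_step(params, x, y)
        losses.append(float(loss))
    return params, losses

def plot_losses(losses):
    import matplotlib.pyplot as plt  # imported here so training works without it
    plt.plot(losses)
    plt.xlabel("training step")
    plt.ylabel("MSE loss")
    plt.yscale("log")
    plt.show()

if __name__ == "__main__":
    _, losses = train()
    plot_losses(losses)
```

Plain gradient descent keeps the sketch dependency-free. An optimizer library such as optax could replace the tree_map update. The log-scale y axis makes the decrease easier to see once the loss gets small.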