Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX, building on `jax.lax.conv_general_dilated`. You will need to handle kernel initialization and the forward pass logic...
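One possible starting point is sketched below. It is not a reference solution: the function names `init_conv_params` and `conv2d_forward`, the NCHW/OIHW layouts, the He-style normal initialization, and the bias term are all assumptions made for illustration, since the exercise text does not specify them.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_conv_params(key, in_channels, out_channels, kernel_size):
    """Hypothetical helper: build a kernel in OIHW layout plus a zero bias."""
    kh, kw = kernel_size
    fan_in = in_channels * kh * kw
    # He-style scaling (an assumed choice; the exercise does not name a scheme).
    scale = jnp.sqrt(2.0 / fan_in)
    kernel = scale * jax.random.normal(key, (out_channels, in_channels, kh, kw))
    bias = jnp.zeros((out_channels,))
    return {"kernel": kernel, "bias": bias}


def conv2d_forward(params, x, stride=(1, 1), padding="SAME"):
    """Hypothetical forward pass for an input x of shape (N, C, H, W)."""
    out = lax.conv_general_dilated(
        lhs=x,
        rhs=params["kernel"],
        window_strides=stride,
        padding=padding,
        # Input is NCHW, kernel is OIHW, output is NCHW (assumed layouts).
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    # Broadcast the per-channel bias over batch and spatial dimensions.
    return out + params["bias"][None, :, None, None]


# Usage: a batch of two 3-channel 16x16 inputs mapped to 8 output channels.
key = jax.random.PRNGKey(0)
params = init_conv_params(key, in_channels=3, out_channels=8, kernel_size=(3, 3))
x = jnp.ones((2, 3, 16, 16))
y = conv2d_forward(params, x)
print(y.shape)  # (2, 8, 16, 16) with "SAME" padding and stride 1
```

Keeping the parameters in a plain dictionary means the forward pass stays a pure function, so it can be passed directly to `jax.jit` or `jax.grad`.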