Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX. This will involve using jax.lax.conv_general_dilated. You will need to manage the kernel initialization and the forward pass logic.
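The layer has two parts: parameter initialization and a forward pass. Below is a minimal sketch that makes several assumptions for illustration. Inputs use NHWC layout and the kernel uses HWIO layout. The kernel gets He-normal initialization (standard deviation sqrt(2 / fan_in)) and the bias starts at zero. None of these choices is required by the task, and the helper names init_conv_params and conv_layer_apply are hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def init_conv_params(key, in_channels, out_channels, kernel_size):
    """Hypothetical helper: He-normal kernel (an assumed choice) plus zero bias.

    The kernel uses HWIO layout: (kernel_h, kernel_w, in_channels, out_channels).
    """
    kh, kw = kernel_size
    fan_in = kh * kw * in_channels          # inputs feeding each output unit
    std = math.sqrt(2.0 / fan_in)           # He-normal scale
    kernel = std * jax.random.normal(key, (kh, kw, in_channels, out_channels))
    bias = jnp.zeros((out_channels,))
    return {"kernel": kernel, "bias": bias}


def conv_layer_apply(params, x, stride=(1, 1), padding="SAME"):
    """Hypothetical helper: forward pass for an NHWC batch x."""
    y = jax.lax.conv_general_dilated(
        x,
        params["kernel"],
        window_strides=stride,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y + params["bias"]  # bias broadcasts over the trailing channel axis


params = init_conv_params(jax.random.PRNGKey(0), in_channels=3, out_channels=8, kernel_size=(3, 3))
x = jnp.ones((2, 32, 32, 3))
y = conv_layer_apply(params, x)
print(y.shape)  # (2, 32, 32, 8)
```

Passing dimension_numbers as a tuple of layout strings tells conv_general_dilated how the input, kernel, and output axes are arranged, so no manual transposes are needed.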
Task: Write a function that takes an input image, a kernel, and convolution parameters (stride, padding) and returns the convolved output.
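One way to write the requested function is sketched below, with these assumptions. The input is a single unbatched image of shape (H, W, C_in) and the kernel has shape (KH, KW, C_in, C_out). Padding is either one of the strings conv_general_dilated accepts (such as "VALID" or "SAME") or an integer p meaning p zeros on every side. The name conv2d and the integer-padding convention are illustrative. conv_general_dilated does not flip the kernel, so it computes a cross-correlation, as most deep-learning libraries do.

```python
import jax
import jax.numpy as jnp


def conv2d(image, kernel, stride=1, padding="VALID"):
    """Hypothetical wrapper: 2D convolution of a single image.

    image:   (H, W, C_in); a batch axis is added internally
    kernel:  (KH, KW, C_in, C_out), HWIO layout
    stride:  int or (sh, sw)
    padding: "VALID", "SAME", an int p (p zeros on every side),
             or explicit ((top, bottom), (left, right)) pairs
    """
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    out = jax.lax.conv_general_dilated(
        image[None],  # (1, H, W, C_in)
        kernel,
        window_strides=stride,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return out[0]  # drop the batch axis -> (H_out, W_out, C_out)


image = jnp.ones((5, 5, 1))
kernel = jnp.ones((3, 3, 1, 1))
valid = conv2d(image, kernel)                        # shape (3, 3, 1), every entry 9.0
strided = conv2d(image, kernel, stride=2, padding=1)
print(strided[..., 0])  # [[4, 6, 4], [6, 9, 6], [4, 6, 4]]
```

With stride 2 and padding 1, each corner window overlaps only four ones of the padded image, each edge window six, and the center window all nine.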
Verification:
- The output shape of your convolution should follow from the input shape, kernel size, stride, and padding.
- For a known input and kernel, the output should match the expected result.
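Both checks can be automated. The sketch below shows one way, under these assumptions. With symmetric padding p, the expected size along a spatial axis is floor((H + 2p - K) / s) + 1, and "VALID" is the p = 0 case. For "SAME", JAX's padding gives ceil(H / s). The hypothetical helpers expected_hw and naive_conv2d are illustrative: the second is a loop-based reference for comparing values on random inputs. The known-value test uses a 2x2 all-ones kernel with stride 2, which should sum each non-overlapping 2x2 block.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def conv2d(image, kernel, stride=1, padding="VALID"):
    # Same wrapper idea as the task: single HWC image, HWIO kernel.
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    out = jax.lax.conv_general_dilated(
        image[None], kernel, window_strides=(stride, stride), padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=jax.lax.Precision.HIGHEST)  # avoid reduced-precision math on GPU/TPU
    return out[0]


def expected_hw(h, w, kh, kw, stride, padding):
    # "SAME" gives ceil(size / stride); otherwise floor((size + 2p - k) / stride) + 1.
    if padding == "SAME":
        return math.ceil(h / stride), math.ceil(w / stride)
    p = 0 if padding == "VALID" else padding
    return (h + 2 * p - kh) // stride + 1, (w + 2 * p - kw) // stride + 1


def naive_conv2d(image, kernel, stride, pad):
    # Loop-based cross-correlation with symmetric zero padding, used as a reference.
    image = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    kh, kw, _, c_out = kernel.shape
    h_out = (image.shape[0] - kh) // stride + 1
    w_out = (image.shape[1] - kw) // stride + 1
    out = np.zeros((h_out, w_out, c_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = image[i * stride:i * stride + kh, j * stride:j * stride + kw, :]
            out[i, j] = np.tensordot(patch, kernel, axes=3)  # sum over KH, KW, C_in
    return out


# 1) Shape check across several (H, W, KH, KW, stride, padding) configurations.
for h, w, kh, kw, s, pad in [(5, 5, 3, 3, 1, "VALID"), (5, 5, 3, 3, 2, 1),
                             (7, 6, 3, 2, 2, "SAME"), (8, 8, 4, 4, 3, 0)]:
    out = conv2d(jnp.ones((h, w, 2)), jnp.ones((kh, kw, 2, 4)), s, pad)
    assert out.shape == (*expected_hw(h, w, kh, kw, s, pad), 4), (out.shape, h, w, kh, kw, s, pad)

# 2) Known input and kernel: sums of the 2x2 blocks of [[0..3], [4..7], [8..11], [12..15]].
image = jnp.arange(16.0).reshape(4, 4, 1)
known = conv2d(image, jnp.ones((2, 2, 1, 1)), stride=2)[..., 0]
assert np.allclose(known, [[10.0, 18.0], [42.0, 50.0]])

# 3) Random input: compare against the loop-based reference.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (6, 7, 3))
k = jax.random.normal(k2, (3, 3, 3, 2))
fast = np.asarray(conv2d(x, k, 2, 1))
max_err = np.max(np.abs(fast - naive_conv2d(np.asarray(x), np.asarray(k), 2, 1)))
assert max_err < 1e-4
print("all checks passed")
```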