ML Katas

Implement a Convolutional Layer

hard (<30 mins) cnn jax convolution deep-learning
yesterday by E

Implement a 2D convolutional layer from scratch in JAX, using jax.lax.conv_general_dilated for the convolution itself. You will need to handle both kernel initialization and the forward pass.
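A minimal sketch of the two pieces follows. It assumes channels-last inputs (N, H, W, C_in) and an HWIO kernel layout; both are choices, not requirements of the kata. The helper names init_conv_params and conv_forward are hypothetical, and He-normal initialization is simply one common choice used here for illustration.

```python
import jax
import jax.numpy as jnp


def init_conv_params(key, in_channels, out_channels, kernel_size):
    """Hypothetical helper: He-normal kernel in HWIO layout plus a zero bias."""
    kh, kw = kernel_size
    # For an HWIO kernel, axis -2 is input channels and axis -1 is output channels,
    # so fan_in = kh * kw * in_channels (receptive field times input channels).
    init = jax.nn.initializers.he_normal(in_axis=-2, out_axis=-1)
    kernel = init(key, (kh, kw, in_channels, out_channels), jnp.float32)
    bias = jnp.zeros((out_channels,), dtype=jnp.float32)
    return {"kernel": kernel, "bias": bias}


def conv_forward(params, x, stride=(1, 1), padding="SAME"):
    """Hypothetical forward pass: x is (N, H, W, C_in), output is (N, H_out, W_out, C_out)."""
    y = jax.lax.conv_general_dilated(
        x,
        params["kernel"],
        window_strides=stride,
        padding=padding,  # "SAME", "VALID", or a sequence of (low, high) pairs
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y + params["bias"]  # bias broadcasts over the trailing channel axis


params = init_conv_params(jax.random.PRNGKey(0), in_channels=3, out_channels=8, kernel_size=(3, 3))
x = jnp.ones((2, 32, 32, 3))
y = conv_forward(params, x, stride=(2, 2), padding="SAME")
print(y.shape)  # (2, 16, 16, 8)
```

With "SAME" padding the spatial size becomes ceil(32 / 2) = 16, so only the channel count changes beyond the stride's downsampling.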

Task: Write a function that takes an input image, a kernel, and convolution parameters (stride, padding) and returns the convolved output.
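One way to write the requested function is sketched below, assuming images are channels-last, either (H, W, C_in) or batched (N, H, W, C_in), and kernels are (KH, KW, C_in, C_out). The name conv2d and the convention that an integer padding means symmetric zero padding are illustrative choices, not part of the kata.

```python
import jax
import jax.numpy as jnp


def conv2d(image, kernel, stride=1, padding="VALID"):
    """Illustrative conv2d built on jax.lax.conv_general_dilated.

    image:   (H, W, C_in) or (N, H, W, C_in)
    kernel:  (KH, KW, C_in, C_out)
    stride:  int or (stride_h, stride_w)
    padding: "VALID", "SAME", an int (symmetric zero padding), or
             ((top, bottom), (left, right))
    """
    unbatched = image.ndim == 3
    if unbatched:
        image = image[None]  # add a batch axis of size 1
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(padding, int):
        padding = ((padding, padding), (padding, padding))
    out = jax.lax.conv_general_dilated(
        image,
        kernel,
        window_strides=stride,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return out[0] if unbatched else out


# Known input: a 4x4 ramp and a 2x2 all-ones kernel, so each output is a 2x2 window sum.
image = jnp.arange(16.0).reshape(4, 4, 1)
kernel = jnp.ones((2, 2, 1, 1))
out = conv2d(image, kernel, stride=1, padding="VALID")
print(out[..., 0])
# [[10. 14. 18.]
#  [26. 30. 34.]
#  [42. 46. 50.]]
```

Note that jax.lax.conv_general_dilated computes a cross-correlation (the kernel is not flipped), which is what deep-learning frameworks conventionally call convolution; any hand-computed expected result should use the same convention.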

Verification:
- The output shape should match what the input shape, kernel size, stride, and padding imply.
- For a known input and kernel, the output should match the expected result.
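For the shape check, the standard output size along each spatial axis is H_out = floor((H + pad_top + pad_bottom - KH) / stride) + 1, and likewise for W. With "SAME" padding, JAX chooses the padding so that H_out = ceil(H / stride). The sketch below compares jax.lax.conv_general_dilated against a slow loop-based reference; expected_out_size and naive_conv2d are hypothetical helpers, and the random shapes are arbitrary.

```python
import numpy as np
import jax


def expected_out_size(n, k, stride, pad_lo, pad_hi):
    """Output length along one spatial axis: floor((n + pads - k) / stride) + 1."""
    return (n + pad_lo + pad_hi - k) // stride + 1


def naive_conv2d(image, kernel, stride, pad):
    """Slow reference. image: (H, W, C_in), kernel: (KH, KW, C_in, C_out), symmetric int pad."""
    kh, kw, _, c_out = kernel.shape
    x = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))  # zero-pad H and W only
    h_out = expected_out_size(image.shape[0], kh, stride, pad, pad)
    w_out = expected_out_size(image.shape[1], kw, stride, pad, pad)
    out = np.zeros((h_out, w_out, c_out), dtype=np.float32)
    for i in range(h_out):
        for j in range(w_out):
            patch = x[i * stride : i * stride + kh, j * stride : j * stride + kw, :]
            # Contract over KH, KW, C_in; no kernel flip (cross-correlation).
            out[i, j] = np.tensordot(patch, kernel, axes=([0, 1, 2], [0, 1, 2]))
    return out


rng = np.random.default_rng(0)
image = rng.standard_normal((7, 9, 3)).astype(np.float32)
kernel = rng.standard_normal((3, 3, 3, 4)).astype(np.float32)
stride, pad = 2, 1

fast = jax.lax.conv_general_dilated(
    image[None],
    kernel,
    window_strides=(stride, stride),
    padding=((pad, pad), (pad, pad)),
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    precision=jax.lax.Precision.HIGHEST,  # avoid reduced-precision matmuls on GPU/TPU
)[0]
ref = naive_conv2d(image, kernel, stride, pad)

print(fast.shape)  # (4, 5, 4)
print(np.allclose(np.asarray(fast), ref, atol=1e-4))  # True
```

Here (7 + 2 - 3) // 2 + 1 = 4 and (9 + 2 - 3) // 2 + 1 = 5, which is where the (4, 5, 4) shape comes from.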