-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
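A minimal sketch of one possible training step follows. It assumes a single-feature linear model `x @ W + b`, mean squared error, and plain gradient descent. The synthetic data (y = 3x + 2 plus noise), `LEARNING_RATE`, `loss_fn`, and `train_step` are hypothetical choices for illustration, not the exercise's own setup.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data (an assumption): y = 3x + 2 plus a little Gaussian noise.
key = jax.random.PRNGKey(0)
k_x, k_noise, k_w = jax.random.split(key, 3)
x = jax.random.normal(k_x, (100, 1))
y = 3.0 * x + 2.0 + 0.1 * jax.random.normal(k_noise, (100, 1))

# Model parameters: weight matrix W and bias b.
W = jax.random.normal(k_w, (1, 1))
b = jnp.zeros((1,))

LEARNING_RATE = 0.1  # assumed step size for plain gradient descent


def loss_fn(W, b, x, y):
    """Mean squared error of the linear model x @ W + b."""
    pred = x @ W + b
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(W, b, x, y):
    """One gradient-descent update of W and b; also returns the pre-update loss."""
    loss, (dW, db) = jax.value_and_grad(loss_fn, argnums=(0, 1))(W, b, x, y)
    return W - LEARNING_RATE * dW, b - LEARNING_RATE * db, loss


initial_loss = loss_fn(W, b, x, y)
for _ in range(200):
    W, b, loss = train_step(W, b, x, y)
final_loss = loss_fn(W, b, x, y)
```

`jax.value_and_grad(..., argnums=(0, 1))` returns the loss together with the gradients for both `W` and `b`, so a single call both reports progress and drives the update.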
-
A simple MLP
Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you will build a 2-layer MLP for a regression problem. 1....
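One way the 2-layer MLP could look, as a sketch under assumptions: a ReLU hidden layer, parameters stored in a dict, a hypothetical target y = x**2 on [-1, 1], and hand-picked layer sizes and step size. The helper names (`init_mlp`, `mlp`, `loss_fn`, `train_step`) are illustrative only.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # assumed step size


def init_mlp(key, in_dim, hidden_dim, out_dim):
    """Random weights (He-style scaling for the ReLU layer) and zero biases."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (in_dim, hidden_dim)) * jnp.sqrt(2.0 / in_dim),
        "b1": jnp.zeros((hidden_dim,)),
        "W2": jax.random.normal(k2, (hidden_dim, out_dim)) * jnp.sqrt(1.0 / hidden_dim),
        "b2": jnp.zeros((out_dim,)),
    }


def mlp(params, x):
    """Two layers: linear -> ReLU -> linear (no activation on the regression output)."""
    h = jax.nn.relu(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def loss_fn(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # The params dict is a pytree, so one tree_map updates every array in it.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Hypothetical 1-D regression target: y = x**2 on [-1, 1].
x = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1)
y = x ** 2
params = init_mlp(jax.random.PRNGKey(0), in_dim=1, hidden_dim=32, out_dim=1)
initial_loss = loss_fn(params, x, y)
for _ in range(500):
    params, _ = train_step(params, x, y)
final_loss = loss_fn(params, x, y)
```

Keeping the parameters in a dict lets `jax.value_and_grad` return gradients with the same structure, so the update needs no per-layer bookkeeping.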
-
A simple CNN
In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
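Below is a small sketch of a CNN regressor built on `jax.lax.conv_general_dilated`. It assumes NCHW inputs, one 3x3 convolution with `padding="SAME"`, global average pooling, and a linear output layer. The architecture, the toy target (mean squared pixel value of random 8x8 images), and all helper names are assumptions for illustration, not the exercise's prescribed design.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed step size


def init_cnn(key, in_ch=1, out_ch=4, ksize=3):
    k1, k2 = jax.random.split(key)
    return {
        # Kernel laid out as OIHW (out channels, in channels, height, width).
        "conv_w": 0.1 * jax.random.normal(k1, (out_ch, in_ch, ksize, ksize)),
        "conv_b": jnp.zeros((out_ch,)),
        "dense_w": 0.1 * jax.random.normal(k2, (out_ch, 1)),
        "dense_b": jnp.zeros((1,)),
    }


def cnn(params, x):
    """x has shape (N, C, H, W); returns one regression output per example."""
    # With dimension_numbers=None, a 2-D convolution uses NCHW input, OIHW kernel, NCHW output.
    h = jax.lax.conv_general_dilated(
        x, params["conv_w"], window_strides=(1, 1), padding="SAME"
    )
    h = jax.nn.relu(h + params["conv_b"][None, :, None, None])
    h = h.mean(axis=(2, 3))  # global average pooling -> (N, out_ch)
    return h @ params["dense_w"] + params["dense_b"]


def loss_fn(params, x, y):
    return jnp.mean((cnn(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Hypothetical data: 8x8 single-channel images, target = mean squared pixel value.
k_data, k_init = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_data, (64, 1, 8, 8))
y = (x ** 2).mean(axis=(2, 3))  # shape (64, 1)
params = init_cnn(k_init)
initial_loss = loss_fn(params, x, y)
for _ in range(100):
    params, _ = train_step(params, x, y)
final_loss = loss_fn(params, x, y)
```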
-
Conditionals with `jit`
Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
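The sketch below contrasts a Python `if` with two jit-compatible alternatives, `jax.lax.cond` and `jnp.where`, using a ReLU-style function as a hypothetical example. The names `relu_python`, `relu_cond`, and `relu_where` are illustrative only.

```python
import jax
import jax.numpy as jnp


def relu_python(x):
    # Fine when called eagerly, but under jax.jit `x > 0` is a tracer, and
    # Python's `if` cannot turn a tracer into a concrete True/False.
    if x > 0:
        return x
    return jnp.zeros_like(x)


@jax.jit
def relu_cond(x):
    # lax.cond(pred, true_fun, false_fun, *operands): pred must be a scalar boolean,
    # and both branches must return arrays with the same shape and dtype.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


@jax.jit
def relu_where(x):
    # jnp.where selects elementwise, so it also works for array inputs.
    return jnp.where(x > 0, x, 0.0)


print(relu_cond(jnp.float32(3.0)), relu_cond(jnp.float32(-2.0)))
print(relu_where(jnp.array([-1.0, 2.0])))
```

`jax.lax.cond` needs a scalar predicate, while `jnp.where` works elementwise but computes both of its value operands. If the condition depends only on a Python value rather than on array data, marking that argument as static (for example with `static_argnums`) is another option.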
-
Loops with `jit`
Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the loop's duration depends on a traced value. JAX provides `jax.lax.fori_loop` and...
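A minimal sketch, assuming a hypothetical `sum_of_squares` function whose trip count `n` is a traced argument, so a Python `range(n)` loop cannot be used under `jit`. `jax.lax.fori_loop(lower, upper, body_fun, init_val)` calls `body_fun(i, carry)` for each `i` in `[lower, upper)` and threads the carry through the iterations.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(n):
    """Return 0**2 + 1**2 + ... + (n-1)**2 for a traced integer n."""
    # A Python `for i in range(n)` would fail here: under jit, n is a tracer and
    # range() needs a concrete Python int.
    def body(i, acc):
        # body_fun(i, carry) -> new carry; the carry keeps the same shape and dtype.
        return acc + i * i

    return jax.lax.fori_loop(0, n, body, jnp.int32(0))


print(sum_of_squares(jnp.int32(4)))  # 0 + 1 + 4 + 9
```

A Python loop over a fixed, concrete count does work under `jit`, but it is unrolled during tracing, which can make compilation slow for long loops.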