-
The need for speed: `jit`
JAX's `jit` transformation compiles a Python function with XLA, which can lead to significant speedups. XLA can fuse operations together, and repeated calls to the compiled function avoid the per-operation overhead of Python's interpreter. In this...
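As a rough illustration of the point above, the sketch below wraps a small element-wise function in `jax.jit`. The function `selu`, its constants, and the input range are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Hypothetical example function: several small element-wise ops
    # that XLA is able to fuse into a single kernel.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit returns a compiled version of the same function.
selu_jit = jax.jit(selu)

x = jnp.linspace(-5.0, 5.0, 1_000_000)

# The first call traces and compiles; later calls with the same
# shape and dtype reuse the compiled code.
y_fast = selu_jit(x).block_until_ready()

# Un-jitted reference result, dispatched op by op from Python.
y_ref = selu(x)
```

To compare timings, call `.block_until_ready()` on the result inside the timed region, since JAX dispatches work asynchronously.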
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
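One possible shape for such a training step is sketched below. The synthetic data, the learning rate, and the helper names (`predict`, `mse_loss`, `train_step`) are assumptions made for illustration, since the full exercise text is not shown here.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Linear model: params is a (W, b) tuple.
    W, b = params
    return x @ W + b

def mse_loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    # Compute the loss and its gradient with respect to params in one pass.
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Plain gradient descent applied to every leaf of the params tuple.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic regression data (assumed for this sketch).
key = jax.random.PRNGKey(0)
k_x, k_w = jax.random.split(key)
x = jax.random.normal(k_x, (128, 3))
true_W = jnp.array([[2.0], [-1.0], [0.5]])
y = x @ true_W + 0.3

# Instantiate W and b, then run a few update steps.
params = (jax.random.normal(k_w, (3, 1)), jnp.zeros((1,)))
loss_before = mse_loss(params, x, y)
for _ in range(200):
    params, loss = train_step(params, x, y)
```

Returning new parameters instead of mutating them in place keeps `train_step` a pure function, which is what `jit` and `grad` expect.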
-
A simple MLP
Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you are to build a 2-layer MLP for a regression problem. 1....
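A minimal sketch of a 2-layer MLP for regression might look like the following. The layer sizes, the ReLU activation, the initialization scale, and the dictionary layout of the parameters are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, hidden_dim, out_dim):
    # Scaled normal initialization for weights, zeros for biases (an assumption).
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (in_dim, hidden_dim)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden_dim),
        "W2": jax.random.normal(k2, (hidden_dim, out_dim)) / jnp.sqrt(hidden_dim),
        "b2": jnp.zeros(out_dim),
    }

def mlp(params, x):
    # Hidden layer with ReLU, followed by a linear output layer.
    h = jax.nn.relu(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def loss_fn(params, x, y):
    # Mean squared error for the regression target.
    return jnp.mean((mlp(params, x) - y) ** 2)

params = init_mlp(jax.random.PRNGKey(0), in_dim=4, hidden_dim=16, out_dim=1)
x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))

preds = mlp(params, x)
# jax.grad returns a dict with the same keys and shapes as params.
grads = jax.grad(loss_fn)(params, x, y)
```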
-
A simple CNN
In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
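The sketch below shows one way to call `jax.lax.conv_general_dilated` inside a small CNN. The `NHWC`/`HWIO` layout, the single convolution layer, the global average pooling, and every parameter name are assumptions for illustration rather than the exercise's intended solution.

```python
import jax
import jax.numpy as jnp

def init_cnn(key, in_channels=1, features=8, kernel_size=3):
    k_conv, k_dense = jax.random.split(key)
    return {
        # Kernel layout HWIO: (height, width, in_channels, out_channels).
        "kernel": 0.1 * jax.random.normal(
            k_conv, (kernel_size, kernel_size, in_channels, features)),
        "conv_b": jnp.zeros(features),
        "W": 0.1 * jax.random.normal(k_dense, (features, 1)),
        "b": jnp.zeros(1),
    }

def conv2d(x, kernel):
    # x uses NHWC layout; stride 1 and SAME padding keep the spatial size.
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

def cnn(params, x):
    h = jax.nn.relu(conv2d(x, params["kernel"]) + params["conv_b"])
    h = jnp.mean(h, axis=(1, 2))          # global average pool -> (N, features)
    return h @ params["W"] + params["b"]  # one regression output per example

params = init_cnn(jax.random.PRNGKey(0))
x = jnp.ones((2, 28, 28, 1))  # batch of 2 single-channel 28x28 inputs
preds = cnn(params, x)
```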
-
Understanding `jit` and tracing
When you `jit` a function, JAX traces it to determine its computational graph. This graph is then compiled by XLA for efficient execution. However, this tracing mechanism has some implications....
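One implication of tracing can be observed directly: Python side effects inside a jitted function run only while JAX traces it, and a retrace happens when the input shape or dtype changes. The names `trace_log` and `scale` below are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time the function is traced

@jax.jit
def scale(x):
    # Python side effect: runs during tracing, not on every call.
    trace_log.append(x.shape)
    return x * 2.0

scale(jnp.ones(3))   # first call: traces and compiles for shape (3,)
scale(jnp.zeros(3))  # same shape and dtype: cached executable is reused
scale(jnp.ones(4))   # new shape: triggers a second trace
```

After these three calls, `trace_log` holds two entries, one per distinct input shape.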
-
Conditionals with `jit`
Standard Python control flow, like `if` statements, can cause issues with `jit` when the condition depends on a traced value. This is because JAX needs to know the entire computation graph at...
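Two commonly used alternatives to a Python `if` on a traced value are `jnp.where` for element-wise selection and `jax.lax.cond` for a scalar predicate. The sketch below assumes these are the remedies in question; the function names and branch bodies are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def leaky(x):
    # Element-wise choice: both expressions are computed, then selected.
    return jnp.where(x > 0, x, 0.1 * x)

@jax.jit
def double_or_negate(x):
    # Scalar predicate: both branches are traced, one is executed at run time.
    # The branches must return values of the same shape and dtype.
    return jax.lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)

a = leaky(jnp.array([-2.0, 3.0]))
b = double_or_negate(jnp.float32(3.0))
c = double_or_negate(jnp.float32(-3.0))
```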
-
Loops with `jit`
Similar to conditionals, standard Python `for` or `while` loops can cause problems with `jit` if the number of iterations depends on a traced value. JAX provides `jax.lax.fori_loop` and...
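As a hedged illustration, the sketch below uses `jax.lax.fori_loop` for a loop whose upper bound is a traced value, and `jax.lax.while_loop` for a loop whose exit condition depends on traced data. The function names and the specific computations are assumptions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_below(n):
    # Sums 0 + 1 + ... + (n - 1); n is a traced value here.
    # body_fun receives the loop index and the carried value.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))

@jax.jit
def halve_until_below(x, threshold):
    # Keeps halving x while it is still >= threshold.
    return jax.lax.while_loop(
        lambda v: v >= threshold,  # cond_fun: continue while True
        lambda v: v / 2.0,         # body_fun: next carried value
        x,
    )

total = sum_below(5)
small = halve_until_below(jnp.float32(40.0), jnp.float32(3.0))
```

The carried value must keep the same shape and dtype across iterations, which is why the accumulator starts as `jnp.int32(0)`.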
-
Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
-
Debugging JAX code with `jax.debug.print`
In JAX, standard Python `print` statements don't always work as expected within `jit`-compiled functions because they execute at trace time. [11] The solution is to use `jax.debug.print`. [11, 23]...
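The difference can be seen in a small sketch. The function `normalize` and its input are hypothetical; `jax.debug.print` takes a format string plus the values to substitute.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # Runs once at trace time and shows an abstract tracer, not a number.
    print("trace-time total:", total)
    # Runs on every call and shows the concrete run-time value.
    jax.debug.print("run-time total = {t}", t=total)
    return x / total

out = normalize(jnp.array([1.0, 2.0, 5.0]))
```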