Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (such as Mean Squared Error), and a gradient descent update rule. This exercise will help you get comfortable with JAX's functional style, its automatic differentiation capabilities with jax.grad, and the use of jax.jit for performance optimization.
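For reference, plain full-batch gradient descent moves each parameter against the gradient of the MSE loss $L(w, b) = \frac{1}{N}\sum_{i=1}^{N}\left(w x_i + b - y_i\right)^2$. The learning rate $\eta$ and the choice of using the whole dataset at every step are assumptions of this sketch, not requirements stated in the task:

$$
w \leftarrow w - \eta \frac{\partial L}{\partial w}, \qquad b \leftarrow b - \eta \frac{\partial L}{\partial b}
$$

$$
\frac{\partial L}{\partial w} = \frac{2}{N}\sum_{i=1}^{N}\left(w x_i + b - y_i\right) x_i, \qquad \frac{\partial L}{\partial b} = \frac{2}{N}\sum_{i=1}^{N}\left(w x_i + b - y_i\right)
$$

In JAX you would normally not code these derivatives by hand; jax.grad computes them from the loss function, and the closed forms above are useful mainly as a cross-check.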
Model: $\hat{y}_i = w x_i + b$, where $w$ is the weight (slope) and $b$ is the bias (intercept).
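A minimal sketch of the model as a pure JAX function. It assumes the parameters are stored as a (w, b) tuple, which JAX treats as a pytree, so the whole tuple can later be passed to jax.grad as one argument. The names predict and params are illustrative choices, not part of the task.

```python
import jax.numpy as jnp

def predict(params, x):
    """Linear model y_hat = w * x + b, applied elementwise to a batch of inputs x."""
    w, b = params
    return w * x + b

# Usage: a batch of three inputs with w = 2, b = 1
params = (jnp.array(2.0), jnp.array(1.0))
x = jnp.array([0.0, 1.0, 2.0])
y_hat = predict(params, x)
```

Because predict has no side effects and depends only on its arguments, it composes cleanly with jax.grad and jax.jit.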
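Loss (MSE): $L(w, b) = \frac{1}{N}\sum_{i=1}^{N}\left(w x_i + b - y_i\right)^2$ over the $N$ training pairs $(x_i, y_i)$.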
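The sketch below writes the MSE as a function of the (w, b) tuple and lets jax.grad differentiate it. It also cross-checks the result against the hand-derived MSE gradients. The names mse_loss and grad_fn, and the tiny dataset (generated from w = 2, b = 1), are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    """Mean squared error between model predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)

# jax.grad differentiates with respect to the first argument (params) by default,
# so it returns a tuple with the same structure: (dL/dw, dL/db).
grad_fn = jax.grad(mse_loss)

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 3.0, 5.0, 7.0])  # generated by w = 2, b = 1
params = (jnp.array(0.0), jnp.array(0.0))

loss = mse_loss(params, x, y)
dw, db = grad_fn(params, x, y)

# Cross-check against the analytic gradients of the MSE:
# dL/dw = (2/N) * sum((y_hat - y) * x), dL/db = (2/N) * sum(y_hat - y)
residual = predict(params, x) - y
dw_manual = 2.0 * jnp.mean(residual * x)
db_manual = 2.0 * jnp.mean(residual)
```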
Verification:
- Your loss should decrease over training iterations.
- The final learned parameters w and b should be close to the true values used to generate the data.
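One way to put the pieces together and check both verification criteria is the following sketch. It assumes synthetic data drawn from true values w = 3.0 and b = -2.0 with small Gaussian noise, a learning rate of 0.1, and 500 full-batch steps. These values, and the function name update, are illustrative choices rather than anything the task prescribes.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y, lr):
    """One full-batch gradient descent step; returns new params and the loss before the step."""
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Apply p <- p - lr * g to every leaf of the (w, b) tuple
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Synthetic data from assumed true values plus small Gaussian noise
true_w, true_b = 3.0, -2.0
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (200,), minval=-1.0, maxval=1.0)
y = true_w * x + true_b + 0.05 * jax.random.normal(key_noise, (200,))

params = (jnp.array(0.0), jnp.array(0.0))
losses = []
for step in range(500):
    params, loss = update(params, x, y, 0.1)
    losses.append(float(loss))

w, b = params
print(f"w={float(w):.3f}, b={float(b):.3f}, final loss={losses[-1]:.5f}")
```

Wrapping the whole step (gradient plus parameter update) in jax.jit lets XLA compile it once and reuse the compiled version on every iteration. The learning rate and step count were picked for this particular synthetic data and may need tuning for other inputs.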