ML Katas

Implement a Simple Linear Regression in JAX

easy (<30 mins) autograd optimization jax jit linear-regression
yesterday by E

Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule. This exercise will help you get comfortable with JAX's functional style, its automatic differentiation capabilities with jax.grad, and the use of jax.jit for performance optimization.

Model: $y = w x + b$

Loss (MSE): $L(w, b) = \frac{1}{N}\sum_{i=1}^{N} \left(y_i - (w x_i + b)\right)^2$
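
A minimal sketch of the model and loss, assuming the parameters are carried as a `(w, b)` tuple; the names `predict` and `mse_loss` are illustrative choices, not prescribed by the kata.

```python
import jax.numpy as jnp

def predict(params, x):
    # Linear model: y = w * x + b.
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    # Mean Squared Error between targets y and model predictions.
    return jnp.mean((y - predict(params, x)) ** 2)
```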

Verification:
- Your loss should decrease over training iterations.
- The final learned parameters $w$ and $b$ should be close to the true values used to generate the data.
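
A possible end-to-end training sketch. The true parameters (2.0, -1.0), noise scale, learning rate, and step count are illustrative assumptions, not part of the kata; the loss is restated so the snippet runs on its own. `jax.grad` differentiates the loss with respect to the parameter tuple, and `jax.jit` compiles the update step.

```python
import jax
import jax.numpy as jnp

def mse_loss(params, x, y):
    # Mean Squared Error for the linear model y = w * x + b.
    w, b = params
    return jnp.mean((y - (w * x + b)) ** 2)

# Synthetic data from known "true" parameters so the result can be verified.
true_w, true_b = 2.0, -1.0
key_x, key_noise = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (100,))
y = true_w * x + true_b + 0.1 * jax.random.normal(key_noise, (100,))

@jax.jit
def update(params, x, y, lr=0.1):
    # One gradient descent step: params <- params - lr * dL/dparams.
    grads = jax.grad(mse_loss)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

params = (jnp.array(0.0), jnp.array(0.0))
for step in range(200):
    params = update(params, x, y)
    if step % 50 == 0:
        # Loss should decrease as training proceeds.
        print(step, float(mse_loss(params, x, y)))

w, b = params
print("learned:", float(w), float(b), "true:", true_w, true_b)
```

The learned `w` and `b` should end up close to the values used to generate the data, which satisfies both verification checks above.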