-
Implement a Simple Linear Regression in JAX
Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
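A minimal sketch of the three pieces named above (model, MSE loss, gradient descent update), assuming plain full-batch gradient descent on synthetic data drawn from $y = 3x + 2$ plus small Gaussian noise. The data distribution, learning rate, step count, and all function names are illustrative assumptions, not part of the original exercise.

```python
import jax
import jax.numpy as jnp

# Assumed synthetic data: y = 3x + 2 + noise, x uniform in [-1, 1].
key = jax.random.PRNGKey(0)
key_x, key_noise = jax.random.split(key)
x = jax.random.uniform(key_x, (100,), minval=-1.0, maxval=1.0)
y = 3.0 * x + 2.0 + 0.1 * jax.random.normal(key_noise, (100,))

LEARNING_RATE = 0.1  # assumed step size


def predict(params, x):
    """Linear model: y_hat = w * x + b, with params = (w, b)."""
    w, b = params
    return w * x + b


def mse_loss(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def update(params, x, y):
    """One plain gradient descent step: p <- p - lr * dL/dp."""
    grads = jax.grad(mse_loss)(params, x, y)  # same (w, b) tuple structure as params
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))


params = (jnp.array(0.0), jnp.array(0.0))  # (w, b) initialized at zero
for step in range(500):
    params = update(params, x, y)

w, b = float(params[0]), float(params[1])
final_loss = float(mse_loss(params, x, y))
print(f"w = {w:.3f}, b = {b:.3f}, loss = {final_loss:.4f}")
```

Keeping the parameters in a tuple lets `jax.grad` return gradients with the same structure, so the update is a simple element-wise zip; wrapping the step in `jax.jit` compiles it once and reuses it on every iteration.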
-
Implement a Linear Regression Model
Build a simple linear regression model using `nn.Module`. Requirements:
- One input feature, one output.
- Train it on synthetic data $y = 3x + 2 + \epsilon$.
- Use `MSELoss` and `SGD`.

Check...
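A minimal sketch that meets the listed requirements, assuming full-batch training on 200 synthetic points with noise standard deviation 0.1, a learning rate of 0.1, and 500 steps. These values and the class name `LinearRegression` are illustrative assumptions. It prints the learned weight and bias, which for this data should land near 3 and 2.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Assumed synthetic data: y = 3x + 2 + eps, with eps ~ N(0, 0.1^2).
x = torch.rand(200, 1) * 2 - 1  # shape (N, 1): one input feature in [-1, 1]
y = 3 * x + 2 + 0.1 * torch.randn(200, 1)


class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)  # one input feature, one output

    def forward(self, x):
        return self.linear(x)


model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(500):
    optimizer.zero_grad()           # clear gradients from the previous step
    loss = criterion(model(x), y)   # full-batch mean squared error
    loss.backward()                 # compute d(loss)/d(weight, bias)
    optimizer.step()                # SGD update

w = model.linear.weight.item()
b = model.linear.bias.item()
print(f"w = {w:.3f}, b = {b:.3f}")
```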
-
Checkpointing with torch.save
Train a simple feedforward model for 1 epoch. Save:
1. The model state dict.
2. The optimizer state dict.
3. The epoch number.

Then load the checkpoint and resume training seamlessly.
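One way to structure this, sketched under assumptions: the random toy data, the two-layer architecture, SGD with momentum (chosen so the optimizer carries state worth saving), the batch size, and the checkpoint key names are all illustrative choices, not prescribed by the exercise. Helper names such as `make_model` and `train_one_epoch` are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)

# Assumed toy regression data: 64 samples, 4 features, 1 target.
x = torch.randn(64, 4)
y = torch.randn(64, 1)


def make_model():
    """Hypothetical small feedforward network."""
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))


def train_one_epoch(model, optimizer, loss_fn, batch_size=16):
    """One pass over the data in fixed order; returns the last batch's loss."""
    model.train()
    for i in range(0, len(x), batch_size):
        optimizer.zero_grad()
        loss = loss_fn(model(x[i:i + batch_size]), y[i:i + batch_size])
        loss.backward()
        optimizer.step()
    return loss.item()


model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.MSELoss()

# Train for 1 epoch, then save the three required items in one dictionary.
train_one_epoch(model, optimizer, loss_fn)
epoch = 1
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch,
}, ckpt_path)

# Resume: build fresh objects first, then restore their state from disk.
restored_model = make_model()
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(ckpt_path, weights_only=True)  # state dicts need only tensors/primitives
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]

num_epochs = 2  # assumed total training length
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(restored_model, restored_optimizer, loss_fn)
    print(f"epoch {epoch + 1}: last-batch loss {loss:.4f}")
```

The optimizer is rebuilt over the restored model's parameters before its state is loaded, so saved momentum buffers are matched to parameters by position. Because this sketch uses no shuffling or dropout, the resumed run should follow the same trajectory as continuing the original model uninterrupted.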