Checkpointing
When training large models, it is important to periodically save the model's parameters (along with the optimizer state and step counter). This is known as checkpointing, and it allows you to resume training from a saved state after an interruption.
Flax provides utilities for checkpointing. In this exercise, you will add checkpointing to your training loop.
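For reference, the `flax.training.checkpoints` module provides `save_checkpoint` and `restore_checkpoint`. Below is a minimal, self-contained sketch of both calls; the toy `Dense` model and the `checkpoints/` directory are placeholders, and recent Flax versions back these functions with Orbax and may print a deprecation warning.

```python
import os

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import checkpoints, train_state

# Tiny placeholder model; any Flax module works the same way.
model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))["params"]
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
)

CKPT_DIR = os.path.abspath("checkpoints")  # recent backends expect an absolute path

# Save the full TrainState (params, optimizer state, step counter).
checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=int(state.step), keep=3)

# Restore into a template of the same structure; if the directory holds no
# checkpoint, the template is returned unchanged.
state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
```

Because `restore_checkpoint` returns the target unchanged when no checkpoint exists, it is convenient to call it unconditionally at startup.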
- Use `flax.training.checkpoints` to save the `TrainState` every 100 training steps.
- Write a script that can resume training from a saved checkpoint (see the sketch after this list).
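One way the two tasks could fit together is sketched below: restore the latest checkpoint (if any) before the loop, then save every 100 steps. The model, loss, synthetic data, and the `CKPT_DIR`, `create_state`, and `train_step` names are placeholders rather than part of the exercise.

```python
"""Resumable training sketch using flax.training.checkpoints (names are illustrative)."""
import os

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import checkpoints, train_state

CKPT_DIR = os.path.abspath("checkpoints")

model = nn.Dense(features=1)


def create_state(rng):
    """Build a fresh TrainState (params + optimizer state + step counter)."""
    params = model.init(rng, jnp.ones((1, 4)))["params"]
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
    )


@jax.jit
def train_step(state, batch):
    # Toy regression loss; replace with the exercise's real model and loss.
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


def main(num_steps=1000):
    state = create_state(jax.random.PRNGKey(0))
    # Resume from the latest checkpoint if one exists; otherwise the fresh
    # state is returned unchanged and training starts from step 0.
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    start = int(state.step)

    rng = jax.random.PRNGKey(1)
    for step in range(start, num_steps):
        rng, key = jax.random.split(rng)
        x = jax.random.normal(key, (8, 4))
        batch = {"x": x, "y": x.sum(axis=1, keepdims=True)}  # synthetic data
        state = train_step(state, batch)

        # Checkpoint every 100 steps, and once more at the end of training.
        if (step + 1) % 100 == 0 or step + 1 == num_steps:
            checkpoints.save_checkpoint(
                ckpt_dir=CKPT_DIR, target=state, step=int(state.step), keep=3
            )


if __name__ == "__main__":
    main()
```

Because the loop starts at `int(state.step)` and checkpoints are written every 100 steps, simply re-running the script after an interruption resumes from the last saved step; `keep=3` bounds disk usage to the three most recent checkpoints.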