ML Katas

Checkpointing

medium (<30 mins) · Flax checkpointing
this year by E

When training large models, it is important to save the model's parameters (and optimizer state) periodically. This is known as checkpointing, and it allows you to resume training from a saved state in case of an interruption.

Flax provides utilities for checkpointing. In this exercise, you will add checkpointing to a training loop.

  1. Use flax.training.checkpoints to save the TrainState every 100 training steps.
  2. Write a script that can resume training from a saved checkpoint.