ML Katas

PyTrees

medium (<30 mins) pytree
this year by E

A PyTree is a nested structure of containers such as dictionaries, lists, and tuples, whose leaves are typically arrays. JAX transformations and utilities are designed to work with PyTrees, which makes it easier to keep model parameters organized.
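To make the definition concrete, here is a minimal sketch of a small PyTree and two `jax.tree_util` helpers that operate on it. The variable names and the shapes are made up for illustration and are not part of the exercise.

```python
import jax
import jax.numpy as jnp

# A hypothetical PyTree: a dict holding a list and a tuple of arrays.
tree = {
    "a": [jnp.ones((2, 3)), jnp.zeros(3)],
    "b": (jnp.arange(4.0),),
}

# Flatten the tree into a list of its leaves (the arrays).
leaves = jax.tree_util.tree_leaves(tree)
shapes = [leaf.shape for leaf in leaves]

# Apply a function to every leaf while keeping the nesting intact.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)
```

`tree_map` returns a new PyTree with the same nesting as its input, which is the property the rest of the exercise relies on.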

In this exercise, re-implement the simple MLP from the previous exercise, this time storing the model parameters in a PyTree. The PyTree should be a dictionary of lists, where each list contains the weight matrix and bias vector for one layer.
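One possible shape for such a parameter PyTree is sketched below. The previous exercise is not reproduced here, so the function names (`init_params`, `forward`), the `layer_i` keys, the layer sizes, the ReLU activation, and the initialization scale are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def init_params(key, layer_sizes):
    """Build a dict of lists: {"layer_i": [weight_matrix, bias_vector]}."""
    params = {}
    keys = jax.random.split(key, len(layer_sizes) - 1)
    pairs = zip(keys, layer_sizes[:-1], layer_sizes[1:])
    for i, (k, n_in, n_out) in enumerate(pairs):
        # Assumed initialization: scaled normal weights, zero biases.
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params[f"layer_{i}"] = [w, b]
    return params

def forward(params, x):
    """Apply each layer in order; ReLU on all but the last layer (assumed)."""
    n_layers = len(params)
    for i in range(n_layers):
        w, b = params[f"layer_{i}"]
        x = x @ w + b
        if i < n_layers - 1:
            x = jax.nn.relu(x)
    return x

params = init_params(jax.random.PRNGKey(0), [4, 8, 2])
x = jnp.ones((5, 4))  # batch of 5 inputs with 4 features each
y = forward(params, x)
```

Indexing layers by an explicit `layer_{i}` key keeps the forward pass independent of dictionary ordering.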

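Note that jax.grad differentiates with respect to the first argument by default, so if you pass the parameter PyTree as that argument, the returned gradients form a PyTree with the same structure.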
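The structure-preserving behavior of jax.grad can be checked with a short sketch. The single-layer model, the mean-squared-error loss, and the learning rate below are hypothetical stand-ins, not the exercise's reference solution.

```python
import jax
import jax.numpy as jnp

# A hypothetical one-layer parameter PyTree (dict of lists).
params = {
    "layer_0": [jnp.ones((3, 2)), jnp.zeros(2)],
}

def loss(params, x, y):
    # Assumed loss: mean squared error of a single linear layer.
    w, b = params["layer_0"]
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# jax.grad differentiates with respect to the first argument (params).
grads = jax.grad(loss)(params, x, y)

# The gradient PyTree has the same tree structure as params.
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params)
)

# Because the structures match, one tree_map performs a gradient step.
lr = 0.1  # hypothetical learning rate
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Since `grads` mirrors `params` leaf for leaf, the update step needs no per-layer bookkeeping.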