Model-Agnostic Meta-Learning (MAML) Update Step
Description
Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that trains a model's initial parameters such that the model can adapt to a new task with only a few gradient steps [1]. The core of MAML is its tricky second-order gradient update. Your task is to implement a single MAML update step for a given model and a single task.
Guidance
The process involves a gradient-through-gradient calculation:
1. Inner Loop: Create a temporary copy of your model. Update this copy using a few gradient steps on the task's support set.
2. Outer Loop: Evaluate the updated, temporary model on the task's query set to get a meta-loss.
3. Meta-Update: The crucial step is to compute the gradient of this meta-loss with respect to the original model's parameters, which requires gradients to flow back through the inner loop's update steps. For that to work, the inner-loop gradients must be computed with torch.autograd.grad(..., create_graph=True), which keeps the graph of the update step alive. The two updates are written out below.
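Concretely, for a single task with support-set loss L_support and query-set loss L_query, the two updates from [1] are (α is the inner learning rate, β the meta learning rate):

θ' = θ − α ∇_θ L_support(f_θ)        (inner loop)
θ ← θ − β ∇_θ L_query(f_θ')          (meta-update)

The meta-gradient ∇_θ L_query(f_θ') differentiates through θ', which is itself a function of θ; this is where the second-order terms arise.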
Starter Code
import torch
import torch.nn as nn


def maml_inner_update(model, support_x, support_y, loss_fn, inner_lr):
    # This function should take a model and a support set, and return a
    # *new* model instance with updated weights after one gradient step.
    # Hint: manually implement the SGD update on the weights, computing
    # the gradients with torch.autograd.grad(..., create_graph=True).
    # The key is to NOT use an optimizer, which would break the
    # computation graph.
    pass


def maml_outer_update(original_model, fast_model, query_x, query_y, loss_fn):
    # This function calculates the meta-loss on the query set using the
    # fast_model, then computes gradients w.r.t. the original_model's
    # weights. Hint: use torch.autograd.grad on the original parameters;
    # the create_graph=True from the inner step is what lets gradients
    # flow back through the update.
    pass
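For reference, one possible solution sketch is below. The deepcopy-and-splice trick for building the fast model (deleting each entry from the private _parameters dict and re-setting it as a plain, graph-attached tensor) is an assumption of this sketch, not the only option: it satisfies the "new model instance" contract but touches PyTorch internals and is version-sensitive. Libraries such as higher, or torch.func.functional_call, automate the same idea.

import copy

import torch
import torch.nn as nn


def maml_inner_update(model, support_x, support_y, loss_fn, inner_lr):
    # Support-set loss on the current (original) parameters.
    loss = loss_fn(model(support_x), support_y)
    params = list(model.parameters())

    # create_graph=True keeps the graph of this gradient computation,
    # so the outer loss can later differentiate through the update
    # (this is what makes MAML second-order).
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Manual SGD step: the updated weights are plain tensors that stay
    # attached to the original parameters in the computation graph.
    fast_weights = [p - inner_lr * g for p, g in zip(params, grads)]

    # Build a "fast" copy of the model and splice the updated tensors
    # in. Assigning them as nn.Parameter would detach them, so each
    # parameter entry is removed and replaced by a plain attribute.
    fast_model = copy.deepcopy(model)
    for (name, _), w in zip(model.named_parameters(), fast_weights):
        *path, leaf = name.split(".")
        module = fast_model
        for part in path:
            module = getattr(module, part)
        del module._parameters[leaf]
        setattr(module, leaf, w)
    return fast_model


def maml_outer_update(original_model, fast_model, query_x, query_y, loss_fn):
    # Meta-loss: the adapted (fast) model evaluated on the query set.
    meta_loss = loss_fn(fast_model(query_x), query_y)

    # Differentiate w.r.t. the ORIGINAL parameters; gradients flow back
    # through the inner update because it was built with create_graph=True.
    original_params = list(original_model.parameters())
    meta_grads = torch.autograd.grad(meta_loss, original_params)

    # autograd.grad returns the gradients instead of populating .grad,
    # so store them manually for an optimizer (or a check) to use.
    for p, g in zip(original_params, meta_grads):
        p.grad = g if p.grad is None else p.grad + g
    return meta_loss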
Verification
After performing the update, check that the original model's parameters have received gradients, e.g., next(original_model.parameters()).grad is not None (note that parameters() returns an iterator, so it cannot be indexed like a list). Also note that torch.autograd.grad returns the gradients rather than populating .grad, so your outer update must write them onto the parameters itself (or call meta_loss.backward() instead). A non-None gradient confirms that the meta-gradient was successfully propagated back through the inner-loop update to the original parameters.
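To make the check concrete, here is a minimal end-to-end run; the linear model, random data, and learning rate are arbitrary illustrations, not part of the task:

model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
support_x, support_y = torch.randn(8, 4), torch.randn(8, 1)
query_x, query_y = torch.randn(8, 4), torch.randn(8, 1)

fast_model = maml_inner_update(model, support_x, support_y, loss_fn, inner_lr=0.01)
meta_loss = maml_outer_update(model, fast_model, query_x, query_y, loss_fn)

# The original model's parameters should now carry meta-gradients.
assert next(model.parameters()).grad is not None
print(f"meta-loss: {meta_loss.item():.4f}")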
References
[1] Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017.