ML Katas

Model-Agnostic Meta-Learning (MAML) Update Step

Difficulty: hard (<30 mins) · Tags: pytorch, meta-learning, maml, few-shot
this year by E

Description

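Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that learns a model's initial parameters so that the model can adapt to a new task with only a few gradient steps on a small amount of data. [1] The core of MAML is its second-order gradient update: the meta-gradient is taken through the inner-loop gradient step, so it involves second derivatives (Hessian-vector products) of the task loss. Your task is to implement a single MAML update step for a given model and a single task.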
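In equations, using notation assumed here rather than taken from the kata (θ for the original parameters, α and β for the inner and meta learning rates, L_S and L_Q for the support and query losses), one inner step and the meta-update can be sketched as follows. The middle line is the chain rule through the inner step. The Hessian term is what makes the update second-order; first-order MAML approximates it by dropping that term, i.e. treating the Jacobian of θ' with respect to θ as the identity.

```latex
\theta' = \theta - \alpha \nabla_\theta \mathcal{L}_S(\theta)
\nabla_\theta \mathcal{L}_Q(\theta') = \left(I - \alpha \nabla^2_\theta \mathcal{L}_S(\theta)\right) \nabla_{\theta'} \mathcal{L}_Q(\theta')
\theta \leftarrow \theta - \beta \nabla_\theta \mathcal{L}_Q(\theta')
```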

Guidance

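The process involves a gradient-through-gradient calculation:

1. Inner loop: Compute adapted ("fast") weights by taking one or a few manual SGD steps on the task's support set, starting from the original model's parameters. The fast weights must be new tensors computed from the original parameters; a deepcopy of the model updated in place, or an update made through a torch.optim optimizer, cuts them off from the original parameters.
2. Outer loop: Evaluate the model with the fast weights on the task's query set to obtain the meta-loss.
3. Meta-update: Compute the gradient of the meta-loss with respect to the original model's parameters. This gradient must flow back through the inner-loop update steps, which is only possible if the inner-loop gradients are computed with torch.autograd.grad(..., create_graph=True), so that the update itself is recorded in the computation graph. The meta-gradient can then be obtained with an ordinary meta_loss.backward() or torch.autograd.grad call.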
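To see why create_graph=True matters, here is a toy sketch with a single scalar weight and the made-up model y_hat = w * x; the data points and learning rate are arbitrary assumptions chosen so the result can be checked by hand. The inner gradient is 4, so w' = 1 - 0.1 * 4 = 0.6, and the query-loss gradient with respect to w' is 2 * (0.6 - 3) = -4.8. Differentiating through the inner step multiplies this by dw'/dw = 1 - 0.1 * 2 * 2^2 = 0.2, giving -0.96. With create_graph=False the inner gradient is treated as a constant (first-order MAML) and the factor is lost.

```python
import torch


def meta_grad(create_graph: bool) -> float:
    # Toy assumption: one scalar weight, model y_hat = w * x, squared error.
    w = torch.tensor(1.0, requires_grad=True)
    x_s, y_s = torch.tensor(2.0), torch.tensor(1.0)  # support point
    x_q, y_q = torch.tensor(1.0), torch.tensor(3.0)  # query point
    alpha = 0.1  # inner learning rate

    # Inner loop: one SGD step computed by hand, not by an optimizer.
    support_loss = (w * x_s - y_s) ** 2
    (g,) = torch.autograd.grad(support_loss, w, create_graph=create_graph)
    w_fast = w - alpha * g  # still a function of w

    # Outer loop: meta-loss of the adapted weight on the query point,
    # differentiated with respect to the ORIGINAL weight w.
    query_loss = (w_fast * x_q - y_q) ** 2
    (meta_g,) = torch.autograd.grad(query_loss, w)
    return meta_g.item()


print(round(meta_grad(True), 4))   # → -0.96 (second-order)
print(round(meta_grad(False), 4))  # → -4.8 (first-order approximation)
```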

Starter Code

import torch
import torch.nn as nn

def maml_inner_update(model, support_x, support_y, loss_fn, inner_lr):
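    # This function should take a model and a support set and return the
    # adapted ("fast") weights after one inner SGD step (e.g. a dict of
    # tensors, or a model whose weights are those tensors). The fast weights
    # must remain differentiable functions of the model's original parameters.
    # Hint: implement the SGD update on the weights manually, computing the
    # gradients with torch.autograd.grad(..., create_graph=True).
    # Do NOT use a torch.optim optimizer or deepcopy the model and update it
    # in place: both break the computation graph back to the original weights.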
    pass

def maml_outer_update(original_model, fast_model, query_x, query_y, loss_fn):
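    # This function computes the meta-loss on the query set using the fast
    # weights, then computes its gradients w.r.t. the original_model's weights.
    # Hint: torch.autograd.grad returns the gradients without filling p.grad;
    # assign them to each p.grad (or call meta_loss.backward()) so they can be
    # used by an optimizer and checked afterwards.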
    pass
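One possible way to fill in the starter functions is sketched below, assuming PyTorch 2.x for torch.func.functional_call. The main design assumption is that maml_inner_update returns the fast weights as a dict mapping parameter names to tensors rather than a new nn.Module, because a deepcopied module would hold fresh leaf parameters with no link to the originals; maml_outer_update therefore receives that dict as fast_model. The architecture, data shapes and learning rates in the usage part are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def maml_inner_update(model, support_x, support_y, loss_fn, inner_lr):
    # Assumption: return fast weights as {name: tensor}, not a new nn.Module.
    params = dict(model.named_parameters())
    preds = functional_call(model, params, (support_x,))
    support_loss = loss_fn(preds, support_y)
    # create_graph=True keeps the inner gradients differentiable (second-order).
    grads = torch.autograd.grad(support_loss, tuple(params.values()), create_graph=True)
    # Manual SGD step: each fast weight is a non-leaf tensor built from p.
    return {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}


def maml_outer_update(original_model, fast_model, query_x, query_y, loss_fn):
    # Here fast_model is the dict of fast weights returned above (assumption).
    preds = functional_call(original_model, fast_model, (query_x,))
    query_loss = loss_fn(preds, query_y)
    params = tuple(original_model.parameters())
    # Gradients flow through the fast weights back to the original parameters.
    meta_grads = torch.autograd.grad(query_loss, params)
    # autograd.grad does not populate .grad, so store the meta-gradients there.
    for p, g in zip(params, meta_grads):
        p.grad = g
    return query_loss


# Usage with an arbitrary toy regression task.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
loss_fn = nn.MSELoss()
support_x, support_y = torch.randn(5, 1), torch.randn(5, 1)
query_x, query_y = torch.randn(10, 1), torch.randn(10, 1)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

fast_weights = maml_inner_update(model, support_x, support_y, loss_fn, inner_lr=0.01)
meta_loss = maml_outer_update(model, fast_weights, query_x, query_y, loss_fn)
meta_opt.step()  # applies the meta-gradients stored in p.grad
print(meta_loss.item())
```

Returning the query loss and writing the meta-gradients into p.grad are also choices made here, so that a standard torch.optim optimizer can apply the meta-update. When averaging over a batch of tasks, the gradients would need to be accumulated rather than overwritten.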

Verification

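After performing the update, check that the original model's parameters have gradients (e.g., next(original_model.parameters()).grad is not None; parameters() returns a generator, so it cannot be indexed with [0]). Note that torch.autograd.grad returns gradients without writing them to .grad, so either assign them to each parameter's .grad or compute the meta-gradient with meta_loss.backward(). A non-None gradient confirms that the meta-gradient flowed back through the inner update to the original parameters. To confirm that the second-order term is actually present, compare against a first-order run (create_graph=False in the inner loop), which in general gives a different gradient.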
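A self-contained version of this check is sketched below, with an arbitrary small network, random data and a single inner step (all assumptions). The hypothetical helper maml_meta_grads uses meta_loss.backward() so that p.grad is populated, the first parameter is checked with next(...), and the second-order meta-gradient is then compared with a first-order one obtained by setting create_graph=False in the inner loop.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def maml_meta_grads(model, sx, sy, qx, qy, inner_lr=0.1, second_order=True):
    # Hypothetical helper: one inner SGD step, then the query-set meta-loss.
    loss_fn = nn.MSELoss()
    params = dict(model.named_parameters())
    s_loss = loss_fn(functional_call(model, params, (sx,)), sy)
    grads = torch.autograd.grad(
        s_loss, tuple(params.values()), create_graph=second_order
    )
    fast = {n: p - inner_lr * g for (n, p), g in zip(params.items(), grads)}
    q_loss = loss_fn(functional_call(model, fast, (qx,)), qy)
    model.zero_grad(set_to_none=True)
    q_loss.backward()  # unlike autograd.grad, backward() fills p.grad
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))
sx, sy = torch.randn(4, 2), torch.randn(4, 1)
qx, qy = torch.randn(6, 2), torch.randn(6, 1)

second = maml_meta_grads(model, sx, sy, qx, qy, second_order=True)
# parameters() is a generator, so use next() instead of indexing.
assert next(model.parameters()).grad is not None

first = maml_meta_grads(model, sx, sy, qx, qy, second_order=False)
# The Hessian term should make the two meta-gradients differ.
differs = any(not torch.allclose(a, b) for a, b in zip(second, first))
print(differs)
```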

References

[1] Finn, C., Abbeel, P., & Levine, S. (2017). Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017.