ML Katas

Implement a Neural Ordinary Differential Equation

Difficulty: hard (<30 mins) | Tags: pytorch, generative, neural ode, dynamics
yesterday by E

Description

Instead of modeling a function directly, a Neural ODE uses a neural network to model the derivative of the state with respect to time. The output is then found by integrating this derivative over time [1]. Your task is to implement a simple ODE solver that uses an MLP to define the dynamics. You will approximate the solution to the differential equation $dy/dt = f(t, y)$, where f is a neural network.
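Written as an initial value problem, the "integrate the derivative" step looks as follows. The symbol $\theta$ (the network's parameters) and the start and end times $t_0$, $t_1$ are notation introduced here for illustration. Euler's method approximates the integral by treating $f_\theta$ as constant over each small time step.

```latex
\frac{dy}{dt} = f_\theta\bigl(t, y(t)\bigr), \qquad y(t_0) = y_0
\quad\Longrightarrow\quad
y(t_1) = y_0 + \int_{t_0}^{t_1} f_\theta\bigl(t, y(t)\bigr)\,dt
```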

Guidance

The core of the exercise is to create two components:

1. An nn.Module that represents the derivative function f(t, y). It should take the current time t and state y and return the derivative dy/dt.
2. A solver loop that repeatedly calls your derivative network to take small steps forward in time. You can implement the simplest solver, Euler's method: $y_{k+1} = y_k + \Delta t \cdot f(t_k, y_k)$, where $\Delta t = t_{k+1} - t_k$.
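One way the two components could be filled in is sketched below. It is a minimal sketch under assumptions the exercise does not fix: states are batched with shape (batch, state_dim), the MLP size (two hidden layers of 64 units with Tanh) is arbitrary, and time enters the network by being concatenated to the state as an extra input column. Dropping that column gives an autonomous system.

```python
import torch
import torch.nn as nn


class ODEF(nn.Module):
    """MLP dynamics f(t, y) -> dy/dt. Assumption: time is appended to the state as an extra input."""

    def __init__(self, state_dim: int = 2, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim),
        )

    def forward(self, t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Broadcast the scalar time to one column per batch row so it can be concatenated with y.
        t_col = t.reshape(1, 1).expand(y.shape[0], 1).to(y.dtype)
        return self.net(torch.cat([y, t_col], dim=-1))


def euler_solver(model: nn.Module, y0: torch.Tensor, t_points: torch.Tensor) -> torch.Tensor:
    """Fixed-grid explicit Euler: y_{k+1} = y_k + (t_{k+1} - t_k) * f(t_k, y_k)."""
    y_trajectory = [y0]
    y = y0
    for i in range(len(t_points) - 1):
        dt = t_points[i + 1] - t_points[i]
        dy_dt = model(t_points[i], y)  # 1. derivative at the current point
        y = y + dt * dy_dt             # 2. Euler step, out of place
        y_trajectory.append(y)         # 3. store the new state
    return torch.stack(y_trajectory)   # shape: (len(t_points), batch, state_dim)


model = ODEF()
y0 = torch.tensor([[2.0, 0.0]])          # one initial state, shape (1, 2)
t_points = torch.linspace(0.0, 1.0, 11)  # 11 time points, i.e. 10 Euler steps
trajectory = euler_solver(model, y0, t_points)
print(trajectory.shape)  # torch.Size([11, 1, 2])
```

The update is written out of place on purpose: an in-place version such as `y += dt * dy_dt` would overwrite the tensor that was just appended to `y_trajectory` (and, on the first step, `y0` itself), corrupting the stored trajectory.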

Starter Code

```python
import torch
import torch.nn as nn

# The network that learns the derivative function f(t, y)
class ODEF(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: define the MLP layers here (e.g. with nn.Sequential)

    def forward(self, t, y):
        # y is the current state
        # t is the current time (can be ignored for autonomous systems)
        # This should return the derivative dy/dt, with the same shape as y
        raise NotImplementedError

# The solver
def euler_solver(model, y0, t_points):
    """Solves the ODE using Euler's method."""
    y_trajectory = [y0]
    y = y0
    for i in range(len(t_points) - 1):
        dt = t_points[i + 1] - t_points[i]
        # 1. Get the derivative from the model at (t_points[i], y)
        # 2. Apply the Euler step: y = y + dt * dy_dt
        # 3. Store the new y in y_trajectory
        raise NotImplementedError  # replace with steps 1-3
    return torch.stack(y_trajectory)
```
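Before training anything, it can help to sanity-check a filled-in solver on an ODE with a known solution. The sketch below assumes the decay equation dy/dt = -y, whose exact solution is y(t) = y0·e^(-t), and uses a hypothetical helper module `ExpDecay` in place of a learned network. Because Euler's method is first order, halving the step size should roughly halve the error at t = 1.

```python
import math

import torch
import torch.nn as nn


class ExpDecay(nn.Module):
    """Hypothetical test dynamics f(t, y) = -y, with exact solution y0 * exp(-t)."""

    def forward(self, t, y):
        return -y


def euler_solver(model, y0, t_points):
    """Filled-in version of the starter solver (explicit Euler on a fixed grid)."""
    y_trajectory = [y0]
    y = y0
    for i in range(len(t_points) - 1):
        dt = t_points[i + 1] - t_points[i]
        y = y + dt * model(t_points[i], y)
        y_trajectory.append(y)
    return torch.stack(y_trajectory)


y0 = torch.tensor([[1.0]], dtype=torch.float64)
exact = math.exp(-1.0)
errors = []
for n_steps in (10, 20, 40, 80):
    t = torch.linspace(0.0, 1.0, n_steps + 1, dtype=torch.float64)
    y_end = euler_solver(ExpDecay(), y0, t)[-1].item()
    errors.append(abs(y_end - exact))
    print(f"steps={n_steps:3d}  y(1)={y_end:.6f}  error={errors[-1]:.2e}")

# For a first-order method, the error ratio between successive halvings of dt approaches 2.
print([round(errors[i] / errors[i + 1], 2) for i in range(len(errors) - 1)])
```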

Verification

Create a target trajectory (e.g., a spiral). Train your ODEF network by running the solver from the initial condition y0, comparing the predicted trajectory to the target (e.g., with a mean squared error loss), and backpropagating the loss through all solver steps. After training, the trajectory solved from y0 should closely match the target.
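A possible end-to-end version of this check is sketched below. Every specific here is an assumption chosen for illustration: the target spiral is the exact solution of the linear system dy/dt = y A with A = [[-0.1, 2.0], [-2.0, -0.1]] (a slowly decaying rotation), computed with `torch.linalg.matrix_exp`; the network is an autonomous one-hidden-layer MLP that ignores t; and the optimizer (Adam, learning rate 1e-2, 400 iterations) and the time grid are untuned choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ODEF(nn.Module):
    """Autonomous MLP dynamics: dy/dt = f(y); the time argument is accepted but ignored."""

    def __init__(self, state_dim=2, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, state_dim)
        )

    def forward(self, t, y):
        return self.net(y)


def euler_solver(model, y0, t_points):
    """Explicit Euler on a fixed grid; returns shape (len(t_points), batch, state_dim)."""
    y_trajectory = [y0]
    y = y0
    for i in range(len(t_points) - 1):
        dt = t_points[i + 1] - t_points[i]
        y = y + dt * model(t_points[i], y)
        y_trajectory.append(y)
    return torch.stack(y_trajectory)


# Target: exact solution of the linear spiral dy/dt = y @ A, i.e. y(t) = y0 @ expm(A t).
A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
y0 = torch.tensor([[2.0, 0.0]])          # shape (batch=1, state_dim=2)
t_points = torch.linspace(0.0, 3.0, 61)  # dt = 0.05, roughly one turn of the spiral
target = y0 @ torch.linalg.matrix_exp(t_points[:, None, None] * A)  # shape (61, 1, 2)

model = ODEF()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = []
for step in range(400):
    optimizer.zero_grad()
    pred = euler_solver(model, y0, t_points)  # same shape as target
    loss = torch.mean((pred - target) ** 2)   # MSE over every time point
    loss.backward()                           # backprop through all Euler steps
    optimizer.step()
    losses.append(loss.item())
    if step % 100 == 0:
        print(f"step {step:3d}  loss {loss.item():.4f}")

print(f"final loss {losses[-1]:.4f}")
```

Fitting the whole trajectory at once is the simplest option; longer or more tightly wound spirals may need more iterations, a smaller step size, or training on shorter sub-trajectories.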

References

[1] Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural Ordinary Differential Equations. Advances in Neural Information Processing Systems 31 (NeurIPS 2018).