ML Katas

Physics-Informed Neural Network (PINN) for an ODE

hard (<30 mins) pytorch autograd ode pinn physics
this year by E

Description

Solve a simple ordinary differential equation (ODE) using a Physics-Informed Neural Network (PINN). A PINN is a neural network trained to satisfy both the available data and the physical law expressed by a differential equation. [1] Your task is to approximate the solution of the ODE $\frac{du}{dx} + u = 0$ with the initial condition $u(0) = 1$. The analytical solution is $u(x) = e^{-x}$.
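For reference, the closed-form solution can be checked by hand. The short derivation below uses separation of variables; the integration constants $C$ and $A$ are introduced here for illustration only.

```latex
\begin{aligned}
\frac{du}{dx} + u = 0
  &\;\Longrightarrow\; \frac{du}{u} = -\,dx \\
  &\;\Longrightarrow\; \ln|u| = -x + C \\
  &\;\Longrightarrow\; u(x) = A\,e^{-x}, \qquad A = \pm e^{C} \\
u(0) = 1
  &\;\Longrightarrow\; A = 1
  \;\Longrightarrow\; u(x) = e^{-x}
\end{aligned}
```

Substituting back gives $\frac{du}{dx} + u = -e^{-x} + e^{-x} = 0$, so the ODE residual vanishes for this function.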

Guidance

A PINN's loss function has two parts:

1. Data Loss: A standard supervised loss (like MSE) on the known data points. Here the only known point is the initial condition $u(0) = 1$.
2. Physics Loss: This is the core idea. The network itself should satisfy the ODE, so you define a loss on the ODE residual, which is what you get when you plug the network's output into the equation: $\frac{du_{\text{net}}}{dx} + u_{\text{net}}$. The residual is zero everywhere for a perfect solution, so its mean squared value over sampled points is the quantity to minimize.

You will need to use torch.autograd.grad to compute the derivative of the network's output with respect to its input. The input tensor must have requires_grad=True, and passing create_graph=True keeps the derivative differentiable so the physics loss can be backpropagated to the network weights.
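To see the derivative step in isolation, here is a minimal sketch that applies torch.autograd.grad to a stand-in for the network output. It assumes the exact decaying exponential exp(-x) in place of a trained model, and the 50-point grid on [0, 2] is an arbitrary choice for illustration. Because the stand-in solves the ODE, the physics loss should come out at (numerically) zero.

```python
import torch

# Sample points; requires_grad=True lets autograd differentiate w.r.t. x.
x = torch.linspace(0.0, 2.0, 50).reshape(-1, 1).requires_grad_(True)

# Stand-in for the network output (assumption): the exact solution exp(-x).
u = torch.exp(-x)

# du/dx at every point. Passing ones as grad_outputs yields the per-sample
# derivative here because each u[i] depends only on x[i].
du_dx = torch.autograd.grad(
    u, x, grad_outputs=torch.ones_like(u), create_graph=True
)[0]

# ODE residual du/dx + u and its mean squared value (the physics loss).
residual = du_dx + u
loss_physics = torch.mean(residual ** 2)
```

With a real network, `u` would be `model(x)` and the residual would generally be nonzero until training drives it down.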

Starter Code

import torch
import torch.nn as nn

# A simple MLP to approximate the function u(x)
class PINN(nn.Module):
    def __init__(self):
        super(PINN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 20),
            nn.Tanh(),
            nn.Linear(20, 1)
        )

    def forward(self, x):
        return self.net(x)

# --- In your training loop ---
# 1. Calculate loss_data: Evaluate the network at x=0 and compare to the known value of 1.
# 2. Calculate loss_physics:
#    a. Create a tensor of random points in your domain (it needs requires_grad=True).
#    b. Compute the network's output u_pred for these points.
#    c. Compute the derivative du_dx using torch.autograd.grad.
#    d. The physics loss is the mean squared error of the ODE residual (du_dx + u_pred).
# 3. total_loss = loss_data + loss_physics
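One possible way to fill in the outlined training loop is sketched below. It is not the reference solution: the Adam optimizer, the learning rate, the step count, the seed, the 64 sample points per step, the domain [0, 2], and the names `x_collocation`, `x0`, `u0`, and `losses` are all assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumed seed, for repeatable runs


class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 1))

    def forward(self, x):
        return self.net(x)


model = PINN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # assumed settings

x0 = torch.zeros(1, 1)  # initial-condition point x = 0
u0 = torch.ones(1, 1)   # known value u(0) = 1

losses = []
for step in range(2000):  # assumed step count
    optimizer.zero_grad()

    # 1. Data loss: enforce u(0) = 1.
    loss_data = torch.mean((model(x0) - u0) ** 2)

    # 2. Physics loss on random points in the assumed domain [0, 2].
    x_collocation = (2.0 * torch.rand(64, 1)).requires_grad_(True)
    u_pred = model(x_collocation)
    du_dx = torch.autograd.grad(
        u_pred,
        x_collocation,
        grad_outputs=torch.ones_like(u_pred),
        create_graph=True,  # keep the graph so the loss can backprop through du/dx
    )[0]
    loss_physics = torch.mean((du_dx + u_pred) ** 2)

    # 3. Total loss with equal weighting of both terms.
    total_loss = loss_data + loss_physics
    total_loss.backward()
    optimizer.step()
    losses.append(total_loss.item())
```

Resampling the points every step is a design choice: it exposes the network to the whole interval over time rather than a fixed grid. If one term dominates, a weighting factor on either loss is a common adjustment.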

Verification

After training, plot the output of your neural network for a range of x values (e.g., from 0 to 2) and compare it with the analytical solution $u(x) = e^{-x}$. The two curves should be very close.
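A small helper can make the comparison quantitative as well as visual. The sketch below assumes the reference curve is the decaying exponential exp(-x); the function names `compare_to_exact` and `plot_comparison` and the 200-point grid are hypothetical choices for this example, and the plotting helper assumes matplotlib is installed. The demo at the bottom uses a stand-in "model" that returns the exact solution, so the reported error is zero; pass your trained network instead.

```python
import torch


def compare_to_exact(model, x_min=0.0, x_max=2.0, n=200):
    """Evaluate `model` on a grid; return (x, u_pred, u_exact, max_abs_error)."""
    x = torch.linspace(x_min, x_max, n).reshape(-1, 1)
    with torch.no_grad():  # no gradients needed for evaluation
        u_pred = model(x)
    u_exact = torch.exp(-x)
    max_abs_error = (u_pred - u_exact).abs().max().item()
    return x, u_pred, u_exact, max_abs_error


def plot_comparison(x, u_pred, u_exact):
    """Overlay the network output and the analytical curve (needs matplotlib)."""
    import matplotlib.pyplot as plt

    plt.plot(x.squeeze().numpy(), u_exact.squeeze().numpy(), label="analytical")
    plt.plot(x.squeeze().numpy(), u_pred.squeeze().numpy(), "--", label="PINN")
    plt.xlabel("x")
    plt.ylabel("u(x)")
    plt.legend()
    plt.show()


# Demo with a stand-in model (assumption): the exact solution itself.
exact_model = lambda x: torch.exp(-x)
x, u_pred, u_exact, err = compare_to_exact(exact_model)
```

Calling `plot_comparison(x, u_pred, u_exact)` afterwards draws both curves on the same axes.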

References

[1] Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2017). Physics Informed Deep Learning.