Normalizing Flow for Density Estimation
Description
Implement a simple 2D Normalizing Flow model. Normalizing Flows transform a simple base distribution (like a Gaussian) into a more complex distribution by applying a sequence of invertible transformations. [1] This allows for both sampling and exact density estimation. Your task is to implement a RealNVP-style coupling layer.
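For the density-estimation part, recall the change-of-variables formula. If z = f(x) is the invertible flow mapping data to the base distribution, with p_X the modelled data density and p_Z the base (Gaussian) density, then

    log p_X(x) = log p_Z(f(x)) + log |det(df(x)/dx)|

so the exact log-likelihood only requires the base density, the forward mapping, and the log-determinant of its Jacobian. The coupling layer described below is designed so that this Jacobian term is cheap to evaluate.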
Guidance
A coupling layer splits the input x into two halves, x1 and x2. One half (x1) is passed through two small neural networks to produce a scale s and a translation t. The other half (x2) is then transformed using this s and t. The key is that this transformation is easily invertible. You also need to calculate the log-determinant of the Jacobian of this transformation, which for this layer is simply the sum of the s values.
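Concretely, writing s = s_net(x1) and t = t_net(x1), the layer and its inverse are:

    forward:   y1 = x1,    y2 = x2 * exp(s) + t
    inverse:   x1 = y1,    x2 = (y2 - t) * exp(-s)
    log|det J| = sum(s)

Note that the inverse recomputes s and t from y1 (which equals x1), so the networks themselves never need to be inverted; this is what makes the transformation easily invertible.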
Starter Code
import torch
import torch.nn as nn

class CouplingLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Define two simple MLPs. Each should take input_dim // 2 inputs
        # and produce input_dim // 2 outputs.
        self.s_net = nn.Sequential(...)
        self.t_net = nn.Sequential(...)

    def forward(self, x):
        # 1. Split x into two halves, x1 and x2.
        # 2. Pass x1 through s_net and t_net to get s and t.
        # 3. Transform x2 using the affine transformation: y2 = x2 * exp(s) + t.
        # 4. The other half, y1, is just x1 (identity transformation).
        # 5. Concatenate y1 and y2 to get the full output y.
        # 6. Calculate the log-determinant of the Jacobian, which is just sum(s).
        # 7. Return the output y and the log-determinant.
        pass

    def inverse(self, y):
        # Implement the reverse transformation to go from y back to x.
        pass
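For reference, here is one possible way the layer could be filled in. This is only a sketch, not the single correct answer: it assumes input_dim is even, always transforms the second half (alternation is handled when stacking layers), and the Tanh on the scale network is an optional stabilizing choice rather than something the task requires.

import torch
import torch.nn as nn

class CouplingLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        half = input_dim // 2
        # MLP producing the log-scale s from the untouched half.
        self.s_net = nn.Sequential(
            nn.Linear(half, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, half), nn.Tanh(),  # optional: bounds s so exp(s) stays well-behaved
        )
        # MLP producing the translation t from the untouched half.
        self.t_net = nn.Sequential(
            nn.Linear(half, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, half),
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        s, t = self.s_net(x1), self.t_net(x1)
        y2 = x2 * torch.exp(s) + t           # affine transformation of the second half
        log_det = s.sum(dim=-1)              # log|det J| per sample
        return torch.cat([x1, y2], dim=-1), log_det

    def inverse(self, y):
        y1, y2 = y.chunk(2, dim=-1)
        s, t = self.s_net(y1), self.t_net(y1)  # s, t depend only on the untouched half
        x2 = (y2 - t) * torch.exp(-s)          # undo the affine transformation
        return torch.cat([y1, x2], dim=-1)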
Verification
Stack a few of these coupling layers (alternating which half is transformed). Train the model on a 2D toy dataset (e.g., the two-moons dataset from sklearn.datasets.make_moons). The goal is to maximize the log-likelihood of the data. After training, you should be able to sample from a standard 2D Gaussian, pass the samples through the inverse function of your flow, and generate points that match the toy dataset's distribution.
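A minimal end-to-end check along these lines might look like the sketch below. It assumes the completed CouplingLayer above; the number of layers, hidden size, learning rate, and training length are arbitrary choices, and alternation between the halves is done by flipping the two coordinates between layers (a permutation, which leaves the log-determinant unchanged).

import torch
from sklearn.datasets import make_moons

class Flow(torch.nn.Module):
    def __init__(self, num_layers=6, hidden_dim=64):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [CouplingLayer(input_dim=2, hidden_dim=hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        # Map data to the latent space, accumulating log|det J| over all layers.
        log_det = torch.zeros(x.shape[0])
        for layer in self.layers:
            x, ld = layer(x)
            log_det = log_det + ld
            x = x.flip(dims=[1])  # swap the halves so the next layer transforms the other one
        return x, log_det

    def inverse(self, z):
        # Map latent samples back to data space by undoing each layer in reverse order.
        for layer in reversed(self.layers):
            z = z.flip(dims=[1])
            z = layer.inverse(z)
        return z

flow = Flow()
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)
base = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))

x_np, _ = make_moons(n_samples=2000, noise=0.05)
data = torch.tensor(x_np, dtype=torch.float32)

for step in range(2000):
    z, log_det = flow(data)
    loss = -(base.log_prob(z) + log_det).mean()  # negative log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Sample from the 2D Gaussian and run the flow backwards to get moon-shaped points.
with torch.no_grad():
    samples = flow.inverse(base.sample((1000,)))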
References
[1] Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2016). Density estimation using Real NVP.