ML Katas

Bayesian Neural Network Layer

hard (<30 mins) pytorch bayesian bnn uncertainty
this year by E

Description

In a standard neural network, each weight is a single point estimate. In a Bayesian Neural Network (BNN), we instead learn a probability distribution over each weight [1]. Sampling different weights for the same input produces a spread of predictions, and that spread is a measure of the model's uncertainty, rather than a single and possibly overconfident output. Your task is to implement a Bayesian linear layer using variational inference and the "Bayes by Backprop" method.
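To make "a distribution over each weight" concrete, here is a minimal sketch that assumes a single Gaussian weight and bias with hand-picked (hypothetical) means and standard deviations rather than learned ones. Running the same inputs through many sampled weight/bias pairs gives a spread of predictions, and the standard deviation of that spread serves as the uncertainty estimate.

```python
import torch

torch.manual_seed(0)

# Hypothetical distributions for a 1-input, 1-output linear model:
# w ~ N(2.0, 0.1^2), b ~ N(0.5, 0.05^2)
w_mean, w_std = torch.tensor(2.0), torch.tensor(0.1)
b_mean, b_std = torch.tensor(0.5), torch.tensor(0.05)

x = torch.tensor([0.0, 1.0, 10.0])  # three test inputs
n_samples = 1000

# Draw n_samples weight/bias pairs and evaluate the model once per pair.
w = w_mean + w_std * torch.randn(n_samples, 1)
b = b_mean + b_std * torch.randn(n_samples, 1)
preds = w * x + b  # shape: (n_samples, 3), broadcast over inputs

pred_mean = preds.mean(dim=0)  # point prediction per input
pred_std = preds.std(dim=0)    # uncertainty per input
print(pred_mean, pred_std)
```

Inputs far from zero get wider error bars, because the weight's uncertainty is multiplied by |x| while the bias uncertainty stays fixed.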

Guidance

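Your BayesianLinear layer will be a custom nn.Module. Instead of a single nn.Parameter each for the weight and bias, you will store two trainable parameters for each: a mean and a log_std. Storing the logarithm of the standard deviation keeps exp(log_std) positive without any constraint on the parameter. The forward pass uses the reparameterization trick to sample weights: weight = mean + exp(log_std) * epsilon, where epsilon is drawn from a standard normal distribution. Because all of the randomness lives in epsilon, the sampled weight is a differentiable function of mean and log_std, so gradients can flow back to both. The training loss has two parts: the standard task loss (e.g., MSE) and a KL divergence term that measures how far the learned weight distributions have moved from a prior distribution (e.g., a standard normal).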
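Both ingredients can be checked in isolation before building the layer. The sketch below uses an illustrative 3x4 weight shape and initial values (assumptions, not part of the task). It draws one weight sample with the reparameterization trick, confirms that gradients reach both mean and log_std, and compares the closed-form KL between N(mean, std^2) and a standard normal, 0.5 * (std^2 + mean^2 - 1) - log(std) per element, against torch.distributions.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Illustrative variational parameters for a 3x4 weight matrix.
mean = torch.randn(3, 4, requires_grad=True)
log_std = torch.full((3, 4), -1.0, requires_grad=True)

# Reparameterization trick: epsilon carries the randomness and has no
# parameters, so the sample is differentiable w.r.t. mean and log_std.
epsilon = torch.randn(3, 4)
weight = mean + torch.exp(log_std) * epsilon

weight.sum().backward()
print(mean.grad is not None, log_std.grad is not None)  # both receive gradients

# Closed-form KL( N(mean, std^2) || N(0, 1) ), summed over all elements.
std = torch.exp(log_std)
kl_closed = (0.5 * (std**2 + mean**2 - 1.0) - log_std).sum()

# The same quantity via torch.distributions, as a cross-check.
kl_ref = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum()
print(torch.allclose(kl_closed, kl_ref))
```

The closed form is what a kl_divergence method can compute directly, without constructing distribution objects.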

Starter Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # 1. Define trainable parameters for weight_mean, weight_log_std,
        #    bias_mean, and bias_log_std.

    def forward(self, x):
        # 1. Sample from a standard normal distribution (epsilon).
        # 2. Use the reparameterization trick to sample the actual weights and biases.
        # 3. Perform the linear transformation using the sampled weights and biases.
        #    Hint: F.linear(x, sampled_weight, sampled_bias)
        pass

    def kl_divergence(self):
        # 1. Calculate the KL divergence between the learned weight distribution
        #    and a standard normal prior (mean=0, std=1).
        # 2. Do the same for the bias.
        # 3. Return the sum of the two KL terms.
        pass
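One possible solution sketch follows, so skip it if you want to solve the kata yourself. Two choices in it are assumptions, not requirements of the task: the means are initialized uniformly in plus or minus 1/sqrt(in_features) (similar in spirit to nn.Linear), and a hypothetical init_log_std=-5.0 starts every standard deviation small so early training behaves almost like a deterministic layer.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BayesianLinear(nn.Module):
    """Linear layer with an independent Gaussian over every weight and bias."""

    def __init__(self, in_features, out_features, init_log_std=-5.0):
        super().__init__()
        # Means: small random values, similar in scale to nn.Linear's init.
        bound = 1.0 / math.sqrt(in_features)
        self.weight_mean = nn.Parameter(
            torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.bias_mean = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))
        # Log standard deviations: start small (assumed init choice).
        self.weight_log_std = nn.Parameter(
            torch.full((out_features, in_features), init_log_std))
        self.bias_log_std = nn.Parameter(torch.full((out_features,), init_log_std))

    def forward(self, x):
        # 1. Fresh standard-normal noise on every call.
        weight_eps = torch.randn_like(self.weight_mean)
        bias_eps = torch.randn_like(self.bias_mean)
        # 2. Reparameterization trick: sample = mean + std * eps.
        weight = self.weight_mean + torch.exp(self.weight_log_std) * weight_eps
        bias = self.bias_mean + torch.exp(self.bias_log_std) * bias_eps
        # 3. Ordinary linear transform with the sampled parameters.
        return F.linear(x, weight, bias)

    @staticmethod
    def _kl_to_standard_normal(mean, log_std):
        # KL(N(mean, std^2) || N(0, 1)) = 0.5 * (std^2 + mean^2 - 1) - log(std), summed.
        return (0.5 * (torch.exp(2 * log_std) + mean**2 - 1.0) - log_std).sum()

    def kl_divergence(self):
        return (self._kl_to_standard_normal(self.weight_mean, self.weight_log_std)
                + self._kl_to_standard_normal(self.bias_mean, self.bias_log_std))


layer = BayesianLinear(in_features=4, out_features=3)
x = torch.randn(8, 4)
out = layer(x)
kl = layer.kl_divergence()
print(out.shape, kl.shape)  # torch.Size([8, 3]) torch.Size([])
```

Because the noise is resampled inside forward, calling the layer twice on the same input gives slightly different outputs; that is expected, not a bug.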

Verification

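Instantiate your layer and check that the forward pass produces an output of shape (batch_size, out_features). Your kl_divergence method should return a scalar (0-dimensional) tensor. In a training loop, the total loss is task_loss + model.kl_divergence(); when task_loss is averaged over examples, the KL term is usually scaled to match, for example by dividing it by the number of training examples. The KL term acts as a regularizer. The task loss alone rewards shrinking every weight's variance, so without the KL term the learned log_std values would keep drifting toward negative infinity and the layer would collapse into an overconfident point estimate. With it, you should observe log_std settling at finite values.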
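A minimal training sketch of that objective follows, written functionally for a 1-to-1 layer so that it runs on its own. The toy data (y = 3x - 1 plus noise), the Adam settings, and the helper kl_std_normal are illustrative assumptions. MSE stands in for the negative log-likelihood, and the KL is divided by the number of training examples N, which is the per-example form of the variational objective.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy regression data: y = 3x - 1 plus Gaussian noise.
N = 200
x = torch.linspace(-1, 1, N).unsqueeze(1)
y = 3 * x - 1 + 0.1 * torch.randn(N, 1)

# Variational parameters for a 1 -> 1 Bayesian linear layer.
w_mean = torch.zeros(1, 1, requires_grad=True)
w_log_std = torch.full((1, 1), -3.0, requires_grad=True)
b_mean = torch.zeros(1, requires_grad=True)
b_log_std = torch.full((1,), -3.0, requires_grad=True)


def kl_std_normal(mean, log_std):
    # Closed-form KL(N(mean, std^2) || N(0, 1)), summed over elements.
    return (0.5 * (torch.exp(2 * log_std) + mean**2 - 1.0) - log_std).sum()


opt = torch.optim.Adam([w_mean, w_log_std, b_mean, b_log_std], lr=0.05)
for step in range(500):
    # Sample weights with the reparameterization trick.
    w = w_mean + torch.exp(w_log_std) * torch.randn_like(w_mean)
    b = b_mean + torch.exp(b_log_std) * torch.randn_like(b_mean)
    task_loss = F.mse_loss(F.linear(x, w, b), y)
    kl = kl_std_normal(w_mean, w_log_std) + kl_std_normal(b_mean, b_log_std)
    loss = task_loss + kl / N  # KL scaled to match the per-example task loss
    opt.zero_grad()
    loss.backward()
    opt.step()

print(w_mean.item(), w_log_std.item(), b_mean.item(), b_log_std.item())
```

The means should end near 3 and -1, while w_log_std should stop at a finite value where the KL's upward pull balances the task loss's push toward zero variance.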

References

[1] Blundell, C., et al. (2015). Weight Uncertainty in Neural Networks. ICML 2015.