Bayesian Neural Network Layer
Description
In a standard neural network, weights are single point estimates. In a Bayesian Neural Network (BNN), we learn a probability distribution over each weight. [1] This allows for better uncertainty estimation. Your task is to implement a Bayesian linear layer using Variational Inference and the "Bayes by Backprop" method.
Guidance
Your BayesianLinear layer will be a custom nn.Module. Instead of a single nn.Parameter each for the weight and bias, you will store a mean and a log_std for both. The forward pass uses the reparameterization trick to sample weights: weight = mean + exp(log_std) * epsilon, where epsilon is standard normal noise. The total training loss will have two parts: the standard task loss (e.g., MSE) and a KL divergence term that measures how far the learned weight distributions have drifted from a prior distribution (e.g., a standard normal). A minimal sketch of these two building blocks follows.
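The snippet below illustrates the two ingredients on bare tensors, assuming a diagonal Gaussian posterior and a standard normal prior; the tensor shapes and initial values are placeholders, not part of the task specification.

import torch

mean = torch.zeros(4, 3, requires_grad=True)            # variational means (placeholder shape)
log_std = torch.full((4, 3), -3.0, requires_grad=True)  # variational log standard deviations

# Reparameterization trick: sample = mean + std * epsilon with epsilon ~ N(0, 1),
# so gradients flow to mean and log_std while the randomness stays in epsilon.
epsilon = torch.randn_like(mean)
sample = mean + torch.exp(log_std) * epsilon

# Closed-form KL( N(mean, std^2) || N(0, 1) ), summed over all elements:
#   0.5 * (std^2 + mean^2 - 1) - log_std
kl = (0.5 * (torch.exp(2 * log_std) + mean ** 2 - 1) - log_std).sum()

Inside BayesianLinear, the same two steps are applied to both the weight and bias parameters, and the two KL terms are summed.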
Starter Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # 1. Define trainable parameters for weight_mean, weight_log_std,
        #    bias_mean, and bias_log_std.

    def forward(self, x):
        # 1. Sample from a standard normal distribution (epsilon).
        # 2. Use the reparameterization trick to sample the actual weights and biases.
        # 3. Perform the linear transformation using the sampled weights and biases.
        #    Hint: F.linear(x, sampled_weight, sampled_bias)
        pass

    def kl_divergence(self):
        # 1. Calculate the KL divergence between the learned weight distribution
        #    and a standard normal prior (mean=0, std=1).
        # 2. Do the same for the bias.
        # 3. Return the sum of the two KL terms.
        pass
Verification
Instantiate your layer and check that the forward pass produces an output of the correct shape. Your kl_divergence method should return a scalar tensor. In a training loop, the total loss would be task_loss + model.kl_divergence(). The KL term acts as a regularizer: you should observe that the learned log_std values do not collapse toward negative infinity (i.e., the model does not become overconfident in its weights). A hypothetical training-loop sketch follows.
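The sketch below assumes BayesianLinear has been completed as specified above; the layer sizes, random data, learning rate, and step count are placeholders for illustration only.

import torch
import torch.nn as nn

layer = BayesianLinear(in_features=10, out_features=1)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)
mse = nn.MSELoss()

x = torch.randn(64, 10)   # dummy inputs
y = torch.randn(64, 1)    # dummy targets

for step in range(200):
    optimizer.zero_grad()
    pred = layer(x)                               # stochastic forward pass (weights are sampled)
    loss = mse(pred, y) + layer.kl_divergence()   # task loss + KL regularizer
    loss.backward()
    optimizer.step()

assert pred.shape == (64, 1)               # output has the expected shape
assert layer.kl_divergence().dim() == 0    # KL term is a scalar tensor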
References
[1] Blundell, C., et al. (2015). Weight Uncertainty in Neural Networks. ICML 2015.