ML Katas

The Stabilizing Power of Batch Normalization

Difficulty: medium (<1 hr) · Tags: Optimization, Batch Normalization, Deep Learning

Batch Normalization (BatchNorm) is a crucial technique for stabilizing and accelerating deep neural network training.

  1. Normalization Step: Given a mini-batch of activations $X = \{x_1, x_2, \ldots, x_m\}$ for a particular feature map (or neuron output), the first step of BatchNorm is to normalize them.
    • Calculate the mean $\mu_B$ and variance $\sigma_B^2$ of this mini-batch.
    • Show the formula for the normalized activation $\hat{x}_i$. Include an epsilon term $\epsilon$ for numerical stability.
  2. Learnable Parameters: After normalization, BatchNorm applies a scaling and shifting operation: $y_i = \gamma \hat{x}_i + \beta$. Explain the purpose of the learnable parameters $\gamma$ (scale) and $\beta$ (shift). Why are they necessary, and what problem would occur if they were omitted?
  3. Gradient Flow Intuition: Intuitively, how does BatchNorm help mitigate the "internal covariate shift" problem (without getting into complex proofs)? How does it make the loss landscape smoother and gradients more stable?
  4. Inference vs. Training: Explain the difference in how BatchNorm operates during training versus inference. What values are used for $\mu_B$ and $\sigma_B^2$ during inference, and how are they obtained?
  5. Verification: You can implement the forward pass of BatchNorm for a small batch of 1D data (see the reference code below) and verify that the normalized activations end up with mean close to 0 and variance close to 1 before the scale and shift.
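For reference, the normalization and scale-and-shift steps that the code below implements are the standard BatchNorm equations, written here in the notation above:

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
$$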
import numpy as np

def batch_norm_forward(x_batch, gamma, beta, epsilon=1e-5):
    # x_batch is a 1D numpy array representing activations for one feature
    mean_b = np.mean(x_batch)   # mini-batch mean
    var_b = np.var(x_batch)     # mini-batch (biased) variance
    x_hat = (x_batch - mean_b) / np.sqrt(var_b + epsilon)  # normalize
    y = gamma * x_hat + beta    # scale and shift with learnable parameters
    return y, mean_b, var_b     # mean_b, var_b returned for verification

# Example:
x = np.array([1.0, 2.0, 3.0, 4.0])
gamma_val = 1.0  # initial value; gamma is learned during training
beta_val = 0.0   # initial value; beta is learned during training
y, mean_b, var_b = batch_norm_forward(x, gamma_val, beta_val)
print(f"Original x: {x}, mean: {np.mean(x):.4f}, var: {np.var(x):.4f}")
print(f"Output y:   {y}, mean: {np.mean(y):.4f}, var: {np.var(y):.4f}")
# With gamma=1 and beta=0, y equals x_hat, so its mean/var should be close to 0/1.
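To explore question 4, here is a minimal sketch of how training-time and inference-time BatchNorm can differ, assuming the common approach of keeping an exponential moving average of the batch statistics. The function names, the momentum value, and the use of the biased batch variance are illustrative choices; real frameworks differ in these details (for example, some use an unbiased variance estimate for the running average).

def batch_norm_train_step(x_batch, gamma, beta, running_mean, running_var,
                          momentum=0.9, epsilon=1e-5):
    # Training mode: normalize with the current mini-batch statistics and
    # update the running (exponential moving average) estimates.
    mean_b = np.mean(x_batch)
    var_b = np.var(x_batch)
    x_hat = (x_batch - mean_b) / np.sqrt(var_b + epsilon)
    y = gamma * x_hat + beta
    running_mean = momentum * running_mean + (1 - momentum) * mean_b
    running_var = momentum * running_var + (1 - momentum) * var_b
    return y, running_mean, running_var

def batch_norm_inference(x, gamma, beta, running_mean, running_var, epsilon=1e-5):
    # Inference mode: no batch statistics are computed; the fixed running
    # estimates accumulated during training are used instead.
    x_hat = (x - running_mean) / np.sqrt(running_var + epsilon)
    return gamma * x_hat + beta

# Simulate a few training batches, then run inference on a single value.
rng = np.random.default_rng(0)
running_mean, running_var = 0.0, 1.0
gamma_val, beta_val = 2.0, 0.5  # non-trivial values to see the scale/shift effect
for _ in range(100):
    batch = rng.normal(loc=3.0, scale=2.0, size=32)
    _, running_mean, running_var = batch_norm_train_step(
        batch, gamma_val, beta_val, running_mean, running_var)

print(f"Running mean (expect ~3): {running_mean:.3f}, running var (expect ~4): {running_var:.3f}")
out = batch_norm_inference(3.0, gamma_val, beta_val, running_mean, running_var)
print(f"Inference output for x = 3.0: {out:.3f}")
# An input near the running mean normalizes to roughly 0, so the output is roughly beta.

Note that rerunning inference on the same input gives the same output, because the running statistics are fixed; in training mode the result would depend on whichever mini-batch the input happened to appear in.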