The Stabilizing Power of Batch Normalization
Batch Normalization (BatchNorm) is a crucial technique for stabilizing and accelerating deep neural network training.
- Normalization Step: Given a mini-batch of activations for a particular feature map (or neuron output), the first step of BatchNorm is to normalize them.
- Calculate the mini-batch mean $\mu_B$ and variance $\sigma_B^2$.
- Show the formula for the normalized activation $\hat{x}_i$. Include an epsilon term $\epsilon$ for numerical stability. (The full equations are collected for reference after the code below.)
- Learnable Parameters: After normalization, BatchNorm applies a scaling and shifting operation: $y_i = \gamma \hat{x}_i + \beta$. Explain the purpose of the learnable parameters $\gamma$ (scale) and $\beta$ (shift). Why are they necessary, and what problem would occur if they were omitted?
- Gradient Flow Intuition: Intuitively, how does BatchNorm help mitigate the "internal covariate shift" problem (without getting into complex proofs)? How does it make the loss landscape smoother and gradients more stable?
- Inference vs. Training: Explain the difference in how BatchNorm operates during training versus inference. What values are used for $\mu$ and $\sigma^2$ during inference, and how are they obtained? (A running-statistics sketch follows the code below.)
- Verification: You can implement the forward pass of BatchNorm for a small batch of 1D data and observe how the mean and variance change before and after normalization, for example:
import numpy as np
def batch_norm_forward(x_batch, gamma, beta, epsilon=1e-5):
# x_batch is a 1D numpy array representing activations for one feature
mean_b = np.mean(x_batch)
var_b = np.var(x_batch)
x_hat = (x_batch - mean_b) / np.sqrt(var_b + epsilon)
y = gamma * x_hat + beta
return y, mean_b, var_b # returning mean_b, var_b for verification
# Example:
# x = np.array([1.0, 2.0, 3.0, 4.0])
# gamma_val = 1.0 # Initial guess, these are learned
# beta_val = 0.0 # Initial guess, these are learned
# normalized_x, mean, var = batch_norm_forward(x, gamma_val, beta_val)
# print(f"Original x: {x}, Mean: {np.mean(x)}, Var: {np.var(x)}")
# print(f"Normalized y: {normalized_x}, Mean: {np.mean(normalized_x)}, Var: {np.var(normalized_x)}")
# (Note: with gamma=1 and beta=0, y equals x_hat, so its mean/var should be close to 0/1)
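For reference, and to make the notation in the bullets above concrete, the per-feature BatchNorm computation over a mini-batch of size $m$ is usually written as:

$$
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_B\right)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta.
$$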
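The verification snippet above only exercises training-mode behavior. As a rough sketch of the inference-time question, the helpers below keep exponential moving averages of the batch statistics during training and reuse those fixed values at inference. The function names, the momentum value, and the running-statistics initialization are illustrative assumptions, not a reference implementation.

import numpy as np
def batch_norm_train_step(x_batch, gamma, beta, running_mean, running_var,
                          momentum=0.9, epsilon=1e-5):
    # Training mode: normalize with the current mini-batch statistics and
    # update the running averages that will be reused at inference.
    mean_b = np.mean(x_batch)
    var_b = np.var(x_batch)
    x_hat = (x_batch - mean_b) / np.sqrt(var_b + epsilon)
    y = gamma * x_hat + beta
    running_mean = momentum * running_mean + (1.0 - momentum) * mean_b
    running_var = momentum * running_var + (1.0 - momentum) * var_b
    return y, running_mean, running_var
def batch_norm_inference(x, gamma, beta, running_mean, running_var, epsilon=1e-5):
    # Inference mode: no batch statistics are computed; the stored running
    # mean/variance stand in for mu and sigma^2.
    x_hat = (x - running_mean) / np.sqrt(running_var + epsilon)
    return gamma * x_hat + beta
# Example (illustrative):
# rng = np.random.default_rng(0)
# run_mean, run_var = 0.0, 1.0   # a common initialization, assumed here
# for _ in range(100):
#     batch = rng.normal(loc=3.0, scale=2.0, size=32)
#     _, run_mean, run_var = batch_norm_train_step(batch, 1.0, 0.0, run_mean, run_var)
# print(batch_norm_inference(np.array([3.0, 5.0]), 1.0, 0.0, run_mean, run_var))

Frameworks differ in the exact bookkeeping (for example, the momentum convention or whether an unbiased variance estimate is stored), but the key point is the same: at inference the statistics are fixed, so a sample's output no longer depends on which other samples happen to share its batch.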