Riding the Momentum Wave in Optimization
Stochastic Gradient Descent (SGD) with momentum is a popular optimization algorithm that often converges faster and more stably than plain SGD.
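- Update Rule: The update rule for SGD with momentum is typically defined as:
  $$v_t = \gamma v_{t-1} + \eta \nabla_\theta L(\theta_{t-1}), \qquad \theta_t = \theta_{t-1} - v_t$$
  where $\theta$ are the parameters, $L$ is the loss function, $\nabla_\theta L(\theta_{t-1})$ is the gradient, $\eta$ is the learning rate, $v_t$ is the velocity vector, and $\gamma$ is the momentum hyperparameter (typically between 0 and 1).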
- Exponential Moving Average: Show that the velocity vector $v_t$ can be interpreted as an exponentially decaying moving average of past gradients. What role does $\gamma$ play in this moving average?
- Intuition:
  - How does momentum help overcome local minima or saddle points?
  - How does it help accelerate convergence in regions where gradients are consistent (e.g., a long, shallow valley)?
  - How does it smooth out oscillations that might occur with plain SGD in noisy gradient landscapes?
- Ball Rolling Analogy: Explain the momentum update rule using the analogy of a ball rolling down a hill. How do friction (decay) and the slope (gradient) influence the ball's movement?
- Verification: Consider a 1D objective function $f(x) = x^2$. Plot the trajectory of plain Gradient Descent and Gradient Descent with momentum for a few steps, starting from the same initial point and using appropriate learning rates and momentum values. Observe the difference in convergence.
import matplotlib.pyplot as plt

def f_prime(x):
    """Derivative of f(x) = x^2."""
    return 2 * x

n_steps = 10

# Plain GD
x_gd = [5.0]
lr_gd = 0.1
for _ in range(n_steps):
    x_gd.append(x_gd[-1] - lr_gd * f_prime(x_gd[-1]))

# GD with momentum
x_mom = [5.0]
v = 0.0
gamma = 0.9
lr_mom = 0.05  # potentially smaller learning rate
for _ in range(n_steps):
    grad = f_prime(x_mom[-1])
    v = gamma * v + lr_mom * grad  # decayed previous velocity plus the new gradient step
    x_mom.append(x_mom[-1] - v)

# Plot both trajectories against the step index
plt.plot(x_gd, marker='o', label='Plain GD')
plt.plot(x_mom, marker='o', label='GD with Momentum')
plt.xlabel('Step')
plt.ylabel('x')
plt.legend()
plt.show()
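
As a hint for the exponential moving average question: assuming $v_0 = 0$ and writing $g_t$ for the gradient at step $t$, $\eta$ for the learning rate, and $\gamma$ for the momentum hyperparameter, unrolling $v_t = \gamma v_{t-1} + \eta g_t$ gives $v_t = \eta \sum_{k=0}^{t-1} \gamma^k g_{t-k}$. The sketch below checks this numerically on a made-up gradient sequence. The function names are illustrative and do not come from any library.

```python
import numpy as np

def velocity_recursive(grads, lr, gamma):
    """Run v_t = gamma * v_{t-1} + lr * g_t from v_0 = 0 and return the final v."""
    v = 0.0
    for g in grads:
        v = gamma * v + lr * g
    return v

def velocity_unrolled(grads, lr, gamma):
    """Evaluate lr * sum_k gamma**k * g_{t-k} directly (grads[0] is the oldest)."""
    t = len(grads)
    # The gradient seen k steps ago gets weight gamma**k.
    weights = gamma ** np.arange(t - 1, -1, -1)
    return lr * float(np.dot(weights, grads))

rng = np.random.default_rng(0)
grads = rng.normal(size=20)  # made-up gradients, for illustration only

v_rec = velocity_recursive(grads, lr=0.1, gamma=0.9)
v_sum = velocity_unrolled(grads, lr=0.1, gamma=0.9)
print("difference between the two forms:", abs(v_rec - v_sum))
```

The weight on a gradient seen $k$ steps ago is $\gamma^k$, so a larger $\gamma$ gives a longer memory, on the order of $1/(1-\gamma)$ steps. The weights sum to $(1-\gamma^t)/(1-\gamma)$ rather than 1, so this is an unnormalized moving average.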
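
For the intuition question about a long, shallow valley, one possible experiment is a 2D quadratic that is shallow along one axis and steep along the other. The objective $f(x, y) = \tfrac{1}{2}(x^2 + 50y^2)$, the start point, and the hyperparameters below are all assumptions chosen for illustration. The sketch addresses only the acceleration question, not local minima or noisy gradients.

```python
import numpy as np

# Hypothetical test problem: f(x, y) = 0.5 * (x**2 + 50 * y**2),
# a valley that is shallow along x and steep along y.
CURVATURE = np.array([1.0, 50.0])

def grad(p):
    """Gradient of the quadratic valley at point p."""
    return CURVATURE * p

def run(start, lr, gamma, steps):
    """Momentum descent; gamma = 0 reduces to plain gradient descent."""
    p = np.array(start, dtype=float)
    v = np.zeros_like(p)
    for _ in range(steps):
        v = gamma * v + lr * grad(p)  # accumulate velocity
        p = p - v                     # move by the velocity
    return p

start = (10.0, 1.0)
p_gd = run(start, lr=0.01, gamma=0.0, steps=200)
p_mom = run(start, lr=0.01, gamma=0.9, steps=200)

print("plain GD distance to minimum:", np.linalg.norm(p_gd))
print("momentum distance to minimum:", np.linalg.norm(p_mom))
```

With this learning rate, plain gradient descent shrinks the $x$ coordinate by a factor of 0.99 per step, so roughly 13% of the initial offset remains after 200 steps. Along the shallow direction the gradients keep the same sign, so the velocity accumulates and momentum ends up much closer to the minimum in the same number of steps.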
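
For the ball rolling analogy, it may help to look at the velocity recursion on its own. The sketch below assumes a constant slope (a constant gradient) and then flat ground, using a hypothetical helper `roll` and illustrative values for the learning rate and momentum. Under these assumptions the slope acts as a steady push, while $\gamma$ acts like friction by keeping only a fraction of the previous velocity.

```python
def roll(slope, lr, gamma, steps, v0=0.0):
    """Return the velocity history of v = gamma * v + lr * slope."""
    v, history = v0, []
    for _ in range(steps):
        v = gamma * v + lr * slope
        history.append(v)
    return history

lr, gamma = 0.1, 0.9  # illustrative values, not taken from the text

# Constant downhill slope: the ball speeds up, then levels off.
v_slope = roll(slope=1.0, lr=lr, gamma=gamma, steps=200)
terminal = lr * 1.0 / (1 - gamma)  # fixed point of v = gamma * v + lr * slope
print("first step:", v_slope[0], "last step:", v_slope[-1], "terminal:", terminal)

# Flat ground (zero slope): only the decay term remains,
# so the velocity shrinks by a factor of gamma each step.
v_flat = roll(slope=0.0, lr=lr, gamma=gamma, steps=3, v0=v_slope[-1])
print("coasting velocities:", v_flat)
```

On a constant slope the velocity approaches the terminal value $\eta g / (1 - \gamma)$, where $g$ is the slope and $\eta$ the learning rate. Once the slope vanishes, the ball coasts and slows geometrically.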