ML Katas

Riding the Momentum Wave in Optimization

medium (<1 hr) · Gradient Descent · Optimization · Momentum · Hyperparameters
this year by E

Stochastic Gradient Descent (SGD) with momentum is a popular optimization algorithm that often converges faster and more stably than plain SGD.

  1. Update Rule: The update rule for SGD with momentum is typically defined as: $v_t = \gamma v_{t-1} + \eta \nabla J(\mathbf{w}_{t-1})$ and $\mathbf{w}_t = \mathbf{w}_{t-1} - v_t$, where $\mathbf{w}$ are the parameters, $J$ is the loss function, $\nabla J$ is its gradient, $\eta$ is the learning rate, $v$ is the velocity vector, and $\gamma$ is the momentum hyperparameter (typically between 0 and 1).
  2. Exponential Moving Average: Show that the velocity vector vt can be interpreted as an exponentially decaying moving average of past gradients. What role does γ play in this moving average?
  3. Intuition:
    • How does momentum help overcome local minima or saddle points?
    • How does it help accelerate convergence in regions where gradients are consistent (e.g., a long, shallow valley)?
    • How does it smooth out oscillations that might occur with plain SGD in noisy gradient landscapes?
  4. Ball Rolling Analogy: Explain the momentum update rule using the analogy of a ball rolling down a hill. How do friction (decay) and the slope (gradient) influence the ball's movement?
  5. Verification: Consider the 1D objective function $f(x) = x^2$. Plot the trajectory of plain Gradient Descent and Gradient Descent with momentum for a few steps, starting from the same initial point and using appropriate learning rates and momentum values. Observe the difference in convergence.
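As a hint for part 2: unrolling the velocity recursion (taking $v_0 = 0$) expresses $v_t$ directly as a $\gamma$-weighted sum of past gradients:

```latex
% Unroll v_t = \gamma v_{t-1} + \eta \nabla J(\mathbf{w}_{t-1}), with v_0 = 0:
v_t = \eta \sum_{k=0}^{t-1} \gamma^{k}\, \nabla J(\mathbf{w}_{t-1-k})
```

Each gradient from $k$ steps ago is discounted by $\gamma^k$, so $v_t$ is an exponentially decaying moving average of past gradients, with an effective averaging window on the order of $1/(1-\gamma)$.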
import numpy as np
import matplotlib.pyplot as plt

def f_prime(x):  # derivative of f(x) = x^2
    return 2 * x

# Plain GD
x_gd = [5.0]
lr = 0.1
for _ in range(10):
    x_gd.append(x_gd[-1] - lr * f_prime(x_gd[-1]))

# GD with Momentum
x_mom = [5.0]
v = 0.0
gamma = 0.9
lr_mom = 0.05  # a smaller learning rate often pairs well with momentum
for _ in range(10):
    grad = f_prime(x_mom[-1])
    v = gamma * v + lr_mom * grad  # velocity accumulates past gradients
    x_mom.append(x_mom[-1] - v)

plt.plot(x_gd, label='Plain GD')
plt.plot(x_mom, label='GD with Momentum')
plt.xlabel('step')
plt.ylabel('x')
plt.legend()
plt.show()
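A quick numeric sanity check for part 2 (a sketch, not part of the original exercise): the velocity computed recursively should equal the explicit $\gamma$-weighted sum of all past gradients.

```python
# Verify v_t (recursive) == eta * sum_k gamma^k * grad_{t-1-k} (explicit EMA form)
def f_prime(x):  # derivative of f(x) = x^2
    return 2 * x

eta, gamma = 0.05, 0.9
x, v = 5.0, 0.0
grads = []
for _ in range(10):
    g = f_prime(x)
    grads.append(g)
    v = gamma * v + eta * g  # recursive velocity update
    x -= v

# Explicit exponentially weighted sum over the recorded gradients
v_explicit = eta * sum(gamma**k * grads[-1 - k] for k in range(len(grads)))
print(abs(v - v_explicit) < 1e-12)  # True
```

The two forms agree up to floating-point rounding, confirming the moving-average interpretation.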