ML Katas

Softmax's Numerical Stability: The Max Trick

Difficulty: medium (<1 hr). Tags: Optimization, Softmax, Numerical Stability, Log-Sum-Exp
this year by E

The standard softmax formula $\mathrm{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$ is mathematically correct, but a direct implementation can be numerically unstable. Exponentiating a large positive logit overflows to inf, and exponentiating a very negative logit underflows to 0. If every term in the denominator underflows, the result becomes 0/0. The problem is most acute for large positive $z_i$.
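A minimal NumPy check of both failure modes, assuming NumPy's default float64 dtype. The np.errstate context only silences the overflow, underflow, and invalid-value RuntimeWarnings so that the resulting inf, 0.0, and nan values print cleanly.

```python
import numpy as np

# exp(x) stays finite in float64 only while x <= ln(max float64), roughly 709.78
print(np.log(np.finfo(np.float64).max))

with np.errstate(over="ignore", under="ignore", invalid="ignore"):
    # Large positive logits: every exp overflows to inf, and inf / inf is nan
    e_large = np.exp(np.array([1000.0, 999.0, 998.0]))
    p_large = e_large / e_large.sum()

    # Very negative logits: every exp underflows to 0.0, and 0 / 0 is nan
    e_small = np.exp(np.array([-1000.0, -1001.0, -1002.0]))
    p_small = e_small / e_small.sum()

print(e_large, p_large)  # [inf inf inf] [nan nan nan]
print(e_small, p_small)  # [0. 0. 0.] [nan nan nan]
```

In float64, np.exp returns inf once its argument exceeds roughly 709.78, so logits around 1000 are far past the limit.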

  1. The Problem with the Naive Implementation: Consider the logits $\mathbf{z} = [1000, 999, 998]$. What would happen if you tried to compute $e^{1000}$ directly in a standard floating-point system? (Hint: think about inf in Python/NumPy.)
  2. The Max Trick: To overcome this, softmax is typically implemented by subtracting the maximum logit from all logits before exponentiation, that is, $\mathrm{softmax}(z_i) = \frac{e^{z_i - C}}{\sum_j e^{z_j - C}}$, where $C = \max_j z_j$. Prove that this transformation does not change the resulting probabilities.
  3. Applying the Trick: Apply the max trick to the logits $\mathbf{z} = [1000, 999, 998]$ and calculate the numerically stable softmax probabilities. Compare them with the results the naive approach would give (if your system handles inf gracefully), or discuss the issues you would expect.
  4. Intuition: Why does subtracting the maximum logit before exponentiation solve the overflow problem without affecting the result? What does it do for the denominator and numerator simultaneously?
  5. Verification: Write two Python functions: one for a naive softmax and one for a numerically stable softmax (using the max trick). Test them with logits that would cause issues for the naive version (e.g., large positive numbers) and verify that the stable version produces correct, finite probabilities.
import numpy as np

def naive_softmax(z):
    # This implementation is prone to overflow
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

def stable_softmax(z):
    # Implement the numerically stable version here
    # max_z = ...
    # exp_z_shifted = ...
    # return ...
    pass

# Test with large logits
# large_logits = np.array([1000.0, 999.0, 998.0])

# print(f"Naive softmax: {naive_softmax(large_logits)}") # Expect RuntimeWarnings (overflow in exp, invalid value in divide) and nan outputs
# print(f"Stable softmax: {stable_softmax(large_logits)}") # Expect finite values

# Test with small logits (less critical for overflow, but good to check)
# small_logits = np.array([-100.0, -101.0, -102.0])
# print(f"Naive softmax (small): {naive_softmax(small_logits)}")
# print(f"Stable softmax (small): {stable_softmax(small_logits)}")

# Verify correctness by comparing to a known good implementation or small values
# normal_logits = np.array([2.0, 1.0, 0.1])
# print(f"Normal logits, Naive: {naive_softmax(normal_logits)}")
# print(f"Normal logits, Stable: {stable_softmax(normal_logits)}")
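A worked sketch for items 2 to 4 follows; it is one possible answer rather than the only valid proof. Because $e^{-C}$ is a positive constant, it can be factored out of every term and cancels:

$$
\frac{e^{z_i - C}}{\sum_j e^{z_j - C}}
= \frac{e^{-C}\, e^{z_i}}{e^{-C} \sum_j e^{z_j}}
= \frac{e^{z_i}}{\sum_j e^{z_j}}
= \mathrm{softmax}(z_i)
$$

This holds for any real $C$. Choosing $C = \max_j z_j$ makes every shifted logit $z_j - C \le 0$, so each numerator is at most 1 and the denominator lies in $[1, n]$, because the maximal logit contributes $e^0 = 1$. Nothing can overflow and the denominator can never underflow to 0; an individual numerator may still underflow, which only rounds a negligible probability to 0. For $\mathbf{z} = [1000, 999, 998]$, $\mathbf{z} - C = [0, -1, -2]$, so the exponentials are approximately $[1, 0.3679, 0.1353]$ with sum $\approx 1.5032$, giving probabilities $\approx [0.6652, 0.2447, 0.0900]$.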
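One possible solution sketch for item 5, assuming NumPy inputs where softmax is taken over the last axis (the axis=-1 generalization to batches is a choice made here, not part of the exercise). naive_softmax is repeated from the scaffold so the block runs on its own.

```python
import numpy as np

def naive_softmax(z):
    # Same as the scaffold: exponentiate directly, then normalize (overflows for large z)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

def stable_softmax(z):
    # Max trick: shift so the largest logit along the last axis becomes 0
    z = np.asarray(z, dtype=np.float64)
    max_z = np.max(z, axis=-1, keepdims=True)
    exp_z_shifted = np.exp(z - max_z)  # every value lies in [0, 1]
    return exp_z_shifted / np.sum(exp_z_shifted, axis=-1, keepdims=True)

large_logits = np.array([1000.0, 999.0, 998.0])
very_negative_logits = np.array([-1000.0, -1001.0, -1002.0])
small_logits = np.array([-100.0, -101.0, -102.0])
normal_logits = np.array([2.0, 1.0, 0.1])

with np.errstate(over="ignore", invalid="ignore"):
    naive_large = naive_softmax(large_logits)  # overflow: inf / inf gives nan
stable_large = stable_softmax(large_logits)

print(f"Naive softmax: {naive_large}")    # [nan nan nan]
print(f"Stable softmax: {stable_large}")  # approximately [0.66524096 0.24472847 0.09003057]

# The stable result is finite and sums to 1
assert np.all(np.isfinite(stable_large)) and np.isclose(stable_large.sum(), 1.0)

# Shift invariance: the same gaps between logits give the same probabilities
assert np.allclose(stable_softmax(very_negative_logits), stable_large)

# Where the naive version does not overflow, the two implementations agree
for logits in (small_logits, normal_logits):
    assert np.allclose(naive_softmax(logits), stable_softmax(logits))
```

The same shift underlies the log-sum-exp trick named in the tags: $\log \sum_j e^{z_j} = C + \log \sum_j e^{z_j - C}$.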