ML Katas

The Elegant Gradient of Softmax-Cross-Entropy

hard (>1 hr) · Backpropagation · Cross-Entropy · Softmax · Gradients

One of the most satisfying derivations in deep learning is the gradient of the combined Softmax and Cross-Entropy loss. For a multi-class classification problem with $K$ classes, given true labels $\mathbf{y}$ (one-hot encoded) and predicted logits $\mathbf{z}$, the softmax output is $\hat{y}_k = \frac{e^{z_k}}{\sum_j e^{z_j}}$. The categorical cross-entropy loss is $L = -\sum_k y_k \log(\hat{y}_k)$.
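
As a concrete anchor, here is the setup evaluated on the three logits used in the starter code further down; the numbers are rounded and purely illustrative:

$$\mathbf{z} = (2.0,\ 1.0,\ 0.1), \qquad \hat{\mathbf{y}} = \frac{\left(e^{2.0},\ e^{1.0},\ e^{0.1}\right)}{e^{2.0} + e^{1.0} + e^{0.1}} \approx (0.659,\ 0.242,\ 0.099),$$
$$\mathbf{y} = (1,\ 0,\ 0) \;\Rightarrow\; L = -\log(0.659) \approx 0.417.$$

Once you complete the derivation below, the gradient for this example is simply $\hat{\mathbf{y}} - \mathbf{y} \approx (-0.341,\ 0.242,\ 0.099)$.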

  1. Partial Derivative of Softmax: First, derive $\frac{\partial \hat{y}_k}{\partial z_i}$. You'll need to consider two cases:
    • When $k = i$ (the derivative of a specific output with respect to its own logit).
    • When $k \neq i$ (the derivative of a specific output with respect to a different logit). Show that $\frac{\partial \hat{y}_k}{\partial z_i} = \hat{y}_k(1 - \hat{y}_i)$ if $k = i$, and $\frac{\partial \hat{y}_k}{\partial z_i} = -\hat{y}_k \hat{y}_i$ if $k \neq i$. This can be written compactly as $\frac{\partial \hat{y}_k}{\partial z_i} = \hat{y}_k(\delta_{ki} - \hat{y}_i)$, where $\delta_{ki}$ is the Kronecker delta.
  2. Chain Rule Application: Now, use the chain rule to find the gradient of the loss $L$ with respect to a logit $z_i$: $\frac{\partial L}{\partial z_i} = \sum_k \frac{\partial L}{\partial \hat{y}_k} \frac{\partial \hat{y}_k}{\partial z_i}$.
    • Recall that $\frac{\partial L}{\partial \hat{y}_k} = -\frac{y_k}{\hat{y}_k}$.
  3. The Result: Show that the final gradient simplifies to a remarkably clean form: $\frac{\partial L}{\partial z_i} = \hat{y}_i - y_i$. This means the gradient is simply the difference between the predicted probability and the true one for that class. (A worked expansion is sketched just after this list.)
  4. Intuition: Why is this derivative so elegant? What does it imply about how the network updates its logits based on the error? How does this simplicity contribute to the efficiency of backpropagation in classification tasks?
  5. Verification: You can conceptually verify this by considering a simple case (e.g., 2 classes, one-hot labels) and thinking about what happens when $\hat{y}_i$ is much larger or much smaller than $y_i$. Numerically, you can run a small forward pass, compute the loss, and compare the analytical gradient against a numerical approximation (e.g., central finite differences) for each logit; a sketch of this check follows the starter code below.
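
If you want to check your algebra after attempting steps 1–3 yourself, the chain-rule combination collapses as follows (using the fact that $\sum_k y_k = 1$ for one-hot labels):

$$\frac{\partial L}{\partial z_i} = \sum_k \left(-\frac{y_k}{\hat{y}_k}\right) \hat{y}_k \left(\delta_{ki} - \hat{y}_i\right) = -\sum_k y_k \left(\delta_{ki} - \hat{y}_i\right) = -y_i + \hat{y}_i \sum_k y_k = \hat{y}_i - y_i.$$

The NumPy starter code below defines the softmax and cross-entropy loss so you can test this result numerically.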
import numpy as np

def softmax(z):
    exp_z = np.exp(z - np.max(z)) # for numerical stability
    return exp_z / np.sum(exp_z)

def cross_entropy_loss(y_true, y_pred_probs):
    # y_true is one-hot encoded
    # y_pred_probs are softmax outputs
    # Clip predictions to avoid log(0)
    epsilon = 1e-10
    y_pred_probs = np.clip(y_pred_probs, epsilon, 1 - epsilon)
    return -np.sum(y_true * np.log(y_pred_probs))

# Example logits and true label
# logits = np.array([2.0, 1.0, 0.1])
# y_true = np.array([1.0, 0.0, 0.0]) # Class 0 is true

# y_pred = softmax(logits)
# print(f"Predicted probabilities: {y_pred}")

# Analytical gradient for logits:
# grad_analytical = y_pred - y_true
# print(f"Analytical gradient: {grad_analytical}")

# For numerical verification (optional, more advanced):
# You'd perturb each logit slightly and observe the change in loss.
# e.g., for logit i: (loss(z + eps * e_i) - loss(z - eps * e_i)) / (2 * eps),
# where e_i is the i-th standard basis vector.
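
Below is a minimal sketch of that finite-difference check, assuming the softmax and cross_entropy_loss functions defined above; the helper name numerical_gradient and the choice of eps are illustrative, not part of the kata.

def numerical_gradient(logits, y_true, eps=1e-6):
    # Central finite differences: perturb one logit at a time.
    grad = np.zeros_like(logits)
    for i in range(len(logits)):
        z_plus = logits.copy()
        z_minus = logits.copy()
        z_plus[i] += eps
        z_minus[i] -= eps
        loss_plus = cross_entropy_loss(y_true, softmax(z_plus))
        loss_minus = cross_entropy_loss(y_true, softmax(z_minus))
        grad[i] = (loss_plus - loss_minus) / (2 * eps)
    return grad

logits = np.array([2.0, 1.0, 0.1])
y_true = np.array([1.0, 0.0, 0.0])  # class 0 is true

grad_analytical = softmax(logits) - y_true
grad_numerical = numerical_gradient(logits, y_true)
print("Analytical gradient:", grad_analytical)
print("Numerical gradient: ", grad_numerical)
# The two should agree to roughly six decimal places.

Central differences are used because their truncation error shrinks quadratically in eps, which makes the agreement with the analytical gradient a meaningful check rather than a coincidence.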