ML Katas

Tensor Manipulation: Numerically Stable Softmax

Difficulty: medium (<10 mins) · Tags: pytorch, deep learning, tensor

Description

Implement the softmax function, which converts a vector of numbers into a probability distribution. A naive implementation can be numerically unstable if the input values are very large. Your task is to implement a stable version from scratch.

Guidance

The trick to stability is the "max trick": subtract the maximum value of the input tensor from all elements before exponentiating. This shifts every value to be at most zero, so the exponentials can no longer overflow, and the final probabilities are unchanged because the common factor cancels between numerator and denominator.

softmax(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}
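
A quick check of this shift-invariance (the example values are arbitrary): subtracting the maximum changes the exponents but not the resulting distribution, because the common factor cancels in the ratio.

import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, 3.0])
shifted = x - x.max()              # [-2., -1., 0.]; every exponent is now <= 1

print(F.softmax(x, dim=-1))        # tensor([0.0900, 0.2447, 0.6652])
print(F.softmax(shifted, dim=-1))  # identical output: the shift cancels out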

Starter Code

import torch

def stable_softmax(x):
    # x shape: (B, ..., C)
    # 1. Subtract the max for numerical stability
    x_minus_max = x - x.max(dim=-1, keepdim=True).values

    # 2. Exponentiate
    exp_x = torch.exp(x_minus_max)

    # 3. Normalize to get probabilities
    return exp_x / exp_x.sum(dim=-1, keepdim=True)
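
A brief usage sketch, assuming the stable_softmax above is in scope (the shapes are arbitrary): because the reduction uses dim=-1, the function works unchanged on any batched input.

import torch

logits = torch.randn(4, 3, 10)     # e.g. (batch, sequence, classes)
probs = stable_softmax(logits)

print(probs.shape)                 # torch.Size([4, 3, 10]) -- shape is preserved
print(probs.sum(dim=-1))           # all ones (up to floating-point error)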

Verification

Create a random tensor. The output of your function should sum to 1 along the last dimension. Compare your function's output to torch.nn.functional.softmax. Then, test it with a tensor containing large values (e.g., torch.tensor([[1000, 1010]])) to see if a naive implementation fails while yours succeeds.
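
One possible verification script, assuming stable_softmax from the starter code is defined (the shapes and test values are arbitrary choices):

import torch
import torch.nn.functional as F

# 1. Probabilities sum to 1 along the last dimension
x = torch.randn(2, 5)
out = stable_softmax(x)
assert torch.allclose(out.sum(dim=-1), torch.ones(2))

# 2. Matches the reference implementation
assert torch.allclose(out, F.softmax(x, dim=-1))

# 3. Large inputs: the naive version overflows to nan, the stable one does not
big = torch.tensor([[1000.0, 1010.0]])
naive = torch.exp(big) / torch.exp(big).sum(dim=-1, keepdim=True)
print(naive)                   # tensor([[nan, nan]])
print(stable_softmax(big))     # tensor([[4.5398e-05, 9.9995e-01]])
assert torch.isfinite(stable_softmax(big)).all()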