Tensor Manipulation: Numerically Stable Softmax
Description
Implement the softmax function, which converts a vector of numbers into a probability distribution. A naive implementation can overflow when exponentiating large input values, producing inf and NaN. Your task is to implement a numerically stable version from scratch.
Guidance
The trick to stability is the "max trick": subtract the maximum value along the last dimension from all elements before exponentiating. This shifts every value to be at most zero, so each exponential lies in (0, 1] and cannot overflow, and the final probabilities are unchanged because the constant shift cancels in the normalization.
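To see why the shift is harmless, note that for any constant c, exp(x_i - c) / sum_j exp(x_j - c) equals exp(x_i) / sum_j exp(x_j), since the factor exp(-c) cancels. A minimal sketch of this identity on a small tensor (the values are chosen only for illustration):

import torch

x = torch.tensor([2.0, 1.0, 0.1])  # arbitrary example values
naive = torch.exp(x) / torch.exp(x).sum()
shifted = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()
print(torch.allclose(naive, shifted))  # True: the constant shift cancels out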
Starter Code
import torch

def stable_softmax(x):
    # x shape: (B, ..., C); softmax is taken over the last dimension
    # 1. Subtract the per-row max for numerical stability
    x_minus_max = x - x.max(dim=-1, keepdim=True).values
    # 2. Exponentiate (every argument is <= 0, so exp cannot overflow)
    exp_x = torch.exp(x_minus_max)
    # 3. Normalize to get probabilities
    return exp_x / exp_x.sum(dim=-1, keepdim=True)
Verification
Create a random tensor. The output of your function should sum to 1 along the last dimension, and it should match torch.nn.functional.softmax. Then test it with a tensor containing large values (e.g., torch.tensor([[1000.0, 1010.0]])) to confirm that a naive implementation fails while yours succeeds.
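A minimal verification sketch along these lines, assuming the stable_softmax from the starter code above is defined; naive_softmax is a hypothetical helper included only to show the failure mode:

import torch
import torch.nn.functional as F

def naive_softmax(x):
    # No max subtraction: exp overflows to inf for large inputs
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum(dim=-1, keepdim=True)

# Random input: rows should sum to 1 and match the library implementation
x = torch.randn(4, 10)
out = stable_softmax(x)
print(out.sum(dim=-1))                            # a tensor of ones
print(torch.allclose(out, F.softmax(x, dim=-1)))  # True

# Large values: the naive version yields inf/inf = nan, the stable one is fine
big = torch.tensor([[1000.0, 1010.0]])
print(naive_softmax(big))   # tensor([[nan, nan]])
print(stable_softmax(big))  # approximately tensor([[4.54e-05, 1.0000]])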