Tensor Manipulation: Causal Mask for Transformers
Description
In decoder-style Transformers (like GPT), we need a "causal" or "look-ahead" mask to prevent positions from attending to subsequent positions. This is typically a lower-triangular matrix. Your task is to create this mask.
Guidance
Create a square matrix of size (N, N) using torch.ones, then apply torch.tril (or torch.triu) to generate the mask. The masked (future) positions are usually set to -inf before the softmax, so that they receive zero attention probability.
Starter Code
import torch
def create_causal_mask(seq_len):
    # Create a mask of shape (seq_len, seq_len)
    # The upper triangle (excluding the diagonal) should be -inf
    # The lower triangle (including the diagonal) should be 0.0
    mask = torch.ones(seq_len, seq_len).tril()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)
    return mask
Verification
For seq_len=4, your function should produce a 4x4 tensor where the diagonal and all elements below it are 0.0, and all elements above the diagonal are -inf.
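As a quick sanity check, here is a sketch (using the same implementation as the starter code) that applies the mask to a matrix of simulated attention scores. Adding the mask before the softmax sends future positions to -inf, so each row's probability mass falls only on positions at or before it:

```python
import torch

def create_causal_mask(seq_len):
    # Lower-triangular mask: 0.0 where attention is allowed, -inf where blocked
    mask = torch.ones(seq_len, seq_len).tril()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)
    return mask

seq_len = 4
mask = create_causal_mask(seq_len)

# Simulated raw attention scores (query x key); the values are arbitrary
scores = torch.randn(seq_len, seq_len)

# Masked positions become -inf, so softmax assigns them zero probability
probs = torch.softmax(scores + mask, dim=-1)

# Row i attends only to positions 0..i, and each row still sums to 1
print(probs)
```

Note that the additive form (0.0 / -inf) composes cleanly with scaled dot-product attention: the mask is simply added to the score matrix, with no conditional logic inside the attention computation.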