ML Katas

Tensor Manipulation: Causal Mask for Transformers

easy (<10 mins) transformer attention tensor
today by E

Description

In decoder-style Transformers (such as GPT), a "causal" or "look-ahead" mask prevents each position from attending to later positions: query position i may attend only to key positions j <= i. The allowed positions therefore form a lower-triangular pattern, diagonal included. Your task is to create this mask.
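To see why the allowed positions are lower-triangular, the sketch below builds the "may attend" pattern directly from the rule j <= i and compares it with torch.tril. The variable names are illustrative only.

```python
import torch

seq_len = 4
# Rule form: query position i (row) may attend to key position j (column) iff j <= i
rows = torch.arange(seq_len).unsqueeze(1)   # shape (seq_len, 1)
cols = torch.arange(seq_len).unsqueeze(0)   # shape (1, seq_len)
allowed_by_rule = cols <= rows              # broadcasts to (seq_len, seq_len), dtype bool

# Matrix form: the lower triangle (diagonal included) of an all-True matrix
allowed_by_tril = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(allowed_by_rule.int())
print(torch.equal(allowed_by_rule, allowed_by_tril))  # True
```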

Guidance

Create a square matrix of shape (seq_len, seq_len). Use torch.ones together with torch.tril (or torch.triu) to generate the mask. The masked (future) positions are usually set to -inf, and the mask is added to the attention scores before the softmax; because exp(-inf) = 0, those positions receive exactly zero probability.
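The effect of the additive mask on the softmax can be checked with a small sketch. The scores here are random stand-ins for real attention scores (for example Q @ K.T / sqrt(d)); only the masking step reflects the technique described above.

```python
import torch

torch.manual_seed(0)
seq_len = 4
# Stand-in for raw attention scores of shape (query, key)
scores = torch.randn(seq_len, seq_len)

# Additive causal mask: 0.0 on and below the diagonal, -inf above it
mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

# exp(-inf) = 0, so future keys get exactly zero weight; each row still sums to 1
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
```

The diagonal is never masked, so every row keeps at least one finite score. A row that was entirely -inf would make the softmax return NaN.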

Starter Code

import torch

def create_causal_mask(seq_len):
    # Create a mask of shape (seq_len, seq_len)
    # The upper triangle (excluding the diagonal) should be -inf
    # The lower triangle (including the diagonal) should be 0.0
    mask = torch.ones(seq_len, seq_len).tril()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)
    return mask
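The same mask can be built in a single step with torch.full and torch.triu, or kept as a boolean tensor and applied with masked_fill. The sketch below shows both variants; the function names create_causal_mask_triu and create_causal_mask_bool are hypothetical and are not part of PyTorch.

```python
import torch

def create_causal_mask_triu(seq_len, dtype=torch.float32, device=None):
    # Start from an all -inf matrix; triu(diagonal=1) keeps the strict upper
    # triangle and sets everything on and below the main diagonal to 0.
    full = torch.full((seq_len, seq_len), float('-inf'), dtype=dtype, device=device)
    return torch.triu(full, diagonal=1)

def create_causal_mask_bool(seq_len, device=None):
    # Boolean variant: True marks a future (blocked) position.
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    return torch.triu(ones, diagonal=1)

# Usage: both masks block the same positions.
scores = torch.randn(4, 4)  # stand-in for raw attention scores
additive = scores + create_causal_mask_triu(4)
filled = scores.masked_fill(create_causal_mask_bool(4), float('-inf'))
print(torch.equal(additive, filled))  # True
```

If you pass a boolean mask to a library function, check which convention it uses. In torch.nn.MultiheadAttention, True in attn_mask means "not allowed to attend". In torch.nn.functional.scaled_dot_product_attention, True means "takes part in attention".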

Verification

For seq_len=4, your function should produce a 4x4 tensor where the diagonal and all elements below it are 0.0, and all elements above the diagonal are -inf.
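You can automate this check. In the sketch below, check_causal_mask is a hypothetical helper that tests the property stated above. The function under test repeats the starter-code logic so that the snippet runs on its own.

```python
import torch

def create_causal_mask(seq_len):
    # Same logic as the starter code
    mask = torch.ones(seq_len, seq_len).tril()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)
    return mask

def check_causal_mask(mask):
    # Hypothetical helper: 0.0 on/below the diagonal, -inf strictly above it
    n = mask.shape[0]
    rows = torch.arange(n).unsqueeze(1)
    cols = torch.arange(n).unsqueeze(0)
    on_or_below = cols <= rows
    return (
        mask.shape == (n, n)
        and bool(torch.all(mask[on_or_below] == 0.0))
        and bool(torch.all(torch.isneginf(mask[~on_or_below])))
    )

mask = create_causal_mask(4)
print(mask)
print(check_causal_mask(mask))  # True
```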