
Build a Transformer Encoder Block from Scratch

Difficulty: hard (<30 mins) · Tags: pytorch, transformer, attention, nlp

Description

The Transformer architecture is built upon a fundamental component: the Encoder block [1]. Each block takes a sequence of embeddings and refines it. Your task is to implement a single Transformer Encoder block from scratch, combining a Multi-Head Self-Attention mechanism with a Position-wise Feed-Forward Network.

Guidance

A standard Encoder block has two main sub-layers:

1. Multi-Head Self-Attention: split the input Query, Key, and Value tensors into multiple heads, apply scaled dot-product attention independently in each head, concatenate the results, and pass them through a final linear layer.
2. Position-wise Feed-Forward Network: a simple two-layer MLP applied independently to each position in the sequence.

Each sub-layer is wrapped with a residual connection followed by Layer Normalization. A sketch of the attention core appears after this list.
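As a reference for the scaled dot-product step inside each head, here is a minimal sketch. The function name scaled_dot_product_attention and the convention that mask == 0 marks blocked positions are assumptions for illustration, not part of the starter code.

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, d_k)
    d_k = q.size(-1)
    # Similarity scores, scaled by sqrt(d_k) to keep the softmax well-behaved.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Positions where the mask is 0 are excluded from attention.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # (batch, num_heads, seq_len, seq_len)
    return weights @ v                       # (batch, num_heads, seq_len, d_k)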

Starter Code

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # 1. Define linear layers for Q, K, V projections and the final output.
        # 2. Store num_heads and d_model.
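        # Note: assumes d_model is divisible by num_heads, so each head works with d_model // num_heads dimensions.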
        pass

    def forward(self, q, k, v, mask=None):
        # 1. Project Q, K, V.
        # 2. Reshape and transpose for multi-head attention.
        # 3. Compute scaled dot-product attention.
        # 4. Concatenate heads and pass through the final linear layer.
        pass

class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # 1. Initialize MultiHeadAttention, two LayerNorms,
        #    and a Position-wise Feed-Forward network.
        self.attention = MultiHeadAttention(...)
        self.norm1 = nn.LayerNorm(d_model)
        # ... and so on

    def forward(self, x, mask=None):
        # 1. Compute attention, add residual, and normalize.
        # 2. Pass through feed-forward, add residual, and normalize.
        pass
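
For the second sub-layer, one possible shape of the Position-wise Feed-Forward Network is sketched below. The class name PositionwiseFeedForward and the attribute names in the wiring comments (self.dropout, self.ffn) are illustrative assumptions; the post-norm ordering follows the original paper [1].

import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    # A two-layer MLP applied identically at every sequence position.
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)

# Post-norm wiring inside EncoderBlock.forward (one common variant):
#   x = self.norm1(x + self.dropout(self.attention(x, x, x, mask)))
#   x = self.norm2(x + self.dropout(self.ffn(x)))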

Verification

Instantiate your EncoderBlock and pass it a random tensor of shape (batch_size, seq_len, d_model). The output tensor must have exactly the same shape; preserving the shape is what makes it possible to stack multiple blocks. A minimal check is sketched below.
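
Once your EncoderBlock is implemented, a shape check could look like this; the dimensions below are arbitrary example values.

import torch

batch_size, seq_len, d_model, num_heads, d_ff = 2, 16, 512, 8, 2048
block = EncoderBlock(d_model, num_heads, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
out = block(x)
assert out.shape == x.shape, f"expected {x.shape}, got {out.shape}"
print(out.shape)  # torch.Size([2, 16, 512])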

References

[1] Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.