Build a Transformer Encoder Block from Scratch
Description
The Transformer architecture is built upon a fundamental component: the Encoder block [1]. Each block takes a sequence of embeddings and produces a refined sequence of the same shape. Your task is to implement a single Transformer Encoder block from scratch, containing a Multi-Head Self-Attention mechanism and a Position-wise Feed-Forward Network.
Guidance
A standard Encoder block has two main sub-layers:
1. Multi-Head Self-Attention: split the input Query, Key, and Value tensors into multiple heads, apply scaled dot-product attention to each head independently (sketched below), concatenate the results, and pass them through a final linear projection.
2. Position-wise Feed-Forward Network: a simple two-layer MLP applied independently to each position in the sequence.
Each of these sub-layers is wrapped with a residual connection and a Layer Normalization.
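As a point of reference, the scaled dot-product attention at the core of each head can be sketched as follows. This is a minimal, standalone version, assuming the inputs have already been split into heads of size d_k; the function name scaled_dot_product_attention is purely illustrative and not part of the starter code below.

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, d_k)
    d_k = q.size(-1)
    # Similarity scores, scaled by sqrt(d_k) to keep the softmax well-behaved.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 receive -inf and thus zero attention weight.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of values: (batch, num_heads, seq_len, d_k)
    return torch.matmul(weights, v)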
Starter Code
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # 1. Define linear layers for the Q, K, V projections and the final output.
        # 2. Store num_heads, d_model, and the per-head dimension d_k = d_model // num_heads.
        pass

    def forward(self, q, k, v, mask=None):
        # 1. Project Q, K, V.
        # 2. Reshape and transpose for multi-head attention.
        # 3. Compute scaled dot-product attention.
        # 4. Concatenate the heads and pass through the final linear layer.
        pass


class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # 1. Initialize MultiHeadAttention, two LayerNorms, dropout, and a
        #    position-wise feed-forward network (d_model -> d_ff -> d_model).
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        # ... and so on

    def forward(self, x, mask=None):
        # 1. Compute self-attention, add the residual, and normalize.
        # 2. Pass through the feed-forward network, add the residual, and normalize.
        pass
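The reshape in step 2 of MultiHeadAttention.forward is the part that most often causes shape bugs. One common pattern is sketched below; the helper names split_heads and combine_heads are illustrative only, and the sketch assumes d_model is divisible by num_heads.

import torch

def split_heads(x, num_heads):
    # x: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads  # assumes d_model is divisible by num_heads
    return x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

def combine_heads(x):
    # x: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
    batch, num_heads, seq_len, d_k = x.shape
    return x.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * d_k)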
Verification
Instantiate your EncoderBlock and pass it a random tensor of shape (batch_size, seq_len, d_model). The output tensor must have exactly the same shape; this property is what allows multiple blocks to be stacked.
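For example, a shape check along these lines should pass once your implementation is complete. The hyperparameter values here are arbitrary; only the constructor signature is taken from the starter code.

import torch

batch_size, seq_len, d_model = 2, 16, 64
block = EncoderBlock(d_model=d_model, num_heads=8, d_ff=256, dropout=0.1)

x = torch.randn(batch_size, seq_len, d_model)
out = block(x)

assert out.shape == (batch_size, seq_len, d_model), out.shape
print("Output shape:", out.shape)  # torch.Size([2, 16, 64])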
References
[1] Vaswani, A., et al. (2017). Attention Is All You Need. NeurIPS 2017.