Sliding Window Attention Preparation
Description
Full self-attention has quadratic complexity in sequence length, which is prohibitive for very long sequences. Models like Longformer introduce sliding window attention, where each token attends only to a local neighborhood of tokens. [1] Your task is to prepare the input tensor for this operation by creating overlapping windows (or chunks) of the sequence.
Guidance
This can be achieved elegantly using torch.Tensor.unfold. You will take a sequence of tokens and create a new dimension containing overlapping windows.
1. Input: A tensor of shape (B, N, D), where N is the sequence length.
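2. Unfold: Use the unfold method on the sequence dimension (dim=1). You need to specify the size of the window and the step between window starts. This yields num_windows = (N - window_size) // step + 1 windows; trailing tokens that do not fill a complete window are dropped.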
3. Reshape for Attention: The output of unfold will have shape (B, num_windows, D, window_size). You'll want to permute and reshape this to (B * num_windows, window_size, D) so you can run a standard attention mechanism on the (B * num_windows) batch.
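To make the window layout from the steps above concrete, here is a minimal pure-Python sketch of which token indices land in each window. The helper name window_bounds is hypothetical and exists only for illustration; it assumes the same window count formula that unfold uses for complete windows.

```python
def window_bounds(n, window_size, step):
    """List the (start, end) token indices of each full window; end is exclusive.

    Hypothetical helper for illustration only, not part of PyTorch.
    """
    if window_size > n:
        return []  # no complete window fits in the sequence
    num_windows = (n - window_size) // step + 1
    return [(i * step, i * step + window_size) for i in range(num_windows)]


print(window_bounds(10, 4, 2))  # [(0, 4), (2, 6), (4, 8), (6, 10)]
print(window_bounds(11, 4, 2))  # [(0, 4), (2, 6), (4, 8), (6, 10)]
```

With a sequence length of 11, the last token (index 10) appears in no window, so you may want to pad the sequence or choose N, window_size and step so that (N - window_size) is divisible by step.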
Starter Code
import torch
from einops import rearrange


def create_sliding_windows(x, window_size, step):
    """
    Args:
        x (torch.Tensor): Input tensor of shape (B, N, D)
        window_size (int): The size of each attention window.
        step (int): The step size between the start of each window.
    Returns:
        torch.Tensor: A tensor of overlapping windows, shape (B * num_windows, window_size, D)
    """
    # 1. Use the .unfold() method on the sequence dimension.
    # The output will have an extra dimension at the end for the window.
    unfolded_x = x.unfold(dimension=1, size=window_size, step=step)
    # 2. Permute and reshape the tensor for batch attention.
    # The unfolded tensor has shape (B, num_windows, D, window_size).
    # We need to rearrange it to (B, num_windows, window_size, D) and then
    # merge the first two dimensions.
    # einops is great for this!
    windows = rearrange(unfolded_x, 'b nw d ws -> (b nw) ws d')
    return windows
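If einops is not available, the same rearrangement can be sketched with plain permute and reshape, as the guidance describes. The function name create_sliding_windows_torch is a hypothetical choice for this example; it is intended to produce the same result as the starter code, not to replace it.

```python
import torch


def create_sliding_windows_torch(x, window_size, step):
    """Sketch of the starter function using only torch operations.

    x: (B, N, D) -> returns (B * num_windows, window_size, D)
    """
    # unfold appends the window dimension last: (B, num_windows, D, window_size)
    unfolded = x.unfold(1, window_size, step)
    b, num_windows, d, ws = unfolded.shape
    # Swap the last two dims to get (B, num_windows, window_size, D)
    permuted = unfolded.permute(0, 1, 3, 2)
    # reshape copies when needed, so the non-contiguous permuted view is fine
    return permuted.reshape(b * num_windows, ws, d)


x = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)
w = create_sliding_windows_torch(x, window_size=4, step=2)
print(w.shape)  # torch.Size([4, 4, 3])
```

Because the batch dimension is the outer one in the merge, windows of the same batch item are adjacent: w[0] and w[1] come from x[0], while w[2] and w[3] come from x[1].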
Verification
Create a dummy tensor and pass it through your function. Verify the output shape and also check that the content of a window corresponds to the correct slice of the original tensor.
# Parameters
B, N, D = 2, 1024, 64
window_size = 128
step = 64 # Overlap of 64
# Dummy data
input_tensor = torch.randn(B, N, D)
# Create windows
windows = create_sliding_windows(input_tensor, window_size, step)
# Verification checks
num_windows = (N - window_size) // step + 1
expected_shape = (B * num_windows, window_size, D)
assert windows.shape == expected_shape
# Check content of the first window of the first batch item
first_window_actual = windows[0]
first_window_expected = input_tensor[0, 0:window_size, :]
assert torch.allclose(first_window_actual, first_window_expected)
# Check content of the second window of the first batch item
second_window_actual = windows[1]
second_window_expected = input_tensor[0, step:step + window_size, :]
assert torch.allclose(second_window_actual, second_window_expected)
print("Verification successful!")
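Once the windows have the shape (B * num_windows, window_size, D), a standard attention computation can treat each window as an independent batch item. The following is a rough sketch under simplifying assumptions: the hypothetical function windowed_self_attention uses the windows directly as queries, keys and values, with no learned projections, heads or masking.

```python
import math

import torch


def windowed_self_attention(windows):
    """Plain scaled dot-product self-attention inside each window.

    windows: (B * num_windows, window_size, D)
    Simplification: queries, keys and values are the raw inputs (no projections).
    """
    d = windows.shape[-1]
    # (B * num_windows, window_size, window_size) similarity scores
    scores = windows @ windows.transpose(-2, -1) / math.sqrt(d)
    weights = torch.softmax(scores, dim=-1)
    return weights @ windows


windows = torch.randn(2 * 15, 128, 64)  # B=2, num_windows=15, as in the verification
out = windowed_self_attention(windows)
print(out.shape)  # torch.Size([30, 128, 64])
```

Since the windows overlap, most tokens receive an output from more than one window. How those outputs are merged back into a (B, N, D) tensor is a separate design decision that this sketch does not address.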
References
[1] Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv preprint arXiv:2004.05150.