## Einops: ViT-Style Patch Embedding

### Description

Vision Transformers (ViT) process images by first breaking them down into a sequence of flattened patches. The `einops` library is perfectly suited for this task: a single `rearrange` pattern expresses the patch extraction, flattening, and reordering that would otherwise take several reshape/permute calls.
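A minimal sketch of the patch-extraction rearrange; the 224×224 RGB input and 16×16 patch size are illustrative assumptions, not values fixed by the exercise:

```python
import torch
from einops import rearrange

B, C, H, W, P = 2, 3, 224, 224, 16  # assumed image and patch sizes
x = torch.randn(B, C, H, W)

# (B, C, H, W) -> (B, num_patches, patch_dim): split H and W into
# P-sized blocks, then flatten each block (with its channels) into a vector.
patches = rearrange(x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=P, p2=P)
print(patches.shape)  # torch.Size([2, 196, 768])
```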
## Einops: Multi-Head Attention Input Projection

### Description

In the Multi-Head Attention mechanism, the input tensor `(B, N, D)` is linearly projected to create the Query, Key, and Value matrices. These are then reshaped to have separate head and per-head feature axes, giving each of Q, K, and V the shape `(B, num_heads, N, head_dim)`.
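One common way to do this is a single fused projection followed by one `rearrange`; the sizes below are assumptions for illustration:

```python
import torch
from einops import rearrange

B, N, D, H = 2, 10, 64, 8  # assumed sizes; head_dim = D // H = 8
x = torch.randn(B, N, D)
qkv_proj = torch.nn.Linear(D, 3 * D)

# Project once to 3*D, then split out the Q/K/V axis and the head axes.
qkv = qkv_proj(x)  # (B, N, 3*D)
q, k, v = rearrange(qkv, "b n (three h d) -> three b h n d", three=3, h=H)
print(q.shape)  # torch.Size([2, 8, 10, 8])
```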
## Einops: Space-to-Depth Transformation

### Description

Space-to-depth is an operation that rearranges blocks of spatial data into the channel dimension. For an input of shape `(B, C, H, W)` and a block size `S`, the output will be `(B, C*S*S, H/S, W/S)`.
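A small sketch; note the grouping order `(c s1 s2)` in the output channel axis is a convention choice, and other orderings are equally valid:

```python
import torch
from einops import rearrange

B, C, H, W, S = 1, 2, 4, 4, 2
x = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# Each non-overlapping S x S spatial block is folded into the channel axis.
out = rearrange(x, "b c (h s1) (w s2) -> b (c s1 s2) h w", s1=S, s2=S)
print(out.shape)  # torch.Size([1, 8, 2, 2])
```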
## Tensor Manipulation: Causal Mask for Transformers

### Description

In decoder-style Transformers (like GPT), we need a "causal" or "look-ahead" mask to prevent positions from attending to subsequent positions. This is typically a lower-triangular matrix: position `i` may attend only to positions `j <= i`.
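A sketch of the standard construction with `torch.tril`, applied to assumed attention scores as an additive `-inf` mask before the softmax:

```python
import torch

N = 5  # assumed sequence length

# Lower-triangular boolean mask: True where attention is allowed.
mask = torch.tril(torch.ones(N, N, dtype=torch.bool))

# Disallowed (future) positions get -inf so the softmax zeroes them out.
scores = torch.randn(N, N)
masked = scores.masked_fill(~mask, float("-inf"))
print(torch.softmax(masked, dim=-1))
```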
## Einops: Reversing a Sequence

### Description

Reversing the order of elements in a sequence is a common operation. While it can be done with `torch.flip` (PyTorch tensors do not support negative-step slicing), let's practice doing it with `einops`-style thinking for a different perspective on axis manipulation.
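One caveat worth stating: `rearrange` only splits, merges, and permutes axes, so it cannot reverse element order within an axis on its own. A sketch of the usual workaround (reversed indices), with `torch.flip` as the reference:

```python
import torch

x = torch.arange(2 * 5 * 3).reshape(2, 5, 3)  # (B, N, D)

# Reference: flip along the sequence axis.
ref = torch.flip(x, dims=[1])

# rearrange cannot reorder elements inside an axis, so the common trick
# is indexing with a reversed index tensor along that axis.
idx = torch.arange(x.shape[1] - 1, -1, -1)
out = x[:, idx]
assert torch.equal(out, ref)
```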
## Tensor Manipulation: One-Hot Encoding

### Description

Implement one-hot encoding for a batch of class indices. Given a 1D tensor of integer labels, create a 2D tensor where each row is a vector of zeros except for a `1` at the index given by that row's label.
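A sketch showing both the built-in `F.one_hot` and a manual `scatter_` construction (the latter is presumably what the exercise is after); the labels are made-up example data:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1, 3])
num_classes = 4

# Built-in: (N,) int labels -> (N, num_classes) one-hot rows.
one_hot = F.one_hot(labels, num_classes=num_classes)

# Manual equivalent: write a 1 into column `label` of each row.
manual = torch.zeros(labels.shape[0], num_classes, dtype=torch.long)
manual.scatter_(1, labels.unsqueeze(1), 1)
assert torch.equal(one_hot, manual)
```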
## Einops: Depth-to-Space Transformation

### Description

Depth-to-space is the inverse of the space-to-depth operation. It rearranges features from the channel dimension into spatial blocks, increasing spatial resolution and decreasing channel depth: `(B, C*S*S, H, W)` becomes `(B, C, H*S, W*S)`.
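A sketch using the same `(c s1 s2)` channel-grouping convention as the space-to-depth example above, so the two operations are exact inverses:

```python
import torch
from einops import rearrange

B, C, H, W, S = 1, 2, 2, 2, 2
x = torch.randn(B, C * S * S, H, W)  # channels carry the S x S blocks

# Inverse of space-to-depth: unfold channel groups back into spatial blocks.
out = rearrange(x, "b (c s1 s2) h w -> b c (h s1) (w s2)", s1=S, s2=S)
print(out.shape)  # torch.Size([1, 2, 4, 4])
```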
## Einops: Batched Matrix Multiplication

### Description

Perform a batched matrix multiplication `(B, N, D) @ (B, D, M) -> (B, N, M)` using `einops` `einsum`. While `torch.bmm` is the standard, this is a good exercise to understand how einsum notation names the shared batch axis and the contracted axis explicitly.
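A minimal sketch checked against `torch.bmm`; the sizes are arbitrary:

```python
import torch
from einops import einsum

B, N, D, M = 2, 3, 4, 5
a = torch.randn(B, N, D)
b = torch.randn(B, D, M)

# d appears in both inputs but not the output, so it is summed over;
# b (the batch axis) is shared and kept.
out = einsum(a, b, "b n d, b d m -> b n m")
assert torch.allclose(out, torch.bmm(a, b), atol=1e-6)
```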
## Einops: Squeeze and Unsqueeze

### Description

`torch.squeeze` and `torch.unsqueeze` are common for removing or adding dimensions of size one. `einops.rearrange` can do this as well, often with more clarity, by explicitly naming every axis in the pattern.
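A sketch of both directions; writing the size-1 axis as a literal `1` in the pattern is what makes the intent explicit:

```python
import torch
from einops import rearrange

x = torch.randn(4, 1, 8)

# Squeeze: drop the size-1 axis by naming it as a literal 1 on the left.
squeezed = rearrange(x, "b 1 d -> b d")

# Unsqueeze: introduce a size-1 axis by writing a literal 1 on the right.
unsqueezed = rearrange(squeezed, "b d -> b 1 d")
assert torch.equal(x, unsqueezed)
```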
## Einops: Transpose for Attention Output

### Description

After the multi-head attention calculation, the output tensor typically has the shape `(B, num_heads, N, head_dim)`. To feed it into the next layer (usually a feed-forward network), the heads must be merged back into a single feature axis, restoring the shape `(B, N, D)` where `D = num_heads * head_dim`.
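A sketch with assumed sizes; the single `rearrange` replaces the usual transpose + contiguous + view chain:

```python
import torch
from einops import rearrange

B, H, N, Dh = 2, 8, 10, 8  # assumed sizes; D = H * Dh = 64
attn_out = torch.randn(B, H, N, Dh)

# Move heads next to the per-head features and merge the two axes.
merged = rearrange(attn_out, "b h n d -> b n (h d)")
print(merged.shape)  # torch.Size([2, 10, 64])
```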
## Einops: Repeat for Tiling/Broadcasting

### Description

The `einops.repeat` function is a powerful and readable alternative to `torch.expand` or `torch.tile` for broadcasting or repeating a tensor along new or existing dimensions.
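A sketch of the two common cases, new axis and existing axis; the shapes are arbitrary:

```python
import torch
from einops import repeat

x = torch.randn(3, 4)

# Broadcast along a brand-new batch axis (like unsqueeze + expand).
batched = repeat(x, "h w -> b h w", b=2)
print(batched.shape)  # torch.Size([2, 3, 4])

# Repeat each element along an existing axis (like repeat_interleave);
# (r w) instead of (w r) would tile whole rows rather than interleave.
tiled = repeat(x, "h w -> h (w r)", r=2)
print(tiled.shape)  # torch.Size([3, 8])
```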
## Build a Custom ReLU Activation Function

Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as $\mathrm{ReLU}(x) = \max(0, x)$.

**Verification:**
- For negative inputs, the gradient should be `0`.
- For positive inputs, the gradient should be `1`.
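A minimal sketch; `jax.grad` needs a scalar-output function, so `jax.vmap` is used to evaluate the derivative over a batch of test points (chosen here to avoid the tie at `x = 0`):

```python
import jax
import jax.numpy as jnp

def relu(x):
    # max(0, x), elementwise.
    return jnp.maximum(0.0, x)

# grad differentiates the scalar function; vmap maps it over a batch.
relu_grad = jax.vmap(jax.grad(relu))

xs = jnp.array([-2.0, -0.5, 0.5, 3.0])
print(relu_grad(xs))  # [0. 0. 1. 1.]
```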
## Vectorized Operations with vmap

You have a function that processes a single data point. Your goal is to use `jax.vmap` to apply this function to a whole batch of data without writing an explicit loop. For example, consider a function that operates on a single vector and should be applied to every row of a batch.
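A sketch under that setup; the single-example function `sq_norm` is a hypothetical stand-in for whatever per-example computation the exercise uses:

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example function: squared L2 norm of one vector.
def sq_norm(v):
    return jnp.dot(v, v)

batch = jnp.arange(6.0).reshape(3, 2)  # 3 examples, 2 features each

# vmap maps sq_norm over axis 0 of its input; no Python loop needed.
batched_sq_norm = jax.vmap(sq_norm)
print(batched_sq_norm(batch))  # [ 1. 13. 41.]
```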
## Debugging JAX code with `jax.debug.print`

In JAX, standard Python `print` statements don't always work as expected within `jit`-compiled functions, because they execute at trace time rather than at run time. [11] The solution is to use `jax.debug.print`, which is staged into the compiled computation so it prints the actual runtime values. [11, 23]
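A minimal sketch of the difference inside a jitted function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = x * 2
    # A plain print here would fire once, at trace time, showing a tracer.
    # jax.debug.print is staged into the computation and shows real values.
    jax.debug.print("y = {}", y)
    return y + 1

f(jnp.array([1.0, 2.0]))  # prints: y = [2. 4.]
```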