-
Implementing a simplified Beam Search Decoder with `gather` and `scatter`
Beam search is a popular decoding algorithm used in machine translation and text generation. A key step in beam search is to select the top-k most likely next tokens and update the corresponding...
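A minimal sketch of one selection step, assuming per-beam log-probabilities of shape `(beam_size, vocab_size)` and running sequences of shape `(beam_size, seq_len)`; names like `step_log_probs` and `sequences` are illustrative, not from any library:

```python
import torch

beam_size, vocab_size, seq_len = 3, 10, 5
sequences = torch.zeros(beam_size, seq_len, dtype=torch.long)
beam_scores = torch.zeros(beam_size)                      # cumulative log-probs
step_log_probs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1)

# Add each beam's running score, then flatten so top-k ranges over
# all (beam, token) pairs at once.
total = (beam_scores.unsqueeze(1) + step_log_probs).view(-1)
top_scores, top_idx = total.topk(beam_size)
beam_idx = torch.div(top_idx, vocab_size, rounding_mode='floor')  # parent beam
token_idx = top_idx % vocab_size                                  # chosen token

# gather: reorder rows so row i holds the parent sequence of new beam i.
sequences = sequences.gather(0, beam_idx.unsqueeze(1).expand(-1, seq_len))
# scatter_: write the chosen next token into the current time step t.
t = 2
sequences.scatter_(1, torch.full((beam_size, 1), t), token_idx.unsqueeze(1))
beam_scores = top_scores
```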
-
Efficiently Updating a Sub-Tensor with `scatter_`
Sometimes you need to update a portion of a larger tensor with values from a smaller tensor, where the update locations are specified by an index tensor. This is a common pattern in various...
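A minimal sketch, assuming we want to overwrite selected rows of a larger tensor with rows of a smaller one; all tensor names are illustrative:

```python
import torch

base = torch.zeros(6, 4)
updates = torch.arange(8, dtype=torch.float32).view(2, 4)
rows = torch.tensor([1, 4])                 # destination row indices

# scatter_ along dim 0 needs an index tensor the same shape as `updates`,
# where index[i, j] names the destination row of updates[i, j].
index = rows.unsqueeze(1).expand(-1, 4)
base.scatter_(0, index, updates)            # base[rows[i], j] = updates[i, j]
```

For whole-row replacement, plain advanced indexing (`base[rows] = updates`) works too; `scatter_` generalizes to element-level update patterns.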
-
Forward-mode vs. Reverse-mode autodiff
JAX supports both forward-mode and reverse-mode automatic differentiation. While `grad` uses reverse-mode, you can use `jax.jvp` for forward-mode, which computes Jacobian-vector products....
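A short sketch comparing the two modes on a scalar function; the function `f` is illustrative:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2

x = jnp.array(1.5)
v = jnp.array(1.0)                       # tangent (direction) vector

# Forward mode: one call returns both f(x) and the Jacobian-vector product J(x) @ v.
y, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode computes the same scalar derivative via grad, for comparison.
assert jnp.allclose(jvp_out, jax.grad(f)(x))
```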
-
Custom VJP
For some functions, you may want to define a custom vector-Jacobian product (VJP). This can be useful for numerical stability or for implementing algorithms that are not easily expressed in terms...
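A minimal sketch using the classic numerical-stability example, `log(1 + exp(x))`, whose gradient has the clean closed form `sigmoid(x)`:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    return log1pexp(x), x                  # save x as the residual

def log1pexp_bwd(x, g):
    return (g * jax.nn.sigmoid(x),)        # VJP: cotangent times sigmoid(x)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(10.0))            # ~sigmoid(10) ≈ 0.99995
```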
-
Parallelization with `pmap`
For large models, it is often necessary to train on multiple devices (e.g., GPUs or TPUs). JAX's `pmap` transformation allows for easy parallelization of computations across devices. In this...
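A minimal data-parallel sketch, assuming a linear model; gradients computed on each device's shard are averaged with `jax.lax.pmean` under a named axis:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def grad_step(w, x, y):
    g = jax.grad(loss)(w, x, y)
    return jax.lax.pmean(g, axis_name='devices')   # average across devices

p_grad_step = jax.pmap(grad_step, axis_name='devices')

n = jax.local_device_count()
w = jnp.broadcast_to(jnp.zeros(4), (n, 4))         # replicate parameters
x = jnp.ones((n, 8, 4))                            # shard the batch per device
y = jnp.ones((n, 8))
grads = p_grad_step(w, x, y)                       # identical gradient on every device
```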
-
The Elegant Gradient of Softmax-Cross-Entropy
One of the most satisfying derivations in deep learning is the gradient of the combined Softmax and Cross-Entropy loss. For a multi-class classification problem with $K$ classes, given true labels...
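The punchline of the derivation is that the gradient with respect to the logits is simply `softmax(z) - y`. A quick numerical check against autograd, with an illustrative one-hot label:

```python
import torch

K = 5
z = torch.randn(K, requires_grad=True)
y = torch.zeros(K); y[2] = 1.0                      # one-hot true label

# L = -sum(y * log softmax(z))
loss = -(y * torch.log_softmax(z, dim=0)).sum()
loss.backward()

# The elegant closed form: dL/dz = softmax(z) - y.
assert torch.allclose(z.grad, torch.softmax(z, dim=0) - y, atol=1e-6)
```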
-
Deconstructing Self-Attention Scores
The self-attention mechanism is a core component of Transformers. Let's break down how attention scores are calculated.
1. **Query, Key, Value**: In self-attention, each input token (or its...
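A minimal sketch of the score computation described above, assuming inputs `X` of shape `(seq_len, d_model)` and illustrative projection matrices `W_q`, `W_k`, `W_v`:

```python
import math
import torch

seq_len, d_model, d_k = 4, 8, 8
X = torch.randn(seq_len, d_model)
W_q, W_k, W_v = (torch.randn(d_model, d_k) for _ in range(3))

Q, K, V = X @ W_q, X @ W_k, X @ W_v
scores = Q @ K.T / math.sqrt(d_k)         # (seq_len, seq_len) scaled dot products
weights = torch.softmax(scores, dim=-1)   # each row sums to 1
output = weights @ V                      # weighted mix of value vectors
```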
-
Dissecting the Variational Autoencoder's ELBO
Variational Autoencoders (VAEs) are powerful generative models that optimize a lower bound on the data log-likelihood, known as the Evidence Lower Bound (ELBO). The ELBO for a single data point...
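A minimal sketch of the per-example ELBO with a diagonal-Gaussian encoder and Bernoulli decoder; `mu`, `logvar`, `x`, and `x_recon` are illustrative stand-ins for network outputs:

```python
import torch

x = torch.rand(784)                                  # data point in [0, 1]
mu, logvar = torch.zeros(16), torch.zeros(16)        # encoder outputs q(z|x)
x_recon = torch.sigmoid(torch.randn(784))            # decoder output in (0, 1)

# Reconstruction term E_q[log p(x|z)], approximated with one z sample.
recon = -torch.nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
# KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian.
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
elbo = recon - kl                                    # training minimizes -elbo
```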
-
PCA from First Principles
Principal Component Analysis (PCA) is a fundamental dimensionality reduction technique. It works by transforming the data into a new coordinate system such that the greatest variance by any...
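A from-scratch sketch via eigendecomposition of the sample covariance, with an illustrative data matrix `X` of shape `(n_samples, n_features)`:

```python
import torch

X = torch.randn(100, 5)
Xc = X - X.mean(dim=0)                         # center the data
cov = Xc.T @ Xc / (Xc.shape[0] - 1)            # sample covariance matrix
eigvals, eigvecs = torch.linalg.eigh(cov)      # eigenvalues in ascending order

k = 2
components = eigvecs[:, -k:].flip(-1)          # top-k principal directions
X_reduced = Xc @ components                    # project onto the new axes
```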
-
SVD for Image Compression
Singular Value Decomposition (SVD) is a powerful matrix factorization technique with numerous applications, including dimensionality reduction, noise reduction, and data compression. Any real $m...
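A minimal compression sketch: truncating the SVD to rank `k` gives the best rank-`k` approximation, cutting storage from `m*n` values to `k*(m + n + 1)`. The matrix `A` stands in for a grayscale image:

```python
import torch

A = torch.randn(64, 48)                        # stand-in for an image
U, S, Vh = torch.linalg.svd(A, full_matrices=False)

k = 10
A_k = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k]    # rank-k reconstruction
rel_err = torch.linalg.norm(A - A_k) / torch.linalg.norm(A)
```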
-
Softmax and its Jacobian
The softmax function is a critical component in multi-class classification, converting a vector of arbitrary real values into a probability distribution. Given an input vector $\mathbf{z} = [z_1,...
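The Jacobian has the closed form $J_{ij} = p_i(\delta_{ij} - p_j)$ where $p = \mathrm{softmax}(\mathbf{z})$. A quick check against autograd:

```python
import torch

z = torch.randn(4)
p = torch.softmax(z, dim=0)
J_manual = torch.diag(p) - torch.outer(p, p)   # p_i * (delta_ij - p_j)

J_auto = torch.autograd.functional.jacobian(
    lambda t: torch.softmax(t, dim=0), z)
assert torch.allclose(J_manual, J_auto, atol=1e-6)
```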
-
Backpropagation for a Single-Layer Network
Backpropagation is the cornerstone algorithm for training neural networks. It efficiently calculates the gradients of the loss function with respect to all the weights and biases in the network by...
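A hand-derived sketch for a one-layer sigmoid network with MSE loss, checked against autograd; shapes and the learning setup are illustrative:

```python
import torch

N, D = 8, 3
x, y = torch.randn(N, D), torch.rand(N, 1)
W = torch.randn(D, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

y_hat = torch.sigmoid(x @ W + b)
loss = ((y_hat - y) ** 2).mean()
loss.backward()

# Chain rule by hand: dL/dy_hat -> through sigmoid -> to W and b.
yh = y_hat.detach()
dL_dyhat = 2 * (yh - y) / y.numel()
dL_dz = dL_dyhat * yh * (1 - yh)        # sigmoid'(z) = s(z)(1 - s(z))
dW, db = x.T @ dL_dz, dL_dz.sum(0)
assert torch.allclose(dW, W.grad, atol=1e-6)
assert torch.allclose(db, b.grad, atol=1e-6)
```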
-
MoE Aggregator: Combining Expert Outputs
After tokens have been dispatched to and processed by their respective experts, the outputs need to be combined based on the weights from the gating network. This exercise focuses on this...
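A minimal aggregation sketch, assuming each token was routed to `k` experts and the per-token expert outputs and gate weights are already available; all names are illustrative:

```python
import torch

num_tokens, d_model, k = 6, 4, 2
expert_outputs = torch.randn(num_tokens, k, d_model)   # outputs of each chosen expert
gate_weights = torch.softmax(torch.randn(num_tokens, k), dim=-1)

# Gate-weighted sum over the k selected experts, per token.
combined = (gate_weights.unsqueeze(-1) * expert_outputs).sum(dim=1)  # (num_tokens, d_model)
```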
-
Building a Simple Mixture of Experts (MoE) Layer
Now, let's combine the concepts of dispatching and aggregating into a full, albeit simplified, `torch.nn.Module` for a Mixture of Experts layer. This layer will replace a standard feed-forward...
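A simplified dense sketch in which every expert sees every token and only the gating weights differ; a production MoE layer would route tokens sparsely instead:

```python
import torch
import torch.nn as nn

class SimpleMoE(nn.Module):
    def __init__(self, d_model, d_hidden, num_experts):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.ReLU(),
                          nn.Linear(d_hidden, d_model))
            for _ in range(num_experts))

    def forward(self, x):                                 # x: (batch, d_model)
        weights = torch.softmax(self.gate(x), dim=-1)     # (batch, E)
        outs = torch.stack([e(x) for e in self.experts], dim=1)  # (batch, E, d_model)
        return (weights.unsqueeze(-1) * outs).sum(dim=1)

moe = SimpleMoE(d_model=16, d_hidden=32, num_experts=4)
y = moe(torch.randn(8, 16))                               # (8, 16)
```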
-
Extract Diagonal from a Batch of Matrices
### Description
Given a batch of square matrices of shape `(B, N, N)`, extract the diagonal of each matrix. The output should be a tensor of shape `(B, N)`. This can be achieved with...
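A short sketch showing both the built-in route and an equivalent `gather`-based one:

```python
import torch

B, N = 4, 3
mats = torch.randn(B, N, N)

# torch.diagonal over the last two dims returns shape (B, N) directly.
diags = torch.diagonal(mats, dim1=-2, dim2=-1)

# Equivalent gather: pick element (i, i) from each row i.
idx = torch.arange(N).view(1, N, 1).expand(B, N, 1)
diags_gather = mats.gather(2, idx).squeeze(-1)
assert torch.equal(diags, diags_gather)
```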
-
Custom Gradient with `jax.custom_vjp`
Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function...
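For the non-differentiable case, one illustrative sketch is a straight-through gradient for rounding: the forward pass rounds, while the backward pass passes the cotangent through as if the function were the identity:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def ste_round(x):
    return jnp.round(x)

def ste_round_fwd(x):
    return jnp.round(x), None      # no residuals needed

def ste_round_bwd(_, g):
    return (g,)                    # pretend d(round)/dx = 1

ste_round.defvjp(ste_round_fwd, ste_round_bwd)

# Without the custom VJP this gradient would be zero almost everywhere.
print(jax.grad(lambda x: ste_round(x) ** 2)(1.3))   # 2 * round(1.3) = 2.0
```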
-
Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX. This will involve using `jax.lax.conv_general_dilated`. You will need to manage the kernel initialization and the forward pass logic....
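A minimal forward-pass sketch with NCHW inputs and OIHW kernels; shapes and the `'SAME'` padding choice are illustrative:

```python
import jax
import jax.numpy as jnp

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (1, 3, 32, 32))          # (N, C_in, H, W)
w = jax.random.normal(kw, (8, 3, 3, 3)) * 0.1      # (C_out, C_in, kH, kW)

y = jax.lax.conv_general_dilated(
    x, w,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'),
)
print(y.shape)                                     # (1, 8, 32, 32)
```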
-
Data Parallelism with `pmap`
Parallelize a training step across multiple devices (e.g., multiple CPU cores if you don't have GPUs/TPUs) using `jax.pmap`. This is a fundamental technique for scaling up training. **Task:** Take...
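A sketch of the full step for an illustrative linear model: the global batch is reshaped to `(n_devices, per_device_batch, ...)`, each device takes an SGD step, and `pmean` keeps the replicated parameters in sync:

```python
import jax
import jax.numpy as jnp

def train_step(w, x, y):
    def loss(w):
        return jnp.mean((x @ w - y) ** 2)
    g = jax.lax.pmean(jax.grad(loss)(w), axis_name='batch')
    return w - 0.1 * g                                # synchronized SGD update

p_train_step = jax.pmap(train_step, axis_name='batch')

n = jax.local_device_count()
w = jnp.broadcast_to(jnp.zeros(4), (n, 4))            # replicated parameters
x = jnp.ones((n * 8, 4)).reshape(n, 8, 4)             # shard the global batch
y = jnp.ones((n * 8,)).reshape(n, 8)
w = p_train_step(w, x, y)                             # params identical on all devices
```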
-
Combine `vmap` and `pmap`
For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices, and `vmap` for model ensembling on each device....
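A minimal sketch of that nesting, assuming an illustrative linear `predict` function, an ensemble axis handled by `vmap`, and a device axis handled by `pmap`:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return x @ w                                   # one model on one data shard

# vmap over the ensemble axis of w, sharing the same data shard.
ensemble_predict = jax.vmap(predict, in_axes=(0, None))
# pmap over the leading device axis of both arguments.
parallel_ensemble = jax.pmap(ensemble_predict)

n, ensemble, d = jax.local_device_count(), 4, 3
w = jnp.ones((n, ensemble, d))                     # per-device copy of the ensemble
x = jnp.ones((n, 8, d))                            # per-device data shard
preds = parallel_ensemble(w, x)                    # (n, ensemble, 8)
```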
-
Implement a Neural Ordinary Differential Equation
### Description
Instead of modeling a function directly, a Neural ODE models its derivative with a neural network. The output is then found by integrating this derivative over time. [1] Your task...
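A minimal forward-pass sketch using fixed-step Euler integration; a real implementation would use an adaptive solver such as the one in the `torchdiffeq` package:

```python
import torch
import torch.nn as nn

class ODEFunc(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 32), nn.Tanh(), nn.Linear(32, dim))

    def forward(self, t, z):
        return self.net(z)                          # dz/dt = f(z, t; theta)

def odeint_euler(func, z0, t0=0.0, t1=1.0, steps=20):
    z, dt = z0, (t1 - t0) / steps
    for i in range(steps):
        z = z + dt * func(t0 + i * dt, z)           # explicit Euler step
    return z

f = ODEFunc(dim=2)
z1 = odeint_euler(f, torch.randn(5, 2))             # integrate 5 states to t=1
```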
-
Model-Agnostic Meta-Learning (MAML) Update Step
### Description
Model-Agnostic Meta-Learning (MAML) is a meta-learning algorithm that trains a model's initial parameters such that it can adapt to a new task with only a few gradient steps. [1]...
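A sketch of one meta-update for an illustrative linear model; the inner adaptation uses `torch.autograd.grad(..., create_graph=True)` so the outer loss can differentiate through it:

```python
import torch

w = torch.randn(3, 1, requires_grad=True)           # meta-parameters
inner_lr, outer_lr = 0.1, 0.01

x_s, y_s = torch.randn(10, 3), torch.randn(10, 1)   # support set (adaptation)
x_q, y_q = torch.randn(10, 3), torch.randn(10, 1)   # query set (evaluation)

inner_loss = ((x_s @ w - y_s) ** 2).mean()
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_adapted = w - inner_lr * g                         # task-specific parameters

outer_loss = ((x_q @ w_adapted - y_q) ** 2).mean()
outer_loss.backward()                                # grads flow through adaptation
with torch.no_grad():
    w -= outer_lr * w.grad                           # meta-update
```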
-
Build a Transformer Encoder Block from Scratch
### Description
The Transformer architecture is built upon a fundamental component: the Encoder block. [1] Each block is responsible for processing a sequence of embeddings and refining them. Your...
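A minimal post-norm sketch: multi-head self-attention and a position-wise feed-forward network, each wrapped in a residual connection plus `LayerNorm`:

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                nn.Linear(d_ff, d_model))
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x):                       # x: (batch, seq, d_model)
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)            # residual around attention
        x = self.norm2(x + self.ff(x))          # residual around feed-forward
        return x

block = EncoderBlock(d_model=32, n_heads=4, d_ff=64)
y = block(torch.randn(2, 10, 32))               # (2, 10, 32)
```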
-
Soft Actor-Critic (SAC) Critic Loss
### Description
Soft Actor-Critic (SAC) is a state-of-the-art reinforcement learning algorithm known for its stability and sample efficiency. [1] A key component is its critic (or Q-network)...
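A sketch of the critic target `r + gamma * (min(Q1', Q2') - alpha * log_pi)`, with twin target Q-values to curb overestimation; the tensors below are illustrative stand-ins for network outputs on a sampled batch:

```python
import torch

batch = 32
r = torch.randn(batch)                    # rewards
done = torch.zeros(batch)                 # termination flags
gamma, alpha = 0.99, 0.2                  # discount and entropy temperature

q1, q2 = torch.randn(batch), torch.randn(batch)             # current critics
q1_targ, q2_targ = torch.randn(batch), torch.randn(batch)   # target critics at a'
log_pi = torch.randn(batch)                                 # log pi(a'|s')

with torch.no_grad():
    target = r + gamma * (1 - done) * (torch.min(q1_targ, q2_targ) - alpha * log_pi)

critic_loss = ((q1 - target) ** 2).mean() + ((q2 - target) ** 2).mean()
```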
-
Neural Cellular Automata (NCA) Update Step
### Description
Neural Cellular Automata (NCA) are a fascinating generative model where complex global patterns emerge from simple, local rules learned by a neural network. [1] A grid of "cells,"...
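A sketch of one update step under common NCA conventions: depthwise Sobel filters gather local neighborhood information, a 1x1 convolution acts as the per-cell network, and a random mask makes updates asynchronous; all shapes are illustrative:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 16, 32, 32
grid = torch.randn(B, C, H, W)                         # cell states

sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
kernels = torch.stack([sobel_x, sobel_x.T])            # x- and y-gradient filters
kernels = kernels.unsqueeze(1).repeat(C, 1, 1, 1)      # depthwise: (2C, 1, 3, 3)

perception = F.conv2d(grid, kernels, padding=1, groups=C)   # (B, 2C, H, W)
state = torch.cat([grid, perception], dim=1)                # (B, 3C, H, W)

update_net = torch.nn.Conv2d(3 * C, C, kernel_size=1)  # per-cell MLP as 1x1 conv
update = update_net(state)
fire_mask = (torch.rand(B, 1, H, W) < 0.5).float()     # stochastic cell updates
grid = grid + update * fire_mask
```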
-
Bayesian Neural Network Layer
### Description
In a standard neural network, weights are single point estimates. In a Bayesian Neural Network (BNN), we learn a probability distribution over each weight. [1] This allows for...
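A minimal sketch of a Bayesian linear layer: each weight has a learned mean and (soft-plus-parameterized) standard deviation, and the forward pass samples weights via the reparameterization trick so gradients flow to both parameters:

```python
import torch
import torch.nn as nn

class BayesianLinear(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_mu = nn.Parameter(torch.randn(d_in, d_out) * 0.1)
        self.w_rho = nn.Parameter(torch.full((d_in, d_out), -3.0))  # softplus -> std
        self.b = nn.Parameter(torch.zeros(d_out))

    def forward(self, x):
        std = torch.nn.functional.softplus(self.w_rho)
        w = self.w_mu + std * torch.randn_like(std)   # reparameterized sample
        return x @ w + self.b

layer = BayesianLinear(4, 2)
# Repeated forward passes give different outputs: predictive uncertainty.
y1, y2 = layer(torch.ones(1, 4)), layer(torch.ones(1, 4))
```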