-
Working with an optimizer library: Optax
Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
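A minimal Optax training-step sketch, assuming a toy parameter pytree and a squared-error loss in place of the full Flax MLP from the referenced exercise:

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameters standing in for the Flax MLP's params pytree (assumption:
# the real exercise would obtain these from model.init(...)).
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

x, y = jnp.ones((8, 3)), jnp.zeros((8,))
params, opt_state, loss = train_step(params, opt_state, x, y)
```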
-
Parallelization with `pmap`
For large models, it is often necessary to train on multiple devices (e.g., GPUs or TPUs). JAX's `pmap` transformation allows for easy parallelization of computations across devices. In this...
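A small `pmap` sketch, assuming the batch's leading axis has been split to match `jax.local_device_count()`; the per-device means are combined with an all-reduce:

```python
import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

@functools.partial(jax.pmap, axis_name="devices")
def device_mean(x):
    # Each device averages its own shard, then the partial means are
    # averaged across devices with jax.lax.pmean (an all-reduce).
    return jax.lax.pmean(jnp.mean(x), axis_name="devices")

# The leading axis must equal the number of devices; each row goes to one device.
batch = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
print(device_mean(batch))  # the same global mean, replicated on every device
```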
-
Checkpointing
When training large models, it is important to save the model's parameters periodically. This is known as checkpointing and allows you to resume training from a saved state in case of an...
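A bare-bones checkpointing sketch using `pickle` on the parameter pytree; the pytree and path below are hypothetical, and production training loops typically use a dedicated checkpointing library such as Orbax:

```python
import pickle
import jax
import jax.numpy as jnp

# Hypothetical params pytree, purely for illustration.
params = {"dense": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}}

def save_checkpoint(path, params, step):
    # Pull arrays back to host memory before serializing.
    with open(path, "wb") as f:
        pickle.dump({"step": step, "params": jax.device_get(params)}, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["params"], ckpt["step"]

save_checkpoint("/tmp/ckpt_100.pkl", params, step=100)
restored_params, step = load_checkpoint("/tmp/ckpt_100.pkl")
```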
-
The Elegant Gradient of Softmax-Cross-Entropy
One of the most satisfying derivations in deep learning is the gradient of the combined Softmax and Cross-Entropy loss. For a multi-class classification problem with $K$ classes, given true labels...
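The well-known result is $\frac{\partial L}{\partial z_i} = \text{softmax}(z)_i - y_i$ for one-hot labels $y$; a quick numerical check against `jax.grad`:

```python
import jax
import jax.numpy as jnp

def cross_entropy(z, y_onehot):
    log_p = jax.nn.log_softmax(z)
    return -jnp.sum(y_onehot * log_p)

z = jnp.array([2.0, -1.0, 0.5, 0.0])
y = jax.nn.one_hot(2, 4)

autodiff_grad = jax.grad(cross_entropy)(z, y)
closed_form = jax.nn.softmax(z) - y          # the "elegant" result: p - y
print(jnp.allclose(autodiff_grad, closed_form, atol=1e-6))  # True
```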
-
Softmax's Numerical Stability: The Max Trick
While the standard softmax formula $\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$ is mathematically correct, a direct implementation can lead to numerical instability due to potential...
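A quick demonstration of the instability and of the max-subtraction fix, with illustrative values chosen to overflow `float32`:

```python
import jax.numpy as jnp

def naive_softmax(z):
    e = jnp.exp(z)                      # overflows once z is large
    return e / jnp.sum(e)

def stable_softmax(z):
    z_shifted = z - jnp.max(z)          # largest exponent is now exp(0) = 1
    e = jnp.exp(z_shifted)
    return e / jnp.sum(e)

z = jnp.array([1000.0, 1001.0, 1002.0])
print(naive_softmax(z))   # [nan nan nan] -- exp(1000) overflows to inf
print(stable_softmax(z))  # well-behaved probabilities
```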
-
Tracing Gradient Descent on a Parabola
Imagine a simple 1D function $f(x) = x^2 - 4x + 5$. Your goal is to find the minimum of this function using Gradient Descent. 1. **Derive the gradient**: What is $\frac{df}{dx}$? 2. **Perform a...
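A short sketch tracing the descent numerically with `jax.grad`; the learning rate and iteration count are illustrative choices:

```python
import jax

f = lambda x: x**2 - 4.0 * x + 5.0      # minimum at x = 2, f(2) = 1
grad_f = jax.grad(f)                    # analytically, df/dx = 2x - 4

x, lr = 0.0, 0.1
for step in range(50):
    x = x - lr * grad_f(x)              # x <- x - lr * df/dx
print(x, f(x))                          # x approaches 2.0, f(x) approaches 1.0
```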
-
The Gradients of Activation Functions
Activation functions introduce non-linearity into neural networks, but their derivatives are crucial for backpropagation. 1. **Sigmoid**: Given $\sigma(x) = \frac{1}{1 + e^{-x}}$, derive...
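A numerical check of two standard derivatives, $\sigma'(x) = \sigma(x)(1-\sigma(x))$ and $\tanh'(x) = 1 - \tanh^2(x)$, against `jax.grad`:

```python
import jax
import jax.numpy as jnp

sigmoid = lambda x: 1.0 / (1.0 + jnp.exp(-x))

x = 0.7
auto = jax.grad(sigmoid)(x)
manual = sigmoid(x) * (1.0 - sigmoid(x))     # sigma'(x) = sigma(x)(1 - sigma(x))
print(jnp.allclose(auto, manual))            # True

auto_tanh = jax.grad(jnp.tanh)(x)
manual_tanh = 1.0 - jnp.tanh(x) ** 2         # tanh'(x) = 1 - tanh(x)^2
print(jnp.allclose(auto_tanh, manual_tanh))  # True
```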
-
Cross-Entropy: A Measure of Surprise
Cross-entropy loss is fundamental for classification tasks. Let's build some intuition for its formulation. 1. **Definition**: For a binary classification problem, the binary cross-entropy (BCE)...
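A minimal BCE sketch, assuming the usual form $-[y\log p + (1-y)\log(1-p)]$ with clipping for numerical safety:

```python
import jax.numpy as jnp

def bce(y_true, p_pred, eps=1e-7):
    # Clipping keeps log() finite when a prediction saturates at 0 or 1.
    p = jnp.clip(p_pred, eps, 1.0 - eps)
    return -(y_true * jnp.log(p) + (1.0 - y_true) * jnp.log(1.0 - p))

# A confident correct prediction is "unsurprising" (low loss); a confident
# wrong prediction is very "surprising" (high loss).
print(bce(1.0, 0.99))   # ~0.01
print(bce(1.0, 0.01))   # ~4.6
```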
-
L2 Regularization's Gradient Impact
L2 regularization (also known as weight decay) is a common technique to prevent overfitting. 1. **Loss Function**: Consider a simple linear regression loss with L2 regularization: $J(\mathbf{w},...
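A tiny sketch, assuming the $\frac{\lambda}{2}\lVert\mathbf{w}\rVert^2$ convention, showing that the penalty's gradient is simply $\lambda\mathbf{w}$, i.e. a constant pull of every weight toward zero:

```python
import jax
import jax.numpy as jnp

lam = 0.1  # regularization strength (illustrative value)

def l2_penalty(w):
    return 0.5 * lam * jnp.sum(w ** 2)

w = jnp.array([1.0, -2.0, 3.0])
print(jax.grad(l2_penalty)(w))   # equals lam * w -> the "weight decay" term
print(lam * w)
```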
-
Deconstructing Self-Attention Scores
The self-attention mechanism is a core component of Transformers. Let's break down how attention scores are calculated. 1. **Query, Key, Value**: In self-attention, each input token (or its...
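A single-head, unbatched sketch of scaled dot-product self-attention; the shapes and random projection matrices are illustrative:

```python
import jax
import jax.numpy as jnp

def self_attention(x, w_q, w_k, w_v):
    q, k, v = x @ w_q, x @ w_k, x @ w_v          # project tokens to Q, K, V
    d_k = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d_k)             # scaled dot-product scores
    weights = jax.nn.softmax(scores, axis=-1)    # each row sums to 1
    return weights @ v                           # attention-weighted mix of values

seq_len, d_model, d_head = 4, 8, 8
x = jax.random.normal(jax.random.PRNGKey(0), (seq_len, d_model))
w_q, w_k, w_v = (jax.random.normal(jax.random.PRNGKey(i), (d_model, d_head))
                 for i in range(1, 4))
print(self_attention(x, w_q, w_k, w_v).shape)    # (4, 8)
```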
-
The Stabilizing Power of Batch Normalization
Batch Normalization (BatchNorm) is a crucial technique for stabilizing and accelerating deep neural network training. 1. **Normalization Step**: Given a mini-batch of activations $X = \{x_1, x_2,...
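A training-mode sketch of the normalization step (running statistics for inference are omitted):

```python
import jax.numpy as jnp

def batch_norm(x, gamma, beta, eps=1e-5):
    mu = jnp.mean(x, axis=0)                  # per-feature batch mean
    var = jnp.var(x, axis=0)                  # per-feature batch variance
    x_hat = (x - mu) / jnp.sqrt(var + eps)    # normalize to ~zero mean, unit variance
    return gamma * x_hat + beta               # learnable scale and shift

x = jnp.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
out = batch_norm(x, gamma=jnp.ones(2), beta=jnp.zeros(2))
print(out.mean(axis=0), out.std(axis=0))      # ~0 and ~1 per feature
```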
-
The Implicit Higher Dimension of Kernels
Support Vector Machines (SVMs) are powerful classifiers, and the "kernel trick" allows them to find non-linear decision boundaries without explicitly mapping data to high-dimensional spaces. 1. **Linear...
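A small check that a degree-2 polynomial kernel on 2-D inputs matches an explicit feature map into 3-D, even though that map is never formed during kernel evaluation:

```python
import jax.numpy as jnp

def poly2_kernel(x, y):
    return jnp.dot(x, y) ** 2                  # computed entirely in the original 2-D space

def phi(x):
    # Explicit degree-2 feature map: the 3-D space the kernel works in implicitly.
    return jnp.array([x[0] ** 2, jnp.sqrt(2.0) * x[0] * x[1], x[1] ** 2])

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, -1.0])
print(poly2_kernel(x, y), jnp.dot(phi(x), phi(y)))   # both equal (x . y)^2 = 1, up to float rounding
```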
-
Dissecting the Variational Autoencoder's ELBO
Variational Autoencoders (VAEs) are powerful generative models that optimize a lower bound on the data log-likelihood, known as the Evidence Lower Bound (ELBO). The ELBO for a single data point...
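A sketch of two ELBO ingredients for a diagonal Gaussian posterior: the closed-form KL term against a standard normal prior, and the reparameterization trick (the encoder and decoder networks are omitted):

```python
import jax
import jax.numpy as jnp

def gaussian_kl(mu, logvar):
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian posterior.
    return 0.5 * jnp.sum(mu ** 2 + jnp.exp(logvar) - 1.0 - logvar)

def reparameterize(key, mu, logvar):
    # z = mu + sigma * eps keeps the sample differentiable w.r.t. mu and logvar.
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps

mu = jnp.array([0.5, -0.2])
logvar = jnp.array([-1.0, 0.3])
z = reparameterize(jax.random.PRNGKey(0), mu, logvar)
print(gaussian_kl(mu, logvar), z)
```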
-
Riding the Momentum Wave in Optimization
Stochastic Gradient Descent (SGD) with momentum is a popular optimization algorithm that often converges faster and more stably than plain SGD. 1. **Update Rule**: The update rule for SGD with...
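A sketch using one common convention, $v \leftarrow \beta v + \nabla f(w)$ and $w \leftarrow w - \eta v$ (conventions vary; some fold the learning rate into $v$):

```python
import jax
import jax.numpy as jnp

f = lambda w: jnp.sum(w ** 2)        # toy convex objective with minimum at the origin
grad_f = jax.grad(f)

w = jnp.array([5.0, -3.0])
v = jnp.zeros_like(w)
lr, beta = 0.1, 0.9

for _ in range(300):
    g = grad_f(w)
    v = beta * v + g                 # exponentially decaying accumulation of gradients
    w = w - lr * v                   # step along the accumulated "velocity"
print(w)                             # converges toward the minimum at the origin
```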
-
KL Divergence Calculation and Interpretation
The Kullback-Leibler (KL) Divergence (also known as relative entropy) is a non-symmetric measure of how one probability distribution $P$ is different from a second, reference probability...
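A direct implementation for discrete distributions, with a small example exposing the asymmetry:

```python
import jax.numpy as jnp

def kl_divergence(p, q, eps=1e-12):
    # D_KL(P || Q) = sum_i p_i * log(p_i / q_i), in nats.
    p = jnp.clip(p, eps, 1.0)
    q = jnp.clip(q, eps, 1.0)
    return jnp.sum(p * jnp.log(p / q))

p = jnp.array([0.5, 0.3, 0.2])
q = jnp.array([0.4, 0.4, 0.2])
print(kl_divergence(p, q), kl_divergence(q, p))   # D_KL(P||Q) != D_KL(Q||P)
```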
-
PCA from First Principles
Principal Component Analysis (PCA) is a fundamental dimensionality reduction technique. It works by transforming the data into a new coordinate system such that the greatest variance by any...
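A covariance-plus-eigendecomposition sketch of PCA (an SVD of the centered data matrix is an equally valid route):

```python
import jax
import jax.numpy as jnp

def pca(x, k):
    x_centered = x - x.mean(axis=0)                       # center each feature
    cov = x_centered.T @ x_centered / (x.shape[0] - 1)    # sample covariance matrix
    eigvals, eigvecs = jnp.linalg.eigh(cov)               # eigh returns ascending eigenvalues
    top = eigvecs[:, ::-1][:, :k]                         # k directions of greatest variance
    return x_centered @ top                               # project onto principal components

x = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
print(pca(x, 2).shape)    # (100, 2)
```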
-
SVD for Image Compression
Singular Value Decomposition (SVD) is a powerful matrix factorization technique with numerous applications, including dimensionality reduction, noise reduction, and data compression. Any real $m...
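A rank-$k$ truncation sketch on a random stand-in for a grayscale image:

```python
import jax
import jax.numpy as jnp

def rank_k_approx(a, k):
    u, s, vt = jnp.linalg.svd(a, full_matrices=False)
    # Keep only the k largest singular values/vectors: the best rank-k
    # approximation in the Frobenius-norm sense.
    return u[:, :k] @ jnp.diag(s[:k]) @ vt[:k, :]

img = jax.random.uniform(jax.random.PRNGKey(0), (64, 64))    # stand-in for a grayscale image
approx = rank_k_approx(img, k=8)
print(jnp.linalg.norm(img - approx) / jnp.linalg.norm(img))  # relative reconstruction error
```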
-
Numerical Gradient Verification
Understanding and correctly implementing backpropagation is crucial in deep learning. A common way to debug backpropagation is using numerical gradient checking. This involves approximating the...
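A central-difference checker compared against `jax.grad` on a small test function; the achievable agreement is limited mainly by `float32` precision:

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) * w ** 2)

def numerical_grad(f, w, h=1e-3):
    # Central differences: perturb one coordinate at a time.
    grads = []
    for i in range(w.shape[0]):
        e = jnp.zeros_like(w).at[i].set(h)
        grads.append((f(w + e) - f(w - e)) / (2.0 * h))
    return jnp.stack(grads)

w = jnp.array([0.3, -1.2, 2.0])
num, ana = numerical_grad(f, w), jax.grad(f)(w)
print(jnp.max(jnp.abs(num - ana)))   # small; limited mainly by float32 round-off
```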
-
Softmax and its Jacobian
The softmax function is a critical component in multi-class classification, converting a vector of arbitrary real values into a probability distribution. Given an input vector $\mathbf{z} = [z_1,...
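The Jacobian is $J_{ij} = s_i(\delta_{ij} - s_j)$, i.e. $\text{diag}(\mathbf{s}) - \mathbf{s}\mathbf{s}^\top$; a quick check against `jax.jacfwd`:

```python
import jax
import jax.numpy as jnp

z = jnp.array([1.0, 2.0, 0.5])
s = jax.nn.softmax(z)

analytic = jnp.diag(s) - jnp.outer(s, s)             # J_ij = s_i (delta_ij - s_j)
autodiff = jax.jacfwd(jax.nn.softmax)(z)
print(jnp.allclose(analytic, autodiff, atol=1e-6))   # True
```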
-
Numerical Stability: Log-Sum-Exp
When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ can be...
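A demonstration of the overflow and of the shifted form $m + \log\sum_i \exp(x_i - m)$ with $m = \max_i x_i$:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def naive_lse(x):
    return jnp.log(jnp.sum(jnp.exp(x)))          # overflows for large x

def stable_lse(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))  # exponents are now <= 0

x = jnp.array([1000.0, 999.0, 998.0])
print(naive_lse(x))    # inf
print(stable_lse(x))   # ~1000.41
print(logsumexp(x))    # JAX's built-in agrees
```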
-
L2 Regularization Gradient
L2 regularization (also known as Ridge Regression or weight decay) is a common technique to prevent overfitting in machine learning models by adding a penalty proportional to the square of the...
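A sketch checking the ridge-regression gradient, assuming the loss $\frac{1}{2n}\lVert X\mathbf{w} - \mathbf{y}\rVert^2 + \frac{\lambda}{2}\lVert\mathbf{w}\rVert^2$, whose gradient is $\frac{1}{n}X^\top(X\mathbf{w}-\mathbf{y}) + \lambda\mathbf{w}$:

```python
import jax
import jax.numpy as jnp

lam = 0.5  # regularization strength (illustrative value)

def ridge_loss(w, X, y):
    residual = X @ w - y
    return 0.5 * jnp.mean(residual ** 2) + 0.5 * lam * jnp.sum(w ** 2)

X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (20,))
w = jnp.array([1.0, -1.0, 0.5])

closed_form = X.T @ (X @ w - y) / X.shape[0] + lam * w
print(jnp.allclose(jax.grad(ridge_loss)(w, X, y), closed_form, atol=1e-5))   # True
```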
-
Linear Regression via Gradient Descent
Linear regression is a foundational supervised learning algorithm. Given a dataset of input features $X$ and corresponding target values $y$, the goal is to find a linear relationship $y =...
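A gradient-descent sketch on synthetic data drawn from a known line; the data-generating line and hyperparameters are illustrative:

```python
import jax
import jax.numpy as jnp

# Synthetic data from y = 3x + 2 plus a little noise.
x = jax.random.uniform(jax.random.PRNGKey(0), (100,))
y = 3.0 * x + 2.0 + 0.05 * jax.random.normal(jax.random.PRNGKey(1), (100,))

def mse(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

params = jnp.array([0.0, 0.0])      # (w, b) initialized at zero
lr = 0.5
grad_fn = jax.jit(jax.grad(mse))
for _ in range(500):
    params = params - lr * grad_fn(params, x, y)
print(params)    # close to the true (w, b) = (3, 2)
```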
-
Backpropagation for a Single-Layer Network
Backpropagation is the cornerstone algorithm for training neural networks. It efficiently calculates the gradients of the loss function with respect to all the weights and biases in the network by...
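A single-layer sketch (sigmoid activation and squared-error loss, both illustrative choices) with the backward pass written out by hand and checked against `jax.grad`:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    W, b = params
    return jax.nn.sigmoid(x @ W + b)

def loss(params, x, y):
    return 0.5 * jnp.sum((forward(params, x) - y) ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
y = jnp.ones((5, 2))
W = jax.random.normal(jax.random.PRNGKey(1), (3, 2))
b = jnp.zeros(2)

# Manual backward pass via the chain rule.
a = forward((W, b), x)
dz = (a - y) * a * (1.0 - a)        # dL/da * da/dz for sigmoid
dW_manual = x.T @ dz                # dL/dW
db_manual = dz.sum(axis=0)          # dL/db, summed over the batch

dW_auto, db_auto = jax.grad(loss)((W, b), x, y)
print(jnp.allclose(dW_manual, dW_auto, atol=1e-5),
      jnp.allclose(db_manual, db_auto, atol=1e-5))   # True True
```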
-
Matrix Multiplication and Efficiency
Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C =...
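A naive triple-loop implementation, making the $O(m \cdot k \cdot n)$ scalar multiply-adds explicit, checked against the optimized `@` operator:

```python
import jax
import jax.numpy as jnp

def naive_matmul(a, b):
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    c = jnp.zeros((m, n))
    # Three nested loops: m * n * k scalar multiply-adds, i.e. O(mkn) work.
    for i in range(m):
        for j in range(n):
            for p in range(k):
                c = c.at[i, j].add(a[i, p] * b[p, j])
    return c

a = jax.random.normal(jax.random.PRNGKey(0), (4, 5))
b = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
print(jnp.allclose(naive_matmul(a, b), a @ b, atol=1e-5))   # True
```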
-
Einops Warm-up: Reshaping Tensors for Expert Batching
In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each...
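A warm-up sketch with `einops.rearrange`, using hypothetical shapes and the simplifying assumption that tokens are already arranged into equal-sized chunks per expert:

```python
import jax
import jax.numpy as jnp
from einops import rearrange

batch, seq, d_model, n_experts = 2, 8, 16, 4
x = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, d_model))

# Flatten (batch, seq) into a single token axis so every token can be
# routed independently of which sequence it came from.
tokens = rearrange(x, "b s d -> (b s) d")                          # (16, 16)

# Group the flat token axis into equal-sized per-expert chunks
# (a simplifying assumption; real routing assigns tokens unevenly).
per_expert = rearrange(tokens, "(e c) d -> e c d", e=n_experts)    # (4, 4, 16)
print(tokens.shape, per_expert.shape)
```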