-
Numerical Stability: Log-Sum-Exp
When dealing with probabilities, especially in log-space, sums of exponentials can lead to numerical underflow or overflow. For example, computing $\log \left( \sum_i \exp(x_i) \right)$ directly overflows if any $x_i$ is large and underflows if all $x_i$ are very negative. The log-sum-exp trick sidesteps this by subtracting the maximum $m = \max_i x_i$ before exponentiating: $\log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)$.
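
A minimal JAX sketch of the trick (in practice `jax.scipy.special.logsumexp` already provides this); the function name and axis handling are illustrative choices:

```python
import jax.numpy as jnp

def logsumexp(x, axis=-1):
    """Numerically stable log-sum-exp: subtract the max before exponentiating."""
    m = jnp.max(x, axis=axis, keepdims=True)
    # exp(x - m) is at most 1, so the sum cannot overflow.
    return jnp.squeeze(m, axis=axis) + jnp.log(jnp.sum(jnp.exp(x - m), axis=axis))

print(logsumexp(jnp.array([1000.0, 1000.0])))  # ~1000.69, whereas the naive form gives inf
```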
-
Backpropagation for a Single-Layer Network
Backpropagation is the cornerstone algorithm for training neural networks. It efficiently calculates the gradients of the loss function with respect to all the weights and biases in the network by applying the chain rule layer by layer, propagating error signals backwards from the output to the inputs.
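
As an illustration, here is a sketch of backprop by hand for a single sigmoid layer with a mean-squared loss, checked against `jax.grad`; the layer shape, loss, and helper names are assumptions, not part of the original exercise:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Single layer: sigmoid(x @ W + b)
    return jax.nn.sigmoid(x @ params["W"] + params["b"])

def loss_fn(params, x, y):
    return 0.5 * jnp.mean((forward(params, x) - y) ** 2)

def manual_grads(params, x, y):
    """Backprop by hand: chain rule through the loss, the sigmoid, and the affine map."""
    z = x @ params["W"] + params["b"]           # pre-activation
    a = jax.nn.sigmoid(z)                       # activation
    dL_da = (a - y) / a.size                    # derivative of the mean-squared loss
    dL_dz = dL_da * a * (1.0 - a)               # chain through sigmoid'(z)
    return {"W": x.T @ dL_dz, "b": dL_dz.sum(axis=0)}

x = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (8, 2))
params = {"W": 0.1 * jax.random.normal(jax.random.PRNGKey(2), (3, 2)),
          "b": jnp.zeros(2)}

auto = jax.grad(loss_fn)(params, x, y)
manual = manual_grads(params, x, y)
print(jnp.allclose(auto["W"], manual["W"], atol=1e-5))  # True
```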
-
Building a Simple Mixture of Experts (MoE) Layer
Now, let's combine the concepts of dispatching and aggregating into a full, albeit simplified, `torch.nn.Module` for a Mixture of Experts layer. This layer will replace a standard feed-forward block, routing each token to one or more experts and combining their outputs according to the gating scores.
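
The exercise builds this as a `torch.nn.Module`; purely to illustrate the dispatch-and-aggregate idea in the same framework as the JAX exercises below, here is a rough top-1-routing sketch. The dense dispatch, parameter layout, and sizes are all illustrative assumptions, not the exercise's implementation:

```python
import jax
import jax.numpy as jnp

def init_moe(key, d_model, d_hidden, num_experts):
    """Hypothetical parameter layout: a router plus a stack of two-layer expert MLPs."""
    k_r, k_w1, k_w2 = jax.random.split(key, 3)
    return {
        "router": 0.02 * jax.random.normal(k_r, (d_model, num_experts)),
        "w1": 0.02 * jax.random.normal(k_w1, (num_experts, d_model, d_hidden)),
        "w2": 0.02 * jax.random.normal(k_w2, (num_experts, d_hidden, d_model)),
    }

def moe_forward(params, x):
    """Top-1 routing: each token is dispatched to its highest-scoring expert."""
    gates = jax.nn.softmax(x @ params["router"], axis=-1)   # (tokens, experts)
    expert_idx = jnp.argmax(gates, axis=-1)                 # chosen expert per token
    top_gate = jnp.take_along_axis(gates, expert_idx[:, None], axis=-1)

    # Dense dispatch for clarity: run every token through every expert, then keep
    # only the chosen expert's output (real MoE layers avoid this extra compute).
    h = jax.nn.relu(jnp.einsum("td,edh->teh", x, params["w1"]))
    y = jnp.einsum("teh,ehd->ted", h, params["w2"])          # (tokens, experts, d_model)
    chosen = jnp.take_along_axis(y, expert_idx[:, None, None], axis=1)[:, 0, :]
    return chosen * top_gate                                 # aggregate, scaled by the gate

params = init_moe(jax.random.PRNGKey(0), d_model=16, d_hidden=32, num_experts=4)
print(moe_forward(params, jnp.ones((10, 16))).shape)  # (10, 16)
```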
-
Implement Layer Normalization
Implement Layer Normalization from scratch in JAX. Layer Normalization is a key component in many modern neural network architectures, like Transformers. It normalizes the inputs across the feature dimension (the last axis), then applies learned scale and shift parameters.
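
One possible shape for the solution, assuming normalization over the last axis with learnable `gamma`/`beta`; the parameter names and epsilon value are illustrative:

```python
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize across the last (feature) axis, then apply learned scale and shift."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

d = 8
gamma, beta = jnp.ones(d), jnp.zeros(d)          # learnable in a real module
x = jnp.arange(16.0).reshape(2, d)
print(layer_norm(x, gamma, beta).mean(axis=-1))  # ~0 per row after normalization
```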
-
Implementing Dropout
Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training time. Remember that dropout should only be active during training; at inference time the layer should return its input unchanged.
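
A sketch of one common variant, inverted dropout, where surviving activations are rescaled at training time so no scaling is needed at inference; the function signature and explicit `training` flag are assumptions:

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, training):
    """Inverted dropout: zero units with probability `rate`, rescale survivors by 1/(1-rate)."""
    if not training or rate == 0.0:
        return x                                   # identity at inference time
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, keep_prob, x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 5))
print(dropout(key, x, rate=0.5, training=True))   # mix of 0s and 2s
print(dropout(key, x, rate=0.5, training=False))  # unchanged
```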
-
Implement a Convolutional Layer
Implement a 2D convolutional layer from scratch in JAX. This will involve using `jax.lax.conv_general_dilated`. You will need to manage the kernel initialization and the forward pass logic.
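
A rough sketch of how such a layer might look, assuming NCHW inputs, an OIHW kernel, and He-style initialization; these layout and init choices are assumptions, not requirements of the exercise:

```python
import jax
import jax.numpy as jnp

def init_conv(key, in_channels, out_channels, kernel_size):
    """He-style initialization for an OIHW kernel (an assumed choice)."""
    fan_in = in_channels * kernel_size * kernel_size
    w = jax.random.normal(key, (out_channels, in_channels, kernel_size, kernel_size))
    return {"w": w * jnp.sqrt(2.0 / fan_in), "b": jnp.zeros(out_channels)}

def conv2d(params, x, stride=1, padding="SAME"):
    """x has layout NCHW; the kernel is OIHW, matching dimension_numbers below."""
    y = jax.lax.conv_general_dilated(
        x, params["w"],
        window_strides=(stride, stride),
        padding=padding,
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    return y + params["b"][None, :, None, None]    # broadcast bias over N, H, W

params = init_conv(jax.random.PRNGKey(0), in_channels=3, out_channels=8, kernel_size=3)
x = jnp.ones((2, 3, 32, 32))
print(conv2d(params, x).shape)  # (2, 8, 32, 32) with SAME padding and stride 1
```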