ML Katas

Spiking Neuron with Leaky Integrate-and-Fire

hard (<30 mins) pytorch snn spiking neuroscience
yesterday by E

Description

Implement a single Leaky Integrate-and-Fire (LIF) neuron, the fundamental building block of many Spiking Neural Networks (SNNs). Unlike the neurons in conventional artificial neural networks, which output continuous activations, LIF neurons evolve over discrete time steps and communicate through binary spikes. Their internal state (the membrane potential) accumulates input current and decays through a leak term. [1]

Equations

Membrane Potential Update:

$$V[t+1] = \beta V[t] + I[t+1]$$

Spike Condition:

$$S[t+1] = \begin{cases} 1 & \text{if } V[t+1] > V_{\mathrm{th}} \\ 0 & \text{otherwise} \end{cases}$$

Reset Mechanism:

$$V[t+1] = V[t+1] \times (1 - S[t+1])$$

Where $V$ is the membrane potential, $I$ is the input current, $S$ is the output spike, $V_{\mathrm{th}}$ is the threshold, and $\beta$ is the leak/decay factor ($0 < \beta < 1$).
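As a worked consequence of the update rule, the following sketch derives when a constant input makes the neuron fire. It assumes a constant positive input $I$ and a zero initial potential, neither of which is stated in the kata. The numeric example uses the starter code defaults ($\beta = 0.9$, $V_{\mathrm{th}} = 1$) with an illustrative input of $I = 0.5$.

```latex
% Constant input I, initial potential V[0] = 0, no spike so far:
V[t] = I \sum_{k=0}^{t-1} \beta^{k} = I \, \frac{1 - \beta^{t}}{1 - \beta}
\;\xrightarrow{\;t \to \infty\;}\; \frac{I}{1 - \beta}

% The potential rises monotonically toward I / (1 - \beta),
% so the neuron spikes at some step if and only if
\frac{I}{1 - \beta} > V_{\mathrm{th}}
\quad\Longleftrightarrow\quad
I > (1 - \beta)\, V_{\mathrm{th}}

% Example with \beta = 0.9, V_{\mathrm{th}} = 1, I = 0.5:
V[1] = 0.5, \quad V[2] = 0.95, \quad V[3] = 1.355 > V_{\mathrm{th}}
\;\Rightarrow\; S[3] = 1, \quad V[3] \text{ reset to } 0
```

Under these assumptions the cutoff input is $0.1$, and an input of $0.5$ produces one spike every three steps.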

Guidance

The key challenge here is managing the neuron's state (its membrane potential) across multiple time steps (i.e., multiple calls to forward). This state should be stored as part of the module, for example as a PyTorch buffer, so that it persists between calls.
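To illustrate how a buffer keeps state between calls, here is a minimal sketch using a toy module. `StepCounter` and its `count` buffer are hypothetical names invented for this example and are not part of the kata. It only shows that a tensor registered with `register_buffer` persists across calls to forward, appears in the module's `state_dict`, and is not treated as a trainable parameter.

```python
import torch
import torch.nn as nn


class StepCounter(nn.Module):
    """Hypothetical toy module that counts how many times forward() has run."""

    def __init__(self):
        super().__init__()
        # Registered buffers are saved in state_dict and moved by .to(device),
        # but they are not returned by parameters() and receive no gradients.
        self.register_buffer('count', torch.tensor(0.0))

    def forward(self, x):
        # Assigning a tensor to a registered buffer name updates the buffer.
        self.count = self.count + 1
        return x


counter = StepCounter()
for _ in range(3):
    counter(torch.tensor(1.0))

print(counter.count.item())               # 3.0
print(list(counter.state_dict().keys()))  # ['count']
print(len(list(counter.parameters())))    # 0
```

The same pattern applies to the membrane potential: read the buffer at the start of forward, compute the new value, and assign it back before returning.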

Starter Code

import torch
import torch.nn as nn

class LIFNeuron(nn.Module):
    def __init__(self, threshold=1.0, beta=0.9):
        super().__init__()
        self.threshold = threshold
        self.beta = beta
        # 1. The membrane potential is a state that needs to be stored.
        #    `register_buffer` is a good way to do this for non-trainable state.
        self.register_buffer('potential', torch.tensor(0.0))
        self.register_buffer('spike', torch.tensor(0.0))

    def forward(self, input_current):
        # This function simulates one time step.
        # 1. Update the membrane potential based on the leak (beta) and the input current.
        # 2. Check if the new potential exceeds the threshold, producing a binary spike.
        # 3. If a spike occurred, reset the potential to 0 (or a reset potential).
        # 4. Store the new potential and the output spike in the module's state.
        # 5. Return the output spike.
        pass

Verification

Instantiate the LIF neuron and pass it a sequence of input currents over several calls to forward(). With the default threshold=1.0 and beta=0.9 and a zero initial potential, a low constant input such as 0.05 should never cause a spike, because the potential settles at input / (1 - beta) = 0.5, below the threshold. A high constant input such as 0.5 should make the neuron spike periodically. You can plot the membrane potential over time to observe its dynamics. Note that making the neuron differentiable for backpropagation requires a surrogate gradient for the spike function, which is a more advanced topic. [2]
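The check described above could be sketched as follows. The `forward` body is one possible way to fill in the starter code, not a reference solution. The helper `run` and the input values 0.05 and 0.5 are assumptions chosen for illustration. State is stored detached, which is a design choice for this non-trainable sketch.

```python
import torch
import torch.nn as nn


class LIFNeuron(nn.Module):
    def __init__(self, threshold=1.0, beta=0.9):
        super().__init__()
        self.threshold = threshold
        self.beta = beta
        self.register_buffer('potential', torch.tensor(0.0))
        self.register_buffer('spike', torch.tensor(0.0))

    def forward(self, input_current):
        # 1. Leak the previous potential and integrate the new input.
        potential = self.beta * self.potential + input_current
        # 2. Emit a binary spike when the potential exceeds the threshold.
        spike = (potential > self.threshold).to(potential.dtype)
        # 3. Reset the potential to 0 wherever a spike occurred.
        potential = potential * (1.0 - spike)
        # 4. Store the state; detach so no autograd graph is kept across calls.
        self.potential = potential.detach()
        self.spike = spike.detach()
        # 5. Return the output spike.
        return spike


def run(neuron, current, steps):
    """Hypothetical helper: feed a constant current and record the results."""
    spikes, potentials = [], []
    with torch.no_grad():
        for _ in range(steps):
            out = neuron(torch.tensor(current))
            spikes.append(int(out.item()))
            potentials.append(neuron.potential.item())
    return spikes, potentials


# Low input: the potential saturates near 0.05 / (1 - 0.9) = 0.5, below threshold.
low_spikes, low_potentials = run(LIFNeuron(), 0.05, 50)
# High input: the potential crosses the threshold on every third step.
high_spikes, high_potentials = run(LIFNeuron(), 0.5, 9)

print(sum(low_spikes))  # 0
print(high_spikes)      # [0, 0, 1, 0, 0, 1, 0, 0, 1]
```

A fresh instance is created for each run because the buffers carry state from one call to the next. Reusing one neuron across sequences would require resetting `potential` first.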

References

[1] Gerstner, W., & Kistler, W. M. (2002). Spiking neuron models.

[2] Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate gradient learning in spiking neural networks.