ML Katas

Neural Cellular Automata (NCA) Update Step

Difficulty: hard (<30 mins). Tags: pytorch, generative, nca, alife, complex systems.
yesterday by E

Description

Neural Cellular Automata (NCA) are a fascinating class of generative models in which complex global patterns emerge from simple local rules learned by a neural network [1]. A grid of "cells", each holding a state vector, updates itself iteratively by applying the same learned rule at every cell. Your task is to implement a single update step for a 2D NCA.

Guidance

The update step is applied to all cells in parallel and has four stages:
1. Perception: each cell perceives its local neighborhood. This can be implemented efficiently as a convolution of the cell state grid with fixed filters, e.g., Sobel filters that estimate the x and y gradients, combined (as in [1]) with an identity filter so that each cell also sees its own state.
2. Update rule: each cell's perception vector is passed through a small MLP (the same MLP for every cell), which outputs a state update vector.
3. Stochastic update: a random mask decides which cells actually apply their update in this step. Because a cell cannot count on its neighbors updating in lockstep, the NCA cannot learn to rely on brittle, globally synchronized updates.
4. State update: the masked updates are added to the current cell states.
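The two fixed (non-learned) stages, perception and the stochastic mask, can be sketched as follows. This is a minimal sketch that assumes a (batch, channels, height, width) grid layout, identity plus Sobel filters scaled by 1/8, and zero padding at the grid border; the helper names make_perception_kernels, perceive and stochastic_mask are hypothetical.

```python
import torch
import torch.nn.functional as F


def make_perception_kernels(channel_n: int) -> torch.Tensor:
    """Identity, Sobel-x and Sobel-y kernels, repeated once per state channel.

    Returns a depthwise-conv weight of shape (3 * channel_n, 1, 3, 3).
    """
    identity = torch.zeros(3, 3)
    identity[1, 1] = 1.0
    # Scaling by 1/8 is an assumption; it normalizes the gradient estimate.
    sobel_x = torch.tensor([[-1.0, 0.0, 1.0],
                            [-2.0, 0.0, 2.0],
                            [-1.0, 0.0, 1.0]]) / 8.0
    sobel_y = sobel_x.T
    kernels = torch.stack([identity, sobel_x, sobel_y])  # (3, 3, 3)
    # With groups=channel_n, output channels 3c, 3c+1, 3c+2 are the identity,
    # x-gradient and y-gradient of input channel c.
    return kernels.repeat(channel_n, 1, 1).unsqueeze(1)


def perceive(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    """Map a state grid (B, C, H, W) to perception vectors (B, 3C, H, W)."""
    # padding=1 zero-pads the border so the grid size is unchanged.
    return F.conv2d(x, kernels, padding=1, groups=x.shape[1])


def stochastic_mask(x: torch.Tensor, update_rate: float = 0.5) -> torch.Tensor:
    """Per-cell 0/1 mask of shape (B, 1, H, W); broadcasts over channels."""
    b, _, h, w = x.shape
    return (torch.rand(b, 1, h, w, device=x.device) < update_rate).to(x.dtype)


# Usage: a horizontal ramp (value = column index) in 4 channels.
x = torch.arange(8.0).repeat(1, 4, 8, 1)  # shape (1, 4, 8, 8)
p = perceive(x, make_perception_kernels(4))
print(p.shape)          # torch.Size([1, 12, 8, 8])
print(p[0, 1, 4, 1:7])  # interior x-gradient of channel 0, expected to be all ones
```

Passing groups=channel_n makes this a depthwise convolution: each state channel is filtered on its own and the fixed filters never mix channels, leaving all channel mixing to the MLP.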

Starter Code

import torch
import torch.nn as nn

class NCA(nn.Module):
    def __init__(self, channel_n=16, hidden_n=128):
        super().__init__()
        # The update rule is a small MLP
        self.update_mlp = nn.Sequential(
            # Takes each cell's perception vector as input and outputs a state update.
            # Input size = (number of perception filters) * channel_n; output size = channel_n.
        )
        # Perception filters (e.g., Sobel filters) can be defined here
        # as fixed convolutional kernels.

    def perceive(self, x):
        # Use your fixed convolutional kernels to perceive the neighborhood
        pass

    def forward(self, x, update_rate=0.5):
        # 1. Get perception vectors for each cell.
        # 2. Pass them through the update MLP to get the state updates.
        # 3. Create a random mask for the stochastic update.
        # 4. Add the updates to the original state grid x, but only where the mask is active.
        pass
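One possible way to fill in the starter class, combining all four stages, is sketched below. It is an illustrative completion, not the reference solution, and it assumes: the MLP is written as two 1x1 convolutions (equivalent to applying the same small MLP independently at every cell), both layers are bias-free, perception uses identity plus Sobel filters, and the stochastic mask is a per-cell Bernoulli draw shared by all channels of a cell.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NCA(nn.Module):
    def __init__(self, channel_n=16, hidden_n=128):
        super().__init__()
        self.channel_n = channel_n
        # Update rule: a per-cell MLP written as 1x1 convolutions, so every cell
        # shares the same weights. bias=False (an assumption) maps an all-zero
        # perception vector to an all-zero update.
        self.update_mlp = nn.Sequential(
            nn.Conv2d(3 * channel_n, hidden_n, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_n, channel_n, kernel_size=1, bias=False),
        )
        # Fixed perception filters: identity, Sobel-x and Sobel-y (scaled by 1/8).
        identity = torch.zeros(3, 3)
        identity[1, 1] = 1.0
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0],
                                [-2.0, 0.0, 2.0],
                                [-1.0, 0.0, 1.0]]) / 8.0
        kernels = torch.stack([identity, sobel_x, sobel_x.T])
        # A buffer follows .to(device) but is not returned by .parameters().
        self.register_buffer("kernels", kernels.repeat(channel_n, 1, 1).unsqueeze(1))

    def perceive(self, x):
        # Depthwise convolution: (B, C, H, W) -> (B, 3C, H, W).
        return F.conv2d(x, self.kernels, padding=1, groups=self.channel_n)

    def forward(self, x, update_rate=0.5):
        # 1. Perception vectors for every cell.
        perception = self.perceive(x)
        # 2. The shared MLP turns each perception vector into a state update.
        dx = self.update_mlp(perception)
        # 3. Random per-cell mask, shared by all channels of a cell.
        b, _, h, w = x.shape
        mask = (torch.rand(b, 1, h, w, device=x.device) < update_rate).to(x.dtype)
        # 4. Residual update, applied only where the mask is 1.
        return x + dx * mask
```

Registering the kernels as a buffer keeps them out of nca.parameters(), so an optimizer never changes them, while .to(device) still moves them with the model. The Growing NCA reference code [1] zero-initializes the weights of the final layer so that an untrained model leaves the grid unchanged; PyTorch's default random initialization is kept here because the verification exercise asks for random weights.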

Verification

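Create an initial grid of cell states (e.g., all zeros except for a single "seed" cell in the center). Instantiate your NCA model with random weights and apply it to the grid repeatedly (calling the module, e.g. nca(x), runs its forward method). You should see the cell states change and spread outward from the seed, which shows that the local update rule is working. Because perception uses 3x3 kernels, the changed region can grow by at most one cell in each direction per step. Note that with PyTorch's default (random) bias initialization, the MLP produces a nonzero update even for cells whose whole neighborhood is zero, so every cell changes from the first step; building the MLP without biases (or zeroing them) keeps untouched cells at zero and makes the spread from the seed easy to observe.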
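The verification procedure can be scripted as below. This self-contained sketch uses a functional update step rather than a class so that it runs on its own, and it assumes random, bias-free MLP weights scaled by 0.1, a 33x33 grid with 16 channels, and a seed whose channels are all set to 1; nca_step and active_radius are hypothetical helper names. Besides printing how far the changed region extends, it checks the locality bound: after k steps, no cell farther than k cells (Chebyshev distance) from the seed has changed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, H, W, HIDDEN = 16, 33, 33, 128

# Fixed perception kernels (identity, Sobel-x, Sobel-y), one set per channel.
ident = torch.zeros(3, 3)
ident[1, 1] = 1.0
sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) / 8.0
kernels = torch.stack([ident, sobel_x, sobel_x.T]).repeat(C, 1, 1).unsqueeze(1)

# Hypothetical random, bias-free MLP weights stored as 1x1 conv kernels.
w1 = torch.randn(HIDDEN, 3 * C, 1, 1) * 0.1
w2 = torch.randn(C, HIDDEN, 1, 1) * 0.1


def nca_step(x, update_rate=0.5):
    """One NCA update: perceive, shared MLP, random mask, residual add."""
    p = F.conv2d(x, kernels, padding=1, groups=C)
    dx = F.conv2d(F.relu(F.conv2d(p, w1)), w2)
    b, _, h, w = x.shape
    mask = (torch.rand(b, 1, h, w) < update_rate).to(x.dtype)
    return x + dx * mask


def active_radius(x, center):
    """Chebyshev distance from `center` to the farthest cell with a nonzero state."""
    alive = x[0].abs().sum(dim=0) > 0  # (H, W) boolean map
    rows, cols = torch.nonzero(alive, as_tuple=True)
    if rows.numel() == 0:
        return -1
    dist = torch.maximum((rows - center[0]).abs(), (cols - center[1]).abs())
    return int(dist.max())


# Seed: all zeros except the center cell (setting every channel to 1 is an assumption).
x = torch.zeros(1, C, H, W)
center = (H // 2, W // 2)
x[0, :, center[0], center[1]] = 1.0

for step in range(1, 11):
    x = nca_step(x)
    radius = active_radius(x, center)
    # A 3x3 neighborhood can pass information by at most one cell per step.
    assert radius <= step
    print(f"step {step:2d}: active radius {radius}")
```

Adding a bias term to either 1x1 convolution typically breaks this check: cells with an all-zero neighborhood then receive a nonzero update, so the radius jumps to the grid border after a single step.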

References

[1] Mordvintsev, A., et al. (2020). Growing Neural Cellular Automata. Distill.