ML Katas

Implement Lottery Ticket Hypothesis Pruning

medium (<30 mins) pytorch pruning lottery ticket
yesterday by E

Description

The Lottery Ticket Hypothesis suggests that a randomly initialized, dense network contains a smaller subnetwork (a "winning ticket") that, when trained in isolation from the same initial weights, can match the test accuracy of the original network in a comparable number of training iterations. [1] Your task is to implement one step (a single prune-and-rewind round) of the iterative magnitude pruning algorithm used to find these tickets.

Algorithm

  1. Initialize a dense network and save a deep copy of its initial weights (θ₀).
  2. Train the network for j iterations to get weights θⱼ.
  3. Prune the p% of weights in θⱼ with the smallest magnitudes. This creates a binary mask m.
  4. Reset the remaining weights to their initial values from θ₀, giving the candidate ticket f(x; m ⊙ θ₀).

Iterative pruning repeats steps 2 to 4 for several rounds, each round pruning a fraction of the weights that survived the previous one.
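For intuition, steps 3 and 4 can be sketched on a single weight tensor without torch.nn.utils.prune. This is a minimal sketch under stated assumptions: magnitude_mask is a hypothetical helper, and a random perturbation of θ₀ stands in for the training in step 2.

```python
import torch

def magnitude_mask(weights: torch.Tensor, prune_fraction: float) -> torch.Tensor:
    """Hypothetical helper: 0/1 mask that zeroes the smallest-|w| fraction."""
    k = round(prune_fraction * weights.numel())  # number of weights to remove
    mask = torch.ones_like(weights)
    if k == 0:
        return mask
    # Indices of the k entries with the smallest absolute value.
    _, idx = torch.topk(weights.abs().flatten(), k, largest=False)
    mask.view(-1)[idx] = 0.0
    return mask

torch.manual_seed(0)
theta_0 = torch.randn(4, 5)                   # step 1: saved initial weights
theta_j = theta_0 + 0.5 * torch.randn(4, 5)   # step 2: stand-in for training (hypothetical)
m = magnitude_mask(theta_j, 0.2)              # step 3: mask ranked by |theta_j|
ticket = m * theta_0                          # step 4: survivors rewound to theta_0
```

Note that the mask is computed from the trained weights θⱼ, but the values that survive come from θ₀.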

Guidance

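You can implement this as a function that takes a trained model and its initial state dictionary. The torch.nn.utils.prune module is your friend here, specifically global_unstructured, which ranks all the tensors you pass it together and removes the smallest-magnitude fraction across all of them rather than layer by layer. When a tensor is pruned, the module's weight parameter is removed and replaced by a parameter weight_orig (the unmasked tensor) and a buffer weight_mask (the binary mask); weight itself becomes a plain attribute equal to weight_orig * weight_mask, recomputed by a forward pre-hook on every forward pass. Your job is to access weight_orig in each pruned layer and copy the values from your saved initial weights into it, in place and under torch.no_grad(). Keep in mind that the weight attribute still holds its old value until the next forward pass.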
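A quick way to see this reparametrization is to prune a single nn.Linear and inspect it. This sketch uses prune.l1_unstructured on one layer for brevity (global_unstructured produces the same weight_orig / weight_mask layout); the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.25)  # mask 8 of 32 weights

param_names = [n for n, _ in layer.named_parameters()]   # weight_orig replaces weight
buffer_names = [n for n, _ in layer.named_buffers()]     # weight_mask is a buffer
print(param_names, buffer_names)

# Editing weight_orig does not touch the cached `weight` attribute ...
with torch.no_grad():
    layer.weight_orig.fill_(1.0)
stale = layer.weight.clone()

# ... until a forward pass runs the pruning pre-hook, which recomputes
# weight = weight_orig * weight_mask.
layer(torch.zeros(1, 8))
refreshed = layer.weight.clone()
```

After the forward pass, refreshed equals the mask itself, because every entry of weight_orig was set to 1.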

Starter Code

import torch
import torch.nn.utils.prune as prune
import copy

# Assume 'model' is your nn.Module instance
# and 'initial_state_dict' is a deep copy of its state at initialization.

def find_winning_ticket_step(model, initial_state_dict, prune_amount=0.2):
    # Note: The model is assumed to have been trained before this function is called.

    # 1. Identify the parameters to prune (e.g., weights of Linear and Conv2d layers).
    parameters_to_prune = []
    # ... loop through modules, appending (module, "weight") tuples to this list ...

    # 2. Use prune.global_unstructured (e.g. with pruning_method=prune.L1Unstructured
    #    and amount=prune_amount) to apply the pruning mask in place.

    # 3. Iterate through the model's modules again. Each pruned module now has a
    #    'weight_orig' parameter and a 'weight_mask' buffer.
    #    Copy the matching tensor from 'initial_state_dict'
    #    (key '<module_name>.weight') into 'weight_orig' in place.
    with torch.no_grad():
        pass # Your logic here

    return model
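One possible completion of the starter code follows, offered as a sketch rather than the reference solution. It assumes global L1 (magnitude) pruning over every nn.Linear and nn.Conv2d weight, and that initial_state_dict was captured from the unpruned model. It also rewinds biases, which goes slightly beyond the guidance above (the paper's step 4 resets all remaining parameters to θ₀); drop that part if you only want to reset weight_orig.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def find_winning_ticket_step(model, initial_state_dict, prune_amount=0.2):
    """One prune-and-rewind round of iterative magnitude pruning.

    Assumes `model` has already been trained and `initial_state_dict` is a copy
    of the *unpruned* model's state_dict taken at initialization, so its keys
    look like '<module>.weight' and '<module>.bias'.
    """
    # 1. Collect the weight tensors of Linear and Conv2d layers.
    parameters_to_prune = [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (nn.Linear, nn.Conv2d))
    ]

    # 2. Rank all collected weights together by |w| and mask the smallest
    #    `prune_amount` fraction (of the still-unpruned weights, on later rounds).
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=prune_amount,
    )

    with torch.no_grad():
        # 3. Rewind parameters to theta_0. Pruned tensors are registered as
        #    '<module>.weight_orig', so strip the suffix to find the saved key.
        for pname, param in model.named_parameters():
            key = pname
            if key not in initial_state_dict and key.endswith("_orig"):
                key = key[: -len("_orig")]
            param.copy_(initial_state_dict[key])

        # The cached `weight` attribute is only recomputed on the next forward
        # pass; refresh it now so `module.weight` already equals m * theta_0.
        for module in model.modules():
            if hasattr(module, "weight_mask"):
                module.weight = module.weight_mask * module.weight_orig

    return model
```

Calling it again after retraining prunes a further prune_amount of the weights that survived, because PyTorch combines the new mask with the existing one.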

Verification

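Verify that after running your function, the model's non-pruned weights have been reset to their initial values: compare the weight_orig tensor in each pruned module with the corresponding '<module_name>.weight' tensor in initial_state_dict. Also check that the correct percentage of weights has been pruned. Because weight_orig still holds values at the pruned positions, count the zeros in weight_mask (or in weight after a forward pass) rather than in weight_orig. With global pruning, the pruned fraction applies to all pruned layers combined, so individual layers may end up more or less sparse than prune_amount.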
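A minimal verification sketch, assuming a hypothetical check_ticket helper and a toy two-layer model in which a random perturbation stands in for training. It reads sparsity from the masks and checks that every weight_orig matches θ₀ exactly.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def check_ticket(model, initial_state_dict):
    """Hypothetical helper: return (all_rewound, global_sparsity) over pruned modules."""
    all_rewound, zeros, total = True, 0, 0
    for name, module in model.named_modules():
        if not hasattr(module, "weight_mask"):
            continue  # module was not pruned
        key = f"{name}.weight" if name else "weight"
        all_rewound = all_rewound and torch.equal(module.weight_orig, initial_state_dict[key])
        # Count pruned entries via the mask, not via weight_orig.
        zeros += int((module.weight_mask == 0).sum())
        total += module.weight_mask.numel()
    return all_rewound, zeros / max(total, 1)

# Toy usage: the "training" below is a hypothetical random perturbation.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
theta_0 = copy.deepcopy(model.state_dict())
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1 * torch.randn_like(p))

prune.global_unstructured(
    [(model[0], "weight"), (model[2], "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
with torch.no_grad():  # rewind the surviving weights to theta_0
    model[0].weight_orig.copy_(theta_0["0.weight"])
    model[2].weight_orig.copy_(theta_0["2.weight"])

rewound, sparsity = check_ticket(model, theta_0)
print(rewound, sparsity)
```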

References

[1] Frankle, J., & Carbin, M. (2018). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. arXiv:1803.03635.