Implement Lottery Ticket Hypothesis Pruning
Description
The Lottery Ticket Hypothesis suggests that a randomly initialized, dense network contains a smaller subnetwork (a "winning ticket") that, when trained in isolation, can match the performance of the original network. [1] Your task is to implement one step of the iterative magnitude pruning algorithm used to find these tickets.
Algorithm
- Initialize a dense network and save a deep copy of its initial weights (θ₀), as sketched after this list.
- Train the network for j iterations to get weights θⱼ.
- Prune a percentage (p%) of the weights in θⱼ with the lowest magnitude. This creates a mask.
- Reset the weights of the remaining connections to their initial values from θ₀.
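The first step, saving θ₀, is just a deep copy of the model's state dict taken before any training. A minimal sketch (the toy nn.Sequential model here is purely illustrative):

import copy
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
# Deep-copy so later optimizer updates cannot mutate the saved tensors in place.
initial_state_dict = copy.deepcopy(model.state_dict())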
Guidance
You can implement this as a function that takes a trained model and its initial state dictionary. The torch.nn.utils.prune module is your friend here, specifically global_unstructured. After pruning is applied, the module's weight parameter is replaced by weight_orig (a parameter holding the original values) and weight_mask (a buffer holding the binary pruning mask), and weight itself is recomputed as their element-wise product. Your job is to access weight_orig for each pruned layer and copy the values from your saved initial weights into it.
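As a quick, self-contained illustration of this reparameterization (a toy nn.Linear layer, not part of the starter code):

import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2)
# Prune the 50% of weights with the smallest absolute value (L1 magnitude).
prune.global_unstructured(
    [(layer, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)
print(hasattr(layer, "weight_orig"), hasattr(layer, "weight_mask"))  # True True
# 'weight' is now weight_orig * weight_mask, so half its entries are zero.
print((layer.weight == 0).float().mean())  # tensor(0.5000)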
Starter Code
import torch
import torch.nn.utils.prune as prune
import copy

# Assume 'model' is your nn.Module instance
# and 'initial_state_dict' is a deep copy of its state at initialization.

def find_winning_ticket_step(model, initial_state_dict, prune_amount=0.2):
    # Note: the model is assumed to have been trained before this function is called.

    # 1. Identify the parameters to prune (e.g., weights of Linear and Conv2d layers).
    parameters_to_prune = []
    # ... loop through modules to populate this list ...

    # 2. Use prune.global_unstructured to apply the pruning mask in place.

    # 3. Iterate through the model's modules again. Any module that has been pruned
    #    will now have 'weight_orig' and 'weight_mask' attributes. Reset the values
    #    in 'weight_orig' to the corresponding weights from 'initial_state_dict'.
    with torch.no_grad():
        pass  # Your logic here

    return model
Verification
Verify that after running your function, the model's non-pruned weights have been reset to their initial values. You can do this by comparing the weight_orig tensor in a pruned module to the corresponding tensor in initial_state_dict. Also check the sparsity of the model to ensure that the correct percentage of weights has been pruned (set to zero).
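One possible verification sketch, assuming you have completed find_winning_ticket_step from the starter code; the toy nn.Sequential model stands in for your trained network, and the training loop is omitted:

import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
initial_state_dict = copy.deepcopy(model.state_dict())
# ... train the model here, then apply the pruning/reset step ...
model = find_winning_ticket_step(model, initial_state_dict, prune_amount=0.2)

# 1. weight_orig in each pruned module should match the saved initialization.
for name, module in model.named_modules():
    if hasattr(module, "weight_orig"):
        assert torch.equal(module.weight_orig, initial_state_dict[name + ".weight"])

# 2. Roughly prune_amount of the prunable weights should now be masked to zero.
pruned = [m for m in model.modules() if hasattr(m, "weight_mask")]
zeros = sum(int((m.weight_mask == 0).sum()) for m in pruned)
total = sum(m.weight_mask.nelement() for m in pruned)
print(f"Global sparsity: {zeros / total:.2%}")  # expect about 20%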
References
[1] Frankle, J., & Carbin, M. (2018). The Lottery Ticket Hypothesis.