HyperNetwork for Weight Generation
Description
Implement a simple HyperNetwork. A HyperNetwork is a neural network that generates the weights for another, larger network (the "target network"). [1] This allows for dynamic weight generation or significant parameter compression. Your task is to create a small MLP (the HyperNetwork) that outputs the flattened weights for a simple, two-layer target MLP.
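In symbols, the idea can be sketched as follows: where a standard network learns its parameters θ directly, a hypernetwork g with its own parameters φ produces them from a latent input z, so only φ (and optionally z) are trained.

```latex
y = f(x;\, \theta)           % standard network: \theta is learned directly
\theta = g(z;\, \phi)        % hypernetwork: \theta is generated from z
y = f(x;\, g(z;\, \phi))     % combined forward pass; gradients flow into \phi
```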
Guidance
The core of this exercise is careful parameter counting and tensor reshaping. The HyperNet will take some input (e.g., a latent vector z) and produce a single long vector. You must then slice this vector into the correct shapes for the weights and biases of each layer in the TargetNet. The TargetNet's forward pass must be written to accept these externally generated weights.
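As a concrete sketch of the counting and slicing, consider a hypothetical 4 → 8 → 2 target MLP (these sizes are arbitrary, chosen only for illustration):

```python
import torch

target_in, target_hidden, target_out = 4, 8, 2

# Each nn.Linear(i, o) needs an (o, i) weight matrix plus an (o,) bias.
shapes = [
    (target_hidden, target_in),   # w1
    (target_hidden,),             # b1
    (target_out, target_hidden),  # w2
    (target_out,),                # b2
]
total = sum(torch.Size(s).numel() for s in shapes)
print(total)  # (8*4 + 8) + (2*8 + 2) = 58

# Slice a flat vector of length `total` back into those shapes.
flat = torch.randn(total)
params, offset = [], 0
for s in shapes:
    n = torch.Size(s).numel()
    params.append(flat[offset:offset + n].view(s))
    offset += n
print([tuple(p.shape) for p in params])  # [(8, 4), (8,), (2, 8), (2,)]
```

The same counting-then-slicing pattern works for any number of layers: record the shapes once, sum their element counts, and walk a running offset through the flat vector.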
Starter Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetNet(nn.Module):
    def __init__(self):
        super().__init__()
        # This network is defined by its architecture, but has no parameters itself.
        # All weights will be passed into the forward method.

    def forward(self, x, weights):
        # 'weights' will be a tuple or list like (w1, b1, w2, b2)
        # 1. Implement a standard MLP forward pass, but use the weights
        #    provided instead of internal nn.Linear layers.
        #    Hint: Use F.linear(input, weight, bias)
        pass

class HyperNet(nn.Module):
    def __init__(self, z_dim, target_in, target_hidden, target_out):
        super().__init__()
        # 1. Calculate the total number of parameters (weights and biases)
        #    needed for the TargetNet.
        self.total_target_params = ...

        # 2. Define a simple MLP that takes z and outputs a vector of
        #    size self.total_target_params.
        self.generator = nn.Sequential(...)

    def forward(self, z):
        # 1. Get the flat vector of parameters from the generator.
        # 2. Reshape this vector into the separate weight and bias tensors
        #    (w1, b1, w2, b2) with the correct dimensions.
        # 3. Return these tensors.
        pass
Verification
Create instances of both networks. Pass a random vector z to the HyperNet to generate weights. Pass data x and the generated weights to the TargetNet. The forward pass should complete without shape errors. The ultimate test is to successfully backpropagate a loss from the TargetNet's output through to the HyperNet's weights.
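Pulling the pieces together, here is one possible completion and verification script. It is a sketch, not the only valid solution: the generator architecture (a single hidden layer of 64 units) and the assumption that z is a single unbatched latent vector are choices made here for simplicity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetNet(nn.Module):
    # No parameters of its own; all weights arrive via forward().
    def forward(self, x, weights):
        w1, b1, w2, b2 = weights
        h = F.relu(F.linear(x, w1, b1))
        return F.linear(h, w2, b2)

class HyperNet(nn.Module):
    def __init__(self, z_dim, target_in, target_hidden, target_out):
        super().__init__()
        # Shapes of (w1, b1, w2, b2) for the two-layer target MLP.
        self.shapes = [
            (target_hidden, target_in), (target_hidden,),
            (target_out, target_hidden), (target_out,),
        ]
        self.total_target_params = sum(torch.Size(s).numel() for s in self.shapes)
        # The 64-unit hidden layer is an arbitrary choice.
        self.generator = nn.Sequential(
            nn.Linear(z_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.total_target_params),
        )

    def forward(self, z):
        flat = self.generator(z)  # z: (z_dim,) -> flat: (total_target_params,)
        params, offset = [], 0
        for s in self.shapes:
            n = torch.Size(s).numel()
            params.append(flat[offset:offset + n].view(s))
            offset += n
        return params  # [w1, b1, w2, b2]

# Verification: forward-pass shapes, then backprop into the HyperNet.
hyper = HyperNet(z_dim=16, target_in=4, target_hidden=8, target_out=2)
target = TargetNet()

z = torch.randn(16)
weights = hyper(z)
x = torch.randn(32, 4)
y = target(x, weights)  # expected shape: (32, 2)

loss = y.pow(2).mean()
loss.backward()
# TargetNet has no parameters; gradients flow to the HyperNet instead.
print(y.shape, all(p.grad is not None for p in hyper.parameters()))
```

Because TargetNet owns no parameters, any optimizer should be given hyper.parameters(); the target weights are regenerated from z on every forward pass.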
References
[1] Ha, D., Dai, A., & Le, Q. V. (2016). HyperNetworks. arXiv preprint arXiv:1609.09106.