ML Katas

HyperNetwork for Weight Generation

hard (<30 mins) pytorch hypernetwork meta-learning
yesterday by E

Description

Implement a simple HyperNetwork. A HyperNetwork is a neural network that generates the weights for another, larger network (the "target network"). [1] This allows for dynamic weight generation or significant parameter compression. Your task is to create a small MLP (the HyperNetwork) that outputs the flattened weights for a simple, two-layer target MLP.

Guidance

The core of this exercise is careful parameter counting and tensor reshaping. The HyperNet will take some input (e.g., a latent vector z) and produce a single long vector. You must then slice this vector into the correctly shaped weight and bias tensors for each layer of the TargetNet; note that F.linear expects a weight tensor of shape (out_features, in_features). The TargetNet's forward pass must be written to accept these externally generated weights rather than using its own parameters.
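The counting-and-slicing step might look something like the sketch below. The 4→8→3 layer sizes are arbitrary, purely for illustration, and the random vector stands in for the HyperNet's output:

```python
import torch

# Illustrative sizes for a hypothetical 4 -> 8 -> 3 target MLP.
in_dim, hidden, out_dim = 4, 8, 3

# Element counts for w1, b1, w2, b2, in order.
sizes = [hidden * in_dim, hidden, out_dim * hidden, out_dim]
total = sum(sizes)  # total parameter count the HyperNet must emit

flat = torch.randn(total)  # stands in for the generator's flat output
w1_flat, b1, w2_flat, b2 = torch.split(flat, sizes)

# F.linear expects weights of shape (out_features, in_features).
w1 = w1_flat.view(hidden, in_dim)
w2 = w2_flat.view(out_dim, hidden)
```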

Starter Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetNet(nn.Module):
    def __init__(self):
        super().__init__()
        # This network is defined by its architecture, but has no parameters itself.
        # All weights will be passed into the forward method.

    def forward(self, x, weights):
        # 'weights' will be a tuple or list like (w1, b1, w2, b2)
        # 1. Implement a standard MLP forward pass, but use the weights
        #    provided instead of internal nn.Linear layers.
        #    Hint: Use F.linear(input, weight, bias)
        pass

class HyperNet(nn.Module):
    def __init__(self, z_dim, target_in, target_hidden, target_out):
        super().__init__()
        # 1. Calculate the total number of parameters (weights and biases)
        #    needed for the TargetNet.
        self.total_target_params = ...

        # 2. Define a simple MLP that takes z and outputs a vector of
        #    size self.total_target_params.
        self.generator = nn.Sequential(...)

    def forward(self, z):
        # 1. Get the flat vector of parameters from the generator.
        # 2. Reshape this vector into the separate weight and bias tensors
        #    (w1, b1, w2, b2) with the correct dimensions.
        # 3. Return these tensors.
        pass

Verification

Create instances of both networks. Pass a random vector z to the HyperNet to generate weights. Pass data x and the generated weights to the TargetNet. The forward pass should complete without shape errors. The ultimate test is to successfully backpropagate a loss from the TargetNet's output through to the HyperNet's weights.
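One way such an end-to-end check might look is sketched below. This is one possible solution shape, not the only one; the 8-dim z, the 4→16→3 target architecture, and the 64-unit generator hidden layer are all arbitrary choices for the sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetNet(nn.Module):
    def forward(self, x, weights):
        w1, b1, w2, b2 = weights
        h = F.relu(F.linear(x, w1, b1))
        return F.linear(h, w2, b2)

class HyperNet(nn.Module):
    def __init__(self, z_dim, t_in, t_hid, t_out):
        super().__init__()
        # Shapes of (w1, b1, w2, b2) in the order the TargetNet expects.
        self.shapes = [(t_hid, t_in), (t_hid,), (t_out, t_hid), (t_out,)]
        total = sum(torch.Size(s).numel() for s in self.shapes)
        self.generator = nn.Sequential(
            nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, total))

    def forward(self, z):
        flat = self.generator(z)
        out, i = [], 0
        for s in self.shapes:
            n = torch.Size(s).numel()
            out.append(flat[i:i + n].view(s))
            i += n
        return tuple(out)

hyper = HyperNet(z_dim=8, t_in=4, t_hid=16, t_out=3)
target = TargetNet()

z = torch.randn(8)
x = torch.randn(5, 4)
weights = hyper(z)
y = target(x, weights)  # shape (5, 3): no shape errors
y.sum().backward()      # gradient flows back into the HyperNet
```

If the check passes, `hyper.generator[0].weight.grad` is populated, confirming the loss backpropagates through the generated weights into the HyperNet.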

References

[1] Ha, D., Dai, A., & Le, Q. V. (2016). HyperNetworks. arXiv:1609.09106.