ML Katas

Gradient Reversal Layer

Difficulty: medium (<30 mins) · Tags: pytorch, autograd, gan, domain adaptation
yesterday by E

Description

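Implement a Gradient Reversal Layer (GRL), a key component in Domain-Adversarial Neural Networks (DANNs) [1]. The GRL acts as an identity function during the forward pass but reverses the gradient during the backward pass, multiplying it by -alpha, where alpha is a non-negative scaling factor. Placed between a feature extractor and a domain classifier, it makes the feature extractor ascend the domain-classification loss that the classifier descends, which pushes it toward domain-invariant features. This is achieved by creating a custom torch.autograd.Function.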
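In equation form, writing the layer as R_alpha (a label chosen here for illustration) with the same alpha as the starter code, and taking a loss L computed from the layer's output y = R_alpha(x), the forward map, its Jacobian and the resulting chain-rule step are:

```latex
R_\alpha(x) = x, \qquad \frac{\partial R_\alpha}{\partial x} = -\alpha I
\qquad\Longrightarrow\qquad
\frac{\partial L}{\partial x} = -\alpha \, \frac{\partial L}{\partial y}
```

With alpha = 1 the backward pass is a pure sign flip; with alpha = 0 it blocks the gradient entirely.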

Guidance

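To create a custom layer with a non-standard backward pass, subclass torch.autograd.Function and implement two static methods, forward and backward. The forward method performs the operation and can store anything the backward pass needs on the context object ctx: tensors via ctx.save_for_backward, plain Python values as attributes such as ctx.alpha. The backward method receives the incoming gradient (the gradient of the loss w.r.t. the function's output) and must return one gradient per input of forward, in the same order, with None for inputs that need no gradient. The function is then invoked through its apply method, e.g. GradientReversalFunction.apply(x, alpha), rather than by instantiating it.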
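To illustrate the forward/backward/ctx mechanics without giving away the kata, here is a sketch of an unrelated custom op. The class name Square is hypothetical; it computes y = x**2, saves its input with ctx.save_for_backward, and applies the chain rule by hand. torch.autograd.gradcheck then compares the hand-written backward against finite differences (it expects double-precision inputs).

```python
import torch
from torch.autograd import Function


class Square(Function):
    """Hypothetical example op: y = x**2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        # Save the input tensor; backward needs it to compute dy/dx = 2x.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dL/dy; the chain rule gives dL/dx = dL/dy * 2x.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


# gradcheck raises on a mismatch and returns True when the gradients agree.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(Square.apply, (x,))
print(ok)  # True

# Custom Functions are called through .apply, not instantiated.
t = torch.tensor([3.0], requires_grad=True)
out = Square.apply(t)
out.sum().backward()
print(out.item(), t.grad.item())  # 9.0 6.0
```

Because forward here takes a single tensor, backward returns a single gradient; a forward with an extra non-tensor argument needs an extra None in the return value.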

Starter Code

from torch.autograd import Function
import torch.nn as nn

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Store the alpha value for the backward pass
        ctx.alpha = alpha
        # This function is an identity function during the forward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient of the loss w.r.t. this function's output
        # 1. Reverse the gradient by multiplying it by -1.
        # 2. Scale it by the stored alpha value.
        # 3. backward must return one gradient per input of forward (x, alpha).
        #    alpha is a plain Python number, so its gradient is None.
        pass

class GradientReversalLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        # Apply the custom autograd function
        return GradientReversalFunction.apply(x, self.alpha)
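One possible completion of the starter code is sketched below; the only change is the body of backward, which returns the reversed, scaled gradient for x and None for alpha. This is a sketch, not a reference solution, and it assumes alpha is a plain Python float.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Store the alpha value for the backward pass.
        ctx.alpha = alpha
        # Identity in the forward pass; view_as returns a new tensor object.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale: dL/dx = -alpha * dL/dy.
        # alpha is not differentiated, so its gradient is None.
        return -ctx.alpha * grad_output, None


class GradientReversalLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.alpha)


# Usage: the forward value is unchanged, the gradient is scaled by -alpha.
x = torch.ones(3, requires_grad=True)
y = GradientReversalLayer(alpha=0.5)(x)
y.sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000])
```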

Verification

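Create a simple network: input -> linear1 -> GRL -> linear2 -> output. Compute a loss and call loss.backward(). Then check the gradient of the weights in linear1 (linear1.weight.grad). Compared with the same network using identical weights but no GRL, this gradient should be scaled by -alpha, so with alpha = 1 its sign is exactly flipped. The gradient of linear2, which sits after the GRL, should be unchanged.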
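The check can be scripted as follows. This sketch assumes a completed GRL (the compact one repeated here so the snippet is self-contained), alpha = 1, an MSE loss on random data, and layer sizes chosen arbitrarily. The baseline is a deep copy of the same model with the GRL swapped for nn.Identity, so both runs start from identical weights and produce identical forward outputs.

```python
import copy

import torch
import torch.nn as nn
from torch.autograd import Function


class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None


class GradientReversalLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.alpha)


torch.manual_seed(0)
x = torch.randn(8, 4)
target = torch.randn(8, 1)

# input -> linear1 -> GRL -> linear2 -> output
with_grl = nn.Sequential(nn.Linear(4, 16), GradientReversalLayer(alpha=1.0), nn.Linear(16, 1))
# Baseline: same weights, GRL replaced by an identity.
baseline = copy.deepcopy(with_grl)
baseline[1] = nn.Identity()

for model in (with_grl, baseline):
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()

g_grl = with_grl[0].weight.grad
g_base = baseline[0].weight.grad

# linear1 sits before the GRL, so its gradient is negated.
print(torch.allclose(g_grl, -g_base))  # True
# linear2 sits after the GRL, so its gradient is unchanged.
print(torch.allclose(with_grl[2].weight.grad, baseline[2].weight.grad))  # True
```

Comparing against a baseline, rather than only inspecting signs, also catches a backward that forgets the alpha scaling when alpha is set to something other than 1.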

References

[1] Ganin, Y., et al. (2016). Domain-Adversarial Training of Neural Networks. Journal of Machine Learning Research, 17(59), 1-35.