Gradient Reversal Layer
Description
Implement a Gradient Reversal Layer (GRL), a key component of Domain-Adversarial Neural Networks (DANNs) [1]. The GRL acts as the identity function during the forward pass but reverses the gradient (multiplies it by a negative scalar) during the backward pass. This is achieved by creating a custom torch.autograd.Function.
Guidance
To create a custom layer with a non-standard backward pass, you must subclass torch.autograd.Function and implement two static methods: forward and backward. The forward method performs the operation and can save any values needed for the backward pass in the context object ctx. The backward method receives the incoming gradient (the gradient of the loss w.r.t. the function's output) and must return one gradient for each input of forward, in order; inputs that need no gradient (such as the scalar alpha here) get None.
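As a minimal illustration of this contract, separate from the exercise itself, here is a custom squaring function (the name Square is just for this sketch). It shows ctx.save_for_backward for stashing tensors, and the chain rule applied in backward:

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input tensor; backward needs it to compute the gradient.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, chained with the incoming gradient.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor(6.)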
Starter Code
from torch.autograd import Function
import torch.nn as nn

class GradientReversalFunction(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        # Store the alpha value for the backward pass
        ctx.alpha = alpha
        # This function is the identity during the forward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient from the next layer
        # 1. Reverse the gradient by multiplying it by -1.
        # 2. Scale it by the stored alpha value.
        # 3. backward must return one gradient per input of forward;
        #    the gradient for alpha is None.
        pass

class GradientReversalLayer(nn.Module):
    def __init__(self, alpha=1.0):
        super(GradientReversalLayer, self).__init__()
        self.alpha = alpha

    def forward(self, x):
        # Apply the custom autograd function
        return GradientReversalFunction.apply(x, self.alpha)
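For reference, one possible completion of the backward method, to drop into GradientReversalFunction above in place of the pass (a sketch, not the only valid formulation):

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the incoming gradient and scale it by alpha.
        # alpha is a plain Python float, so its gradient slot is None.
        return -ctx.alpha * grad_output, None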
Verification
Create a simple network: input -> linear1 -> GRL -> linear2 -> output. Compute a loss and call loss.backward(). Then check the gradient of the weights in linear1 (linear1.weight.grad). Its sign should be flipped compared to what it would be without the GRL.
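A minimal script along these lines (layer sizes and the seed are arbitrary choices, and it assumes GradientReversalFunction.backward above has been completed):

import torch
import torch.nn as nn

torch.manual_seed(0)
linear1 = nn.Linear(4, 4)
linear2 = nn.Linear(4, 1)
grl = GradientReversalLayer(alpha=1.0)
x = torch.randn(2, 4)

# Baseline: gradient of linear1's weights without the GRL.
linear2(linear1(x)).sum().backward()
grad_plain = linear1.weight.grad.clone()

# Same forward pass with the GRL inserted between the two layers.
linear1.zero_grad()
linear2.zero_grad()
linear2(grl(linear1(x))).sum().backward()
grad_reversed = linear1.weight.grad.clone()

# With alpha=1.0, the GRL should flip the sign of the gradient exactly.
assert torch.allclose(grad_reversed, -grad_plain)
print("linear1 gradient is sign-flipped, as expected.")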
References
[1] Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., & Lempitsky, V. (2016). Domain-adversarial training of neural networks. Journal of Machine Learning Research, 17(59), 1-35.