ML Katas

Implementing a Custom `nn.Module` for a Gated Recurrent Unit (GRU)

Difficulty: medium (<1 hr) | Tags: rnn, gru, custom module, recurrent
this month by E

Implement a custom GRU cell as a subclass of `torch.nn.Module`. Your implementation should compute the reset gate, the update gate, the candidate hidden state, and the new hidden state from scratch, using `torch.nn.Linear` layers for the weight matrices. The `forward` method should take an input and the previous hidden state and return the new hidden state. Do not use `torch.nn.GRU` or `torch.nn.GRUCell`.
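For reference, these are the single-step GRU equations in the form given in the `torch.nn.GRUCell` documentation. Here $\sigma$ is the logistic sigmoid, $\odot$ is element-wise multiplication, $x_t$ is the input and $h_{t-1}$ the previous hidden state. Matching this exact form is what makes an element-wise comparison against `torch.nn.GRUCell` possible.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

Some other formulations apply the reset gate before the hidden-to-hidden projection, i.e. $W_{hn}(r_t \odot h_{t-1})$, or swap the roles of $z_t$ and $1 - z_t$. Those are still valid GRUs, but they will not reproduce `torch.nn.GRUCell` numerically.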

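Verification: Compare the output of your custom GRU cell with that of `torch.nn.GRUCell` for a single step, using a random input and hidden state. Because both modules are randomly initialized, first copy the `torch.nn.GRUCell` parameters into your cell (or vice versa) so that both use identical weights. The outputs should then agree up to floating-point rounding, e.g. as checked with `torch.allclose`.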
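A minimal sketch of one possible solution together with the verification step. It assumes PyTorch's documented layout for `torch.nn.GRUCell`, where `weight_ih` and `weight_hh` have shape `(3*hidden_size, ·)` with row blocks ordered (reset, update, candidate). The class name `MyGRUCell` and the helper method `copy_from` are hypothetical names chosen for illustration.

```python
import torch
import torch.nn as nn


class MyGRUCell(nn.Module):
    """Hypothetical from-scratch GRU cell following torch.nn.GRUCell's equations."""

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        # One Linear per source; output rows are stacked as (reset, update, candidate),
        # assumed to match GRUCell's (r, z, n) ordering.
        self.ih = nn.Linear(input_size, 3 * hidden_size, bias=bias)
        self.hh = nn.Linear(hidden_size, 3 * hidden_size, bias=bias)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        i_r, i_z, i_n = self.ih(x).chunk(3, dim=-1)
        h_r, h_z, h_n = self.hh(h).chunk(3, dim=-1)
        r = torch.sigmoid(i_r + h_r)    # reset gate
        z = torch.sigmoid(i_z + h_z)    # update gate
        n = torch.tanh(i_n + r * h_n)   # candidate hidden state (r applied after W_hn)
        return (1 - z) * n + z * h      # new hidden state

    @torch.no_grad()
    def copy_from(self, ref: nn.GRUCell) -> None:
        """Hypothetical helper: copy parameters from an nn.GRUCell of the same sizes and bias setting."""
        self.ih.weight.copy_(ref.weight_ih)
        self.hh.weight.copy_(ref.weight_hh)
        if self.ih.bias is not None:
            self.ih.bias.copy_(ref.bias_ih)
            self.hh.bias.copy_(ref.bias_hh)


# Verification: one step on random data, with both cells sharing the same weights.
torch.manual_seed(0)
input_size, hidden_size, batch = 10, 20, 4
ref = nn.GRUCell(input_size, hidden_size)
mine = MyGRUCell(input_size, hidden_size)
mine.copy_from(ref)

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
out_ref = ref(x, h)
out_mine = mine(x, h)

max_diff = (out_mine - out_ref).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
print("allclose:", torch.allclose(out_mine, out_ref, atol=1e-6))
```

Fusing the three gates into one `nn.Linear` per source mirrors the stacked layout of `weight_ih` and `weight_hh`, so copying parameters is a direct `copy_`. Six separate per-gate `nn.Linear` layers would be equally valid, but the reference weights would then need slicing into row blocks of size `hidden_size`. The `atol=1e-6` tolerance is an assumed choice for float32; the default `torch.allclose` absolute tolerance of `1e-8` can be too strict for outputs near zero.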