ML Katas

Differentiating Through a Non-differentiable Function with `torch.autograd.Function`

hard (<1 hr) autograd custom gradient function backprop
this month by E

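Implement a custom `torch.autograd.Function` for a non-differentiable operation, such as a quantization function. The static `forward` method performs the non-differentiable operation, and the static `backward` method returns a manually chosen gradient approximation in its place (e.g., the straight-through estimator, which passes the incoming gradient through as if the operation were the identity). This is essential for training models with discrete operations such as rounding, whose true derivative is zero almost everywhere and so gives the optimizer no learning signal.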
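A minimal sketch, assuming "quantization" means rounding to a uniform grid with step `1/scale`. The names `RoundSTE`, `quantize`, and `scale` are hypothetical and not part of the kata. `forward` rounds, and `backward` applies the straight-through estimator by returning the upstream gradient unchanged.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Hypothetical uniform quantizer with a straight-through gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        # Non-differentiable step: snap x to the nearest multiple of 1/scale.
        # The true derivative of round() is zero almost everywhere.
        return torch.round(x * scale) / scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Straight-through estimator: treat the quantizer as the identity,
        # so d(out)/d(x) is approximated by 1 and grad_output passes through.
        # backward returns one gradient per forward input; `scale` is a
        # plain Python float, so its gradient is None.
        return grad_output, None


def quantize(x: torch.Tensor, scale: float = 4.0) -> torch.Tensor:
    # Custom Functions are invoked through .apply, not by calling forward().
    return RoundSTE.apply(x, scale)
```

Passing the gradient through unchanged is the simplest choice. A common variant also zeroes the gradient for inputs outside a clipping range, which is a design decision rather than something this kata requires.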

Verification: Build a simple computation graph that uses your custom function, then print the gradient of the loss with respect to the input tensor. The gradient should be non-zero and should match the manual gradient you defined in `backward`, even though the forward operation itself is non-differentiable.
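One way to run the check, sketched under the same assumptions as a rounding quantizer (the class name `RoundSTE`, the weights `w`, and the scale `4.0` are illustrative choices). The loss is a weighted sum, so the gradient reaching the quantizer's output is exactly `w`. The same graph built with plain `torch.round` serves as a baseline, because PyTorch's built-in derivative for `round` is zero.

```python
import torch


class RoundSTE(torch.autograd.Function):
    # Hypothetical quantizer: round in forward, identity gradient in backward.
    @staticmethod
    def forward(ctx, x, scale):
        return torch.round(x * scale) / scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the upstream gradient unchanged; no grad for scale.
        return grad_output, None


# Input we differentiate with respect to, and fixed downstream weights.
x = torch.tensor([0.1, 0.3, 0.6], requires_grad=True)
w = torch.tensor([1.0, 2.0, 3.0])

# Simple graph: quantize, then a weighted sum, so dL/dq = w.
q = RoundSTE.apply(x, 4.0)
loss = (w * q).sum()
loss.backward()
print("quantized: ", q.detach())
print("STE grad:  ", x.grad)  # straight-through, so this equals dL/dq = w

# Baseline: the same graph with plain torch.round, whose autograd
# derivative is zero, so no learning signal reaches the input.
x_plain = x.detach().clone().requires_grad_(True)
loss_plain = (w * (torch.round(x_plain * 4.0) / 4.0)).sum()
loss_plain.backward()
print("round grad:", x_plain.grad)  # all zeros

# The manual gradient from backward() appears unchanged at the input.
assert torch.equal(x.grad, w)
assert torch.all(x_plain.grad == 0)
```

Under the straight-through estimator, `x.grad` equals `w` and is non-zero, while the plain `torch.round` path produces an all-zero gradient for the same forward values.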