Implement a Knowledge Distillation Loss
Description
Knowledge Distillation is a model compression technique where a small "student" model is trained to mimic a larger, pre-trained "teacher" model. [1] This is achieved by training the student with a loss that combines a standard supervised term on the true labels with a term that rewards matching the teacher's output probabilities. Your task is to implement this combined loss function.
Guidance
The total loss is a weighted sum: L_total = alpha * L_CE + (1 - alpha) * T^2 * L_KD.
- L_CE is the standard cross-entropy loss against the ground-truth labels.
- L_KD is the distillation loss. It is typically the KL divergence between the student's and teacher's "softened" predictions. Softening is done by dividing the logits by a temperature T before applying the softmax (a short sketch follows this list).
The distillation term is conventionally scaled by T^2, as in the return statement of the starter code: the gradients produced by the softened softmax scale as 1/T^2, so the T^2 factor keeps the relative contribution of the two terms roughly independent of the chosen temperature [1].
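For intuition, here is a minimal sketch of what softening does; the logits and the temperature T=4 are arbitrary illustrative values, not part of the starter code. Dividing by a larger T flattens the distribution, exposing the teacher's relative preferences among the non-argmax classes.

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])
print(F.softmax(logits, dim=-1))         # sharp: roughly [0.66, 0.24, 0.10]
print(F.softmax(logits / 4.0, dim=-1))   # softened with T=4: roughly [0.42, 0.32, 0.26]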
Starter Code
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    def __init__(self, teacher_model, alpha=0.1, T=2.0):
        super().__init__()
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.T = T
        # Use KL Divergence for distillation and Cross Entropy for the student loss
        self.distillation_loss_fn = nn.KLDivLoss(reduction='batchmean')
        self.student_loss_fn = nn.CrossEntropyLoss()

    def forward(self, student_logits, labels, inputs):
        # 1. Get the teacher's logits for the same inputs.
        #    Remember to put the teacher in eval mode and use no_grad().
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        # 2. Calculate the standard student loss against the true labels.
        loss_ce = self.student_loss_fn(student_logits, labels)

        # 3. Calculate the distillation loss.
        #    a. Soften the student and teacher logits using the temperature T.
        #    b. Apply log_softmax to the softened student logits and softmax to the teacher's.
        #    c. Compute the KL Divergence.
        loss_kd = self.distillation_loss_fn(
            ...,  # Soft student predictions
            ...   # Soft teacher predictions
        )

        # 4. Return the weighted average of the two losses.
        return self.alpha * loss_ce + (1.0 - self.alpha) * self.T**2 * loss_kd
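If you want a self-contained reference for step 3, here is one possible sketch written as a standalone helper; the function name and the use of F.kl_div (the functional form of nn.KLDivLoss) are illustrative choices, not part of the starter code. Inside forward, the same two tensors would be what you pass to self.distillation_loss_fn.

import torch
import torch.nn.functional as F

def softened_kl_term(student_logits, teacher_logits, T=2.0):
    # Soften both sets of logits by dividing by the temperature T.
    # KLDivLoss / F.kl_div expect the input as log-probabilities
    # and the target as probabilities.
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean')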
Verification
Create dummy student and teacher models and pass random data through your loss function. The output should be a scalar tensor. Check that the loss is lower when the student and teacher logits are similar, and higher when they differ.
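As a concrete starting point, a verification script might look like the sketch below. It assumes your completed DistillationLoss class from the starter code is in scope; the layer sizes, batch size, and seed are arbitrary.

import torch
import torch.nn as nn

torch.manual_seed(0)

teacher = nn.Linear(10, 5)          # dummy "teacher model" mapping 10 features to 5 classes
teacher.eval()                      # good practice, even though Linear has no dropout/batchnorm
criterion = DistillationLoss(teacher, alpha=0.1, T=2.0)

inputs = torch.randn(8, 10)
labels = torch.randint(0, 5, (8,))

with torch.no_grad():
    teacher_logits = teacher(inputs)

matching_logits = teacher_logits.clone()     # student logits identical to the teacher's
different_logits = torch.randn(8, 5) * 5.0   # student logits far from the teacher's

loss_match = criterion(matching_logits, labels, inputs)
loss_diff = criterion(different_logits, labels, inputs)

print(loss_match.dim() == 0)                      # True: the loss is a scalar tensor
print(loss_match.item() < loss_diff.item())       # expected True: matching the teacher lowers the loss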
References
[1] Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.