ML Katas

Implement a Knowledge Distillation Loss

medium (<30 mins) pytorch compression distillation
yesterday by E

Description

Knowledge Distillation is a model compression technique in which a small "student" model is trained to mimic a larger, pre-trained "teacher" model [1]. The student is trained on a loss with two parts: a standard supervised term on the true labels, and a term that rewards matching the teacher's output probabilities. Your task is to implement this combined loss function.

Guidance

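The total loss is a weighted sum:

$$L = \alpha \cdot L_{CE} + (1 - \alpha) \cdot L_{KD}$$

- $L_{CE}$ is the standard cross-entropy loss with the ground-truth labels.
- $L_{KD}$ is the distillation loss. It is typically the KL divergence between the teacher's and the student's "softened" predictions, $\mathrm{KL}(p^{\text{teacher}} \,\|\, p^{\text{student}})$. Softening is done by dividing the logits $z$ by a temperature $T$ before applying the softmax:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

A higher $T$ gives a flatter distribution. The starter code also multiplies $L_{KD}$ by $T^2$. The gradients of the softened term scale roughly as $1/T^2$, so this scaling keeps its contribution comparable to $L_{CE}$ as $T$ changes [1].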
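The following sketch illustrates the temperature-scaled softmax from the formula above and how raising $T$ flattens the distribution. The helper name `soften` and the three example logits are arbitrary choices for illustration. They are not part of the kata's API.

```python
import torch
import torch.nn.functional as F


def soften(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Temperature-scaled softmax over the last (class) dimension (hypothetical helper)."""
    return F.softmax(logits / T, dim=-1)


# One example with three classes; the logit values are arbitrary.
logits = torch.tensor([[4.0, 1.0, 0.0]])

# The formula written out by hand agrees with F.softmax applied to z / T.
T = 2.0
manual = torch.exp(logits / T) / torch.exp(logits / T).sum(dim=-1, keepdim=True)
assert torch.allclose(manual, soften(logits, T))

# Raising T flattens the distribution: the top class loses probability mass
# and the relative probabilities of the other classes become more visible.
for temp in (1.0, 2.0, 5.0):
    probs = soften(logits, temp).squeeze(0)
    print(f"T={temp}: {[round(p, 3) for p in probs.tolist()]}")
```

For these logits, the top-class probability falls from roughly 0.94 at T=1 to about 0.50 at T=5. The teacher's "soft" targets therefore carry more information about how it ranks the wrong classes.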

Starter Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, teacher_model, alpha=0.1, T=2.0):
        super().__init__()
        self.teacher_model = teacher_model
        self.alpha = alpha
        self.T = T
        # Use KL Divergence for distillation and Cross Entropy for the student loss
        self.distillation_loss_fn = nn.KLDivLoss(reduction='batchmean')
        self.student_loss_fn = nn.CrossEntropyLoss()

    def forward(self, student_logits, labels, inputs):
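        # 1. Get the teacher's logits for the same inputs.
        #    The teacher is frozen: keep it in eval mode (dropout off, batch norm
        #    using running statistics) and skip gradient tracking with no_grad().
        self.teacher_model.eval()
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)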

        # 2. Calculate the standard student loss against the true labels.
        loss_ce = self.student_loss_fn(student_logits, labels)

        # 3. Calculate the distillation loss.
        #    a. Soften the student and teacher logits using the temperature T.
        #    b. Apply log_softmax to the softened student logits and softmax to the teacher's.
        #    c. Compute the KL divergence. nn.KLDivLoss expects log-probabilities
        #       as its first argument and probabilities as its second.
        loss_kd = self.distillation_loss_fn(
            ..., # Soft student predictions
            ...  # Soft teacher predictions
        )

        # 4. Return the weighted sum of the two losses. The KD term is scaled by T**2
        #    so its gradients stay comparable in magnitude to the CE term's.
        return self.alpha * loss_ce + (1.0 - self.alpha) * self.T**2 * loss_kd
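The sketch below checks the argument convention that step 3 relies on. `nn.KLDivLoss` takes the student's log-probabilities first and the teacher's probabilities second, and with `reduction="batchmean"` it returns the per-example KL divergence averaged over the batch. The random logits, batch size, and class count are arbitrary stand-ins, not the kata's reference solution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 2.0
student_logits = torch.randn(4, 10)  # (batch, num_classes); random stand-ins
teacher_logits = torch.randn(4, 10)

# nn.KLDivLoss takes log-probabilities as its input and (by default, log_target=False)
# plain probabilities as its target, and computes target * (log(target) - input).
log_p_student = F.log_softmax(student_logits / T, dim=-1)
p_teacher = F.softmax(teacher_logits / T, dim=-1)

kl_fn = nn.KLDivLoss(reduction="batchmean")
kl = kl_fn(log_p_student, p_teacher)

# The same quantity written out: KL(p_teacher || p_student), summed over classes
# and averaged over the batch, which is what reduction="batchmean" does.
manual = (p_teacher * (p_teacher.log() - log_p_student)).sum(dim=-1).mean()

# If the student's logits equal the teacher's, the divergence is (numerically) zero.
kl_same = kl_fn(F.log_softmax(teacher_logits / T, dim=-1), p_teacher)

print(kl.item(), manual.item(), kl_same.item())
```

Passing probabilities instead of log-probabilities as the first argument raises no error. It silently computes a different, meaningless quantity, so this is a common source of bugs.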

Verification

Create dummy student and teacher models and pass random data through your loss function. The output should be a scalar (0-dimensional) tensor. Then check the distillation term in isolation (for example, with alpha=0). It should be lower when the student's logits are close to the teacher's and higher when they differ.
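The checks above could look like the following minimal sketch. `distillation_loss` is a hypothetical functional stand-in for a completed `DistillationLoss.forward` (one possible completion, using the same `alpha`, `T`, and `T**2` scaling as the starter code). The model sizes, batch size, seed, and 0.01 noise scale are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, alpha=0.1, T=2.0):
    """Hypothetical functional version of DistillationLoss.forward."""
    loss_ce = F.cross_entropy(student_logits, labels)
    loss_kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),  # soft student log-probabilities
        F.softmax(teacher_logits / T, dim=-1),      # soft teacher probabilities
        reduction="batchmean",
    )
    return alpha * loss_ce + (1.0 - alpha) * T**2 * loss_kd


torch.manual_seed(0)
num_features, num_classes, batch = 16, 5, 8

# Dummy models: a small MLP "teacher" and a linear "student".
teacher = nn.Sequential(
    nn.Linear(num_features, 32), nn.ReLU(), nn.Linear(32, num_classes)
).eval()
student = nn.Linear(num_features, num_classes)

inputs = torch.randn(batch, num_features)
labels = torch.randint(0, num_classes, (batch,))

with torch.no_grad():
    teacher_logits = teacher(inputs)

# 1. The loss is a scalar and backpropagates into the student only.
loss = distillation_loss(student(inputs), teacher_logits, labels)
loss.backward()
print(loss.shape, loss.item())

# 2. With alpha=0 (distillation term only), agreeing with the teacher
#    gives a lower loss than disagreeing with it.
near_teacher = teacher_logits + 0.01 * torch.randn_like(teacher_logits)
loss_similar = distillation_loss(near_teacher, teacher_logits, labels, alpha=0.0)
loss_different = distillation_loss(-teacher_logits, teacher_logits, labels, alpha=0.0)
print(loss_similar.item(), loss_different.item())
```

Setting `alpha=0.0` isolates the distillation term. With any other value, the cross-entropy term also depends on the student's logits, so comparing total losses would conflate the two effects.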

References

[1] Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.