ML Katas

Siamese Network for One-Shot Image Verification

hard (<30 mins) pytorch siamese metric learning one-shot
yesterday by E

Description

Your task is to implement a Siamese network that decides whether two images belong to the same class, even when only one or a few examples of that class are available at test time. You will train the model to learn a feature embedding in which images of the same class lie close together and images of different classes lie far apart. The network takes two images as input, and the distance between their embeddings serves as the similarity score (small distance means high similarity). [1]
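To make the "two images in, one score out" idea concrete, here is a minimal sketch. It assumes a hypothetical flatten-plus-linear encoder on 1x28x28 inputs and maps Euclidean distance d to a score with 1 / (1 + d). That mapping is just one simple choice. Koch et al. [1] instead feed a learned, weighted L1 distance through a sigmoid.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in encoder: flattens a 1x28x28 image into a 16-dim embedding.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16))


def similarity(x1, x2):
    """Return a score in (0, 1]; values near 1 mean near-identical embeddings."""
    # One module embeds both inputs, so both branches share the same weights.
    e1, e2 = encoder(x1), encoder(x2)
    dist = F.pairwise_distance(e1, e2)  # Euclidean distance per pair, shape (N,)
    return 1.0 / (1.0 + dist)           # small distance -> high similarity


a = torch.randn(4, 1, 28, 28)
b = torch.randn(4, 1, 28, 28)
print(similarity(a, a))  # approximately 1 for identical inputs
print(similarity(a, b))  # lower for unrelated inputs
```

Because the same `encoder` object processes both inputs, identical images always map to the same embedding. This weight sharing is what makes the distance meaningful.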

Guidance

You will need two main components:

1. A core network (e.g., a CNN) that acts as the feature encoder. It is instantiated once and applied to both inputs, so the two branches share the same weights.
2. A custom loss function called contrastive loss. It takes the pair of encoded feature vectors and a label indicating whether they come from the same class. It pulls same-class embeddings closer together and pushes different-class embeddings apart until they are at least a margin away.
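For reference, here is one common form of contrastive loss (Hadsell, Chopra & LeCun, 2006). It is written under an assumed label convention: y = 1 for same-class pairs and y = 0 for different-class pairs. The margin is m, and f is the shared encoder. Some formulations swap the label convention or add a factor of 1/2 to each term.

```latex
D = \lVert f(x_1) - f(x_2) \rVert_2, \qquad
\mathcal{L}(x_1, x_2, y) = y \, D^2 + (1 - y) \, \max(0,\; m - D)^2
```

Here is a worked example with the starter code's default margin m = 2. A same-class pair at distance D = 0.5 costs 0.5^2 = 0.25. A different-class pair at the same distance costs (2 - 0.5)^2 = 2.25. A different-class pair at D = 3 is already beyond the margin and costs 0. Over a batch, the per-pair losses are typically averaged.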

Starter Code

import torch
import torch.nn as nn

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # 1. Define your CNN feature encoder here.
        # It should take an image and output a flat feature vector.
        self.encoder = nn.Sequential(
            # ... your layers here ...
        )

    def forward_one(self, x):
        # Helper function to pass one image through the encoder
        return self.encoder(x)

    def forward(self, input1, input2):
        # Pass each input through the shared-weight encoder
        output1 = self.forward_one(input1)
        output2 = self.forward_one(input2)
        return output1, output2

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # 1. Calculate the Euclidean distance between the two feature vectors.
        # 2. Implement the contrastive loss formula based on the label.
        pass
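Below is one possible way to fill in the starter code, as a sketch rather than a reference solution. The encoder layer sizes are hypothetical and untuned. `AdaptiveAvgPool2d` lets the same encoder accept 105x105 Omniglot-sized images or smaller inputs. The loss assumes label = 1 for same-class pairs and 0 for different-class pairs. Reverse the two terms if your data uses the opposite convention.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    def __init__(self, embedding_dim=64):
        super().__init__()
        # Hypothetical encoder for 1-channel images (e.g. 105x105 Omniglot);
        # layer sizes are illustrative, not tuned.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),  # collapse spatial dims -> (N, 64, 1, 1)
            nn.Flatten(),             # -> (N, 64)
            nn.Linear(64, embedding_dim),
        )

    def forward_one(self, x):
        return self.encoder(x)

    def forward(self, input1, input2):
        # The same self.encoder (same weights) embeds both inputs.
        return self.forward_one(input1), self.forward_one(input2)


class ContrastiveLoss(nn.Module):
    # ASSUMPTION: label = 1 for same-class pairs, 0 for different-class pairs.
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # 1. Euclidean distance per pair, shape (N,).
        dist = F.pairwise_distance(output1, output2)
        label = label.float()
        # 2a. Same-class pairs: penalize the squared distance.
        pos = label * dist.pow(2)
        # 2b. Different-class pairs: penalize only if closer than the margin.
        neg = (1 - label) * F.relu(self.margin - dist).pow(2)
        return (pos + neg).mean()


# Usage on random tensors (stand-ins for real image pairs).
torch.manual_seed(0)
model, criterion = SiameseNetwork(), ContrastiveLoss(margin=2.0)
x1, x2 = torch.randn(8, 1, 105, 105), torch.randn(8, 1, 105, 105)
label = torch.randint(0, 2, (8,))
e1, e2 = model(x1, x2)
loss = criterion(e1, e2, label)
loss.backward()  # gradients flow into the single shared encoder
print(e1.shape, loss.item())
```

`F.relu(margin - dist)` is a compact way to write max(0, m - D). It means different-class pairs that are already farther apart than the margin contribute no gradient.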

Verification

Train on a dataset like Omniglot. [2, 3] To test, create pairs of images (some same class, some different) from classes unseen during training. The model should output high similarity (low distance) for same-class pairs and low similarity (high distance) for different-class pairs.
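One way to run this check is sketched below. It samples balanced same/different pairs from held-out classes, predicts "same class" whenever the embedding distance falls below a threshold, and reports accuracy. `make_pairs` and `verification_accuracy` are hypothetical helpers, and the threshold is an assumption you would normally tune on a validation split. The toy embeddings here are synthetic clusters standing in for encoder outputs on unseen classes.

```python
import random

import torch
import torch.nn.functional as F


def make_pairs(labels, num_pairs, seed=0):
    """Hypothetical helper: sample (i, j, same) index triples, alternating
    same-class (same=1) and different-class (same=0) pairs."""
    rng = random.Random(seed)
    by_class = {}
    for idx, c in enumerate(labels):
        by_class.setdefault(c, []).append(idx)
    multi = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    all_classes = list(by_class)
    pairs = []
    for k in range(num_pairs):
        if k % 2 == 0:  # same-class pair: two distinct images of one class
            i, j = rng.sample(by_class[rng.choice(multi)], 2)
            pairs.append((i, j, 1))
        else:  # different-class pair: one image from each of two classes
            c1, c2 = rng.sample(all_classes, 2)
            pairs.append((rng.choice(by_class[c1]), rng.choice(by_class[c2]), 0))
    return pairs


@torch.no_grad()
def verification_accuracy(embeddings, pairs, threshold):
    """Predict 'same class' when the Euclidean distance is below threshold."""
    i = torch.tensor([p[0] for p in pairs])
    j = torch.tensor([p[1] for p in pairs])
    y = torch.tensor([p[2] for p in pairs])
    dist = F.pairwise_distance(embeddings[i], embeddings[j])
    pred = (dist < threshold).long()
    return (pred == y).float().mean().item()


torch.manual_seed(0)
labels = [c for c in range(5) for _ in range(4)]  # 5 unseen classes, 4 images each
centers = torch.randn(5, 16) * 10                 # synthetic, well-separated clusters
embeddings = centers[torch.tensor(labels)] + 0.1 * torch.randn(20, 16)
pairs = make_pairs(labels, num_pairs=100)
acc = verification_accuracy(embeddings, pairs, threshold=1.0)
print(acc)  # should be close to 1.0 for these well-separated toy clusters
```

A single fixed threshold keeps the check simple. Sweeping the threshold, for example to plot an ROC curve, gives a threshold-free view of how well the two kinds of pairs are separated.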

References

[1] Koch, G., Zemel, R., & Salakhutdinov, R. (2015). Siamese Neural Networks for One-shot Image Recognition.

[2] Lake, B. M., et al. (2015). Human-level concept learning through probabilistic program induction.

[3] Omniglot dataset on GitHub.