ML Katas

Deep Canonical Correlation Analysis (DCCA) Loss

medium (<30 mins) pytorch dcca correlation multimodal
yesterday by E

Description

Canonical Correlation Analysis (CCA) is a statistical method that finds linear projections of two sets of variables that are maximally correlated with each other. Deep CCA (DCCA) uses neural networks to project two "views" of the data (e.g., an image and its text description) into a latent space, and trains those networks to maximize the correlation between the projections. [1] Your task is to implement the DCCA loss function, which is the core of this technique.
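For reference, the two objectives can be written as follows, using the notation of Andrew et al. [1]. Linear CCA searches for projection vectors w1 and w2 given the covariances of the two views. DCCA instead searches for the parameters θ1 and θ2 of two networks f1 and f2 applied to the views X1 and X2.

```latex
% Linear CCA: maximally correlated linear projections of the two views
(w_1^*, w_2^*) = \arg\max_{w_1, w_2}
  \frac{w_1^\top \Sigma_{12} w_2}
       {\sqrt{w_1^\top \Sigma_{11} w_1 \; w_2^\top \Sigma_{22} w_2}}

% Deep CCA: maximally correlated network outputs
(\theta_1^*, \theta_2^*) = \arg\max_{\theta_1, \theta_2}
  \operatorname{corr}\big(f_1(X_1; \theta_1),\, f_2(X_2; \theta_2)\big)
```

The loss in this kata is the negative of that correlation, estimated on a batch of network outputs.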

Guidance

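Given the outputs of the two networks, H1 and H2, the goal is to maximize their correlation. The loss function is derived from the correlation calculation:

1. Center H1 and H2 by subtracting the mean of each output dimension (column).
2. Compute the covariance matrix of the concatenated, centered outputs. Its diagonal blocks are the within-view covariances $\Sigma_{11}$ and $\Sigma_{22}$, and its off-diagonal block is the cross-covariance $\Sigma_{12} = \frac{1}{n-1}\bar{H}_1^\top \bar{H}_2$, where $n$ is the batch size. Adding a small ridge term $\epsilon I$ to $\Sigma_{11}$ and $\Sigma_{22}$ keeps them invertible.
3. Normalize the cross-covariance by the within-view covariances: $T = \Sigma_{11}^{-1/2} \Sigma_{12} \Sigma_{22}^{-1/2}$. Skipping this normalization and using $T = \bar{H}_1^\top \bar{H}_2$ directly is simpler, but its singular values then measure covariance rather than correlation.
4. The sum of the singular values of $T$ (its trace norm) is the total correlation. Because optimizers minimize, the loss is the negative of this sum.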
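The steps above can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not a reference solution: it uses the ridge-regularized covariance estimates of Andrew et al. [1] with a 1/(n - 1) normalization, computes the inverse square roots by eigendecomposition with torch.linalg.eigh, and takes singular values with torch.linalg.svdvals. The helper name _inv_sqrt and the 1e-12 eigenvalue floor are illustrative choices.

```python
import torch


def _inv_sqrt(mat: torch.Tensor) -> torch.Tensor:
    """Inverse square root of a symmetric positive-definite matrix (hypothetical helper)."""
    eigvals, eigvecs = torch.linalg.eigh(mat)
    # Guard against tiny negative eigenvalues caused by round-off.
    eigvals = eigvals.clamp_min(1e-12)
    # Scale each eigenvector column by lambda^{-1/2}, then rotate back: V diag(l^-1/2) V^T.
    return (eigvecs * eigvals.rsqrt()) @ eigvecs.T


def dcca_loss(H1: torch.Tensor, H2: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor:
    """Negative total canonical correlation between H1 (n, d1) and H2 (n, d2)."""
    n = H1.shape[0]

    # 1. Center each output dimension (column).
    H1_bar = H1 - H1.mean(dim=0, keepdim=True)
    H2_bar = H2 - H2.mean(dim=0, keepdim=True)

    # 2. Within-view covariances (with ridge epsilon * I) and the cross-covariance.
    I1 = torch.eye(H1.shape[1], dtype=H1.dtype, device=H1.device)
    I2 = torch.eye(H2.shape[1], dtype=H2.dtype, device=H2.device)
    sigma11 = H1_bar.T @ H1_bar / (n - 1) + epsilon * I1
    sigma22 = H2_bar.T @ H2_bar / (n - 1) + epsilon * I2
    sigma12 = H1_bar.T @ H2_bar / (n - 1)

    # 3. Normalized cross-covariance T = Sigma11^{-1/2} Sigma12 Sigma22^{-1/2}.
    T = _inv_sqrt(sigma11) @ sigma12 @ _inv_sqrt(sigma22)

    # 4. Total correlation = sum of singular values of T; the loss is its negative.
    return -torch.linalg.svdvals(T).sum()
```

With H1 = torch.randn(500, 10) and H3 = H1 + torch.randn_like(H1) * 0.01, dcca_loss(H1, H3) comes out close to -10, the lowest value possible for out_dim = 10, because each of the ten canonical correlations is at most 1.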

Starter Code

import torch

def dcca_loss(H1, H2, epsilon=1e-5):
    """Calculates the DCCA loss.

    H1, H2: (batch_size, out_dim) outputs of the two networks.
    epsilon: small ridge term added to the diagonals of Sigma11 and Sigma22.
    """
    # 1. Center the features by subtracting the mean of each column.
    H1_bar = H1 - H1.mean(dim=0)
    H2_bar = H2 - H2.mean(dim=0)

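    # 2. Compute the within-view covariance matrices Sigma11 and Sigma22
    #    (add epsilon * I to each for numerical stability) and the
    #    cross-covariance Sigma12.

    # 3. Form T = Sigma11^{-1/2} @ Sigma12 @ Sigma22^{-1/2}.
    #    A simpler, unnormalized alternative is T = H1_bar.T @ H2_bar,
    #    but its singular values measure covariance, not correlation.

    # 4. Compute the singular values of T with one of PyTorch's SVD routines.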

    # 5. The loss is the negative sum of the singular values.
    #    A high correlation corresponds to a low loss.
    pass
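Step 3 of the starter code mentions a simpler formulation, T = H1_bar.T @ H2_bar. A minimal sketch of that variant (dcca_loss_unnormalized is a hypothetical name) makes its main drawback visible. Without the Sigma11^{-1/2} and Sigma22^{-1/2} normalization, the loss scales with the magnitude of the features, so a network could lower it simply by inflating its outputs.

```python
import torch


def dcca_loss_unnormalized(H1: torch.Tensor, H2: torch.Tensor) -> torch.Tensor:
    """Hypothetical variant: negative nuclear norm of H1_bar^T H2_bar, with no whitening."""
    H1_bar = H1 - H1.mean(dim=0, keepdim=True)
    H2_bar = H2 - H2.mean(dim=0, keepdim=True)
    T = H1_bar.T @ H2_bar  # raw cross-product, not normalized by Sigma11 or Sigma22
    return -torch.linalg.svdvals(T).sum()


torch.manual_seed(0)
H1 = torch.randn(256, 8)
H2 = torch.randn(256, 8)
base = dcca_loss_unnormalized(H1, H2)
scaled = dcca_loss_unnormalized(H1, 10.0 * H2)  # same data, outputs scaled by 10
print(base.item(), scaled.item())  # scaled is about 10x base
```

This is why the normalized T is preferred: its singular values are canonical correlations, which are bounded by 1 and do not change when either view is rescaled.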

Verification

Create two independent random tensors, H1 and H2, and compute dcca_loss(H1, H2) as a baseline. Then create a tensor H3 that is highly correlated with H1 (e.g., H3 = H1 + torch.randn_like(H1) * 0.01). The value of dcca_loss(H1, H3) should be significantly lower (more negative) than dcca_loss(H1, H2), because the correlation is higher. Use a batch size well above out_dim: with too few samples, even independent random outputs can appear almost perfectly correlated.
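The check above can be wrapped in a small reusable helper. This is a sketch: verify_dcca and its default sizes are hypothetical choices, and it accepts any loss function with the dcca_loss(H1, H2) signature. A seeded torch.Generator keeps the comparison reproducible.

```python
from typing import Callable, Tuple

import torch


def verify_dcca(
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    batch_size: int = 512,
    out_dim: int = 10,
    noise: float = 0.01,
    seed: int = 0,
) -> Tuple[float, float]:
    """Hypothetical helper: compare loss_fn on independent vs. nearly identical views.

    batch_size should be well above out_dim; with too few samples even
    independent random features can look highly correlated.
    """
    gen = torch.Generator().manual_seed(seed)
    H1 = torch.randn(batch_size, out_dim, generator=gen)
    H2 = torch.randn(batch_size, out_dim, generator=gen)  # independent of H1
    # Nearly a copy of H1, so it should be highly correlated with it.
    H3 = H1 + noise * torch.randn(batch_size, out_dim, generator=gen)
    loss_random = loss_fn(H1, H2).item()
    loss_correlated = loss_fn(H1, H3).item()
    assert loss_correlated < loss_random, (loss_correlated, loss_random)
    return loss_random, loss_correlated
```

Once dcca_loss is implemented, verify_dcca(dcca_loss) returns both losses and raises an AssertionError if the correlated pair does not score lower.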

References

[1] Andrew, G., et al. (2013). Deep Canonical Correlation Analysis. ICML 2013.