Deep Canonical Correlation Analysis (DCCA) Loss
Description
Canonical Correlation Analysis (CCA) is a statistical method for finding correlations between two sets of variables. Deep CCA (DCCA) uses neural networks to first project two "views" of data (e.g., an image and its text description) into a latent space, and then maximizes the correlation between those projections. [1] Your task is to implement the DCCA loss function, which is the core of this technique.
Guidance
Given the outputs of two networks, H1 and H2 (each of shape `(batch_size, out_dim)`), the goal is to maximize the correlation between them. The loss function is derived from the correlation calculation:

1. Center the outputs H1 and H2 by subtracting their column means.
2. Compute the covariance matrix of the concatenated, centered outputs. Its diagonal blocks are the per-view covariances Sigma11 and Sigma22 (regularized by adding `epsilon` to their diagonals), and its off-diagonal block is the cross-covariance Sigma12.
3. Form the whitened matrix T = Sigma11^(-1/2) @ Sigma12 @ Sigma22^(-1/2); the formulas are written out after this list.
4. The sum of the singular values of T is the total correlation. Therefore, your loss function should be the negative sum of these singular values, since optimizers minimize loss.
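Written out, following the formulation in [1] (with m the batch size and epsilon the regularizer from the starter code):

```latex
% Regularized covariance estimates of the centered views:
\Sigma_{11} = \tfrac{1}{m-1}\bar{H}_1^\top \bar{H}_1 + \epsilon I, \quad
\Sigma_{22} = \tfrac{1}{m-1}\bar{H}_2^\top \bar{H}_2 + \epsilon I, \quad
\Sigma_{12} = \tfrac{1}{m-1}\bar{H}_1^\top \bar{H}_2

% Whitened cross-covariance and the loss (\sigma_i are singular values):
T = \Sigma_{11}^{-1/2}\,\Sigma_{12}\,\Sigma_{22}^{-1/2}, \qquad
\mathcal{L}_{\mathrm{DCCA}} = -\sum_i \sigma_i(T)
```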
Starter Code
```python
import torch

def dcca_loss(H1, H2, epsilon=1e-5):
    """Calculates the DCCA loss.

    H1, H2: (batch_size, out_dim) outputs of the two networks.
    epsilon: regularizer added to the covariance diagonals for stability.
    """
    # 1. Center the features by subtracting the mean of each column.
    H1_bar = H1 - H1.mean(dim=0)
    H2_bar = H2 - H2.mean(dim=0)
    # 2. Compute the covariance matrices for each view, and the cross-covariance:
    #    Sigma11, Sigma22 (each regularized with epsilon), and Sigma12.
    # 3. Form the whitened matrix T = Sigma11^(-1/2) @ Sigma12 @ Sigma22^(-1/2).
    # 4. Compute the singular values of T; torch.linalg.svdvals can be used.
    # 5. The loss is the negative sum of the singular values.
    #    A high correlation corresponds to a low loss.
    pass
```
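For reference, here is one possible completion, offered as a sketch rather than the canonical solution: the inverse square roots are computed via a symmetric eigendecomposition, and the eigenvalue clamping inside `inv_sqrt` is an extra safeguard I have added beyond what the task requires.

```python
import torch

def dcca_loss(H1, H2, epsilon=1e-5):
    """Negative total correlation between two views (one possible sketch)."""
    m, d = H1.shape  # batch size and output dimension

    # 1. Center each view.
    H1_bar = H1 - H1.mean(dim=0)
    H2_bar = H2 - H2.mean(dim=0)

    # 2. Regularized covariance and cross-covariance estimates.
    eye = torch.eye(d, dtype=H1.dtype, device=H1.device)
    Sigma11 = H1_bar.T @ H1_bar / (m - 1) + epsilon * eye
    Sigma22 = H2_bar.T @ H2_bar / (m - 1) + epsilon * eye
    Sigma12 = H1_bar.T @ H2_bar / (m - 1)

    # 3. Inverse square root of a symmetric positive-definite matrix
    #    via eigendecomposition; clamping guards against tiny eigenvalues.
    def inv_sqrt(S):
        vals, vecs = torch.linalg.eigh(S)
        return vecs @ torch.diag(vals.clamp_min(epsilon).rsqrt()) @ vecs.T

    # 4. Whitened cross-covariance; its singular values are the
    #    canonical correlations between the two views.
    T = inv_sqrt(Sigma11) @ Sigma12 @ inv_sqrt(Sigma22)

    # 5. Negative sum of singular values: minimizing the loss
    #    maximizes the total correlation.
    return -torch.linalg.svdvals(T).sum()
```

Using `torch.linalg.eigh` on the symmetric covariance matrices is cheaper and more numerically stable than a general matrix square root, and every operation above is differentiable, so the loss can be backpropagated through both networks.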
Verification
Create two random tensors, `H1` and `H2`, and record the loss as a baseline. Now create a new tensor `H3` that is highly correlated with `H1` (e.g., `H3 = H1 + torch.randn_like(H1) * 0.01`). The result of `dcca_loss(H1, H3)` should be significantly lower (more negative) than `dcca_loss(H1, H2)`, because the correlation is higher.
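A minimal check along these lines (the shapes and seed are arbitrary choices):

```python
torch.manual_seed(0)
H1 = torch.randn(128, 10)
H2 = torch.randn(128, 10)
H3 = H1 + torch.randn_like(H1) * 0.01  # near-duplicate of H1

baseline = dcca_loss(H1, H2)    # chance-level correlation
correlated = dcca_loss(H1, H3)  # should approach -10: all 10 canonical correlations near 1
assert correlated < baseline
```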
References
[1] Andrew, G., et al. (2013). Deep Canonical Correlation Analysis. ICML 2013.