ML Katas

Soft Actor-Critic (SAC) Critic Loss

hard (<30 mins) pytorch reinforcement rl sac actor-critic
this year by E

Description

Soft Actor-Critic (SAC) is an off-policy, maximum-entropy reinforcement learning algorithm known for its stability and sample efficiency [1]. A key component is its critic (Q-network) update, whose target includes an entropy term that encourages exploration. Your task is to implement the loss function for the SAC critics.

Guidance

SAC uses two critics, Q1 and Q2, to reduce overestimation bias. The loss for each is the MSE between its prediction and a shared target value y. The target y is defined as:

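$$y = r + \gamma \left( \min\left(Q_{\text{target},1}(s', a'),\ Q_{\text{target},2}(s', a')\right) - \alpha \log \pi(a' \mid s') \right)$$

Where $s'$ is the next state, $a'$ is the next action sampled from the policy $\pi$ at $s'$, $Q_{\text{target},1}$ and $Q_{\text{target},2}$ are the slow-moving target networks, $\gamma$ is the discount factor, and $\alpha$ is the entropy coefficient.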
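To make the target concrete, here is a small worked example with plain tensors standing in for the network outputs. All numbers are made up for illustration, and the variable names are assumptions chosen to mirror the starter code; the sketch only shows the arithmetic of the target, with the entropy term subtracted inside the discounted bracket.

```python
import torch

# Hypothetical batch of two transitions; every value here is invented.
rewards = torch.tensor([1.0, 0.0])
target_q1_next = torch.tensor([2.0, 5.0])   # stands in for Q_target1(s', a')
target_q2_next = torch.tensor([3.0, 4.0])   # stands in for Q_target2(s', a')
log_pi_next = torch.tensor([-1.0, -2.0])    # stands in for log pi(a' | s')
gamma, alpha = 0.5, 0.5

# Element-wise minimum of the two target critics: [2.0, 4.0]
min_q = torch.min(target_q1_next, target_q2_next)

# Subtract the scaled log-probability (adds an entropy bonus): [2.5, 5.0]
soft_value = min_q - alpha * log_pi_next

# Discount and add the reward: [2.25, 2.5]
target_y = rewards + gamma * soft_value
print(target_y)
```

Because log-probabilities of sampled actions are typically negative in this example, subtracting `alpha * log_pi_next` raises the target, which is how the entropy term rewards less deterministic policies.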

Starter Code

import torch
import torch.nn.functional as F

def calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2, policy, rewards, states, actions, next_states, gamma, alpha):
    # 1. Get the actions and log-probabilities for the *next_states* from the policy.
    next_actions, log_pi = policy.sample(next_states)

    # 2. Get the target Q-values for the next_state-action pairs from both target critics.
    target_q1_next = target_critic1(next_states, next_actions)
    target_q2_next = target_critic2(next_states, next_actions)

    # 3. Take the minimum of the two target Q-values to mitigate overestimation.
    min_target_q_next = ...

    # 4. Calculate the full target value 'y' using the formula above.
    #    Use .detach() on the target value so gradients don't flow into the target networks or the policy.
    target_y = ...

    # 5. Get the current Q-values for the original state-action pairs.
    current_q1 = critic1(states, actions)
    current_q2 = critic2(states, actions)

    # 6. Calculate the MSE loss for each critic against the target 'y'.
    loss1 = F.mse_loss(...)
    loss2 = F.mse_loss(...)

    return loss1 + loss2

Verification

Instantiate the required networks with random weights. Create dummy input tensors for states, actions, rewards, etc. Your function should return a single scalar loss tensor. Verify that calling loss.backward() populates the gradients for critic1 and critic2, but not for the policy or target critic networks.
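The check described above could be sketched as follows. `TinyCritic` and `TinyPolicy` are hypothetical stand-ins invented for this example (a real SAC policy usually applies tanh squashing, which is omitted here), and the function body is one possible completion of the starter code, not the only valid one. The sketch assumes rewards and log-probabilities both have shape `(batch, 1)` so they broadcast cleanly against the critic outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCritic(nn.Module):
    """Hypothetical Q-network: (state, action) -> one Q-value per sample."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Linear(state_dim + action_dim, 1)

    def forward(self, states, actions):
        return self.net(torch.cat([states, actions], dim=-1))

class TinyPolicy(nn.Module):
    """Hypothetical Gaussian policy with sample(states) -> (actions, log_pi)."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.mean = nn.Linear(state_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def sample(self, states):
        dist = torch.distributions.Normal(self.mean(states), self.log_std.exp())
        actions = dist.rsample()  # reparameterised sample
        log_pi = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        return actions, log_pi

def calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2,
                              policy, rewards, states, actions, next_states,
                              gamma, alpha):
    # One possible completion of the starter code (an assumption, not the official answer).
    next_actions, log_pi = policy.sample(next_states)
    min_target_q_next = torch.min(target_critic1(next_states, next_actions),
                                  target_critic2(next_states, next_actions))
    # Detach so the target is treated as a constant during backprop.
    target_y = (rewards + gamma * (min_target_q_next - alpha * log_pi)).detach()
    loss1 = F.mse_loss(critic1(states, actions), target_y)
    loss2 = F.mse_loss(critic2(states, actions), target_y)
    return loss1 + loss2

def has_grad(module):
    """True if any parameter of the module received a gradient."""
    return any(p.grad is not None for p in module.parameters())

torch.manual_seed(0)
batch, state_dim, action_dim = 8, 3, 2
critic1, critic2 = TinyCritic(state_dim, action_dim), TinyCritic(state_dim, action_dim)
target_critic1, target_critic2 = TinyCritic(state_dim, action_dim), TinyCritic(state_dim, action_dim)
policy = TinyPolicy(state_dim, action_dim)

# Dummy transition batch with random contents.
states = torch.randn(batch, state_dim)
actions = torch.randn(batch, action_dim)
next_states = torch.randn(batch, state_dim)
rewards = torch.randn(batch, 1)

loss = calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2,
                                 policy, rewards, states, actions, next_states,
                                 gamma=0.99, alpha=0.2)
loss.backward()

print("scalar loss:", loss.dim() == 0)
print("critic grads:", has_grad(critic1), has_grad(critic2))
print("policy grads:", has_grad(policy))
print("target grads:", has_grad(target_critic1), has_grad(target_critic2))
```

If the `.detach()` call is removed, the policy and target critics would also receive gradients, which is a quick way to confirm that the check is actually sensitive to the mistake it is meant to catch.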

References

[1] Haarnoja, T., et al. (2018). Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. ICML 2018.