Soft Actor-Critic (SAC) Critic Loss
Description
Soft Actor-Critic (SAC) is a widely used off-policy reinforcement learning algorithm known for its stability and sample efficiency [1]. A key component is its critic (Q-network) update, whose target includes an entropy term that encourages exploration. Your task is to implement the loss function for the SAC critics.
Guidance
SAC uses two critics, $Q_1$ and $Q_2$, to reduce overestimation bias. The loss for each critic is the mean squared error (MSE) between its prediction and a shared target value $y$, defined as:

$$
y = r + \gamma \left( \min_{i=1,2} \bar{Q}_i(s', a') - \alpha \log \pi(a' \mid s') \right), \qquad a' \sim \pi(\cdot \mid s')
$$

where $s'$ is the next state, $a'$ is the next action sampled from the policy $\pi$, $\bar{Q}_1$ and $\bar{Q}_2$ are the slow-moving target networks, and $\alpha$ is the entropy coefficient.
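As a quick sanity check with purely illustrative numbers (not part of the task): if $r = 1.0$, $\gamma = 0.99$, $\min_i \bar{Q}_i(s', a') = 5.0$, $\alpha = 0.2$, and $\log \pi(a' \mid s') = -1.5$, then $y = 1.0 + 0.99 \times (5.0 - 0.2 \times (-1.5)) = 1.0 + 0.99 \times 5.3 = 6.247$.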
Starter Code
import torch
import torch.nn.functional as F
def calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2, policy, rewards, states, actions, next_states, gamma, alpha):
    # 1. Get the actions and log-probabilities for the *next_states* from the policy.
    next_actions, log_pi = policy.sample(next_states)

    # 2. Get the target Q-values for the next state-action pairs from both target critics.
    target_q1_next = target_critic1(next_states, next_actions)
    target_q2_next = target_critic2(next_states, next_actions)

    # 3. Take the minimum of the two target Q-values to mitigate overestimation.
    min_target_q_next = ...

    # 4. Calculate the full target value 'y' using the formula above.
    #    Use .detach() on the target value so gradients don't flow through the target networks.
    target_y = ...

    # 5. Get the current Q-values for the original state-action pairs.
    current_q1 = critic1(states, actions)
    current_q2 = critic2(states, actions)

    # 6. Calculate the MSE loss for each critic against the target 'y'.
    loss1 = F.mse_loss(...)
    loss2 = F.mse_loss(...)

    return loss1 + loss2
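For self-checking, one possible completion of the blanks is sketched below. It is not the only valid implementation; it assumes rewards and log_pi are shaped like the critics' outputs (e.g., [batch_size, 1]) so the arithmetic broadcasts elementwise.

import torch
import torch.nn.functional as F

def calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2, policy,
                              rewards, states, actions, next_states, gamma, alpha):
    # Sample next actions and their log-probabilities from the current policy.
    next_actions, log_pi = policy.sample(next_states)

    # Evaluate both target critics and take the elementwise minimum (clipped double-Q).
    target_q1_next = target_critic1(next_states, next_actions)
    target_q2_next = target_critic2(next_states, next_actions)
    min_target_q_next = torch.min(target_q1_next, target_q2_next)

    # Soft Bellman target; detach so gradients do not flow into the policy or target networks.
    target_y = (rewards + gamma * (min_target_q_next - alpha * log_pi)).detach()

    # Current Q-value estimates for the sampled (state, action) pairs.
    current_q1 = critic1(states, actions)
    current_q2 = critic2(states, actions)

    # MSE of each critic against the shared target; summing lets one backward() update both critics.
    loss1 = F.mse_loss(current_q1, target_y)
    loss2 = F.mse_loss(current_q2, target_y)
    return loss1 + loss2

In a full SAC implementation the discounted term is usually also multiplied by (1 - done) to zero out bootstrapping at terminal transitions; the starter signature has no done flag, so that factor is omitted here.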
Verification
Instantiate the required networks with random weights and create dummy input tensors for states, actions, rewards, etc. Your function should return a single scalar loss tensor. Verify that calling loss.backward() populates gradients for critic1 and critic2, but not for the policy or the target critic networks.
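A minimal harness for these checks might look like the sketch below. The DummyCritic and DummyPolicy classes, the dimensions, and the hyperparameter values (gamma=0.99, alpha=0.2) are illustrative assumptions, and the sketch assumes calculate_sac_critic_loss has been filled in as above.

import torch
import torch.nn as nn

# Illustrative dimensions (assumptions, not specified by the task).
STATE_DIM, ACTION_DIM, BATCH = 8, 2, 32

class DummyCritic(nn.Module):
    # Minimal Q-network: concatenates state and action, outputs one value per sample.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(STATE_DIM + ACTION_DIM, 64), nn.ReLU(), nn.Linear(64, 1))
    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=-1))

class DummyPolicy(nn.Module):
    # Minimal stochastic policy exposing the sample(states) -> (actions, log_probs) interface.
    def __init__(self):
        super().__init__()
        self.mean = nn.Linear(STATE_DIM, ACTION_DIM)
        self.log_std = nn.Parameter(torch.zeros(ACTION_DIM))
    def sample(self, s):
        dist = torch.distributions.Normal(self.mean(s), self.log_std.exp())
        a = dist.rsample()
        log_pi = dist.log_prob(a).sum(dim=-1, keepdim=True)
        return a, log_pi

critic1, critic2 = DummyCritic(), DummyCritic()
target_critic1, target_critic2 = DummyCritic(), DummyCritic()
policy = DummyPolicy()

states = torch.randn(BATCH, STATE_DIM)
actions = torch.randn(BATCH, ACTION_DIM)
rewards = torch.randn(BATCH, 1)          # shaped to match the critics' [batch, 1] output
next_states = torch.randn(BATCH, STATE_DIM)

loss = calculate_sac_critic_loss(critic1, critic2, target_critic1, target_critic2, policy,
                                 rewards, states, actions, next_states, gamma=0.99, alpha=0.2)
assert loss.dim() == 0  # single scalar loss tensor

loss.backward()
assert all(p.grad is not None for p in critic1.parameters())
assert all(p.grad is not None for p in critic2.parameters())
assert all(p.grad is None for p in target_critic1.parameters())
assert all(p.grad is None for p in target_critic2.parameters())
assert all(p.grad is None for p in policy.parameters())
print("Gradient checks passed, loss =", loss.item())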
References
[1] Haarnoja, T., et al. (2018). Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. ICML 2018.