Implementing Soft Actor-Critic (SAC) in PyTorch
Last updated: November 16, 2024
Introduction
Now that we've covered the mathematical foundations of SAC, let's implement this algorithm step by step. We'll see how the theoretical concepts translate into code, making it easier to understand the practical aspects of this powerful algorithm.
Architecture Overview
SAC consists of several key components:
- Two Q-Networks (Critics)
- A Policy Network (Actor)
- Target Q-Networks
- A Replay Buffer
Let's break down each component and understand its implementation.
1. Q-Networks Implementation
The Q-networks (critics) are essential for value estimation. We implement two Q-networks to reduce overestimation bias.
import torch
import torch.nn as nn
from typing import Tuple


class QNetworks(nn.Module):
    def __init__(self, num_observations: int, num_actions: int):
        super(QNetworks, self).__init__()
        # First Q-network: maps a (state, action) pair to a scalar Q-value
        self.phi_1 = nn.Sequential(
            nn.Linear(num_observations + num_actions, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        # Second Q-network: identical architecture, separate parameters
        self.phi_2 = nn.Sequential(
            nn.Linear(num_observations + num_actions, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Concatenate state and action along the feature dimension
        x = torch.cat([state, action], 1)
        phi_1_value = self.phi_1(x)
        phi_2_value = self.phi_2(x)
        return phi_1_value, phi_2_value
Key Points:
- Both networks have identical architecture but different parameters
- Input concatenates state and action
- Output is a single Q-value
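As a quick sanity check (the observation and action sizes below are illustrative, not taken from any particular environment), the twin critics can be exercised with random tensors:

# Illustrative dimensions: 8-dimensional observations, 2-dimensional actions
q_networks = QNetworks(num_observations=8, num_actions=2)
states = torch.randn(32, 8)
actions = torch.randn(32, 2)
q1, q2 = q_networks(states, actions)
print(q1.shape, q2.shape)  # torch.Size([32, 1]) torch.Size([32, 1])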
Soft and Hard Update
In reinforcement learning, target networks (e.g., the target Q-value networks in SAC) are used to stabilize training by providing a more consistent reference for the value estimates. These target networks are periodically updated to slowly track the weights of the main networks.
- Hard Update: Directly copies the weights from the source network to the target network.
- Soft Update: Gradually updates the target network's weights as a weighted average of its current weights and the source network's weights. This smooth transition helps stabilize training.
def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak update: target ← (1 - τ) * target + τ * source."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


def hard_update(target: nn.Module, source: nn.Module) -> None:
    """Copy the source network's weights directly into the target network."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)
2. Policy Network Implementation
The policy network determines the agent's behavior by generating actions.
class Policy(nn.Module):
    def __init__(self, input_dim: int, num_actions: int, device: torch.device):
        super(Policy, self).__init__()
        self.noise = torch.Tensor(num_actions).to(device)
        self.policy_network = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )
        self.device = device

    def forward(self, state: torch.Tensor):
        return self.policy_network(state)

    def sample(self, state: torch.Tensor):
        # Forward pass through the network to get the mean action
        action = self.forward(state)
        # Generate exploration noise and clip it to a bounded range
        noise = self.noise.normal_(0.0, std=0.1)
        noise = noise.clamp(-0.25, 0.25)
        # Compute the action as mean + noise
        noisy_action = action + noise
        # Return a dummy log_prob to keep the interface compatible with a stochastic policy
        return noisy_action, torch.tensor(0.0), action

    def select_action(self, state, evaluate: bool = False):
        # Convert the state to a tensor, add a batch dimension, and move it to the correct device
        state_tensor = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # Sample a noisy action during training, or take the mean action during evaluation
        if not evaluate:
            action, _, _ = self.sample(state_tensor)
        else:
            _, _, action = self.sample(state_tensor)
        # Return the action as a numpy array
        return action.detach().cpu().numpy()[0]
Important Features:
- Generates continuous actions
- Includes exploration noise
- Returns both noisy actions (for training) and mean actions (for evaluation)
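To make the interface concrete, here is a small usage sketch; the observation and action sizes are illustrative, and the CPU device is used only for the example:

import numpy as np

example_device = torch.device("cpu")  # illustrative; a GPU device works the same way
example_policy = Policy(input_dim=8, num_actions=2, device=example_device)

observation = np.zeros(8, dtype=np.float32)  # placeholder observation
train_action = example_policy.select_action(observation)                # noisy action for data collection
eval_action = example_policy.select_action(observation, evaluate=True)  # mean action for evaluation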
3. Replay Buffer
The replay buffer stores experiences for off-policy learning.
from collections import deque
import random
import numpy as np
from typing import NamedTuple, Tuple, Deque
from numpy.typing import NDArray
class Transition(NamedTuple):
    state: NDArray[np.float64]
    action: NDArray[np.float64]
    reward: float
    next_state: NDArray[np.float64]
    done: int


class ReplayMemory:
    def __init__(self, capacity: int, seed: int) -> None:
        random.seed(seed)
        self.capacity: int = capacity
        self.buffer: Deque[Transition] = deque(maxlen=capacity)

    def push(
        self,
        state: NDArray[np.float64],
        action: NDArray[np.float64],
        reward: float,
        next_state: NDArray[np.float64],
        done: bool,
    ) -> None:
        """Store a transition in the replay buffer."""
        self.buffer.append(Transition(state, action, reward, next_state, int(done)))

    def sample(
        self, batch_size: int
    ) -> Tuple[
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
    ]:
        """Sample a random batch of transitions."""
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self) -> int:
        return len(self.buffer)
Key Functionality:
- Stores transitions (state, action, reward, next_state, done)
- Provides random sampling for batch updates
- Fixed capacity with FIFO behavior
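For illustration, the buffer can be filled with random transitions and then sampled as follows (the dimensions and values are placeholders):

memory = ReplayMemory(capacity=1000, seed=0)
for _ in range(64):
    memory.push(
        state=np.random.randn(8),
        action=np.random.randn(2),
        reward=0.0,
        next_state=np.random.randn(8),
        done=False,
    )
states, actions, rewards, next_states, dones = memory.sample(batch_size=32)
print(states.shape, actions.shape)  # (32, 8) (32, 2)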
4. Training: Soft Actor-Critic
Before training the SAC agent, it's essential to define the training setup and initialize key components, such as the neural networks, optimizers, loss functions, and replay buffer. Let's walk through the initialization step-by-step.
Step 1: Define Key Hyperparameters
These hyperparameters govern the agent's learning process:
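The training-loop snippets later in this post refer to the names below. The values shown are common defaults (e.g., the ones popularized by the original SAC paper and open-source implementations), so treat them as a starting point rather than tuned settings:

# Common default values; tune them per environment
discount_factor_gamma = 0.99   # γ: discount factor for future rewards
target_smoothing_tau = 0.005   # τ: soft-update coefficient for the target networks
temperature_alpha = 0.2        # α: entropy temperature
learning_rate = 3e-4           # learning rate shared by actor and critic optimizers
batch_size = 256               # number of transitions per gradient update

# Environment dimensions (taken from your environment, e.g. a Gym/Gymnasium Box space)
# num_observations = env.observation_space.shape[0]
# num_actions = env.action_space.shape[0]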
Step 2: Select Computation Device
SAC can be computationally intensive, so we'll use a GPU if available:
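The standard PyTorch pattern for this is:

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")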
Step 3: Initialize the Actor (Policy) Network
The policy network decides the agent's actions. We'll define a neural network model for this.
# Initialize the policy (actor) network
policy = Policy(num_observations, num_actions, device=device).to(device)
Step 4: Initialize the Critic Networks
SAC uses two Q-value networks (critics) to reduce overestimation bias, along with corresponding target networks that provide stable Q-value targets during training.
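A minimal initialization sketch, reusing the QNetworks class and the hard_update helper defined earlier (the names critic and critic_target match the ones used in the optimizer and training-loop snippets below):

# Each QNetworks instance holds both Q-networks (phi_1 and phi_2)
critic = QNetworks(num_observations, num_actions).to(device)
critic_target = QNetworks(num_observations, num_actions).to(device)

# Start the target networks with the same weights as the main critics
hard_update(critic_target, critic)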
Step 5: Set Up the Replay Buffer
A replay buffer stores past experiences for training. This allows the agent to learn from diverse experiences instead of only recent events.
# Replay buffer for experience replay
replay_buffer = ReplayMemory(capacity=100000, seed=1212)
Step 6: Define Optimizers
To optimize the neural networks, we need separate optimizers for the policy and critic networks:
from torch.optim import Adam
# Optimizer for the policy network (actor)
policy_optimizer = Adam(policy.parameters(), lr=learning_rate)
# Optimizer for the critic networks
critic_optimizer = Adam(critic.parameters(), lr=learning_rate)
Key Notes for Training
- Actor-Critic Framework: The policy (actor) network learns to select actions, while the critic networks evaluate the quality of these actions.
- Target Networks: Using slowly updated target networks improves training stability by providing consistent Q-value targets.
- Replay Buffer: Enables more stable learning by breaking correlations in the training data.
This setup ensures that your SAC agent has all the necessary components for effective training! Next, we’ll dive into the training loop.
Training Loop:
Here's how all components work together during training:
- Sample Batch:
# Sample and prepare batch data from replay buffer
state_batch, action_batch, reward_batch, next_state_batch, done_signal_batch = replay_buffer.sample(batch_size)
# Convert numpy arrays to PyTorch tensors and move to specified device (CPU/GPU)
state_batch = torch.FloatTensor(state_batch).to(device) # Current states
next_state_batch = torch.FloatTensor(next_state_batch).to(device) # Next states
action_batch = torch.FloatTensor(action_batch).to(device) # Actions taken
reward_batch = torch.FloatTensor(reward_batch).to(device).unsqueeze(1) # Rewards received
done_signal_batch = torch.FloatTensor(done_signal_batch).to(device).unsqueeze(1) # Episode termination signals
- Compute Target Q-values:
# Compute target Q-values using target networks (no gradient tracking needed)
with torch.no_grad():
    # Get next actions and their log probabilities from the policy network
    next_state_action, next_state_log_pi, _ = policy.sample(next_state_batch)
    # Compute Q-values for next state-action pairs using the target critic networks
    qf1_next_target, qf2_next_target = critic_target(next_state_batch, next_state_action)
    # Take the minimum Q-value to reduce overestimation bias and subtract the entropy term
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - temperature_alpha * next_state_log_pi
    # Compute the target Q-value using the Bellman equation: r + γ(1-d)(min Q - α*log_pi)
    next_q_value = reward_batch + (1.0 - done_signal_batch) * discount_factor_gamma * min_qf_next_target
- Update Critics:
import torch.nn.functional as F  # provides F.mse_loss used below

# Compute critic loss using both Q-networks
# Get current Q-value estimates for both critics
qf1, qf2 = critic(state_batch, action_batch)
# Compute MSE loss between current Q-values and target Q-values
# Using two critics helps reduce overestimation bias in Q-values
qf1_loss = F.mse_loss(qf1, next_q_value) # Loss for first Q-network
qf2_loss = F.mse_loss(qf2, next_q_value) # Loss for second Q-network
qf_loss = qf1_loss + qf2_loss # Combined critic loss
# Update critic networks using gradient descent
critic_optimizer.zero_grad() # Clear previous gradients
qf_loss.backward() # Compute gradients
critic_optimizer.step() # Update network parameters
- Update Policy:
# Sample actions for the current states using the policy network
# Returns the noisy action, a (dummy) log probability, and the mean action
action_pi, log_pi, _ = policy.sample(state_batch)
# Evaluate action values using current critic networks
# Get Q-values for policy-generated actions
qf1_pi, qf2_pi = critic(state_batch, action_pi)
# Take minimum Q-value to reduce overestimation
min_qf_pi = torch.min(qf1_pi, qf2_pi)
# Compute policy loss
# Policy loss = entropy term (α * log_pi) - Q-value
# We minimize this loss to maximize expected return while maintaining entropy
policy_loss = ((temperature_alpha * log_pi) - min_qf_pi).mean()
# Update policy network using gradient descent
policy_optimizer.zero_grad() # Clear previous gradients
policy_loss.backward() # Compute gradients
policy_optimizer.step() # Update network parameters
# Update target critic networks using soft update
# Slowly blend target network weights with current network weights
# target_params = τ * current_params + (1 - τ) * target_params
soft_update(critic_target, critic, target_smoothing_tau)
Evaluation
Monitor these metrics during training:
- Average episode return
- Q-value estimates
- Policy entropy
- Actor and critic losses
Remember to periodically evaluate without exploration noise to assess true performance.
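A rough sketch of such an evaluation loop, assuming a Gymnasium-style env whose reset() returns (observation, info) and whose step() returns (observation, reward, terminated, truncated, info):

def evaluate_policy(env, policy: Policy, num_episodes: int = 10) -> float:
    """Run the mean (noise-free) policy and return the average episode return."""
    total_return = 0.0
    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_done = False
        while not episode_done:
            # evaluate=True selects the mean action, i.e. no exploration noise
            action = policy.select_action(state, evaluate=True)
            state, reward, terminated, truncated, _ = env.step(action)
            total_return += reward
            episode_done = terminated or truncated
    return total_return / num_episodes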