Implementing Soft Actor-Critic (SAC) in PyTorch

Reinforcement Learning

Last updated: November 16, 2024

Introduction

Now that we've covered the mathematical foundations of SAC, let's implement this algorithm step by step. We'll see how the theoretical concepts translate into code, making it easier to understand the practical aspects of this powerful algorithm.

Architecture Overview

SAC consists of several key components:

  1. Two Q-Networks (Critics)
  2. A Policy Network (Actor)
  3. Target Q-Networks
  4. A Replay Buffer

Let's break down each component and understand its implementation.

1. Q-Networks Implementation

The Q-networks (critics) are essential for value estimation. We implement two Q-networks to reduce overestimation bias.

from typing import Tuple

import torch
import torch.nn as nn


class QNetworks(nn.Module):
    def __init__(self, num_observations: int, num_actions: int):
        super(QNetworks, self).__init__()

        self.phi_1 = nn.Sequential(
            nn.Linear(num_observations + num_actions, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

        self.phi_2 = nn.Sequential(
            nn.Linear(num_observations + num_actions, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat([state, action], 1)

        phi_1_value = self.phi_1(x)
        phi_2_value = self.phi_2(x)

        return phi_1_value, phi_2_value

Key Points:

  1. Both critics live in a single nn.Module, so one optimizer can update them together.
  2. Each critic takes the concatenated state-action vector and outputs a single scalar Q-value.
  3. Maintaining two independent estimates lets us take their element-wise minimum later, which counteracts overestimation bias.

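As a quick sanity check, here is a small sketch (batch size and dimensions are arbitrary, chosen only for illustration) that runs a forward pass through both critics:

# Dummy forward pass through the twin critics (dimensions are illustrative)
critics = QNetworks(num_observations=4, num_actions=2)

dummy_states = torch.randn(8, 4)   # batch of 8 four-dimensional states
dummy_actions = torch.randn(8, 2)  # batch of 8 two-dimensional actions

q1, q2 = critics(dummy_states, dummy_actions)
print(q1.shape, q2.shape)  # both torch.Size([8, 1])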
Soft and Hard Update

In reinforcement learning, target networks (e.g., the target Q-value networks in SAC) are used to stabilize training by providing a more consistent reference for the value estimates. These target networks are periodically updated to slowly track the weights of the main networks.

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

def hard_update(target: nn.Module, source: nn.Module) -> None:
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)

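As a small illustration (the real call sites appear in the training setup later in this post), hard_update performs a one-time exact copy at initialization, while soft_update nudges the target towards the main network by a small factor tau at every training step:

# Illustrative usage with freshly created critics
main_critics = QNetworks(num_observations=4, num_actions=2)
target_critics = QNetworks(num_observations=4, num_actions=2)

hard_update(target_critics, main_critics)        # one-time exact copy at initialization
soft_update(target_critics, main_critics, 0.05)  # per-step Polyak averaging with a small tau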
2. Policy Network Implementation

The policy network determines the agent's behavior by generating actions.

class Policy(nn.Module):
    def __init__(self, input_dim: int, num_actions: int, device: torch.device):
        super(Policy, self).__init__()
        self.noise = torch.Tensor(num_actions).to(device)

        self.policy_network = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

        self.device = device

    def forward(self, state: torch.Tensor):
        return self.policy_network(state)

    def sample(self, state: torch.Tensor):
        # Forward pass through the network to get mean action
        action = self.forward(state)

        # Generate noise for exploration
        noise = self.noise.normal_(0.0, std=0.1)
        noise = noise.clamp(-0.25, 0.25)

        # Compute the noisy action as mean + clipped noise
        noisy_action = action + noise

        # The zero log-probability is a placeholder so this simplified policy matches
        # the (action, log_prob, mean) interface expected by the SAC update
        return noisy_action, torch.tensor(0.0), action

    def select_action(self, state: torch.Tensor, evaluate: bool = False):
        # Convert state to a tensor and move to the correct device
        state_tensor = torch.FloatTensor(state).to(self.device).unsqueeze(0)

        # Sample action with noise if not evaluating, otherwise take the mean action
        if not evaluate:
            action, _, _ = self.sample(state_tensor)
        else:
            _, _, action = self.sample(state_tensor)

        # Return the action in numpy format
        return action.detach().cpu().numpy()[0]

Important Features:

  1. select_action handles the numpy-to-tensor conversion, batching, and device placement, so the environment loop can pass raw observations directly.
  2. For simplicity, this policy is deterministic with additive, clipped Gaussian exploration noise; the canonical SAC policy samples from a squashed Gaussian and returns a real log-probability, which is why sample() only exposes a placeholder log-prob here.
  3. Passing evaluate=True returns the mean (noise-free) action, which is what we want when measuring performance.

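A short usage sketch with a made-up four-dimensional observation (purely illustrative):

import numpy as np

policy = Policy(input_dim=4, num_actions=2, device=torch.device("cpu"))

dummy_observation = np.array([0.1, -0.2, 0.05, 0.3])  # hypothetical environment state

exploratory_action = policy.select_action(dummy_observation)            # mean action + clipped noise
greedy_action = policy.select_action(dummy_observation, evaluate=True)  # mean action only
print(exploratory_action.shape, greedy_action.shape)  # (2,) (2,)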
3. Replay Buffer

The replay buffer stores experiences for off-policy learning.

from collections import deque
import random
import numpy as np
from typing import NamedTuple, Tuple, Deque
from numpy.typing import NDArray


class Transition(NamedTuple):
    state: NDArray[np.float64]
    action: NDArray[np.float64]
    reward: float
    next_state: NDArray[np.float64]
    done: int


class ReplayMemory:
    def __init__(self, capacity: int, seed: int) -> None:
        random.seed(seed)
        self.capacity: int = capacity
        self.buffer: Deque[Transition] = deque(maxlen=capacity)

    def push(
        self,
        state: NDArray[np.float64],
        action: NDArray[np.float64],
        reward: float,
        next_state: NDArray[np.float64],
        done: bool,
    ) -> None:
        """Store a transition in the replay buffer."""
        self.buffer.append(Transition(state, action, reward, next_state, int(done)))

    def sample(
        self, batch_size: int
    ) -> Tuple[
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
        NDArray[np.float64],
    ]:
        """Sample a batch of transitions."""
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self) -> int:
        return len(self.buffer)

Key Functionality:

  1. push stores a named Transition and converts the done flag to an integer; the deque's maxlen automatically discards the oldest entry once capacity is reached.
  2. sample draws a uniformly random batch and stacks each field into a NumPy array, ready for conversion to tensors.
  3. __len__ lets the training loop check whether enough transitions have been collected before starting updates.

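A brief usage sketch with made-up transitions:

import numpy as np

memory = ReplayMemory(capacity=1000, seed=42)

# Store a handful of dummy transitions
for _ in range(5):
    memory.push(
        state=np.zeros(4),
        action=np.zeros(2),
        reward=1.0,
        next_state=np.zeros(4),
        done=False,
    )

states, actions, rewards, next_states, dones = memory.sample(batch_size=3)
print(states.shape, actions.shape, rewards.shape)  # (3, 4) (3, 2) (3,)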
4. Training: Soft Actor-Critic

Before training the SAC agent, it's essential to define the training setup and initialize key components, such as the neural networks, optimizers, loss functions, and replay buffer. Let's walk through the initialization step-by-step.

Step 1: Define Key Hyperparameters

These hyperparameters govern the agent's learning process:

# Environment dimensions
num_observations = 4  # Dimension of the observation (state) space
num_actions = 2       # Dimension of the continuous action space

# Hyperparameters for SAC
learning_rate = 0.001       # Learning rate for optimizers
batch_size = 256            # Number of samples per training batch
discount_factor_gamma = 0.99 # Discount factor for future rewards
temperature_alpha = 0.2     # Entropy temperature to encourage exploration
target_smoothing_tau = 0.05 # Soft update rate for target networks

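In practice, these two dimensions are usually read from the environment rather than hard-coded. Here is a hedged sketch using Gymnasium's Pendulum-v1 as an example of a continuous-control task (it assumes the gymnasium package is installed):

import gymnasium as gym

env = gym.make("Pendulum-v1")
num_observations = env.observation_space.shape[0]  # 3 for Pendulum-v1
num_actions = env.action_space.shape[0]            # 1 for Pendulum-v1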
Step 2: Select Computation Device

SAC can be computationally intensive, so we'll use a GPU if available:

import torch

# Select device: Use GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Step 3: Initialize the Actor (Policy) Network

The policy network decides the agent's actions. We'll define a neural network model for this.

# Initialize the policy (actor) network and move it to the selected device
policy = Policy(num_observations, num_actions, device=device).to(device)

Step 4: Initialize the Critic Networks

SAC uses two Q-value networks (critics) to reduce overestimation bias, each paired with a target network that stabilizes training.

# Initialize two Q-value networks (critics) and move them to the device
critic = QNetworks(num_observations, num_actions).to(device)

# Initialize target networks for both critics
critic_target = QNetworks(num_observations, num_actions).to(device)

# Synchronize target networks with the main critic networks
hard_update(critic_target, critic)  # Copies parameters directly

Step 5: Set Up the Replay Buffer

A replay buffer stores past experiences for training. This allows the agent to learn from diverse experiences instead of only recent events.

# Replay buffer for experience replay
replay_buffer = ReplayMemory(capacity=100000, seed=1212)

Step 6: Define Optimizers

To optimize the neural networks, we need separate optimizers for the policy and critic networks:

from torch.optim import Adam

# Optimizer for the policy network (actor)
policy_optimizer = Adam(policy.parameters(), lr=learning_rate)

# Optimizer for the critic networks
critic_optimizer = Adam(critic.parameters(), lr=learning_rate)

Key Notes for Training

  1. The entropy temperature alpha is fixed here; many SAC implementations learn it automatically with a separate temperature loss.
  2. The target critics are never trained directly; they only track the main critics through soft updates.
  3. Updates should begin only once the replay buffer holds at least batch_size transitions.

This setup ensures that your SAC agent has all the necessary components for effective training! Next, we’ll dive into the training loop.

Training Loop:

Here's how all components work together during training:

  1. Sample Batch:
# Sample and prepare batch data from replay buffer
state_batch, action_batch, reward_batch, next_state_batch, done_signal_batch = replay_buffer.sample(batch_size)

# Convert numpy arrays to PyTorch tensors and move to specified device (CPU/GPU)
state_batch = torch.FloatTensor(state_batch).to(device)          # Current states
next_state_batch = torch.FloatTensor(next_state_batch).to(device)  # Next states
action_batch = torch.FloatTensor(action_batch).to(device)        # Actions taken
reward_batch = torch.FloatTensor(reward_batch).to(device).unsqueeze(1)  # Rewards received
done_signal_batch = torch.FloatTensor(done_signal_batch).to(device).unsqueeze(1)  # Episode termination signals
 
  2. Compute Target Q-values:
# Compute target Q-values using target networks (no gradient tracking needed)
with torch.no_grad():
    # Get next actions and their log probabilities from the policy network
    next_state_action, next_state_log_pi, _ = policy.sample(next_state_batch)
    
    # Compute Q-values for next state-action pairs using target critic networks
    qf1_next_target, qf2_next_target = critic_target(next_state_batch, next_state_action)
    
    # Take minimum Q-value to reduce overestimation bias and subtract entropy term
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - temperature_alpha * next_state_log_pi
    
    # Compute target Q-value using the Bellman equation: r + γ(1-d)(min_Q - α*log_pi)
    next_q_value = reward_batch + (1 - done_signal_batch) * discount_factor_gamma * min_qf_next_target
 
  3. Update Critics:
import torch.nn.functional as F  # needed for F.mse_loss

# Compute critic loss using both Q-networks
# Get current Q-value estimates for both critics
qf1, qf2 = critic(state_batch, action_batch)

# Compute MSE loss between current Q-values and target Q-values
# Using two critics helps reduce overestimation bias in Q-values
qf1_loss = F.mse_loss(qf1, next_q_value)  # Loss for first Q-network
qf2_loss = F.mse_loss(qf2, next_q_value)  # Loss for second Q-network
qf_loss = qf1_loss + qf2_loss              # Combined critic loss

# Update critic networks using gradient descent
critic_optimizer.zero_grad()    # Clear previous gradients
qf_loss.backward()             # Compute gradients
critic_optimizer.step()         # Update network parameters
 
  4. Update Policy:
# Sample actions for the current states using the policy network
# Returns noisy actions, a (placeholder) log probability, and the mean action
action_pi, log_pi, _ = policy.sample(state_batch)

# Evaluate action values using current critic networks
# Get Q-values for policy-generated actions
qf1_pi, qf2_pi = critic(state_batch, action_pi)
# Take minimum Q-value to reduce overestimation
min_qf_pi = torch.min(qf1_pi, qf2_pi)

# Compute policy loss
# Policy loss = entropy term (α * log_pi) - Q-value
# We minimize this loss to maximize expected return while maintaining entropy
policy_loss = ((temperature_alpha * log_pi) - min_qf_pi).mean()

# Update policy network using gradient descent
policy_optimizer.zero_grad()    # Clear previous gradients
policy_loss.backward()          # Compute gradients
policy_optimizer.step()         # Update network parameters

# Update target critic networks using soft update
# Slowly blend target network weights with current network weights
# target_params = τ * current_params + (1 - τ) * target_params
soft_update(critic_target, critic, target_smoothing_tau)
 

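To see where these four update steps sit in a full run, here is a hedged sketch of the outer interaction loop. It assumes a Gymnasium environment and a hypothetical update() helper that wraps the four steps above; action scaling to the environment's bounds is omitted for brevity.

import gymnasium as gym  # assumed environment library; any Gym-style API works

env = gym.make("Pendulum-v1")  # illustrative continuous-control task

for episode in range(1000):
    state, _ = env.reset()
    done = False
    while not done:
        # Act with exploration noise during training
        action = policy.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Store the transition; bootstrap only through true terminations
        replay_buffer.push(state, action, reward, next_state, terminated)
        state = next_state

        # Run steps 1-4 above once enough data has been collected
        if len(replay_buffer) >= batch_size:
            update()  # hypothetical helper wrapping the four update steps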
Evaluation

Monitor these metrics during training:

  1. Average episode return, the primary signal of learning progress.
  2. Critic loss (qf_loss) and the scale of the Q-value estimates; rapidly growing Q-values are a common sign of instability.
  3. Policy loss and, if you switch to a stochastic policy, the entropy of the sampled actions.

Remember to periodically evaluate without exploration noise to assess true performance.

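A minimal evaluation helper, sketched under the same Gymnasium assumption as above; the evaluate=True flag selects the mean action, so no exploration noise is applied:

def evaluate_policy(env, policy, num_episodes: int = 10) -> float:
    """Run the noise-free policy and return the average episode return."""
    total_return = 0.0
    for _ in range(num_episodes):
        state, _ = env.reset()
        done = False
        while not done:
            # Mean action only -- no exploration noise
            action = policy.select_action(state, evaluate=True)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_return += reward
    return total_return / num_episodes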