Training a Deep Q Network using PyTorch

Reinforcement Learning

Last updated: December 15, 2024

Code

The complete code for this project can be found at the following link: Deep Q-Learning on GitHub by edreate.com.

1. Replay Memory in Deep Q-Learning 

Replay memory $D$ is a critical component in Deep Q-Learning. It stores experiences as transitions $(s_t, a_t, r_t, s_{t+1})$, where:

  • $s_t$ is the state observed at time step $t$,
  • $a_t$ is the action the agent took in that state,
  • $r_t$ is the reward received after taking $a_t$,
  • $s_{t+1}$ is the resulting next state.

$$
 D = \{ (s_t, a_t, r_t, s_{t+1}) \} 
$$

The size of $D$ is fixed: once it reaches capacity, the oldest transitions are discarded as new ones arrive.

During training, transitions are sampled uniformly at random from $D$. This breaks the correlation between consecutive experiences, which stabilizes learning, and lets each transition be reused in many updates, which improves data efficiency.

1a. Replay Memory Implementation

from collections import namedtuple, deque
import random
from typing import Optional

import torch

# Define the structure of a Transition tuple.
# Note the field order (state, action, next_state, reward) differs from (s_t, a_t, r_t, s_{t+1}).
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

class ReplayMemory:
    def __init__(self, capacity: int):
        # A deque with maxlen drops the oldest transition once it is full
        self.memory: deque[Transition] = deque([], maxlen=capacity)

    def push(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        next_state: Optional[torch.Tensor],
        reward: torch.Tensor,
    ) -> None:
        """Save a transition in the memory (next_state is None for terminal states)."""
        self.memory.append(Transition(state, action, next_state, reward))

    def sample(self, batch_size: int) -> list[Transition]:
        """Retrieve a uniformly random batch of transitions (without replacement)."""
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        """Return the current number of stored transitions."""
        return len(self.memory)
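To see the fixed-capacity eviction and the batch transposition that `optimize_model` relies on, here is a minimal stdlib-only sketch. It uses a plain `deque` in place of `ReplayMemory` and integers in place of state tensors; both are simplifying assumptions for illustration.

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

# Same storage as ReplayMemory, with a tiny capacity of 3
memory = deque([], maxlen=3)
for t in range(5):
    # Integers stand in for state tensors
    memory.append(Transition(state=t, action=t % 2, next_state=t + 1, reward=1.0))

# The two oldest transitions (states 0 and 1) have been evicted
stored_states = [tr.state for tr in memory]
print(stored_states)  # [2, 3, 4]

# Uniform random sampling without replacement, as in ReplayMemory.sample
random.seed(0)
transitions = random.sample(memory, 2)

# Transpose: a list of Transitions becomes one Transition of per-field tuples
batch = Transition(*zip(*transitions))
print(batch.state, batch.reward)
```

After the transpose, `batch.state`, `batch.action`, and so on are tuples that can each be concatenated into a single batch tensor.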

2. Action Selection: Epsilon-Greedy Strategy 

The epsilon-greedy strategy balances exploration (trying new actions) and exploitation (choosing the best-known action).

$$
 a_t =
\begin{cases}
\text{random action} & \text{with probability } \epsilon \\
\arg \max_a Q(s_t, a; \theta) & \text{with probability } 1 - \epsilon
\end{cases} 
$$
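The two cases of the formula can be sketched in plain Python, with a list of numbers standing in for the network's Q-values $Q(s_t, \cdot\,; \theta)$. The function name `epsilon_greedy` is hypothetical and not part of the project code.

```python
import random

def epsilon_greedy(q_values: list[float], epsilon: float, rng: random.Random) -> int:
    """Return a random action with probability epsilon, else argmax_a Q(s, a)."""
    if rng.random() < epsilon:
        # Explore: uniformly random action index
        return rng.randrange(len(q_values))
    # Exploit: index of the largest Q-value
    return max(range(len(q_values)), key=lambda a: q_values[a])

rng = random.Random(42)
q = [0.1, 0.7, 0.3]
print(epsilon_greedy(q, epsilon=0.0, rng=rng))  # 1 (epsilon = 0 is always greedy)
```

With `epsilon=1.0` every call explores, and with values in between the agent mixes the two behaviours in proportion to $\epsilon$.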

2a. Implementation

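import math
import random

import torch

EPS_START = 0.9   # initial exploration rate
EPS_END = 0.05    # final exploration rate
EPS_DECAY = 1000  # decay time constant, in steps
steps_done = 0

def select_action(state):
    # Assumes dqn (the policy network), env and device are defined elsewhere
    global steps_done
    sample = random.random()
    # Exponentially decay epsilon from EPS_START towards EPS_END
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)
    steps_done += 1

    if sample > eps_threshold:
        # Exploit: pick the action with the highest predicted Q-value
        with torch.no_grad():
            return dqn(state).max(1).indices.view(1, 1)
    else:
        # Explore: pick a uniformly random action
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)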
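The threshold computed in `select_action` follows the exponential schedule $\epsilon(t) = \epsilon_{\text{end}} + (\epsilon_{\text{start}} - \epsilon_{\text{end}})\, e^{-t/\tau}$, with $\tau$ = `EPS_DECAY`. The sketch below evaluates it at a few step counts using the constants above; the helper name `epsilon_at` is hypothetical.

```python
import math

EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

def epsilon_at(step: int) -> float:
    """Exploration rate after `step` calls to select_action."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-step / EPS_DECAY)

for step in (0, 1000, 3000, 5000):
    print(step, round(epsilon_at(step), 3))
# 0 0.9
# 1000 0.363
# 3000 0.092
# 5000 0.056
```

Epsilon drops from 0.9 to about 0.36 after 1,000 steps and is within 0.01 of `EPS_END` by 5,000 steps, so a larger `EPS_DECAY` keeps the agent exploring for longer.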

3. Neural Networks for Q-Function Approximation

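The Q-function is approximated by a neural network $Q(s, a; \theta)$, where $\theta$ denotes the network's weights. The network takes a state $s$ as input and outputs one Q-value for each possible action $a$, so a single forward pass scores every action.

3a. Network Architecture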

import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    """Multilayer perceptron mapping a state to one Q-value per action."""

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        # Two hidden layers with ReLU activations
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Linear output layer: raw Q-values, one per action
        return self.layer3(x)
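As a quick shape check, the sketch below builds the same architecture for a hypothetical environment with 4-dimensional observations and 2 actions (CartPole-sized, an assumption) and runs a batch of random states through it. The greedy-action indexing mirrors `select_action`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

torch.manual_seed(0)
net = DQN(n_observations=4, n_actions=2)  # assumed sizes

states = torch.randn(5, 4)  # a batch of 5 random states
with torch.no_grad():
    q_values = net(states)  # shape (5, 2): one Q-value per action
    greedy = q_values.max(1).indices.view(-1, 1)  # shape (5, 1): best action per state

print(q_values.shape, greedy.shape)  # torch.Size([5, 2]) torch.Size([5, 1])
```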

4. Optimizing the Model 

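The optimization process minimizes the error between the Q-values predicted for sampled transitions and their Bellman targets. Key steps include:

  1. Batch Sampling: Randomly sample a batch of transitions from replay memory.
  2. Q-value Calculation:
    • $Q(s_t, a_t; \theta)$: Predicted by the policy network for the action actually taken.
    • $\hat{Q}(s_t, a_t)$: Target value from the Bellman equation, $\hat{Q}(s_t, a_t) = r_t + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)$, where $\gamma$ is the discount factor and $\theta^-$ are the target network's weights. For terminal states, $\hat{Q}(s_t, a_t) = r_t$.
  3. Loss Calculation: Smooth L1 (Huber) loss between predicted and target Q-values.
  4. Gradient Descent: Backpropagate the loss, clip the gradients, and update the policy network's weights.

# Define the loss function and optimizer
import torch.nn as nn
import torch.optim as optim

criterion = nn.SmoothL1Loss()  # Huber loss with the default beta=1.0
# dqn is the policy network and LR the learning rate, both defined elsewhere
optimizer = optim.AdamW(dqn.parameters(), lr=LR, amsgrad=True)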
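To illustrate why the Huber loss is used, here is a minimal stdlib sketch of the element-wise Smooth L1 loss with PyTorch's default `beta=1.0` (0.5 x^2 for |x| < 1, |x| - 0.5 otherwise), compared against half the squared error. `nn.SmoothL1Loss` applies this formula per element and averages over the batch by default; the helper name `smooth_l1` is hypothetical.

```python
def smooth_l1(error: float, beta: float = 1.0) -> float:
    """Element-wise Smooth L1 loss, using the same formula as nn.SmoothL1Loss."""
    if abs(error) < beta:
        return 0.5 * error ** 2 / beta
    return abs(error) - 0.5 * beta

errors = [0.5, -2.0, 10.0]  # predicted minus target Q-values
for e in errors:
    print(e, smooth_l1(e), 0.5 * e ** 2)
# 0.5 0.125 0.125
# -2.0 1.5 2.0
# 10.0 9.5 50.0
```

Small errors are penalized quadratically, exactly like half the squared error, but the outlier of 10.0 costs 9.5 instead of 50.0 and its gradient is capped at magnitude 1, so a few badly wrong targets cannot dominate an update.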

4a. Optimization Function

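import torch

# Assumes memory, Transition, dqn, target_net, optimizer, criterion,
# BATCH_SIZE, GAMMA and device are defined above
def optimize_model():
    # Wait until memory holds enough transitions for a full batch
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into a Transition of per-field tuples
    batch = Transition(*zip(*transitions))

    # Boolean mask marking transitions whose next state is not terminal
    non_final_mask = torch.tensor(
        tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool
    )
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a_t): select the Q-value of the action actually taken in each row
    state_action_values = dqn(state_batch).gather(1, action_batch)

    # max_a' Q(s_{t+1}, a') from the target network; stays 0 for terminal states
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # Bellman target: r_t + gamma * max_a' Q(s_{t+1}, a')
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Backpropagate, clip each gradient value to [-100, 100], then update the weights
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(dqn.parameters(), 100)
    optimizer.step()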
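The tensor indexing in `optimize_model` is easiest to follow on a tiny hand-built batch. The sketch below uses made-up Q-values in place of the two networks and an assumed `GAMMA = 0.99`; the second transition is terminal, so its target is just its reward. `gather(1, action_batch)` picks, for each row, the column given by that row's action index.

```python
import torch

GAMMA = 0.99  # assumed discount factor

# Policy-network Q-values for 3 states and 2 actions (made-up numbers)
q_policy = torch.tensor([[1.0, 2.0],
                         [0.5, -1.0],
                         [3.0, 0.0]])
action_batch = torch.tensor([[1], [0], [0]])  # actions actually taken
state_action_values = q_policy.gather(1, action_batch)  # [[2.0], [0.5], [3.0]]

# Transition 2 is terminal, so only transitions 1 and 3 have next states
non_final_mask = torch.tensor([True, False, True])
q_target_next = torch.tensor([[0.0, 1.0],   # target-network Q-values for
                              [4.0, 2.0]])  # the 2 non-final next states
next_state_values = torch.zeros(3)
next_state_values[non_final_mask] = q_target_next.max(1).values  # [1.0, 0.0, 4.0]

reward_batch = torch.tensor([1.0, 1.0, 1.0])
expected_state_action_values = next_state_values * GAMMA + reward_batch  # [1.99, 1.0, 4.96]

loss = torch.nn.SmoothL1Loss()(state_action_values, expected_state_action_values.unsqueeze(1))
print(loss.item())
```

The per-element errors are 0.01, -0.5 and -1.96, giving Smooth L1 terms of 0.00005, 0.125 and 1.46, whose mean is about 0.528.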


5. Main Training Loop 

The training loop iterates over episodes to:

  1. Interact with the environment.
  2. Store transitions in replay memory.
  3. Optimize the model using the sampled transitions.

Core Loop:

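from itertools import count

import torch

# Assumes env, device, memory, num_episodes, select_action and optimize_model are defined above
for i_episode in range(num_episodes):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], dtype=torch.float32, device=device)
        done = terminated or truncated

        # Only true termination has no next state; a time-limit truncation
        # still bootstraps from the observed next state
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        memory.push(state, action, next_state, reward)
        state = next_state

        # Perform one optimization step on the policy network
        optimize_model()

        if done:
            break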
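The loop above does not show how `target_net` is kept in sync with the policy network; without some update it would keep its initial weights. One common choice, used in the official PyTorch DQN tutorial, is a soft update after each optimization step: $\theta^- \leftarrow \tau\theta + (1 - \tau)\theta^-$. The sketch below assumes that approach with a hypothetical `TAU = 0.005` and a hypothetical helper `soft_update`; the project's own code may instead copy the weights periodically.

```python
import torch
import torch.nn as nn

TAU = 0.005  # hypothetical soft-update rate

def soft_update(target_net: nn.Module, policy_net: nn.Module, tau: float = TAU) -> None:
    """Blend weights: theta_target <- tau * theta_policy + (1 - tau) * theta_target."""
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = policy_state[key] * tau + target_state[key] * (1 - tau)
    target_net.load_state_dict(target_state)

# Usage with two small stand-in networks
policy_net = nn.Linear(2, 1)
target_net = nn.Linear(2, 1)
nn.init.constant_(policy_net.weight, 1.0)
nn.init.constant_(policy_net.bias, 1.0)
nn.init.constant_(target_net.weight, 0.0)
nn.init.constant_(target_net.bias, 0.0)

soft_update(target_net, policy_net, tau=0.5)
print(target_net.weight.tolist())  # [[0.5, 0.5]]
```

In the training loop, this could be called as `soft_update(target_net, dqn)` right after `optimize_model()`, with `target_net` first initialized as a copy via `target_net.load_state_dict(dqn.state_dict())`. A small `TAU` makes the targets change slowly, which is what stabilizes the bootstrapped Bellman targets.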
