Mastering CartPole with DQN: Deep Reinforcement Learning for Beginners

If you’ve played with reinforcement learning (RL) before, you’ve probably seen the classic CartPole balancing problem. And if you’ve tried solving it with traditional Q-learning, you might have run into some limitations.

That’s where DQN — Deep Q-Network — comes in.

In this guide, we’ll explain what DQN is, why it was a breakthrough in RL, and how to implement it step-by-step to solve the CartPole-v1 environment using OpenAI Gym and PyTorch. Whether you’re new to RL or ready to level up from Q-tables, this tutorial is for you.


1. What is DQN?

Q-learning works well for problems with small, discrete state spaces. But in many real problems, and even in a simple simulation like CartPole, the state is continuous. There are infinitely many possible states, so a Q-table with an entry for every one of them cannot be built.

DQN solves this by using a neural network to approximate the Q-function. Instead of looking values up in a table, the network takes a state as input and predicts the expected discounted return (the Q-value) of each action.

DQN = Q-Learning + Deep Learning

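  • Neural network: predicts the Q-value of each action for a given state
  • Replay buffer: stores past transitions so they can be resampled at random
  • Target network: a periodically synced copy of the network that provides stable learning targets
  • ε-greedy policy: balances exploration vs. exploitation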
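Put together, these components define the loss that the training loop in Section 5 minimizes. Writing θ for the policy network's weights, θ⁻ for the target network's periodically copied weights, 𝒟 for the replay buffer, and d = 1 for a terminal transition, the standard DQN loss is:

```latex
\mathcal{L}(\theta) = \mathbb{E}_{(s,a,r,s',d) \sim \mathcal{D}}
\left[ \left( r + \gamma\,(1 - d)\,\max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta) \right)^{2} \right]
```

The ε-greedy policy does not appear in the loss; it only determines which transitions end up in 𝒟. In code, the mean squared error over a sampled mini-batch stands in for the expectation.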

This combination enabled DQN to scale to far more complex environments; most famously, it learned to play Atari 2600 games directly from screen pixels.


2. Recap: The CartPole Problem

In CartPole, your agent controls a cart with a pole attached to it. The goal? Keep the pole from falling over by moving the cart left or right.

Environment Details:

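  • State Space: 4 floating-point values (cart position, cart velocity, pole angle, pole angular velocity)
  • Action Space: 0 (push cart left), 1 (push cart right)
  • Reward: +1 for every time step the pole remains upright
  • Done: the episode terminates when the pole tilts more than 12 degrees from vertical or the cart moves more than 2.4 units from the center; CartPole-v1 also cuts the episode off (truncates it) after 500 steps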
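The end-of-episode rule can be sketched in plain Python. The thresholds (±2.4 for cart position, ±12° for pole angle) match the Gym/Gymnasium CartPole implementation, and 500 steps is the CartPole-v1 time limit. `episode_status` is a hypothetical helper for illustration, not part of the Gym API. In Gym 0.26+ and Gymnasium, the first case is reported as `terminated` and the time limit as `truncated`.

```python
import math

# Thresholds as in the Gym/Gymnasium CartPole source (assumed for this sketch)
X_THRESHOLD = 2.4                          # max |cart position|
THETA_THRESHOLD = 12 * 2 * math.pi / 360   # 12 degrees, in radians
MAX_STEPS = 500                            # CartPole-v1 time limit

def episode_status(x, theta, step):
    """Hypothetical helper: classify a CartPole state after `step` steps."""
    if abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD:
        return "terminated"   # pole fell over or cart left the track
    if step >= MAX_STEPS:
        return "truncated"    # time limit reached while still balancing
    return "running"

print(episode_status(0.0, 0.0, 10))                 # running
print(episode_status(2.5, 0.0, 10))                 # terminated
print(episode_status(0.0, math.radians(13), 10))    # terminated
print(episode_status(0.0, 0.0, 500))                # truncated
```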

Its small state vector, two actions, and short episodes make it a great starting point for deep reinforcement learning.


3. Why Q-Learning Isn’t Enough

Traditional Q-learning relies on a Q-table that maps each state-action pair to its expected return (cumulative discounted reward). That works for environments like FrozenLake or GridWorld, but fails when:

  • States are continuous
  • The environment has high dimensionality
  • We want to generalize across unseen states

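DQN overcomes these limitations by using a function approximator, a neural network, to estimate Q-values. Because the network maps similar inputs to similar outputs, what it learns in one state carries over to nearby states it has never seen.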
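A minimal tabular Q-learning update, keyed on raw state tuples, makes the generalization problem concrete: two CartPole-like states that differ only in the fourth decimal place are separate table entries, so an update to one teaches the agent nothing about the other. The learning rate `alpha=0.1` and the sample states are arbitrary values chosen for illustration.

```python
ACTIONS = (0, 1)  # CartPole: 0 = push left, 1 = push right

def q_learning_update(q_table, state, action, reward, next_state, done,
                      alpha=0.1, gamma=0.99):
    """One tabular Q-learning step: Q(s,a) += alpha * (target - Q(s,a))."""
    # Unseen entries default to 0.0; .get() does not insert them into the table
    best_next = 0.0 if done else max(q_table.get((next_state, a), 0.0) for a in ACTIONS)
    target = reward + gamma * best_next
    old = q_table.get((state, action), 0.0)
    q_table[(state, action)] = old + alpha * (target - old)

# Two observations that differ only in the 4th decimal of cart position
s1 = (0.0100, 0.0, 0.0200, 0.0)
s2 = (0.0101, 0.0, 0.0200, 0.0)

q = {}
q_learning_update(q, s1, 1, 1.0, s2, False)
print(q.get((s1, 1), 0.0))  # 0.1: the visited state was updated
print(q.get((s2, 1), 0.0))  # 0.0: its near-identical neighbor learned nothing
```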


4. DQN Components Explained

Let’s break down what you’ll need to build a working DQN agent.

1. Neural Network

The core of DQN is a network that takes in a state and outputs Q-values for all possible actions.

import torch
import torch.nn as nn
import torch.nn.functional as F

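class DQN(nn.Module):
    """Maps a state vector to one Q-value per action."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)   # state -> 128 hidden units
        self.fc2 = nn.Linear(128, 128)         # second hidden layer
        self.out = nn.Linear(128, action_dim)  # one output per action

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)  # raw Q-values, no activation on the output layer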
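For CartPole (`state_dim=4`, `action_dim=2`) this network is small. Counting weights and biases layer by layer in plain Python (so it runs without PyTorch) gives its size; the helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer like nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out

def dqn_param_count(state_dim, action_dim, hidden=128):
    """Parameter count of the two-hidden-layer DQN defined above."""
    return (linear_params(state_dim, hidden)      # fc1
            + linear_params(hidden, hidden)       # fc2
            + linear_params(hidden, action_dim))  # out

print(dqn_param_count(4, 2))  # 17410
```

With PyTorch installed, `sum(p.numel() for p in DQN(4, 2).parameters())` should report the same number.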

2. Experience Replay Buffer

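The buffer stores past transitions (state, action, reward, next state, done flag) and samples random mini-batches from them. Training on shuffled samples instead of consecutive steps breaks the strong correlation between successive transitions and lets each experience be reused many times.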

from collections import deque
import random

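class ReplayBuffer:
    def __init__(self, capacity):
        # Once full, the oldest transitions are discarded automatically
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        # transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        # Uniform random mini-batch, drawn without replacement
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)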
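Two behaviors of this buffer are easy to check with a bare `deque`: once it is full, appending a new transition silently drops the oldest one, and a sampled list of `(state, action, reward, next_state, done)` tuples can be transposed into per-field tuples with `zip(*batch)`, which the training loop in Section 5 relies on. The placeholder strings below stand in for real state tensors.

```python
import random
from collections import deque

buffer = deque(maxlen=3)  # tiny capacity so eviction is visible
for step in range(5):
    # (state, action, reward, next_state, done) with placeholder values
    buffer.append((f"s{step}", step % 2, 1.0, f"s{step + 1}", False))

print([t[0] for t in buffer])  # ['s2', 's3', 's4']: s0 and s1 were evicted

batch = random.sample(buffer, 2)  # uniform sampling without replacement
states, actions, rewards, next_states, dones = zip(*batch)  # rows -> columns
print(len(states), rewards)  # 2 (1.0, 1.0)
```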

3. Epsilon-Greedy Action Selection

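With probability ε the agent picks a random action; otherwise it picks the action with the highest predicted Q-value. Starting ε at 1.0 and decaying it over time means the agent explores randomly at first, then gradually exploits what it has learned.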

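def select_action(state, epsilon, policy_net, n_actions=2):
    # Explore: with probability epsilon, pick a uniformly random action
    if random.random() < epsilon:
        return random.randrange(n_actions)
    # Exploit: pick the action with the highest predicted Q-value
    with torch.no_grad():
        return policy_net(state).argmax().item()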
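The training loop in Section 5 multiplies ε by 0.995 once per episode, starting at 1.0 with a floor of 0.01. A quick calculation shows what that schedule implies for a 300-episode run; the helper functions are hypothetical, written only to evaluate the schedule.

```python
import math

def epsilon_after(episodes, start=1.0, decay=0.995, floor=0.01):
    """Epsilon after `episodes` applications of eps = max(floor, eps * decay)."""
    eps = start
    for _ in range(episodes):
        eps = max(floor, eps * decay)
    return eps

def episodes_to_floor(decay=0.995, floor=0.01):
    """Smallest n with decay**n <= floor, i.e. when the floor takes over."""
    return math.ceil(math.log(floor) / math.log(decay))

print(round(epsilon_after(300), 3))  # 0.222
print(episodes_to_floor())           # 919
```

So at the end of 300 episodes the agent still acts randomly about 22% of the time, and the 0.01 floor is only reached after 919 episodes. A faster decay, or decaying per step instead of per episode, would shift the agent toward exploitation sooner.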

5. Training DQN on CartPole

Step-by-Step Loop:

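import gym  # gym>=0.26; Gymnasium ("import gymnasium as gym") has the same API
import torch
import torch.nn.functional as F
import torch.optim as optim

env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]  # 4
action_dim = env.action_space.n             # 2

policy_net = DQN(state_dim, action_dim)
target_net = DQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())  # start with identical weights

optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
replay_buffer = ReplayBuffer(10000)

batch_size = 64
gamma = 0.99   # discount factor
epsilon = 1.0  # start fully exploratory

for episode in range(300):
    # reset() returns (observation, info); gym<0.26 returned only the observation
    obs, _ = env.reset()
    state = torch.as_tensor(obs, dtype=torch.float32)
    total_reward = 0.0

    while True:
        action = select_action(state, epsilon, policy_net)
        # step() returns 5 values; gym<0.26 returned (obs, reward, done, info)
        obs, reward, terminated, truncated, _ = env.step(action)
        next_state = torch.as_tensor(obs, dtype=torch.float32)

        # Store `terminated`, not `truncated`: hitting the 500-step time limit
        # is not a true terminal state, so its future value should still count.
        replay_buffer.push((state, action, reward, next_state, terminated))
        state = next_state
        total_reward += reward

        if len(replay_buffer) >= batch_size:
            transitions = replay_buffer.sample(batch_size)
            s, a, r, ns, d = zip(*transitions)  # list of tuples -> per-field tuples

            s = torch.stack(s)
            a = torch.tensor(a, dtype=torch.long).unsqueeze(1)
            r = torch.tensor(r, dtype=torch.float32).unsqueeze(1)
            ns = torch.stack(ns)
            d = torch.tensor(d, dtype=torch.float32).unsqueeze(1)

            # Q(s, a) for the actions that were actually taken
            q_values = policy_net(s).gather(1, a)
            # Bootstrapped target from the frozen target network (no gradients)
            with torch.no_grad():
                next_q = target_net(ns).max(1)[0].unsqueeze(1)
            target = r + gamma * next_q * (1 - d)

            loss = F.mse_loss(q_values, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if terminated or truncated:
            break

    # Decay exploration once per episode, never below 1%
    epsilon = max(0.01, epsilon * 0.995)

    # Copy the policy network's weights into the target network every 10 episodes
    if episode % 10 == 0:
        target_net.load_state_dict(policy_net.state_dict())

    print(f"Episode {episode}, Total Reward: {total_reward}")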

6. Observations and Performance

As training progresses:

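  • The total reward per episode tends to increase, though it is noisy from one episode to the next
  • The agent keeps the pole upright for longer durations
  • On a good run, it reaches the maximum score of 500 (the CartPole-v1 time limit) in many episodes

Rising episode rewards are a good sign that the DQN is learning the task. DQN training on CartPole is sensitive to random seeds and hyperparameters, though, and performance can drop again after improving, so judge progress by an average over many episodes rather than by a single score.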
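A simple way to track progress more rigorously than eyeballing the printed rewards is a moving average over recent episodes. CartPole-v1 is registered in Gym/Gymnasium with a `reward_threshold` of 475, and a common convention treats the task as solved once the mean return over the last 100 episodes reaches it. `is_solved` is a hypothetical helper, not part of Gym.

```python
def is_solved(episode_returns, threshold=475.0, window=100):
    """True once the mean of the last `window` episode returns reaches `threshold`."""
    if len(episode_returns) < window:
        return False
    recent = episode_returns[-window:]
    return sum(recent) / window >= threshold

print(is_solved([500.0] * 99))                 # False: fewer than 100 episodes
print(is_solved([10.0] * 50 + [480.0] * 100))  # True: only the last 100 count
```

In the training loop, append each episode's `total_reward` to a list and call `is_solved` after every episode.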


7. DQN Limitations and Next Steps

While DQN is powerful, it’s not perfect:

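  • Training can be unstable, and Q-value estimates can diverge, even with a replay buffer and target network
  • It cannot directly handle continuous action spaces, because choosing an action requires a max over every possible action
  • It samples uniformly from the replay buffer, treating rare, informative transitions the same as routine ones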

Enhancements (aka “Better DQN”):

  • Double DQN: Reduces overestimation bias
  • Dueling DQN: Separates state value and action advantage
  • Prioritized Experience Replay: Focus on important transitions
  • Rainbow DQN: Combines all of the above with further extensions (multi-step returns, distributional Q-learning, and noisy networks)
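The core change in Double DQN is small enough to show with plain lists of next-state Q-values: the online (policy) network picks the best next action and the target network scores it, instead of the target network doing both. The function names and numbers below are illustrative.

```python
def dqn_target(reward, target_next_q, done, gamma=0.99):
    """Vanilla DQN: the target network both picks and scores the next action."""
    if done:
        return reward
    return reward + gamma * max(target_next_q)

def double_dqn_target(reward, online_next_q, target_next_q, done, gamma=0.99):
    """Double DQN: the online network picks the action, the target network scores it."""
    if done:
        return reward
    best_action = max(range(len(online_next_q)), key=online_next_q.__getitem__)
    return reward + gamma * target_next_q[best_action]

# Q-values for the next state from each network (illustrative numbers)
online_next_q = [1.0, 2.0]  # online network prefers action 1
target_next_q = [5.0, 0.5]  # target network overrates action 0

print(dqn_target(1.0, target_next_q, False))                        # uses 5.0
print(double_dqn_target(1.0, online_next_q, target_next_q, False))  # uses 0.5
```

If one network overestimates an action by chance, decoupling selection from evaluation keeps that error from automatically entering the target. In the batched PyTorch loop, the same idea would pick actions with `policy_net(ns).argmax(1, keepdim=True)` and score them with `target_net(ns).gather(1, ...)`.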

We’ll explore these in future posts.


8. Conclusion

You’ve just implemented your first Deep Q-Network from scratch and trained it to solve CartPole. This is a big step toward mastering reinforcement learning with deep learning.

By understanding both the code and the concepts, you’re ready to explore more complex environments and powerful variants of DQN.