
Training a Reinforcement Learning Agent to Balance an Inverted Pendulum


Reinforcement learning (RL) is a general approach to training controllers for complex systems with limited observability and controllability, and possibly unknown dynamics. When RL "agents" (feedback control laws) are trained using a simulator, the algorithm typically treats the simulation as a black box, and learns to control the system by trial and error. The agent receives a reward signal from the simulator, and uses this signal to update its control policy. The goal is to learn a policy that maximizes the expected sum of rewards over time.
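Schematically, the training loop treats the simulator as a black-box step function that returns the next state and a scalar reward. The sketch below is purely illustrative; the names [.code]env_step[.code], [.code]policy[.code], and [.code]params[.code] are hypothetical placeholders rather than the API of any particular framework.

import numpy as np

def env_step(state, action):
    # Hypothetical black-box simulator: advance the state and emit a reward
    next_state = state + 0.1 * action + 0.01 * np.random.randn(*state.shape)
    reward = -float(np.sum(next_state**2))  # e.g. penalize distance from a setpoint
    return next_state, reward

def policy(state, params):
    # Hypothetical parameterized control law (the "agent")
    return -params @ state

params = np.zeros((2, 2))  # tuning these is the RL algorithm's job
state = np.random.randn(2)
total_reward = 0.0
for _ in range(100):
    action = policy(state, params)
    state, reward = env_step(state, action)
    total_reward += reward  # the return the agent tries to maximize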

This approach splits the computation into two main parts: the simulation, which computes the system dynamics, and the agent, which learns to control the system. These are typically implemented in different frameworks; for instance, a simulation might be written in C++ or vanilla Python, while the agent and training loop might be implemented in TensorFlow or PyTorch.

Besides implementation complexity, the main drawback of this approach is that it can create computational bottlenecks as the agent and simulation communicate across frameworks, data structures, and potentially even hardware (e.g., CPU vs. GPU). This can slow down training and make it difficult to scale to more complex systems.

Brax: end-to-end training with JAX

One way to address these challenges is to use a single framework for both the simulation and the agent. This is the approach taken by Brax, an end-to-end framework for training RL agents with JAX. Brax is designed to be fast, flexible, and scalable, and it provides a simple API for defining custom environments and agents.

Since Brax uses the JAX framework, both the RL and simulation code can be written in Python and JIT-compiled for efficiency. Moreover, it is straightforward to target accelerators like GPUs and TPUs, and to scale up to large clusters of machines. With this approach the entire simulation, agent, and training loop can run on the same hardware, with no need to communicate across any computational barriers.
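For instance, JAX's [.code]jit[.code] and [.code]vmap[.code] transformations let a single Python step function be compiled once and then evaluated across thousands of simulated environments in parallel. The toy [.code]step[.code] function below is illustrative only, not part of the Brax or Collimator APIs.

import jax
import jax.numpy as jnp

def step(state, action):
    # Toy dynamics: one explicit-Euler step of dx/dt = -x + u
    return state + 0.01 * (-state + action)

# Compile once, then evaluate a whole batch of environments in parallel
batched_step = jax.jit(jax.vmap(step))

states = jnp.zeros((4096, 2))           # 4096 environments, 2 states each
actions = jnp.ones((4096, 2))
states = batched_step(states, actions)  # runs unchanged on CPU, GPU, or TPU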


The original paper on Brax showed that a controller for the "Ant" robotic locomotion benchmark can be trained in about 10 seconds using this approach, compared to approximately an hour using a traditional RL simulator configuration (MuJoCo). This speedup is achieved by taking advantage of JAX's JIT compilation, parallelization, and hardware acceleration.

Beyond rigid body dynamics

While Brax has demonstrated impressive performance on the typical robotics-inspired reinforcement learning benchmarks, its focus has primarily been on that domain: rigid body mechanics with contact. However, since the Collimator simulation engine is written in JAX, we can easily package Collimator block-diagram system models as Brax reinforcement learning environments and take advantage of these end-to-end training capabilities.

This brings the power of reinforcement learning to a broader category of systems, including those with any combination of:

  • Hybrid dynamics, i.e. a mixture of continuous dynamics and discrete-time logic
  • Modelica-style acausal physical modeling (electrical, mechanical, hydraulic, etc.)
  • Finite state machines defining control logic or mode-switching behavior
  • Custom user-defined Python code

Besides flexibility in modeling the physical system, this also allows for developing advanced "hybrid" control strategies that combine more traditional controllers (e.g. state machines) with learned policies.

Controlling an inverted pendulum

In this notebook we will explore the process of training an RL agent to balance the "CartPole" inverted pendulum system. We will implement the plant in straightforward custom Python code, and then see how to package it as a Brax environment. We will then train a PPO agent to balance the pendulum using the Brax API. Because of the flexibility of targeting accelerators, the same training code can be run on a CPU, GPU, or TPU with minor modifications.


from functools import partial
from datetime import datetime

import matplotlib.pyplot as plt

import numpy as np
import jax
import jax.numpy as jnp
from brax.training.agents.ppo import train as ppo

# For 3D rendering
import mujoco
import mediapy

import collimator
from collimator import library
from collimator.optimization import RLEnv

from IPython.display import clear_output


Defining the dynamics model

First we will define the dynamics of the CartPole system. This is a simple system with two degrees of freedom: the position of the cart along a frictionless track and the angle of the pole (the state consists of these two positions and their velocities). The control input is a force applied to the cart. The dynamics are governed by the equations of motion for the cart and pole, which are typically derived using Lagrangian mechanics. A CartPole model is actually built into Collimator, but we will implement it from scratch here for the sake of illustration.

In general, the most convenient way to define the dynamics of a system is to use a block diagram representation. This allows us to easily visualize the system, and to break it down into smaller components that can be implemented separately. However, since the CartPole system is so simple, we will implement the dynamics as a single custom block. For more information on custom block implementation, see the associated tutorial.
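Measuring the pole angle $\theta$ from the downward vertical (so that $\theta = \pi$ is the upright position), the equations of motion implemented in the block below are

$$\ddot{x} = \frac{f_x + m_p \sin\theta \left( L \dot{\theta}^2 + g \cos\theta \right)}{m_c + m_p \sin^2\theta}, \qquad \ddot{\theta} = \frac{-f_x \cos\theta - m_p L \dot{\theta}^2 \cos\theta \sin\theta - (m_c + m_p)\, g \sin\theta}{L \left( m_c + m_p \sin^2\theta \right)},$$

where $m_c$ is the cart mass, $m_p$ the pole mass, $L$ the pole length, $g$ the gravitational acceleration, and $f_x$ the horizontal force applied to the cart.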

class CartPole(collimator.LeafSystem):
    def __init__(
        self,
        x0=jnp.zeros(4),
        m_c=1.0,
        m_p=0.1,
        L=1.0,
        g=9.81,
        name="CartPole",
    ):
        super().__init__(name=name)
        self.declare_dynamic_parameter("m_c", m_c)
        self.declare_dynamic_parameter("m_p", m_p)
        self.declare_dynamic_parameter("L", L)
        self.declare_dynamic_parameter("g", g)

        self.declare_input_port(name="fx")
        self.declare_continuous_state(default_value=x0, ode=self.ode)
        self.declare_continuous_state_output()

    def ode(self, time, state, *inputs, **parameters):
        x, theta, dot_x, dot_theta = state.continuous_state
        (fx,) = inputs

        m_c = parameters["m_c"]
        m_p = parameters["m_p"]
        L = parameters["L"]
        g = parameters["g"]

        mf = 1.0 / (m_c + m_p * jnp.sin(theta) ** 2)
        ddot_x = mf * (
            fx + m_p * jnp.sin(theta) * (L * dot_theta**2 + g * jnp.cos(theta))
        )
        ddot_theta = (1.0 / L) * mf * (
            -fx * jnp.cos(theta)
            - m_p * L * dot_theta**2 * jnp.cos(theta) * jnp.sin(theta)
            - (m_c + m_p) * g * jnp.sin(theta)
        )

        return jnp.array([dot_x, dot_theta, ddot_x[0], ddot_theta[0]])


Now we can simulate the system on its own using the [.code]simulate[.code] interface. Again, this works in much the same way for hierarchical block diagram representations, including those with state machines, mixed discrete/continuous dynamics, acausal physical components, etc.

system = CartPole(x0=np.array([0.0, np.pi/4, 0.0, 0.0]))
system.input_ports[0].fix_value(np.array([0.0]))  # No input

context = system.create_context()

recorded_signals = {"x": system.output_ports[0]}
results = collimator.simulate(
    system, context, (0.0, 10.0), recorded_signals=recorded_signals
)

plt.figure(figsize=(7, 2))
plt.plot(results.time, results.outputs["x"][:, :2], label=[r"$x$", r"$\theta$"])
plt.legend(loc=1)
plt.xlabel("Time [s]")
plt.ylabel("State")
plt.grid()
plt.show()
11:05:02.214 - [collimator][INFO]: max_major_steps=200 by default since no discrete period in system
11:05:02.223 - [collimator][INFO]: Simulator ready to start: SimulatorOptions(math_backend=jax, enable_tracing=True, max_major_step_length=None, max_major_steps=200, ode_solver_method=auto, rtol=1e-06, atol=1e-08, min_minor_step_size=None, max_minor_step_size=None, zc_bisection_loop_count=40, save_time_series=True, recorded_signals=1, return_context=True), Dopri5Solver(system=CartPole(system_id=1, name='CartPole', ui_id=None, parent=None), rtol=1e-06, atol=1e-08, max_step_size=None, min_step_size=None, method='auto', enable_autodiff=False, supports_mass_matrix=False)

Implementing the reinforcement learning environment

Our goal is to learn a feedback control law that can keep the pendulum in the unstable vertical position. To do this, we will train a reinforcement learning "agent" consisting of a neural network that produces an appropriate action given the observations of the system.

Training this agent requires defining an "environment" that simulates the system and provides the agent with observations and rewards. Additionally, the environment can optionally implement functions that randomize the initial state of the system on reset and define early termination conditions.

In Brax, environments are defined as Python classes that inherit from [.code]brax.envs.Env[.code]. Collimator wraps this in an [.code]RLEnv[.code] class that takes care of various bookkeeping tasks, such as resetting the environment and interfacing with the neural network. Here we will implement a custom [.code]CartPoleEnv[.code] class that inherits from [.code]RLEnv[.code] and defines the stabilization problem.

Specifically, the [.code]CartPoleEnv[.code] will implement three methods (closely following the benchmark Gymnasium implementation):

  1. [.code]randomize[.code]: randomly perturb the initial state of the system on reset
  2. [.code]get_reward[.code]: return 1.0 for as long as the episode continues, rewarding the agent for keeping the pendulum upright
  3. [.code]get_done[.code]: return True if the pole falls below a certain angle, terminating the episode

class CartPoleEnv(RLEnv):
    def __init__(self, dt):
        plant = CartPole(name="plant")
        self.xr = np.array([0.0, np.pi, 0.0, 0.0])
        super().__init__(plant, act_size=1, dt=dt)
    
    def randomize(self, context, key):
        x0 = self.xr + jax.random.uniform(key, (4,), minval=-0.05, maxval=0.05)
        plant_context = context[self._plant_id].with_continuous_state(x0)
        return context.with_subcontext(self._plant_id, plant_context)

    def get_reward(self, context, obs, act):
        # Reward the time spent in the upright position
        return 1.0
    
    def get_done(self, context, obs):
        x = context[self._plant_id].continuous_state
        theta_lim = 12 * 2 * np.pi / 360  # Pole angle limit: 12 degrees, in radians
        x_lim = 2.4  # Cart position limit
        return jnp.where(
            (jnp.abs(x[0]) > x_lim) | (jnp.abs(x[1] - np.pi) > theta_lim),
            1.0,
            0.0,
        )
 

Now we can interact with this environment using an API very similar to the widely-used OpenAI/Farama Gym. This includes functions like [.code]reset[.code], [.code]step[.code], and [.code]render[.code], which allow us to interact with the environment in a standard way.

For example, the following code will create an environment, reset it, and take ten steps using random actions:

env = CartPoleEnv(dt=0.02)
state = env.reset(jax.random.PRNGKey(0))

for i in range(10):
    action = np.random.randn(1)
    state = env.step(state, action)
    print(state.obs)
11:05:03.130 - [collimator][INFO]: max_major_steps=200 by default since no discrete period in system
11:05:03.130 - [collimator][INFO]: Simulator ready to start: SimulatorOptions(math_backend=jax, enable_tracing=True, max_major_step_length=None, max_major_steps=200, ode_solver_method=auto, rtol=1e-06, atol=1e-08, min_minor_step_size=None, max_minor_step_size=None, zc_bisection_loop_count=40, save_time_series=False, recorded_signals=0, return_context=True), Dopri5Solver(system=Diagram(root, 2 nodes), rtol=1e-06, atol=1e-08, max_step_size=None, min_step_size=None, method='auto', enable_autodiff=False, supports_mass_matrix=False)
11:05:03.419 - [collimator][INFO]: max_major_steps=200 by default since no discrete period in system
11:05:03.420 - [collimator][INFO]: Simulator ready to start: SimulatorOptions(math_backend=jax, enable_tracing=True, max_major_step_length=None, max_major_steps=200, ode_solver_method=auto, rtol=1e-06, atol=1e-08, min_minor_step_size=None, max_minor_step_size=None, zc_bisection_loop_count=40, save_time_series=False, recorded_signals=0, return_context=True), Dopri5Solver(system=Diagram(root, 2 nodes), rtol=1e-06, atol=1e-08, max_step_size=None, min_step_size=None, method='auto', enable_autodiff=False, supports_mass_matrix=False)
[ 3.01805607e-03  3.12752256e+00 -5.81476639e-03  5.12125317e-02]
[ 2.71882296e-03  3.12833688e+00 -2.41058975e-02  3.02493694e-02]
[ 2.21236367e-03  3.12889190e+00 -2.65382197e-02  2.52726160e-02]
[ 1.65409237e-03  3.12934525e+00 -2.92874296e-02  2.00779675e-02]
[ 8.65986553e-04  3.12952062e+00 -4.95225823e-02 -2.53398332e-03]
[-6.61936745e-04  3.12890861e+00 -1.03271720e-01 -5.86892133e-02]
[-2.65371799e-03  3.12778284e+00 -9.59100888e-02 -5.39287569e-02]
[-0.0048979   3.1263504  -0.12851257 -0.08936614]
[-0.00761922  3.12438089 -0.14362629 -0.1076555 ]
[-0.01038208  3.12230228 -0.13266637 -0.10028048]

Now we can train an agent using any of Brax's reference algorithms, including PPO and SAC. Here we will use PPO, which is a simple and effective algorithm for training agents in continuous action spaces.

The following code for the training loop closely follows the Brax tutorial for training. We'll start with a very short training run using only 100 timesteps; this will let us see the initial behavior of the agent before properly training it.

state = env.reset(jax.random.PRNGKey(0))

# Based on "inverted pendulum" hyperparameters from brax tutorials
tf = 20.0
hparams = {
    "num_timesteps": 100,
    "reward_scaling": 10,
    "num_evals": 20,
    "episode_length": int(tf / env.dt),
    "normalize_observations": True,
    "action_repeat": 1,
    "unroll_length": 5,
    "num_minibatches": 1,
    "num_updates_per_batch": 4,
    "discounting": 0.97,
    "deterministic_eval": True,
    "learning_rate": 3e-4,
    "entropy_cost": 1e-2,
    "num_envs": 1,
    "batch_size": 1,
    "seed": 1,
}

times = [datetime.now()]
def progress(num_steps, metrics):
    times.append(datetime.now())

train_fn = partial(ppo.train, **hparams)
make_inference_fn, params, metrics = train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
11:05:06.532 - [collimator][INFO]: max_major_steps=200 by default since no discrete period in system
11:05:06.533 - [collimator][INFO]: Simulator ready to start: SimulatorOptions(math_backend=jax, enable_tracing=True, max_major_step_length=None, max_major_steps=200, ode_solver_method=auto, rtol=1e-06, atol=1e-08, min_minor_step_size=None, max_minor_step_size=None, zc_bisection_loop_count=40, save_time_series=False, recorded_signals=0, return_context=True), Dopri5Solver(system=Diagram(root, 2 nodes), rtol=1e-06, atol=1e-08, max_step_size=None, min_step_size=None, method='auto', enable_autodiff=False, supports_mass_matrix=False)
11:05:09.689 - [collimator][INFO]: max_major_steps=200 by default since no discrete period in system
11:05:09.690 - [collimator][INFO]: Simulator ready to start: SimulatorOptions(math_backend=jax, enable_tracing=True, max_major_step_length=None, max_major_steps=200, ode_solver_method=auto, rtol=1e-06, atol=1e-08, min_minor_step_size=None, max_minor_step_size=None, zc_bisection_loop_count=40, save_time_series=False, recorded_signals=0, return_context=True), Dopri5Solver(system=Diagram(root, 2 nodes), rtol=1e-06, atol=1e-08, max_step_size=None, min_step_size=None, method='auto', enable_autodiff=False, supports_mass_matrix=False)
time to jit: 0:00:04.082039
time to train: 0:00:40.833354

Next let's see how the agent performs after this very short training run.

inference_fn = make_inference_fn(params)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

rollout = []
rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)
for i in range(500):
    rollout.append(state.obs)
    act_rng, rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, act)
    if state.done:
        state = jit_env_reset(rng)


t = np.arange(len(rollout)) * env.dt
x = np.array(rollout)

plt.figure(figsize=(7, 2))
plt.plot(t, x)
plt.grid()
plt.show()


We're not using MuJoCo as a simulation backend here, but we can still take advantage of its 3D rendering capabilities to create an animation of the rollout.

To do this, we just have to create an XML file defining the system in MuJoCo format (without needing actuators, sensors, etc.). Here we can borrow from DeepMind's CartPole model, modified slightly to match our system.
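The exact XML file is not reproduced here, but a minimal MJCF description along the following lines would work. This is a rough, hypothetical sketch (geometry, sizes, and names are illustrative); the essential ingredients are the two joints (a slide joint for the cart and a hinge for the pole, oriented so that [.code]qpos = 0[.code] corresponds to the pole hanging down) and the [.code]hanging_down[.code] keyframe referenced in the next cell.

# A rough, hypothetical approximation of "cartpole.xml" (MJCF format); the
# actual file is adapted from DeepMind's CartPole model.
cartpole_xml = """
<mujoco model="cartpole">
  <worldbody>
    <body name="cart" pos="0 0 1">
      <joint name="slider" type="slide" axis="1 0 0"/>
      <geom name="cart_geom" type="box" size="0.2 0.1 0.05" mass="1.0"/>
      <body name="pole">
        <joint name="hinge" type="hinge" axis="0 1 0"/>
        <geom name="pole_geom" type="capsule" fromto="0 0 0 0 0 -1" size="0.02" mass="0.1"/>
      </body>
    </body>
  </worldbody>
  <keyframe>
    <key name="hanging_down" qpos="0 0"/>
  </keyframe>
</mujoco>
"""

with open("cartpole.xml", "w") as f:
    f.write(cartpole_xml)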

mjmodel = mujoco.MjModel.from_xml_path("cartpole.xml")
mjdata = mujoco.MjData(mjmodel)
mujoco.mj_resetDataKeyframe(
    mjmodel, mjdata, key=mjmodel.keyframe("hanging_down").id
)
mujoco.mj_forward(mjmodel, mjdata)

renderer = mujoco.Renderer(mjmodel)
renderer.update_scene(mjdata)
frame = renderer.render()

frames = np.zeros((len(t), *frame.shape), dtype=np.uint8)
for i, q in enumerate(x):
    # Only need to change the joint positions, since we're not doing any
    # dynamics simulation.  But we do need to call the forward kinematics
    # to propagate the joint positions to the Cartesian body positions.
    mjdata.qpos[:] = q[:2]
    mujoco.mj_forward(mjmodel, mjdata)
    renderer.update_scene(mjdata)
    frames[i] = renderer.render()
    
mediapy.set_ffmpeg("/opt/homebrew/bin/ffmpeg")
mediapy.show_video(frames, fps=60, loop=False)


Unsurprisingly, the initial controller is unable to stabilize the pendulum. Let's train the agent for a longer period of time and see if it can learn to balance the pendulum.

hparams.update({
    "num_timesteps": 20_000_000,
    "num_minibatches": 32,
    "num_envs": 2048,
    "batch_size": 1024,
})

xdata, ydata = [], []
times = [datetime.now()]

# Function for plotting the progress of the training
def progress(num_steps, metrics):
    times.append(datetime.now())
    print(num_steps, metrics)
    xdata.append(num_steps)
    ydata.append(metrics['eval/episode_reward'])
    clear_output(wait=True)

    plt.figure(figsize=(7, 4))
    plt.plot(xdata, ydata, c='k')
    plt.xlim([0, train_fn.keywords['num_timesteps']])
    plt.ylim([0, train_fn.keywords['episode_length']])
    plt.xlabel('# environment steps')
    plt.ylabel('reward per episode')
    plt.grid()
    plt.show()


train_fn = partial(ppo.train, **hparams)
make_inference_fn, params, metrics = train_fn(environment=env, progress_fn=progress)


inference_fn = make_inference_fn(params)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

rollout = []
rng = jax.random.PRNGKey(seed=42)
state = jit_env_reset(rng=rng)
for _ in range(500):
  rollout.append(state.obs)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)
  if state.done:
    state = jit_env_reset(rng)

t = np.arange(len(rollout)) * env.dt
x = np.array(rollout)

plt.figure(figsize=(7, 2))
plt.plot(t, x)
plt.grid()
plt.show()

The RL agent is now much more effective at stabilizing the pendulum. It is still not perfect (in fact, linear LQR control would be much more effective for this particular task), but it has improved significantly. Further training would likely improve the performance even more.

frames = np.zeros((len(t), *frame.shape), dtype=np.uint8)
for i, q in enumerate(x):
    mjdata.qpos[:] = q[:2]
    mujoco.mj_forward(mjmodel, mjdata)
    renderer.update_scene(mjdata)
    frames[i] = renderer.render()
    
mediapy.show_video(frames, fps=60, loop=False)
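For comparison, the LQR controller mentioned above can be obtained by linearizing the cart-pole dynamics about the upright equilibrium and solving the continuous-time algebraic Riccati equation. The following is a minimal sketch, not part of the training workflow above: it restates the dynamics as a plain function, linearizes it with [.code]jax.jacobian[.code], and uses SciPy with arbitrarily chosen weight matrices.

import jax
import jax.numpy as jnp
import numpy as np
from scipy.linalg import solve_continuous_are

m_c, m_p, L, g = 1.0, 0.1, 1.0, 9.81

def cartpole_ode(x, u):
    # Same equations of motion as the CartPole LeafSystem above
    _, theta, dot_x, dot_theta = x
    fx = u[0]
    mf = 1.0 / (m_c + m_p * jnp.sin(theta) ** 2)
    ddot_x = mf * (fx + m_p * jnp.sin(theta) * (L * dot_theta**2 + g * jnp.cos(theta)))
    ddot_theta = (1.0 / L) * mf * (
        -fx * jnp.cos(theta)
        - m_p * L * dot_theta**2 * jnp.cos(theta) * jnp.sin(theta)
        - (m_c + m_p) * g * jnp.sin(theta)
    )
    return jnp.array([dot_x, dot_theta, ddot_x, ddot_theta])

# Linearize about the upright equilibrium x* = [0, pi, 0, 0], u* = 0
x_eq = jnp.array([0.0, jnp.pi, 0.0, 0.0])
u_eq = jnp.zeros(1)
A = np.array(jax.jacobian(cartpole_ode, argnums=0)(x_eq, u_eq))
B = np.array(jax.jacobian(cartpole_ode, argnums=1)(x_eq, u_eq))

# Solve the Riccati equation and form the state-feedback gain K = R^{-1} B^T P
Q = np.diag([10.0, 10.0, 1.0, 1.0])
R = np.array([[1.0]])
P = solve_continuous_are(A, B, Q, R)
K = np.linalg.solve(R, B.T @ P)

# Linear state-feedback law about the upright equilibrium: fx = -K @ (x - x_eq)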