Learning Universal Differential Equations with Neural Networks
This tutorial demonstrates how one can use Collimator to learn Universal Differential Equations (UDEs)—differential equations that are either fully or partly defined by a universal approximator such as a neural network. UDEs that are fully defined by a neural network are often known as neural differential equations.
As in [1], we demonstrate how the Lotka-Volterra system can be learned as a UDE from data obtained over a short duration. Symbolic regression is then applied to the trained UDE to find analytical expressions for the learned terms. The resulting system can then be used to reliably predict longer-term behavior.
Lotka-Volterra system
The Lotka-Volterra system, also known as the predator-prey equations, comprises a pair of first-order, nonlinear, differential equations frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. The system equations are given by:
$$\frac{dx}{dt} = \alpha x - \beta xy,$$
$$\frac{dy}{dt} = \gamma xy - \delta y,$$
where:
$x$ represents the number of prey (for example, rabbits),
$y$ represents the number of predators (for example, foxes),
$\frac{dx}{dt}$ and $\frac{dy}{dt}$ represent the growth rates of prey and predator populations over time, respectively,
$\alpha$, $\beta$, $\delta$, and $\gamma$ are positive real constants that represent the interaction rates between the species and their environment. Specifically, $\alpha$ is the natural growth rate of prey in the absence of predators, $\beta$ is the mortality rate of prey due to predation, $\delta$ is the natural death rate of predators in the absence of prey, and $\gamma$ represents the growth rate of the predator population per prey consumed.
These equations produce oscillations in the populations of both species, with the number of prey increasing when predators are scarce, followed by an increase in the number of predators as there is more prey to consume. This, in turn, leads to a decrease in the prey population due to predation, and subsequently, a decrease in the predator population as food becomes scarce. These cyclic dynamics are observed in many real-world predator-prey systems and provide insights into the complexity of ecological interactions and the importance of biodiversity for ecosystem stability.
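The cyclic behavior described above can be checked numerically without any Collimator machinery. The following is a minimal plain-Python sketch, not part of the notebook's workflow: the helper names ([.code]lotka_volterra_rhs[.code], [.code]rk4_step[.code], [.code]invariant[.code]) and the fixed-step RK4 integrator are choices made here for illustration. It uses the parameter values and initial state that appear later in this notebook, and relies on two standard properties of the system: the coexistence equilibrium at $(\delta/\gamma, \alpha/\beta)$ and the conserved quantity $V = \gamma x - \delta \ln x + \beta y - \alpha \ln y$, which stays constant along every closed orbit.

```python
import math

def lotka_volterra_rhs(x, y, alpha=1.3, beta=0.9, gamma=0.8, delta=1.8):
    """Right-hand side of the Lotka-Volterra equations."""
    return alpha * x - beta * x * y, gamma * x * y - delta * y

def rk4_step(x, y, dt):
    """One classical fourth-order Runge-Kutta step (illustrative integrator)."""
    k1x, k1y = lotka_volterra_rhs(x, y)
    k2x, k2y = lotka_volterra_rhs(x + 0.5 * dt * k1x, y + 0.5 * dt * k1y)
    k3x, k3y = lotka_volterra_rhs(x + 0.5 * dt * k2x, y + 0.5 * dt * k2y)
    k4x, k4y = lotka_volterra_rhs(x + dt * k3x, y + dt * k3y)
    x_new = x + dt / 6.0 * (k1x + 2.0 * k2x + 2.0 * k3x + k4x)
    y_new = y + dt / 6.0 * (k1y + 2.0 * k2y + 2.0 * k3y + k4y)
    return x_new, y_new

def invariant(x, y, alpha=1.3, beta=0.9, gamma=0.8, delta=1.8):
    """Conserved quantity V = gamma*x - delta*ln(x) + beta*y - alpha*ln(y)."""
    return gamma * x - delta * math.log(x) + beta * y - alpha * math.log(y)

# Coexistence equilibrium: both derivatives vanish at (delta/gamma, alpha/beta)
x_eq, y_eq = 1.8 / 0.8, 1.3 / 0.9
rhs_at_equilibrium = lotka_volterra_rhs(x_eq, y_eq)

# Integrate from the notebook's initial state and track the invariant
x, y = 0.44249296, 4.6280594
v_start = invariant(x, y)
dt = 0.001
for _ in range(10_000):  # t from 0 to 10
    x, y = rk4_step(x, y, dt)
drift = abs(invariant(x, y) - v_start)
print(f"RHS at equilibrium: {rhs_at_equilibrium}, invariant drift: {drift:.2e}")
```

A drift close to zero indicates that the trajectory stays on a single closed orbit, which is the oscillation described above.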
Steps in this notebook
We will follow these steps in this notebook:
1. Create and simulate the Lotka-Volterra system in Collimator with known parameters.
We will simulate the system for a short duration and a long duration to generate true data. We will treat the short duration data after adding noise as experimentally obtained data for training. The longer duration data will be used for testing.
2. We will assume that only the linear terms in the Lotka-Volterra system are known a priori.
For the missing dynamics we will use a neural network. Thus, our UDE will be of the following form:
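$$\frac{dx}{dt} = \alpha x + f_{\theta}(x, y)[0],$$
$$\frac{dy}{dt} = -\delta y + f_{\theta}(x, y)[1],$$
where $f_{\theta}$ represents a neural network with parameters $\theta$ that takes the state $[x, y]$ as input and returns two outputs, which are added to the known terms. Comparing with the true system, the network should learn $f_{\theta}[0] \approx -\beta xy$ and $f_{\theta}[1] \approx \gamma xy$. We will use a multilayer perceptron (MLP) for the neural network.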
3. We will train $f_{\theta}$ on the short duration training data.
4. Finally, we will use symbolic regression to recover analytical expressions that represent $f_{\theta}$ and use this learned model for prediction.
The reader is referred to [1] for more details of the adopted methodology.
[1] Rackauckas, C., Ma, Y., Martensen, J., Warner, C., Zubov, K., Supekar, R., Skinner, D., Ramadhan, A. and Edelman, A., 2020. Universal differential equations for scientific machine learning (arXiv preprint arXiv:2001.04385).
%matplotlib inline
import matplotlib.pyplot as plt
from functools import partial
from scipy.optimize import minimize
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import collimator
from collimator.framework import LeafSystem
from collimator.simulation import SimulatorOptions, ResultsOptions
from collimator import logging
logging.set_log_level(logging.ERROR)
from collimator.library import MLP, Adder, Integrator, Power, Clock, FeedthroughBlock, SourceBlock
Create a Collimator LeafSystem for the true Lotka-Volterra system
The code-block below implements the true Lotka-Volterra system dynamics.
class LotkaVolterra(LeafSystem):
    ''' True Lotka Volterra system '''
    def __init__(self, u0=[10.0, 10.0], alpha=1.3, beta=0.9, gamma=0.8, delta=1.8, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.declare_parameter("alpha", alpha)
        self.declare_parameter("beta", beta)
        self.declare_parameter("gamma", gamma)
        self.declare_parameter("delta", delta)
        # Continuous state u = [x, y]
        self.declare_continuous_state(default_value=jnp.array(u0), ode=self.ode)
        # Output the full state [x,y] from the block
        self.declare_continuous_state_output()

    def ode(self, time, state, *inputs, **parameters):
        # Unpack state
        x, y = state.continuous_state
        # Gather parameters
        alpha = parameters["alpha"]
        beta = parameters["beta"]
        gamma = parameters["gamma"]
        delta = parameters["delta"]
        # Implement the ODE RHS for the Lotka-Volterra system
        dot_x = alpha * x - beta * x * y
        dot_y = gamma * x * y - delta * y
        return jnp.array([dot_x, dot_y])
Simulate the system for short and long durations
We implement Scenario-1 from reference [1].
T_short = 3.0
T_long = 50.0
u0 = [0.44249296, 4.6280594] # initial state
params = [1.3, 0.9, 0.8, 1.8] # True parameters
def plot_sol(sols, labels=None):
    ''' Utility to plot multiple Simulation results objects '''
    if labels is None:
        labels = [f"sol-{i}" for i in range(len(sols))]
    fig, (ax1, ax2) = plt.subplots(1,2,figsize=[18,4])
    for sol, label in zip(sols, labels):
        ax1.plot(sol.time, sol.outputs["u"][:,0], label = label + ": x")
        ax2.plot(sol.time, sol.outputs["u"][:,1], label = label + ": y")
    for ax in (ax1, ax2):
        ax.legend(loc="upper right")
        ax.set_xlabel("t")
    plt.show()

def make_true_sys():
    ''' Create diagram with just the Lotka-Volterra block '''
    builder = collimator.DiagramBuilder()
    lv = builder.add(LotkaVolterra(u0, *params, name="lv"))
    diagram = builder.build()
    return diagram

def run_true_sys(tspan):
    ''' Utility to run the true system for a given time span '''
    diagram = make_true_sys()
    context = diagram.create_context()
    options = SimulatorOptions(max_minor_step_size=0.1, rtol=1e-12, atol=1e-12)
    recorded_signals = {"u": diagram["lv"].output_ports[0]}
    sol = collimator.simulate(diagram, context, tspan, options=options, recorded_signals=recorded_signals)
    return sol
sol_true_short = run_true_sys([0.0, T_short])
sol_true_long = run_true_sys([0.0, T_long])
plot_sol([sol_true_long, sol_true_short], ["true-test", "true-train"])
Generate training data by adding noise to the short duration simulation output
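The notebook's own noise-generation cell is not reproduced in this text. As a rough illustration of the idea only, the sketch below perturbs a clean signal with zero-mean Gaussian noise. Everything in it is an assumption made for the example: the helper name [.code]add_gaussian_noise[.code], the choice to scale the noise by the mean magnitude of each state, the noise level, and the placeholder signal that stands in for the short-duration simulation output.

```python
import numpy as np

def add_gaussian_noise(u, noise_level=5e-3, seed=0):
    """Return u plus zero-mean Gaussian noise scaled by each column's mean magnitude.

    The noise model and the default noise_level are assumptions for this sketch.
    """
    rng = np.random.default_rng(seed)
    scale = noise_level * np.mean(np.abs(u), axis=0)  # one scale per state variable
    return u + scale * rng.standard_normal(u.shape)

# Placeholder clean signal standing in for the short-duration simulation output
t_demo = np.linspace(0.0, 3.0, 31)
u_clean = np.column_stack([1.0 + 0.5 * np.sin(t_demo), 2.0 + 0.5 * np.cos(t_demo)])
u_noisy = add_gaussian_noise(u_clean)
print(np.max(np.abs(u_noisy - u_clean)))
```

In the notebook itself, the clean signal would be the recorded output of the short simulation rather than the synthetic array used here.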
We will represent the terms $[\alpha x, -\delta y]$ in a [.code]FeedthroughBlock[.code] and $f_{\theta}$ as an [.code]MLP[.code] block. We can then use the [.code]Adder[.code] block to sum their outputs, which gives the RHS of the UDE. This RHS is passed to an [.code]Integrator[.code] block to solve the UDE.
For training the UDE, it is also useful to read the training data and compute the loss during the simulation. The instantaneous error $\mathbf{e}$ is the difference between the training data at time $t$ and the simulated output of the UDE, $\mathbf{e} = \mathbf{u}_{\text{train}} - \mathbf{u}_{\text{ude}}$. To obtain the cumulative loss over the entire training duration, we integrate the quadratic form of the instantaneous error, giving the loss $\mathcal{L} = \int_0^T \mathbf{e}^T \mathbf{Q}\, \mathbf{e} \, dt$, where $\mathbf{Q}$ is a positive definite matrix (we will choose the identity matrix). Thus, we also add blocks that compute this integral as part of the simulation run. The overall structure of the block diagram is shown in the schematic below, followed by the implementation in code.
class KnownLotkaVolterraRHS(FeedthroughBlock):
    ''' Known terms of the Lotka-Volterra system '''
    def __init__(self, alpha=1.3, delta=1.8, *args, **kwargs):
        super().__init__(self.rhs, parameters = {"alpha": alpha, "delta":delta}, *args, **kwargs)

    def rhs(self, inputs, **parameters):
        x, y = inputs
        alpha = parameters["alpha"]
        delta = parameters["delta"]
        return jnp.array([alpha * x, - delta * y])

class InterpTrainingData(SourceBlock):
    ''' Block to interpolate the training data and output the interpolated
    values at the input time `t`
    '''
    def __init__(self, t_vec, x_arr, *args, **kwargs):
        self.t_vec = t_vec
        self.x_arr = x_arr
        # Interpolate each column of x_arr (one per state variable) at time t
        self.interp_fun = jax.vmap(jnp.interp, (None, None, 1))
        super().__init__(
            lambda t: self.interp_fun(t, self.t_vec, self.x_arr), *args, **kwargs
        )

class LossStructure(FeedthroughBlock):
    ''' Quadratic product of the error as loss '''
    def __init__(self, Q, *args, **kwargs):
        super().__init__(lambda e: e.T @ Q @ e, *args, **kwargs)
def make_ude_system():
    builder = collimator.DiagramBuilder()

    # the ude system
    klv_rhs = builder.add(KnownLotkaVolterraRHS(name="klv_rhs"))
    mlp = builder.add(MLP(2,2,5,3, seed=42, activation_str="rbf", name="mlp"))
    rhs = builder.add(Adder(2, name="rhs"))
    lv = builder.add(Integrator(u0, name="lv"))

    builder.connect(klv_rhs.output_ports[0], rhs.input_ports[0])
    builder.connect(mlp.output_ports[0], rhs.input_ports[1])
    builder.connect(rhs.output_ports[0], lv.input_ports[0])
    builder.connect(lv.output_ports[0], klv_rhs.input_ports[0])
    builder.connect(lv.output_ports[0], mlp.input_ports[0])

    # add extra components to compute cost/loss during simulation
    ref = builder.add(InterpTrainingData(t_train, u_train, name="ref"))
    err = builder.add(Adder(2, operators="+-", name="err"))
    cost = builder.add(LossStructure(jnp.eye(2), name="cost"))
    int_cost = builder.add(Integrator(0.0, name="int_cost"))

    builder.connect(ref.output_ports[0], err.input_ports[0])
    builder.connect(lv.output_ports[0], err.input_ports[1])
    builder.connect(err.output_ports[0], cost.input_ports[0])
    builder.connect(cost.output_ports[0], int_cost.input_ports[0])

    diagram = builder.build()
    return diagram
diagram = make_ude_system()
base_context = diagram.create_context()
The parameters of the MLP are of interest as these will be optimized to fit the training data. In the code below, we first extract the initial MLP parameters, which inherently have a PyTree structure. Then, we flatten this PyTree while also generating an [.code]unflatten[.code] function that restores the original structure from the flat vector.
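The flatten and unflatten round trip can be illustrated with JAX's [.code]ravel_pytree[.code] utility. This is a generic sketch rather than the notebook's own cell: the nested dictionary below is a made-up stand-in for the MLP parameters, and the notebook may obtain its [.code]unflatten[.code] function in a different way.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter PyTree standing in for the MLP weights and biases
toy_params = {
    "layer0": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "layer1": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)},
}

# Flatten to a single 1-D vector and get a function that restores the structure
flat_params, unflatten = ravel_pytree(toy_params)
restored = unflatten(flat_params)

print(flat_params.shape)               # total number of scalar parameters
print(restored["layer0"]["w"].shape)   # original leaf shape is recovered
```

Working with a flat vector is convenient because optimizers such as BFGS in SciPy expect a one-dimensional array of decision variables.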
We can now take advantage of JAX's JIT compilation and automatic differentiation capabilities to compute the gradient of the loss with respect to the MLP parameters.
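As a small illustration of this pattern, the sketch below combines [.code]jax.jit[.code] and [.code]jax.value_and_grad[.code] on a toy quadratic. The toy loss is an assumption made for the example. In the notebook, the differentiated function is instead the simulation-based loss, with the flat MLP parameters as its argument.

```python
import jax
import jax.numpy as jnp

def toy_loss(theta):
    """Stand-in for the simulation-based loss: squared distance to [1, -2]."""
    target = jnp.array([1.0, -2.0])
    return jnp.sum((theta - target) ** 2)

# Compile a function that returns both the loss and its gradient
toy_value_and_grad = jax.jit(jax.value_and_grad(toy_loss))

value, grad = toy_value_and_grad(jnp.zeros(2))
print(value, grad)
```

Returning the value and the gradient from one compiled call avoids evaluating the forward pass twice in each optimizer iteration.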
We now have all the machinery to train the UDE. As in [1], we first run the ADAM optimizer for a small number (200) of iterations to get a good guess. Subsequently, we use the BFGS optimizer, while starting from the solution of the ADAM optimizer.
First run ADAM for 200 iterations to get a good guess for BFGS
num_epochs = 200

# Optax optimizer
optimizer = optax.adam(learning_rate=0.1)

# Initialize optimizer state
params = initial_mlp_parameters_flat
opt_state = optimizer.init(params)

train_loss_history = []
for epoch in range(num_epochs):
    train_loss, grads = value_and_grad(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    train_loss_history.append(train_loss)

    # Print the loss every 50 epochs (an interval of 500 would never
    # trigger within 200 epochs)
    if (epoch + 1) % 50 == 0:
        print(
            f"Epoch [{epoch+1}/{num_epochs}]: loss = {train_loss}"
        )

opt_params_adam = params
fig, ax = plt.subplots(1,1,figsize=(6,4))
ax.plot(train_loss_history, "-r", label="train loss")
ax.set_xlabel("epochs")
ax.set_ylabel("loss")
ax.set_yscale("log")
ax.legend()
plt.show()
We find that the training data is reproduced very well. However, the UDE model quickly falls apart when predicting for durations beyond the training duration. It is wiser to use the UDE to learn the missing ODE terms, i.e. to find analytical expressions that represent the trained neural network in the UDE. This is shown next.
Symbolic regression to recover missing ODE terms
To obtain analytical expressions for the trained MLP, we can utilize Symbolic Regression. There are multiple choices for Symbolic Regression within the Python ecosystem. Below we demonstrate a few of these.
Prepare data for symbolic regression
We directly use the data generated during the above training. However, it is also possible to sample from the MLP to get better training data for symbolic regression.
sol = sol_ude_short
mlp_in = sol.outputs["u"]
mlp_out = sol.outputs["mlp"]
The problem of symbolic regression is to learn analytical terms that would map [.code]mlp_in[.code] vector to the [.code]mlp_out[.code] vector.
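One common way to pose this problem is as sparse regression over a library of candidate terms. The sketch below is a hand-written NumPy illustration of sequentially thresholded least squares on synthetic data. It is not PySINDy's implementation, and the function names, the synthetic inputs, and the noiseless targets are all assumptions made for the example. The targets are generated from the true missing terms $[-\beta xy, \gamma xy]$ with $\beta = 0.9$ and $\gamma = 0.8$.

```python
import numpy as np

def poly_library(u):
    """Candidate terms up to degree 2: [1, x, y, x^2, x*y, y^2]."""
    x, y = u[:, 0], u[:, 1]
    return np.column_stack([np.ones_like(x), x, y, x**2, x * y, y**2])

def stlsq(theta, target, threshold=0.05, n_iter=10):
    """Sequentially thresholded least squares (illustrative version)."""
    coef = np.linalg.lstsq(theta, target, rcond=None)[0]
    for _ in range(n_iter):
        small = np.abs(coef) < threshold
        coef[small] = 0.0                      # prune negligible terms
        for k in range(target.shape[1]):
            keep = ~small[:, k]
            if keep.any():                     # refit using surviving terms only
                coef[keep, k] = np.linalg.lstsq(
                    theta[:, keep], target[:, k], rcond=None
                )[0]
    return coef

# Synthetic stand-ins for mlp_in and mlp_out
rng = np.random.default_rng(0)
u_demo = rng.uniform(0.5, 5.0, size=(200, 2))
xy = u_demo[:, 0] * u_demo[:, 1]
target_demo = np.column_stack([-0.9 * xy, 0.8 * xy])

coef = stlsq(poly_library(u_demo), target_demo)
print(coef)  # rows follow the library order; columns are the two outputs
```

Only the row for the $xy$ term survives the thresholding, which is the kind of sparse, interpretable result that symbolic regression aims for.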
Use PySINDy library for symbolic regression
To use PySINDy for our purposes, we can represent the derivatives [.code]x_dot[.code] as the [.code]mlp_out[.code] vector and [.code]x[.code] as the [.code]mlp_in[.code] vector. PySINDy will then perform symbolic regression mapping [.code]x[.code] to [.code]x_dot[.code], thus achieving our objective.
import pysindy as ps
# Fit the model
poly_order = 2
threshold = 0.05
model = ps.SINDy(
    optimizer=ps.STLSQ(threshold=threshold),
    feature_library=ps.PolynomialLibrary(degree=poly_order),
)
model.fit(mlp_in, x_dot=mlp_out)
model.print()
(x0)' = -0.886 x0 x1
(x1)' = 0.809 x0 x1
We find that [.code]pysindy[.code] has recovered the missing terms from the learned UDE data. The same form of equations as the true Lotka-Volterra system is recovered, with $\beta = 0.886$ and $\gamma = 0.809$. Recall that the true values of these parameters are 0.9 and 0.8, respectively. For more digits, we can look at [.code]model.coefficients()[.code].
We can now simulate our learned system from the UDE approximation. Here, we replace the MLP with the analytical expressions we have found via symbolic regression.