Solving ODEs with Probabilistic Numerics

In this tutorial we solve a simple non-linear ordinary differential equation (ODE) with the probabilistic numerical ODE solvers implemented in this package.

Note

If you have never used DifferentialEquations.jl, check out its "Getting Started with Differential Equations in Julia" tutorial. It explains how to define and solve ODE problems and how to analyze the solution, so it's a great starting point. Most of ProbNumDiffEq.jl works exactly as you would expect from DifferentialEquations.jl, just with some added uncertainties and related functionality on top!

In this tutorial, we consider a FitzHugh-Nagumo model described by an ODE of the form

\[\begin{aligned} \dot{y}_1 &= c (y_1 - \frac{y_1^3}{3} + y_2) \\ \dot{y}_2 &= -\frac{1}{c} (y_1 - a - b y_2) \end{aligned}\]

on a time span $t \in [0, T]$, with initial value $y(0) = y_0$. In the following, we

  1. define the problem with explicit choices of initial values, integration domains, and parameters,
  2. solve the problem with our ODE filters, and
  3. visualize the results and the corresponding uncertainties.

TL;DR: Just use DifferentialEquations.jl with the EK1 algorithm

using ProbNumDiffEq, Plots

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
u0 = [-1.0; 1.0]
tspan = (0.0, 20.0)
p = (0.2, 0.2, 3.0)
prob = ODEProblem(fitz, u0, tspan, p)

sol = solve(prob, EK1())
plot(sol)

Step 1: Define the problem

First, import ProbNumDiffEq.jl:

using ProbNumDiffEq

Then, set up the ODEProblem exactly as you would in DifferentialEquations.jl. Define the vector field

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end

and then the ODEProblem, with initial value u0, time span tspan, and parameters p

u0 = [-1.0; 1.0]
tspan = (0.0, 20.0)
p = (0.2, 0.2, 3.0)
prob = ODEProblem(fitz, u0, tspan, p)
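As a quick sanity check before solving, the vector field can be evaluated by hand at the initial value. The following is a minimal sketch that needs no packages; the buffer name `du` is only illustrative. With $a = b = 0.2$, $c = 3$ and $y_0 = (-1, 1)$, the formulas above give $\dot{y}_1 = 3(-1 + 1/3 + 1) = 1$ and $\dot{y}_2 = -\frac{1}{3}(-1 - 0.2 - 0.2) = 1.4/3$.

```julia
# In-place FitzHugh-Nagumo vector field, same definition as in the tutorial.
function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end

u0 = [-1.0; 1.0]
p = (0.2, 0.2, 3.0)

du = similar(u0)      # preallocated output buffer (illustrative name)
fitz(du, u0, p, 0.0)  # writes the derivative at t = 0 into du

# Hand computation for comparison:
#   du[1] = 3 * (-1 + 1/3 + 1)         = 1
#   du[2] = -(1/3) * (-1 - 0.2 - 0.2)  = 1.4 / 3
println(du)
```

If the printed values match the hand computation up to floating-point rounding, the vector field is implemented as intended.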

Step 2: Solve the problem

To solve the ODE we just use DifferentialEquations.jl's solve interface, together with one of the algorithms implemented in this package. For now, let's use EK1:

sol = solve(prob, EK1())
retcode: Success
Interpolation: ODE Filter Posterior
t: 267-element Vector{Float64}:
  0.0
  0.021276864853851562
  0.05530062522770621
  0.09069833974374658
  0.13926827781702378
  0.18486670656185625
  0.24179366397911012
  0.29051816272523345
  0.3490769287679599
  0.39571551758025386
  ⋮
 19.52620471671648
 19.57118993725592
 19.618145281565827
 19.669683561096146
 19.72660662373306
 19.789272758467693
 19.860208126620186
 19.94055004101715
 20.0
u: 267-element Vector{Vector{Float64}}:
 [-1.0, 1.0]
 [-0.978397898660792, 1.0098599972789335]
 [-0.9424079091803095, 1.0253304289278915]
 [-0.9028542905882087, 1.0410170752594414]
 [-0.8445349615757094, 1.0618117231749664]
 [-0.7847703259504633, 1.0804970332988895]
 [-0.7018976697192764, 1.102556299329847]
 [-0.622068920127181, 1.1201790771323807]
 [-0.512584623578099, 1.1395948416743362]
 [-0.41230242602460154, 1.1534759756505175]
 ⋮
 [2.0826740294562933, 0.9061101524201192]
 [2.078744560756923, 0.880585842963934]
 [2.073202127513046, 0.8539368332415443]
 [2.0660462633712053, 0.8247002716083859]
 [2.0573291294637492, 0.7924433159689404]
 [2.0471229128144675, 0.7569883434597867]
 [2.0350861534147575, 0.7169394922557054]
 [2.0210436519484105, 0.6717007174964964]
 [2.0104405118668844, 0.6383145073764129]

That's it! We just computed a probabilistic numerical ODE solution!
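Since the solvers plug into the usual solve interface, the standard step-size tolerances should be usable here too. The sketch below assumes that EK1 honours the common `abstol` and `reltol` keyword arguments of DifferentialEquations.jl; the tolerance values are arbitrary and chosen only for illustration.

```julia
using ProbNumDiffEq

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
prob = ODEProblem(fitz, [-1.0; 1.0], (0.0, 20.0), (0.2, 0.2, 3.0))

# Assumption: EK1 accepts the common abstol/reltol keywords.
# The values below are illustrative, not recommendations.
sol_loose = solve(prob, EK1(), abstol=1e-3, reltol=1e-2)
sol_tight = solve(prob, EK1(), abstol=1e-8, reltol=1e-8)

# Tighter tolerances are expected to require more steps.
println("steps (loose): ", length(sol_loose.t))
println("steps (tight): ", length(sol_tight.t))
```

Comparing the number of time steps is a cheap way to see how much extra work a tighter tolerance costs.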

Step 3: Analyze the solution

Let's plot the result with Plots.jl.

using Plots
plot(sol)

Looks good! It seems that EK1 managed to solve the FitzHugh-Nagumo problem quite well.

Tip

To learn more about plotting ODE solutions, check out the plotting tutorial for DifferentialEquations.jl + Plots.jl. Most of it works exactly as expected with ProbNumDiffEq.jl.

Plot the probabilistic error estimates

The plot above looks like a standard ODE solution, but it is not! The numerical errors, and the probabilistic error estimates along with them, are simply too small to see at this scale. We can visualize both by plotting the errors and error estimates directly:

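using OrdinaryDiffEq, Statistics
# High-accuracy reference solution, saved at the same time points as `sol`
reference = solve(prob, Vern9(), abstol=1e-9, reltol=1e-9, saveat=sol.t)
# Rows are time steps, columns are the two state components
errors = reduce(hcat, mean.(sol.pu) .- reference.u)'
error_estimates = reduce(hcat, std.(sol.pu))'
plot(sol.t, errors, label="error", color=[1 2], xlabel="t", ylabel="err")
# Shade a band of three standard deviations around zero
plot!(sol.t, zero(errors), ribbon=3error_estimates, label="error estimate",
      color=[1 2], alpha=0.2)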
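The same comparison can be summarized numerically. The sketch below recomputes the errors and error estimates and reports how often the true error falls inside the three-standard-deviation band. The variable name `inside` is illustrative, and the script makes no claim about what fraction to expect.

```julia
using ProbNumDiffEq, OrdinaryDiffEq, Statistics

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
prob = ODEProblem(fitz, [-1.0; 1.0], (0.0, 20.0), (0.2, 0.2, 3.0))

sol = solve(prob, EK1())
reference = solve(prob, Vern9(), abstol=1e-9, reltol=1e-9, saveat=sol.t)

# Rows are time steps, columns are state components.
errors = reduce(hcat, mean.(sol.pu) .- reference.u)'
error_estimates = reduce(hcat, std.(sol.pu))'

# Element-wise check: is the absolute error within three standard deviations?
inside = abs.(errors) .<= 3 .* error_estimates

println("largest absolute error:   ", maximum(abs, errors))
println("fraction inside 3σ band:  ", mean(inside))
```

A fraction close to one suggests that the error estimates are not overconfident for this problem; a small fraction would indicate the opposite.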

More about the ProbabilisticODESolution

The solution object returned by ProbNumDiffEq.jl mostly behaves just like any other ODESolution in DifferentialEquations.jl, with some added uncertainties and related functionality on top. The ProbabilisticODESolution can be indexed as usual:

julia> sol.u[1]
2-element Vector{Float64}:
 -1.0
  1.0

julia> sol.u[end]
2-element Vector{Float64}:
 2.0104405118668844
 0.6383145073764129

julia> sol.t[end]
20.0

But since sol is a probabilistic numerical ODE solution, it contains Gaussian distributions over solution values. The marginals of this posterior are stored in sol.pu:

julia> sol.pu[end]
Gaussian{Vector{Float64},PSDMatrix{Float64, Matrix{Float64}}}([2.0104405118668844, 0.6383145073764129], 2x2 PSDMatrix{Float64, Matrix{Float64}}; R=[4.819328867073896e-5 0.000150086281654363; -4.209242496334892e-5 -0.00013099898185084177; 4.6624624651022493e-7 6.695624542113227e-6; -9.340635819017074e-8 5.806477585177133e-8; 1.3698832861365067e-6 4.548663507983009e-6; -1.6958383599728362e-7 -5.775663746240759e-6; 0.0 0.0; 0.0 0.0])

You can compute means, covariances, and standard deviations via Statistics.jl:

julia> using Statistics

julia> mean(sol.pu[5])
2-element Vector{Float64}:
 -0.8445349615757094
  1.0618117231749664

julia> cov(sol.pu[5])
2x2 PSDMatrix{Float64, Matrix{Float64}}
Right square root: R=8×2 Matrix{Float64}:
 -2.8014e-6  -1.24815e-7
  0.0         2.74772e-6
  0.0        -2.06464e-21
  0.0        -9.13207e-22
  0.0         1.90066e-19
  0.0         0.0
  0.0         0.0
  0.0         0.0

julia> std(sol.pu[5])
2-element Vector{Float64}:
 2.8014003609809637e-6
 2.750553392466558e-6
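The covariance is printed through its right square root $R$, which suggests that it is stored as $R^\top R$. Under that reading, each standard deviation is the Euclidean norm of the corresponding column of $R$. The sketch below checks this relationship; it assumes that the factor is accessible as a field named `R`, a name taken from the printed output rather than from documented API.

```julia
using ProbNumDiffEq, Statistics

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
prob = ODEProblem(fitz, [-1.0; 1.0], (0.0, 20.0), (0.2, 0.2, 3.0))
sol = solve(prob, EK1())

marginal = sol.pu[5]

# Assumption: the PSDMatrix exposes its right square root as the field `R`,
# so that the covariance equals R' * R.
R = cov(marginal).R

# The diagonal of R' * R is the sum of squares of each column of R.
stds_from_R = sqrt.(vec(sum(abs2, R, dims=1)))

println("std from R:      ", stds_from_R)
println("std(marginal):   ", std(marginal))
```

Working with the square-root factor instead of the full covariance is a common way to keep covariances numerically positive semi-definite, which is presumably why the solution stores it in this form.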

Dense output

Probabilistic numerical ODE solvers approximate the posterior distribution

\[p \Big( y(t) ~\big|~ y(0) = y_0, \{ \dot{y}(t_i) = f_\theta(y(t_i), t_i) \} \Big),\]

which is defined not just at the discrete time steps but at every $t \in [0, T]$. In classic ODE solvers, evaluating the solution between steps is known as "interpolation" or "dense output". The probabilistic solutions returned by our solvers can be interpolated as usual by calling them like functions, but they return Gaussian distributions instead of point values:

julia> sol(0.45)
Gaussian{Vector{Float64},PSDMatrix{Float64, Matrix{Float64}}}([-0.2773821283086789, 1.1675659430627081], 2x2 PSDMatrix{Float64, Matrix{Float64}}; R=[-3.208354407821569e-5 -4.789969845355876e-6; 0.0 1.9264788456676448e-5; 0.0 1.460579017560511e-5; 0.0 1.0839739025995338e-6; 0.0 5.606932635215914e-7; 0.0 0.0; 0.0 0.0; 0.0 0.0])

julia> mean(sol(0.45))
2-element Vector{Float64}:
 -0.2773821283086789
  1.1675659430627081
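The dense output makes it straightforward to tabulate the posterior on a grid that is independent of the solver's own time steps. The following is a minimal sketch; the grid size of 200 points and the names `ts`, `posteriors`, `means`, and `stds` are arbitrary choices for illustration.

```julia
using ProbNumDiffEq, Statistics

function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
prob = ODEProblem(fitz, [-1.0; 1.0], (0.0, 20.0), (0.2, 0.2, 3.0))
sol = solve(prob, EK1())

# Evenly spaced evaluation grid on the time span (size chosen arbitrarily).
ts = range(0.0, 20.0, length=200)

# Each call sol(t) returns a Gaussian over the solution value at time t.
posteriors = [sol(t) for t in ts]

# Stack means and standard deviations: rows are grid points, columns are states.
means = reduce(hcat, mean.(posteriors))'
stds = reduce(hcat, std.(posteriors))'

println(size(means))
println("largest posterior std on the grid: ", maximum(stds))
```

The resulting matrices can be passed to any plotting or post-processing routine, for example to draw the mean with an uncertainty band.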

Next steps

Check out one of the other tutorials in this documentation.