Solving ODEs with Probabilistic Numerics
In this tutorial we solve a simple non-linear ordinary differential equation (ODE) with the probabilistic numerical ODE solvers implemented in this package.
If you have never used DifferentialEquations.jl, check out their "Getting Started with Differential Equations in Julia" tutorial. It explains how to define and solve ODE problems and how to analyze the solution, so it's a great starting point. Most of ProbNumDiffEq.jl works exactly as you would expect from DifferentialEquations.jl, just with some added uncertainties and related functionality on top!
Specifically, we consider a Fitzhugh-Nagumo model described by an ODE of the form
\[\begin{aligned} \dot{y}_1 &= c (y_1 - \frac{y_1^3}{3} + y_2) \\ \dot{y}_2 &= -\frac{1}{c} (y_1 - a - b y_2) \end{aligned}\]
on a time span $t \in [0, T]$, with initial value $y(0) = y_0$. In the following, we
- define the problem with explicit choices of initial values, integration domains, and parameters,
- solve the problem with our ODE filters, and
- visualize the results and the corresponding uncertainties.
TL;DR: Just use DifferentialEquations.jl with the EK1 algorithm
using ProbNumDiffEq, Plots
function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
u0 = [-1.0; 1.0]
tspan = (0.0, 20.0)
p = (0.2, 0.2, 3.0)
prob = ODEProblem(fitz, u0, tspan, p)
sol = solve(prob, EK1())
plot(sol)
Step 1: Define the problem
First, import ProbNumDiffEq.jl
using ProbNumDiffEq
Then, set up the ODEProblem exactly as you would in DifferentialEquations.jl. Define the vector field
function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
end
and then the ODEProblem, with initial value u0, time span tspan, and parameters p
u0 = [-1.0; 1.0]
tspan = (0.0, 20.0)
p = (0.2, 0.2, 3.0)
prob = ODEProblem(fitz, u0, tspan, p)
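In-place vector fields like fitz follow the DifferentialEquations.jl convention f(du, u, p, t): they write the derivative into the preallocated du instead of returning a new array. As a quick sanity check, here is a minimal sketch in plain Julia (no packages needed) that evaluates the function once at the initial value. By hand, the expected derivative is (1, 1.4/3) ≈ (1.0, 0.467).

```julia
# Same in-place vector field as above: it writes the derivative into `du`
# and returns nothing.
function fitz(du, u, p, t)
    a, b, c = p
    du[1] = c * (u[1] - u[1]^3 / 3 + u[2])
    du[2] = -(1 / c) * (u[1] - a - b * u[2])
    return nothing
end

u0 = [-1.0; 1.0]
p = (0.2, 0.2, 3.0)

du = similar(u0)         # preallocated output buffer
fitz(du, u0, p, 0.0)     # evaluate the vector field once, at t = 0
@show du                 # approximately [1.0, 0.4667]
```

The solvers call the function in exactly this way at every step. This is why mutating du, rather than allocating a new vector, keeps the inner loop allocation-free.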
Step 2: Solve the problem
To solve the ODE, we just use DifferentialEquations.jl's solve interface, together with one of the algorithms implemented in this package. For now, let's use EK1:
sol = solve(prob, EK1())
retcode: Success
Interpolation: ODE Filter Posterior
t: 267-element Vector{Float64}:
0.0
0.021276864853851562
0.05530062522770621
0.09069833974374658
0.13926827781702378
0.18486670656185625
0.24179366397911012
0.29051816272523345
0.3490769287679599
0.39571551758025386
⋮
19.52620471671648
19.57118993725592
19.618145281565827
19.669683561096146
19.72660662373306
19.789272758467693
19.860208126620186
19.94055004101715
20.0
u: 267-element Vector{Vector{Float64}}:
[-1.0, 1.0]
[-0.9783978986607919, 1.0098599972789337]
[-0.9424079091803095, 1.0253304289278917]
[-0.9028542905882088, 1.0410170752594412]
[-0.8445349615757097, 1.0618117231749666]
[-0.7847703259504635, 1.0804970332988892]
[-0.7018976697192769, 1.1025562993298472]
[-0.6220689201271815, 1.1201790771323805]
[-0.5125846235780994, 1.139594841674336]
[-0.412302426024602, 1.1534759756505175]
⋮
[2.0826740294562938, 0.9061101524201137]
[2.078744560756923, 0.8805858429639286]
[2.0732021275130457, 0.853936833241539]
[2.0660462633712045, 0.8247002716083807]
[2.057329129463748, 0.7924433159689349]
[2.0471229128144666, 0.7569883434597815]
[2.0350861534147566, 0.7169394922557]
[2.0210436519484087, 0.6717007174964914]
[2.0104405118668827, 0.6383145073764079]
That's it! We just computed a probabilistic numerical ODE solution!
Step 3: Analyze the solution
Let's plot the result with Plots.jl.
using Plots
plot(sol)
Looks good! The EK1 seems to have solved the Fitzhugh-Nagumo problem quite well.
To learn more about plotting ODE solutions, check out the DifferentialEquations.jl tutorial on plotting with Plots.jl. Most of it works exactly as expected with ProbNumDiffEq.jl.
Plot the probabilistic error estimates
The plot above looks like a standard ODE solution, but it's not! The numerical errors are so small that we can't see them in the plot, and neither can we see the probabilistic error estimates. We can visualize both by comparing against a high-accuracy reference solution (computed with Vern9 from OrdinaryDiffEq.jl) and plotting the errors and error estimates directly:
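using OrdinaryDiffEq, Statistics
# High-accuracy reference solution, saved at the same time points as `sol`
reference = solve(prob, Vern9(), abstol=1e-9, reltol=1e-9, saveat=sol.t)
# Errors of the posterior means, and posterior standard deviations as error estimates
errors = reduce(hcat, mean.(sol.pu) .- reference.u)'
error_estimates = reduce(hcat, std.(sol.pu))'
# Plot the errors, together with a ±3 standard deviation ribbon around zero
plot(sol.t, errors, label="error", color=[1 2], xlabel="t", ylabel="err")
plot!(sol.t, zero(errors), ribbon=3error_estimates, label="error estimate",
      color=[1 2], alpha=0.2)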
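To summarize this comparison in a single number, you can count how many errors fall inside the k-standard-deviation band drawn by the ribbon. The function coverage below is a hypothetical helper written for this sketch, not part of ProbNumDiffEq.jl. It runs on small toy matrices here. With the arrays from the previous block, you would call coverage(errors, error_estimates).

```julia
"""
    coverage(errors, error_estimates; k=3)

Hypothetical helper (not part of ProbNumDiffEq.jl): return the fraction of
entries whose absolute error is at most `k` times the matching error estimate.
"""
function coverage(errors::AbstractMatrix, error_estimates::AbstractMatrix; k::Real=3)
    size(errors) == size(error_estimates) ||
        throw(DimensionMismatch("errors and error_estimates must have the same size"))
    inside = abs.(errors) .<= k .* error_estimates
    return count(inside) / length(inside)
end

# Toy 2×2 data standing in for the `errors` and `error_estimates` computed above
errs = [1e-6 -2e-6; 5e-6 1e-7]
stds = [1e-6  1e-6; 1e-6 1e-6]
@show coverage(errs, stds)        # 3 of 4 entries lie within 3σ: 0.75
@show coverage(errs, stds; k=1)   # 2 of 4 entries lie within 1σ: 0.5
```

If the error estimates are well calibrated, a large fraction of the errors should fall inside the 3σ band. A fraction close to zero would suggest the solver is overconfident.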
More about the ProbabilisticODESolution
The solution object returned by ProbNumDiffEq.jl mostly behaves just like any other ODESolution in DifferentialEquations.jl, with some added uncertainties and related functionality on top. For example, the ProbabilisticODESolution can be indexed as usual:
julia> sol.u[1]
2-element Vector{Float64}:
 -1.0
  1.0
julia> sol.u[end]
2-element Vector{Float64}:
 2.0104405118668827
 0.6383145073764079
julia> sol.t[end]
20.0
But since sol is a probabilistic numerical ODE solution, it contains Gaussian distributions over the solution values. The marginals of this posterior are stored in sol.pu:
julia> sol.pu[end]
Gaussian{Vector{Float64},PSDMatrix{Float64, Matrix{Float64}}}( μ=[2.0104405118668827, 0.6383145073764079], Σ=2x2 PSDMatrix{Float64, Matrix{Float64}}; R=[4.819328866957359e-5 0.00015008628165071627; 3.994225962938706e-5 0.00012478474684357176; -1.3359035909208665e-5 -4.029546230185406e-5; -2.361011408675205e-7 -8.04110127435485e-6; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0] )
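The covariance Σ is stored in square-root form. A PSDMatrix keeps a right square root R with Σ = RᵀR, which is consistent with the standard deviations printed further below. One convenient consequence is that samples can be drawn without factorizing Σ. If z is standard normal with one entry per row of R, then μ + Rᵀz has covariance RᵀR. The minimal sketch below uses only Julia's standard libraries and the numbers printed above. It illustrates the square-root form and is not a description of ProbNumDiffEq.jl internals.

```julia
using LinearAlgebra, Random, Statistics

# Mean and right square root R of sol.pu[end], copied from the output above.
μ = [2.0104405118668827, 0.6383145073764079]
R = [ 4.819328866957359e-5    0.00015008628165071627
      3.994225962938706e-5    0.00012478474684357176
     -1.3359035909208665e-5  -4.029546230185406e-5
     -2.361011408675205e-7   -8.04110127435485e-6
      0.0                     0.0
      0.0                     0.0
      0.0                     0.0
      0.0                     0.0]

Σ = R' * R                        # dense 2×2 covariance, Σ = RᵀR

# If z ~ N(0, I) has one entry per row of R, then μ + Rᵀz ~ N(μ, RᵀR).
rng = MersenneTwister(42)
n = 100_000
Z = randn(rng, size(R, 1), n)     # 8×n standard normal draws
samples = μ .+ R' * Z             # 2×n samples from N(μ, Σ)

@show Σ
@show cov(samples; dims=2)        # close to Σ, up to Monte Carlo error
```

Working with R instead of Σ also guarantees positive semidefiniteness by construction, since RᵀR can never have negative eigenvalues.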
You can compute means, covariances, and standard deviations via Statistics.jl:
julia> using Statistics
julia> mean(sol.pu[5])
2-element Vector{Float64}:
 -0.8445349615757097
  1.0618117231749666
julia> cov(sol.pu[5])
2x2 PSDMatrix{Float64, Matrix{Float64}}
Right square root: R=8×2 Matrix{Float64}:
 -2.8014e-6  -1.24815e-7
  0.0        -2.74772e-6
  0.0         0.0
  0.0         0.0
  0.0         0.0
  0.0         0.0
  0.0         0.0
  0.0         0.0
julia> std(sol.pu[5])
2-element Vector{Float64}:
 2.8014003609822313e-6
 2.7505533924668052e-6
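The printed standard deviations follow from the right square root via Σ = RᵀR and σᵢ = √Σᵢᵢ. The sketch below checks this with plain LinearAlgebra and the R printed for cov(sol.pu[5]). Because that R is shown to only about six significant digits, the result matches std(sol.pu[5]) only up to that rounding.

```julia
using LinearAlgebra

# Right square root of cov(sol.pu[5]), copied from the output above:
# two nonzero rows followed by six rows of zeros.
R = vcat([-2.8014e-6  -1.24815e-7
           0.0        -2.74772e-6],
         zeros(6, 2))

Σ = R' * R               # covariance matrix Σ = RᵀR
σ = sqrt.(diag(Σ))       # marginal standard deviations
@show σ
```

The zero rows do not contribute to RᵀR, so dropping them would give the same Σ.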
Dense output
Probabilistic numerical ODE solvers approximate the posterior distribution
\[p \Big( y(t) ~\big|~ y(0) = y_0, \{ \dot{y}(t_i) = f_\theta(y(t_i), t_i) \} \Big),\]
which describes a posterior not just at the discrete time steps but at any $t$ in the continuous interval $t \in [0, T]$; in classic ODE solvers, this is also known as "interpolation" or "dense output". The probabilistic solutions returned by our solvers can be interpolated as usual by treating them as functions, but they return Gaussian distributions:
julia> sol(0.45)
Gaussian{Vector{Float64},PSDMatrix{Float64, Matrix{Float64}}}( μ=[-0.2773821283086794, 1.1675659430627088], Σ=2x2 PSDMatrix{Float64, Matrix{Float64}}; R=[-3.2083544077891136e-5 -4.7899698453730076e-6; 0.0 -2.4206415622162586e-5; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0] )
julia> mean(sol(0.45))
2-element Vector{Float64}:
 -0.2773821283086794
  1.1675659430627088
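Because each interpolated value is Gaussian, it can be turned directly into error bars. Here is a minimal sketch, using only the μ and R printed above for sol(0.45), that computes approximate 95% credible intervals for each component as μᵢ ± 1.96σᵢ. These are per-component (marginal) intervals that ignore the correlation between the two components.

```julia
using LinearAlgebra

# Mean and nonzero rows of the right square root R of sol(0.45), copied from
# the output above (the six zero rows are dropped; they do not change RᵀR).
μ = [-0.2773821283086794, 1.1675659430627088]
R = [-3.2083544077891136e-5  -4.7899698453730076e-6
      0.0                    -2.4206415622162586e-5]

σ = sqrt.(diag(R' * R))      # marginal standard deviations at t = 0.45
z = 1.959963984540054        # 0.975 quantile of the standard normal
lower = μ .- z .* σ
upper = μ .+ z .* σ
@show lower upper
```

Evaluating sol at many time points and applying the same formula gives credible bands over the whole interval, which is what the ribbon in the error plot shows (there with ±3σ).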
Next steps
Check out one of the other tutorials:
- "Second Order ODEs and Energy Preservation" explains how to solve second-order ODEs more efficiently while also better preserving energy or other conserved quantities;
- "Solving DAEs with Probabilistic Numerics" demonstrates how to solve differential algebraic equations in a probabilistic numerical way.