Getting Started with Parareax¶
This notebook demonstrates how to use the Parareax package to solve an ordinary differential equation (ODE) using the parareal algorithm. We'll walk through a complete example using the logistic equation.
Setup¶
First, let's import the necessary libraries:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from parareax import run_parareal
# Enable double precision
jax.config.update("jax_enable_x64", True)
Define the Problem¶
We'll solve the logistic equation $$\frac{dy}{dt} = y(1-y),$$ which models constrained growth, on the interval $t \in [0, 10]$, with initial value $y(0) = 10^{-2}$.
# Define the differential equation: logistic equation dy/dt = y(1-y)
def f(y, t):
    return y * (1 - y)
# Initial conditions
t0, t1 = 0.0, 10.0 # Time span
y0 = jnp.array([1e-2]) # Initial value
This ODE has a closed form solution:
# Analytical solution for comparison
def analytical_solution(t, y0):
    return 1 / (1 + (1 - y0) / y0 * jnp.exp(-t))
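We can double-check this closed form with automatic differentiation: differentiating it in $t$ with `jax.grad` should reproduce the right-hand side $y(1-y)$. The snippet below is a standalone check that restates the definitions above so it runs on its own:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(y, t):
    return y * (1 - y)

def analytical_solution(t, y0):
    return 1 / (1 + (1 - y0) / y0 * jnp.exp(-t))

# Differentiate the closed form in t with autodiff and compare
# against the right-hand side of the ODE; the residual should vanish
y0 = 1e-2
t = 3.0
dydt = jax.grad(lambda t: analytical_solution(t, y0))(t)
residual = dydt - f(analytical_solution(t, y0), t)
```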
ts = jnp.linspace(t0, t1, 100)
plt.figure(figsize=(8, 5))
plt.plot(ts, analytical_solution(ts, y0))
plt.xlabel("t")
plt.ylabel("y(t)")
plt.title("Analytical Solution of the Logistic Equation")
plt.grid(True, alpha=0.3)
plt.show()
Define Propagators¶
The parareal algorithm requires two propagators:
- A coarse propagator: Fast but less accurate
- A fine propagator: Slow but more accurate
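Writing $\mathcal{G}$ for the coarse propagator and $\mathcal{F}$ for the fine one (notation introduced here for exposition, not taken from the Parareax API), the standard parareal update is $$U_{n+1}^{k+1} = \mathcal{G}(U_n^{k+1}) + \mathcal{F}(U_n^k) - \mathcal{G}(U_n^k),$$ where $U_n^k$ approximates the solution at time $t_n$ after $k$ iterations. The expensive fine evaluations $\mathcal{F}(U_n^k)$ depend only on the previous iterate, so they can run in parallel across all time intervals; only the cheap coarse sweep is sequential.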
Coarse Propagator¶
We'll use a simple Euler method as our coarse propagator:
@jax.jit
def coarse_step(y0, t0, t1):
    """Simple Euler method for the coarse propagator."""
    dy = f(y0, t0)
    return y0 + dy * (t1 - t0)
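As a quick sanity check (a standalone snippet restating `f` and `analytical_solution` from above), a single Euler step over a small interval should agree with the exact solution up to a first-order local error of roughly $O(h^2)$:

```python
import jax.numpy as jnp

def f(y, t):
    return y * (1 - y)

def coarse_step(y0, t0, t1):
    # One forward-Euler step from t0 to t1
    return y0 + f(y0, t0) * (t1 - t0)

def analytical_solution(t, y0):
    return 1 / (1 + (1 - y0) / y0 * jnp.exp(-t))

y0 = jnp.array([1e-2])
h = 0.1
approx = coarse_step(y0, 0.0, h)
# Local truncation error of one Euler step: small but nonzero
err = float(jnp.abs(approx - analytical_solution(h, y0))[0])
```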
Fine Propagator¶
For the fine propagator, we'll use a high-order Dormand-Prince method from jax.experimental:
from jax.experimental.ode import odeint
@jax.jit
def fine_step(y0, t0, t1):
    """High-accuracy Dormand-Prince method for the fine propagator."""
    return odeint(f, y0, jnp.array([t0, t1]))[-1]
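Before wiring it into parareal, a quick standalone check (restating the definitions above; `odeint` uses adaptive step sizes with default tolerances around `1.4e-8`) confirms that a single fine step tracks the exact solution closely:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)

def f(y, t):
    return y * (1 - y)

def fine_step(y0, t0, t1):
    # odeint integrates with an adaptive Dormand-Prince scheme;
    # [-1] selects the state at the final requested time t1
    return odeint(f, y0, jnp.array([t0, t1]))[-1]

def analytical_solution(t, y0):
    return 1 / (1 + (1 - y0) / y0 * jnp.exp(-t))

y0 = jnp.array([1e-2])
approx = fine_step(y0, 0.0, 1.0)
err = float(jnp.abs(approx - analytical_solution(1.0, y0))[0])
```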
Run the Parareal Algorithm¶
Now we'll solve our problem using the parareal algorithm:
# Define the time points
ts = jnp.linspace(t0, t1, 101) # 100 time intervals
# Solve using Parareal
sol, info = run_parareal(
    coarse_step,
    fine_step,
    y0=y0,
    ts=ts,
    tol=1e-14,  # Convergence tolerance
)
# Print convergence information
print(f"Parareal converged in {info['iterations']} iterations")
print(f"Final error norm: {info['last_change_norm']:.2e}")
Parareal converged in 9 iterations
Final error norm: 9.72e-15
ref = analytical_solution(ts, y0)
plt.figure(figsize=(8, 5))
plt.plot(ts, ref, "k-", linewidth=2, label="Analytical solution")
plt.plot(ts, sol, "r--", linewidth=1.5, label="Parareal solution")
plt.xlabel("t")
plt.ylabel("y(t)")
plt.title("Logistic Equation Solution")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Compare to Standard Approaches¶
To fully appreciate the benefits of the parareal algorithm, let's compare it with the standard approaches of using either the coarse or fine solvers alone. This comparison will demonstrate the accuracy/performance tradeoff that makes parareal valuable.
First, we need to create equivalent solver functions from our stepping functions. The following helper converts any stepping function into a full solver that can propagate an initial state through a sequence of time points:
from parareax.utils import stepper_to_solver
coarse_solver = stepper_to_solver(coarse_step)
fine_solver = stepper_to_solver(fine_step)
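To make the conversion concrete, here is a hypothetical reimplementation of such a helper using `jax.lax.scan`; the real `parareax.utils.stepper_to_solver` may differ in its details, but the idea is to sweep the stepper sequentially over consecutive pairs of time points:

```python
import jax
import jax.numpy as jnp

def stepper_to_solver_sketch(step):
    """Turn a one-interval stepper into a solver over a time grid.

    Hypothetical reimplementation for illustration only; the real
    parareax.utils.stepper_to_solver may differ.
    """
    def solver(y0, ts):
        def scan_fn(y, t_pair):
            # Advance one interval [t_pair[0], t_pair[1]]
            y_next = step(y, t_pair[0], t_pair[1])
            return y_next, y_next

        t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=-1)
        _, ys = jax.lax.scan(scan_fn, y0, t_pairs)
        # Prepend the initial state so the output aligns with ts
        return jnp.concatenate([y0[None], ys], axis=0)

    return solver

# Demo with a forward-Euler step on the logistic equation
def euler(y, t0, t1):
    return y + y * (1 - y) * (t1 - t0)

ts = jnp.linspace(0.0, 1.0, 5)
ys = stepper_to_solver_sketch(euler)(jnp.array([0.5]), ts)
```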
Now let's compute the solutions using all three methods and compare their accuracy against the analytical solution:
sol_coarse = coarse_solver(y0, ts)
sol_fine = fine_solver(y0, ts)
coarse_mse = jnp.mean((sol_coarse.ravel() - ref.ravel()) ** 2)
fine_mse = jnp.mean((sol_fine.ravel() - ref.ravel()) ** 2)
parareal_mse = jnp.mean((sol.ravel() - ref.ravel()) ** 2)
print(f"Coarse solver MSE: {coarse_mse:.2e}")
print(f"Fine solver MSE: {fine_mse:.2e}")
print(f"Parareal MSE: {parareal_mse:.2e}")
Coarse solver MSE: 3.58e-04
Fine solver MSE: 5.28e-14
Parareal MSE: 5.28e-14
Note that the fine solver and the parareal algorithm achieve essentially the same very small error, while the simple Euler method (coarse solver) shows noticeable error. This demonstrates that parareal attains the accuracy of the fine solver.
Let's also benchmark the performance of all three approaches:
print("\nPerformance Benchmarking:")
print("=" * 40)
print("Coarse solver (Euler):")
%timeit -n 10 -r 3 coarse_solver(y0, ts)
print("\nFine solver (Dormand-Prince):")
%timeit -n 10 -r 3 fine_solver(y0, ts)
print("\nParareal algorithm:")
%timeit -n 10 -r 3 run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=1e-14)
Performance Benchmarking:
========================================
Coarse solver (Euler):
12.8 μs ± 4.99 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Fine solver (Dormand-Prince):
356 μs ± 81.2 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Parareal algorithm:
20.2 μs ± 4.98 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Conclusion¶
The parareal algorithm offers an elegant compromise between accuracy and speed:
- It achieves the same high accuracy as the fine solver
- It can be significantly faster on parallel hardware
- It works with any pair of coarse and fine propagators
- It's particularly useful for long time integrations
This example demonstrates how easy it is to use the Parareax library to implement this neat time-parallel algorithm for your own differential equation problems.
Next Steps¶
Now that you've seen the basics, explore more advanced usage:
- Diffrax Integration: Using Parareax with the Diffrax ODE solver
- Flax Integration: Combining Parareax with Flax for neural differential equation solvers
- API Reference: Complete function documentation