Setup
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from parareax import run_parareal
# Enable double precision (the tight solver tolerances used below need float64)
jax.config.update("jax_enable_x64", True)
# Use JAX as the array backend for diffeqzoo
if not backend.has_been_selected:
    backend.select("jax")
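Double precision matters because the tolerances used later in this notebook are far tighter than single precision can resolve. As an illustration (NumPy is used only for this check and is not needed elsewhere), the snippet compares the relative tolerances of the fine propagator and the Dopri8 reference against machine epsilon for float32 and float64.

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable number
eps32 = np.finfo(np.float32).eps
eps64 = np.finfo(np.float64).eps

# Relative tolerances that appear later in this notebook
rtols = {"fine propagator": 1e-13, "Dopri8 reference": 1e-18}

for name, rtol in rtols.items():
    print(
        f"{name}: rtol={rtol:.0e}, above float32 eps: {rtol > eps32}, "
        f"above float64 eps: {rtol > eps64}"
    )

# A relative perturbation below float64 eps is lost entirely when added to 1.0
print(np.float64(1.0) + np.float64(1e-18) == np.float64(1.0))  # True
```

Tolerances below epsilon cannot be met literally; they only push the solver toward the round-off floor of the chosen precision.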
Define the Problem
For this example, we will solve the Henon-Heiles system, a classic Hamiltonian problem originally used to model stellar orbits in a galactic potential, in the first-order form provided by the diffeqzoo package:
f, y0, (t0, t1), args = ivps.henon_heiles_first_order()
# Diffrax expects a vector field with signature (t, y, args); diffeqzoo's f takes (y, *args)
@jax.jit
def vf(t, y, p):
    return f(y, *p)
term = diffrax.ODETerm(vf)
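For reference, the Henon-Heiles system is generated by the Hamiltonian H = (px^2 + py^2)/2 + (x^2 + y^2)/2 + x^2 y - y^3/3, which gives the first-order equations dx/dt = px, dy/dt = py, dpx/dt = -x - 2xy, dpy/dt = -y - x^2 + y^2. The NumPy sketch below writes this vector field out explicitly and checks that the energy is conserved along the flow (grad H · f = 0). It assumes the state ordering [x, y, px, py], matching the plot legend later on; it is not diffeqzoo's implementation, and the helper names are hypothetical.

```python
import numpy as np


def henon_heiles_rhs(state):
    """First-order Henon-Heiles vector field (assumed ordering [x, y, px, py])."""
    x, y, px, py = state
    return np.array([
        px,                 # dx/dt  =  dH/dpx
        py,                 # dy/dt  =  dH/dpy
        -x - 2.0 * x * y,   # dpx/dt = -dH/dx
        -y - x**2 + y**2,   # dpy/dt = -dH/dy
    ])


def energy(state):
    """Henon-Heiles Hamiltonian H(x, y, px, py)."""
    x, y, px, py = state
    return 0.5 * (px**2 + py**2) + 0.5 * (x**2 + y**2) + x**2 * y - y**3 / 3.0


def energy_gradient(state):
    """Analytic gradient of H with respect to (x, y, px, py)."""
    x, y, px, py = state
    return np.array([x + 2.0 * x * y, y + x**2 - y**2, px, py])


# dH/dt = grad(H) . f should vanish at every state, up to round-off
rng = np.random.default_rng(0)
states = rng.uniform(-0.5, 0.5, size=(5, 4))
dHdt = [energy_gradient(s) @ henon_heiles_rhs(s) for s in states]
print(f"max |dH/dt| over sample states: {max(abs(v) for v in dHdt):.1e}")
```

Energy conservation is a useful sanity check for any solver applied to this problem, since the exact flow keeps H constant.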
Define Coarse and Fine Propagators
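To build a Parareal algorithm, we define a coarse and a fine propagator. Here, the coarse propagator takes a single step of the explicit Runge-Kutta method Tsit5() across each interval, while the fine propagator integrates the same interval with the implicit Runge-Kutta method Kvaerno5() under a tight adaptive step-size controller. The coarse step is cheap but inaccurate; the fine step is accurate but expensive.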
@jax.jit
def coarse_step(y0, t0, t1):
    # One explicit Tsit5 step across the whole interval (dt0 = t1 - t0, no step-size controller)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=t0,
        t1=t1,
        y0=y0,
        max_steps=None,
        args=args,
        dt0=(t1 - t0),
    )
    return sol.ys[-1]
@jax.jit
def fine_step(y0, t0, t1):
    # Implicit Kvaerno5 with adaptive steps; dt0=None lets Diffrax choose the initial step size
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Kvaerno5(),
        t0=t0,
        t1=t1,
        y0=y0,
        max_steps=None,
        args=args,
        dt0=None,
        stepsize_controller=diffrax.PIDController(rtol=1e-13, atol=1e-15),
    )
    return sol.ys[-1]
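The two steppers above differ only in cost and accuracy: the coarse step takes a single Tsit5 step across the interval, while the fine step lets a PID controller choose many small Kvaerno5 steps. The NumPy sketch below mimics that split with a swapped-in method, classical RK4: one RK4 step is the coarse propagator and many RK4 substeps are the fine one, applied to the test problem y' = -y with a known exact solution. The names rk4_step, coarse and fine are hypothetical and not part of parareax.

```python
import numpy as np


def rk4_step(f, y, t, dt):
    """One classical fourth-order Runge-Kutta step."""
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt / 2 * k1)
    k3 = f(t + dt / 2, y + dt / 2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def decay(t, y):
    return -y  # exact solution: y(t) = y0 * exp(-t)


def coarse(y0, t0, t1):
    # A single step across the whole interval, analogous to dt0=(t1 - t0) above
    return rk4_step(decay, y0, t0, t1 - t0)


def fine(y0, t0, t1, substeps=100):
    # Many small steps across the same interval
    dt = (t1 - t0) / substeps
    y = y0
    for i in range(substeps):
        y = rk4_step(decay, y, t0 + i * dt, dt)
    return y


y0, t0, t1 = 1.0, 0.0, 1.0
exact = y0 * np.exp(-(t1 - t0))
print(f"coarse error: {abs(coarse(y0, t0, t1) - exact):.1e}")
print(f"fine error:   {abs(fine(y0, t0, t1) - exact):.1e}")
```

Parareal only needs both propagators to share the signature (y0, t0, t1) -> y1; what makes the pair useful is that the coarse one is much cheaper while staying roughly close to the fine one.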
Run the Parareal Algorithm
Now we'll solve our problem using the parareal algorithm:
# Define the time points
ts = jnp.linspace(t0, t1, 201)
# Solve using Parareal
tol = 1e-8
sol, info = run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=tol)
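# Print convergence information
# last_change_norm: size of the final Parareal update (the stopping criterion),
# not an error against the exact solution
print(f"Parareal converged in {info['iterations']} iterations")
print(f"Final change norm: {info['last_change_norm']:.2e}")
Parareal converged in 4 iterations
Final change norm: 8.05e-10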
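The internals of run_parareal are not shown here, but the textbook Parareal correction can be written in a few lines. The NumPy sketch below is a serial, unoptimized illustration, not parareax's implementation: starting from a coarse sweep, each iteration applies U[n+1]^(k+1) = G(U[n]^(k+1)) + F(U[n]^k) - G(U[n]^k), where the F and G evaluations on the previous iterate are independent across time slices and could run in parallel. It stops when the largest change between iterates drops below tol, a test in the spirit of info['last_change_norm']; the exact norm parareax uses is not assumed. The function parareal_sketch and the Euler-based propagators are hypothetical.

```python
import numpy as np


def parareal_sketch(coarse, fine, y0, ts, tol=1e-12, max_iters=None):
    """Serial Parareal sketch returning (states at ts, iterations used)."""
    n = len(ts) - 1
    max_iters = n if max_iters is None else max_iters
    # Iteration 0: one cheap, serial sweep with the coarse propagator
    U = np.empty((n + 1,) + np.shape(y0))
    U[0] = y0
    for i in range(n):
        U[i + 1] = coarse(U[i], ts[i], ts[i + 1])
    k = 0
    for k in range(1, max_iters + 1):
        # F and G on the previous iterate are independent per slice (the parallel part)
        F_old = [fine(U[i], ts[i], ts[i + 1]) for i in range(n)]
        G_old = [coarse(U[i], ts[i], ts[i + 1]) for i in range(n)]
        # Serial correction sweep: new coarse prediction plus the old fine-minus-coarse gap
        U_new = np.empty_like(U)
        U_new[0] = y0
        for i in range(n):
            U_new[i + 1] = coarse(U_new[i], ts[i], ts[i + 1]) + F_old[i] - G_old[i]
        change = np.max(np.abs(U_new - U))
        U = U_new
        if change < tol:
            break
    return U, k


# Hypothetical propagators for y' = -y: one explicit Euler step (coarse)
# versus 1000 explicit Euler substeps (fine) per time slice
def coarse(y, t0, t1):
    return y - (t1 - t0) * y


def fine(y, t0, t1, substeps=1000):
    dt = (t1 - t0) / substeps
    for _ in range(substeps):
        y = y - dt * y
    return y


ts = np.linspace(0.0, 2.0, 11)
U, iters = parareal_sketch(coarse, fine, np.array([1.0]), ts)
print(f"stopped after {iters} iterations; U(t=2) = {U[-1, 0]:.6f}")
```

With N slices, this iteration reproduces the serial fine solution after at most N iterations; the method pays off only when it converges in far fewer.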
Let's plot the results:
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
# Time series
axes[0].plot(ts, sol)
axes[0].set_xlabel("Time")
axes[0].set_ylabel("State variables")
axes[0].set_title("Time Evolution")
axes[0].legend(["x", "y", "px", "py"])
axes[0].grid(True, alpha=0.3)
# Configuration space (x vs. y)
axes[1].plot(sol[:, 0], sol[:, 1])
axes[1].set_xlabel("x position")
axes[1].set_ylabel("y position")
axes[1].set_title("Configuration Space")
axes[1].grid(True, alpha=0.3)
# Momentum space (px vs. py)
axes[2].plot(sol[:, 2], sol[:, 3], color="C1")
axes[2].set_xlabel("x momentum")
axes[2].set_ylabel("y momentum")
axes[2].set_title("Momentum Space")
axes[2].grid(True, alpha=0.3)
fig.suptitle("Henon-Heiles System Solution")
fig.tight_layout()
plt.show()
Compare to Standard Approaches
To fully appreciate the benefits of the parareal algorithm, let's compare it with the standard approaches of using either the coarse or fine solvers alone. This comparison will demonstrate the accuracy/performance tradeoff that makes parareal valuable.
For this comparison, we need:
- Reference solution: High-accuracy solution using Dopri8 with very tight tolerances
- Coarse solver: Fast Tsit5 method (same as our coarse propagator)
- Fine solver: Accurate Kvaerno5 method (same as our fine propagator)
- Parareal solution: Our hybrid approach combining both
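# 1. High-accuracy reference solution
# rtol = atol = 1e-18 lies below float64 machine epsilon (~2.2e-16): the controller takes
# very small steps, but the achievable accuracy is bounded by round-off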
ref = diffrax.diffeqsolve(
term,
diffrax.Dopri8(),
t0=t0,
t1=t1,
y0=y0,
dt0=1e-5,
max_steps=None,
saveat=diffrax.SaveAt(ts=ts),
stepsize_controller=diffrax.PIDController(rtol=1e-18, atol=1e-18),
args=args,
)
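# 2 & 3. Coarse and fine solutions
# stepper_to_solver wraps a single-interval stepper into a solver returning the state at every point of ts
from parareax.utils import stepper_to_solver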
coarse_solver = stepper_to_solver(coarse_step)
sol_coarse = coarse_solver(y0, ts)
fine_solver = stepper_to_solver(fine_step)
sol_fine = fine_solver(y0, ts)
# Compute errors
coarse_rmse = jnp.sqrt(jnp.mean((sol_coarse.ravel() - ref.ys.ravel()) ** 2))
fine_rmse = jnp.sqrt(jnp.mean((sol_fine.ravel() - ref.ys.ravel()) ** 2))
parareal_rmse = jnp.sqrt(jnp.mean((sol.ravel() - ref.ys.ravel()) ** 2))
print("\nAccuracy Comparison (RMSE vs reference):")
print(f" Coarse solver (Tsit5): {coarse_rmse:.2e}")
print(f" Fine solver (Kvaerno5): {fine_rmse:.2e}")
print(f" Parareal (hybrid): {parareal_rmse:.2e}")
print(
    f"\nParareal achieves {coarse_rmse / parareal_rmse:.0f}x better accuracy than the coarse solver"
)
print(
    f"while being comparable to the fine solver ({parareal_rmse / fine_rmse:.1f}x its error)"
)
Accuracy Comparison (RMSE vs reference):
 Coarse solver (Tsit5): 1.90e-03
 Fine solver (Kvaerno5): 2.54e-11
 Parareal (hybrid): 2.57e-11

Parareal achieves 73760971x better accuracy than the coarse solver
while being comparable to the fine solver (1.0x its error)
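The RMSE values above pool all state components and all time points, because each solution array is flattened with ravel() before the comparison. Written out, with y_{i,j} the solution at time t_i for component j, N_t = 201 saved time points and d = 4 components (x, y, px, py):

```latex
\mathrm{RMSE} = \sqrt{\frac{1}{N_t\, d} \sum_{i=1}^{N_t} \sum_{j=1}^{d} \left( y_{i,j} - y^{\mathrm{ref}}_{i,j} \right)^2 }
```

Because the components are pooled, this single number does not indicate which variable carries most of the error.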
Let's also benchmark the performance of all three approaches:
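# JAX dispatches work asynchronously, so block until results are ready to time the full computation.
# All three functions were already called once above, so JIT compilation should not be included.
print("\nPerformance Benchmarking:")
print("=" * 40)
print("Coarse solver (Tsit5):")
%timeit -n 10 -r 3 jax.block_until_ready(coarse_solver(y0, ts))
print("\nFine solver (Kvaerno5):")
%timeit -n 10 -r 3 jax.block_until_ready(fine_solver(y0, ts))
print("\nParareal algorithm:")
%timeit -n 10 -r 3 jax.block_until_ready(run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=tol))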
Performance Benchmarking:
========================================
Coarse solver (Tsit5):
493 μs ± 114 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Fine solver (Kvaerno5):
333 ms ± 115 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Parareal algorithm:
215 ms ± 25.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
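Key Insights:
- Coarse solver: fastest (about 0.5 ms) but least accurate (RMSE 1.90e-03)
- Fine solver: most accurate (RMSE 2.54e-11) but slowest (about 333 ms)
- Parareal: matches the fine solver's accuracy (RMSE 2.57e-11) and, in this run, was about 1.5x faster than the serial fine solver (215 ms vs. 333 ms)
- Gains should grow with longer time horizons and parallel hardware, since the fine solves on different time slices are independent and can run concurrently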
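The last point can be made concrete with the standard idealized cost model for Parareal (an assumption, not a measurement of parareax): with N time slices each handled by its own worker, K iterations, and a per-slice coarse-to-fine cost ratio r = T_G / T_F, the K + 1 serial coarse sweeps and K parallel fine sweeps give a speedup over the serial fine solver of S = 1 / ((K + 1) r + K / N), which can never exceed N / K. The sketch below evaluates this model. Plugging in the timings above and assuming N = 200 slices (one per interval of ts) is a rough, hypothetical estimate only; the helper name is illustrative.

```python
def ideal_parareal_speedup(n_slices, iterations, cost_ratio):
    """Idealized Parareal speedup over a serial fine solve.

    cost_ratio is T_G / T_F: the cost of the coarse propagator relative to the
    fine propagator on one slice. Assumes one worker per slice and ignores
    communication and scheduling overhead.
    """
    serial_fine = n_slices  # N fine solves, in units of T_F
    # (K + 1) serial coarse sweeps over N slices, plus K rounds of parallel fine solves
    parareal = (iterations + 1) * n_slices * cost_ratio + iterations
    return serial_fine / parareal


# Hypothetical inputs: 200 slices, the 4 iterations reported above, and a cost
# ratio estimated from the measured totals (0.493 ms coarse vs. 333 ms fine)
r = 0.493 / 333
print(f"model speedup: {ideal_parareal_speedup(200, 4, r):.1f}x")
print(f"upper bound N/K: {200 / 4:.0f}x")
```

The measured ratio of roughly 1.5x is far below this idealized figure; the model ignores overheads and assumes one worker per slice, so it only indicates the headroom that parallel hardware could offer for the same iteration count.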
Conclusion
This example demonstrates parareal on a nonlinear, multi-dimensional dynamical system:
- Compatibility: works with Diffrax solvers, including adaptive and implicit ones
- Flexibility: easy to swap solver combinations (explicit/implicit, different orders)
- Performance: reaches fine-solver accuracy, and in this benchmark did so in less wall-clock time than the serial fine solver
The combination of Parareax and Diffrax provides a powerful toolkit for solving challenging ODEs with parallel-in-time methods.
Next Steps
- Try different solver combinations (e.g., Euler() + Dopri8())
- Experiment with adaptive vs. fixed timesteps
- Explore stiff problems using implicit methods for both coarse and fine solvers
- Flax Integration: See how parareal works with neural differential equation solvers
- API Reference: Complete function documentation