Parareax: Parareal in JAX
A solver-agnostic JAX implementation of the parareal algorithm for time-parallel ODE solving.
Quick Links: Installation | Quick Start | Documentation
Parareal in a nutshell
Parareal is a parallel-in-time algorithm that combines a cheap coarse propagator \(G: (y_{n-1}, t_{n-1}, t_n) \mapsto y_n\) with an expensive fine propagator \(F: (y_{n-1}, t_{n-1}, t_n) \mapsto y_n\):
- Initialize by sweeping the coarse propagator sequentially over all \(N\) time intervals to get \(y_1^0, y_2^0, \dots, y_N^0\)
- Iterate with the correction formula: \[y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n) - G(y_{n-1}^{k-1}, t_{n-1}, t_n)\]
- Stop when the solution stabilizes (the change between successive iterates falls below a tolerance) or the maximum number of iterations is reached
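The fine solver evaluations in the correction step depend only on the previous iterate \(y_{n-1}^{k-1}\), so they can run in parallel across all time intervals. Only the cheap coarse sweep remains sequential, which gives speedups while retaining the accuracy of the fine solver at convergence.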
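The three steps above can be sketched in a few lines of JAX. This is a minimal illustration of the textbook algorithm, not parareax's actual implementation: the name `parareal_sketch`, its arguments, the max-norm stopping test, and the Euler-based demo propagators are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def parareal_sketch(coarse_step, fine_step, y0, ts, max_iters, tol):
    """Illustrative parareal loop (hypothetical helper, assumes max_iters >= 1)."""
    t_start, t_end = ts[:-1], ts[1:]

    def coarse_sweep(corrections):
        # Sequential sweep: y_n = G(y_{n-1}) + correction_n.
        def body(y_prev, inputs):
            ta, tb, corr = inputs
            y_next = coarse_step(y_prev, ta, tb) + corr
            return y_next, y_next

        _, ys = jax.lax.scan(body, y0, (t_start, t_end, corrections))
        return jnp.concatenate([y0[None], ys], axis=0)

    # Initialization: coarse propagator only (all corrections are zero).
    zeros = jnp.zeros((t_start.shape[0],) + y0.shape, dtype=y0.dtype)
    ys = coarse_sweep(zeros)

    fine_all = jax.vmap(fine_step)      # parallel over time intervals
    coarse_all = jax.vmap(coarse_step)  # parallel over time intervals
    iterations = 0
    for iterations in range(1, max_iters + 1):
        # F - G evaluated on the previous iterate; independent per interval.
        prev = ys[:-1]
        corrections = fine_all(prev, t_start, t_end) - coarse_all(prev, t_start, t_end)
        new_ys = coarse_sweep(corrections)
        update = jnp.max(jnp.abs(new_ys - ys))
        ys = new_ys
        if update < tol:  # assumed stopping rule: max-norm change below tol
            break
    return ys, iterations


# Demo on the logistic equation dy/dt = y (1 - y).
def f(y, t):
    return y * (1 - y)


def coarse_step(y, ta, tb):
    # One explicit Euler step over the whole interval.
    return y + (tb - ta) * f(y, ta)


def fine_step(y, ta, tb, substeps=50):
    # Many small Euler steps stand in for an expensive fine solver.
    h = (tb - ta) / substeps
    return jax.lax.fori_loop(0, substeps, lambda i, y: y + h * f(y, ta + i * h), y)


y0 = jnp.array([1e-2])
ts = jnp.linspace(0.0, 1.0, 11)  # 10 time intervals
ys, iters = parareal_sketch(coarse_step, fine_step, y0, ts, max_iters=10, tol=1e-6)
```

Here `jax.vmap` stands in for true parallel execution of the fine solver, and `jax.lax.scan` carries the sequential coarse sweep. After as many iterations as there are time intervals, parareal reproduces the serial fine solution up to rounding, which makes a convenient sanity check for any implementation.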
Installation
The package is not yet available on PyPI. To install:
Quick Start
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental.ode import odeint
from parareax import run_parareal
# Define the differential equation: logistic equation dy/dt = y(1-y)
def f(y, t):
    return y * (1 - y)
t0, t1 = 0, 1
y0 = jnp.array([1e-2])
# Define the coarse step function (simple Euler method)
@jax.jit
def coarse_step(y0, t0, t1):
    dy = f(y0, t0)
    return y0 + dy * (t1 - t0)
# Define the fine step function (using JAX's odeint)
@jax.jit
def fine_step(y0, t0, t1):
    return odeint(lambda y, t: f(y, t), y0, jnp.array([t0, t1]))[-1]
# Solve using Parareal
ts = jnp.linspace(0, 1, 101) # 100 time intervals
solution, info = run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=1e-14)
print("Parareal iterations:", info["iterations"])
# Compute comparison solutions using odeint
sol_fine = odeint(lambda y, t: f(y, t), y0, ts, rtol=1e-10, atol=1e-13)
# High-accuracy reference solution (tighter tolerances than the comparison solve)
ts_ref = jnp.linspace(t0, t1, 1001) # Denser output grid; only the final point is compared
ref = odeint(lambda y, t: f(y, t), y0, ts_ref, rtol=1e-12, atol=1e-15)
# Compare errors
print("Fine solver error:", sol_fine[-1] - ref[-1])
print("Parareal error:", solution[-1] - ref[-1])
Output:
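In this run, parareal converged in 6 iterations. Each iteration issues 100 fine solver calls that can run in parallel, so the fine solver contributes only 6 sequential stages instead of the 100 sequential steps of a standard serial solve. Each iteration also includes a sequential sweep of the coarse propagator, so the achievable speedup depends on the coarse propagator being much cheaper than the fine one; under that condition, this parallel structure can provide speedups on expensive problems.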
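To make the trade-off concrete, the sketch below estimates the ideal speedup under the textbook cost model: one processor per time interval, no communication overhead, an initial coarse sweep, and one parallel fine stage plus one sequential coarse sweep per iteration. The function name and the cost values are hypothetical, chosen only for illustration; they are not measurements of parareax.

```python
def ideal_parareal_speedup(num_intervals, num_iters, cost_fine, cost_coarse):
    """Ideal speedup of parareal over a serial fine solve (textbook cost model)."""
    # Serial baseline: one fine solve per interval, one after another.
    serial = num_intervals * cost_fine
    # Parareal: (num_iters + 1) sequential coarse sweeps over all intervals,
    # plus num_iters fine stages that each cost one fine solve in wall time.
    parallel = (num_iters + 1) * num_intervals * cost_coarse + num_iters * cost_fine
    return serial / parallel


# 100 intervals, 6 iterations, coarse step assumed 1000x cheaper than fine.
speedup = ideal_parareal_speedup(100, 6, cost_fine=1.0, cost_coarse=0.001)
upper_bound = 100 / 6  # limit as the coarse cost goes to zero
print(f"ideal speedup: {speedup:.1f} (upper bound {upper_bound:.1f})")
```

With these assumed costs the estimate is roughly 14.9, below the bound of about 16.7 set by the ratio of intervals to iterations. Real speedups will be lower once communication and device overheads are included.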
API Reference
run_parareal(
    coarse_step, # Function: (y_start, t_start, t_end) -> y_end
    fine_step, # Function: (y_start, t_start, t_end) -> y_end
    y0, # Initial state array
    ts, # Time points array
    maxiters=1000, # Maximum iterations
    tol=1e-9 # Convergence tolerance
)
Returns:
- ys: Solution array of shape (len(ts), len(y0))
- info: Dictionary with convergence information
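Because both propagators share the signature `(y_start, t_start, t_end) -> y_end`, any integrator can be wrapped to fit. As a hedged sketch, the hypothetical factory `make_rk4_step` below (not part of parareax) builds such a propagator from a fixed number of classical RK4 substeps; one substep gives a cheap coarse propagator and many substeps give a more accurate fine one.

```python
import jax
import jax.numpy as jnp


def make_rk4_step(f, substeps):
    """Build a propagator (y_start, t_start, t_end) -> y_end from RK4 substeps.

    Hypothetical helper for illustration; `f` has the signature f(y, t).
    """
    def step(y_start, t_start, t_end):
        h = (t_end - t_start) / substeps

        def body(i, y):
            t = t_start + i * h
            k1 = f(y, t)
            k2 = f(y + 0.5 * h * k1, t + 0.5 * h)
            k3 = f(y + 0.5 * h * k2, t + 0.5 * h)
            k4 = f(y + h * k3, t + h)
            return y + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

        return jax.lax.fori_loop(0, substeps, body, y_start)

    return jax.jit(step)


# Logistic equation dy/dt = y (1 - y), as in the Quick Start.
def f(y, t):
    return y * (1 - y)


coarse_step = make_rk4_step(f, substeps=1)   # cheap, less accurate
fine_step = make_rk4_step(f, substeps=20)    # more expensive, more accurate

y0 = jnp.array([1e-2])
y_end = fine_step(y0, 0.0, 1.0)
```

Propagators built this way have the same call shape as the `coarse_step` and `fine_step` functions in the Quick Start, so they should be usable in their place, assuming `run_parareal` imposes no further requirements on them.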