run_parareal

Parareal Algorithm for Time-Parallel State Propagation.

run_parareal(coarse_step, fine_step, y0, ts, maxiters=1000, tol=1e-09)

Run the parareal algorithm for time-parallel state propagation.

The parareal algorithm is a parallel-in-time method for propagating a state through a sequence of time points. While commonly used for differential equations, this implementation is general and can work with any state propagation functions.

It uses two state propagation functions:

  • A coarse step function \(G(y_n, t_n, t_{n+1})\) which is computationally efficient but potentially less accurate
  • A fine step function \(F(y_n, t_n, t_{n+1})\) which is more accurate but computationally expensive

The algorithm works as follows:

  1. Initialize using the coarse propagator: $$ y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n) \quad \text{for} \quad n=1...N $$
  2. For each iteration k=1,2,...: $$ y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n) - G(y_{n-1}^{k-1}, t_{n-1}, t_n) $$
  3. Terminate when \(||y^k - y^{k-1}|| < tol\) or \(k\) reaches maxiters
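
The three steps above can be sketched in a minimal, self-contained way. This is an illustrative NumPy toy, not the library's JAX implementation: the model problem \(dy/dt = -y\), the Euler propagators, and the substep/sweep counts are all assumptions for demonstration.

```python
import numpy as np

# Toy propagators for dy/dt = -y.
def coarse(y, t0, t1):
    # one Euler step over the whole interval: cheap, less accurate
    return y + (t1 - t0) * (-y)

def fine(y, t0, t1, substeps=100):
    # many Euler substeps: expensive, more accurate
    h = (t1 - t0) / substeps
    for _ in range(substeps):
        y = y + h * (-y)
    return y

ts = np.linspace(0.0, 1.0, 11)
y0 = np.array([1.0])

# Step 1: serial coarse sweep produces the zeroth iterate y^0.
ys = [y0]
for n in range(len(ts) - 1):
    ys.append(coarse(ys[-1], ts[n], ts[n + 1]))
ys = np.stack(ys)

# Step 2: parareal correction sweeps. The fine steps in each sweep are
# independent of one another, which is where the parallelism comes from.
for k in range(5):
    ys_fine = [fine(ys[n], ts[n], ts[n + 1]) for n in range(len(ts) - 1)]
    ys_new = [ys[0]]
    for n in range(len(ts) - 1):
        ys_new.append(
            coarse(ys_new[-1], ts[n], ts[n + 1])
            + ys_fine[n]
            - coarse(ys[n], ts[n], ts[n + 1])
        )
    ys = np.stack(ys_new)
```

After a few sweeps the iterate agrees closely with the purely serial fine solution, which is the fixed point of the parareal update.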

Parameters:

  coarse_step : Callable[[ndarray, float, float], ndarray]  (required)
      Coarse propagator G(y, t_start, t_end) that takes a state vector, start
      time, and end time, returning the propagated state. Should be
      computationally efficient.
  fine_step : Callable[[ndarray, float, float], ndarray]  (required)
      Fine propagator F(y, t_start, t_end) that takes a state vector, start
      time, and end time, returning the propagated state. May be
      computationally expensive but should be more accurate than the coarse
      step.
  y0 : ndarray  (required)
      Initial state vector at time ts[0].
  ts : ndarray  (required)
      Array of time points at which the solution is desired; must be strictly
      increasing.
  maxiters : int  (default: 1000)
      Maximum number of parareal iterations.
  tol : float  (default: 1e-9)
      Convergence tolerance, measured as the mean normalized difference between
      successive iterations.
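
As a concrete reading of this convergence metric, here is a NumPy sketch that computes the same quantity as the `_norm` helper in the source code below: the mean Euclidean norm of the per-time-point change, normalized by the square root of the state dimension. The helper name `change_norm` is hypothetical.

```python
import numpy as np

def change_norm(ys_before, ys_after):
    # mean per-time-point Euclidean norm of the change, normalized by
    # sqrt(state dimension), mirroring _norm in the source below
    diff = ys_before - ys_after
    return np.linalg.norm(diff, axis=1).mean() / np.sqrt(diff.shape[1])

a = np.zeros((4, 9))
b = np.ones((4, 9))
change_norm(a, b)  # each row norm is 3, so the metric is 3 / sqrt(9) = 1.0
```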
Returns:

  ys : jnp.ndarray
      Solution array of shape (len(ts), len(y0)) containing the state at each
      time point.
  info : Dict[str, Union[int, float]]
      Dictionary containing convergence information:
        • 'iterations': number of iterations performed
        • 'last_change_norm': final change norm between successive iterations

Source code in parareax/parareax.py
@partial(jax.jit, static_argnames=("coarse_step", "fine_step", "maxiters"))
def run_parareal(
    coarse_step: Callable[[jnp.ndarray, float, float], jnp.ndarray],
    fine_step: Callable[[jnp.ndarray, float, float], jnp.ndarray],
    y0: jnp.ndarray,
    ts: jnp.ndarray,
    maxiters: int = 1000,
    tol: float = 1e-9,
) -> tuple[jnp.ndarray, dict[str, float]]:
    r"""Run the parareal algorithm for time-parallel state propagation.

    The parareal algorithm is a parallel-in-time method for propagating a state through
    a sequence of time points. While commonly used for differential equations, this
    implementation is general and can work with any state propagation functions.

    It uses two state propagation functions:

    - A coarse step function $G(y_n, t_n, t_{n+1})$ which is computationally efficient
      but potentially less accurate
    - A fine step function $F(y_n, t_n, t_{n+1})$ which is more accurate but
      computationally expensive

    The algorithm works as follows:

    1. Initialize using the coarse propagator:
       $$
       y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n) \quad \text{for} \quad n=1...N
       $$
    2. For each iteration k=1,2,...:
       $$
       y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n)
               - G(y_{n-1}^{k-1}, t_{n-1}, t_n)
       $$
    3. Terminate when $||y^k - y^{k-1}|| < tol$ or $k$ reaches `maxiters`

    Parameters
    ----------
    coarse_step : Callable[[jnp.ndarray, float, float], jnp.ndarray]
        Coarse propagator G(y, t_start, t_end) that takes a state vector, start time,
        and end time, returning the propagated state. Should be computationally
        efficient.
    fine_step : Callable[[jnp.ndarray, float, float], jnp.ndarray]
        Fine propagator F(y, t_start, t_end) that takes a state vector, start time,
        and end time, returning the propagated state. Can be computationally expensive
        but should be more accurate than the coarse step.
    y0 : jnp.ndarray
        Initial state vector at time ts[0].
    ts : jnp.ndarray
        Array of time points at which the solution is desired, must be strictly
        increasing.
    maxiters : int, default=1000
        Maximum number of parareal iterations.
    tol : float, default=1e-9
        Convergence tolerance, measured as mean normalized difference between successive
        iterations.

    Returns
    -------
    ys : jnp.ndarray
        Solution array of shape (len(ts), len(y0)) containing the state at each time
        point.
    info : Dict[str, Union[int, float]]
        Dictionary containing convergence information:
        - 'iterations': Number of iterations performed
        - 'last_change_norm': Final change norm between successive iterations
    """
    N_intervals = len(ts) - 1

    # initial serial sweep of the coarse propagator to build the zeroth iterate
    def scan_fn(
        y_prev: jnp.ndarray, t_pair: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        t_current, t_next = t_pair[0], t_pair[1]
        y_next = coarse_step(y_prev, t_current, t_next)
        return y_next, y_next

    t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=1)
    _, ys_scan = jax.lax.scan(scan_fn, y0, t_pairs)
    ys = jnp.concatenate([jnp.expand_dims(y0, 0), ys_scan])

    def parareal_step(ts: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
        N = len(ts)
        ys_fine = jnp.concatenate(
            (
                ys[:1],
                jax.vmap(lambda i: fine_step(ys[i], ts[i], ts[i + 1]))(
                    jnp.arange(N - 1)
                ),
            )
        )

        def f(carry: jnp.ndarray, j: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
            y = carry
            ynew = (
                coarse_step(y, ts[j], ts[j + 1])
                + ys_fine[j + 1]
                - coarse_step(ys[j], ts[j], ts[j + 1])
            )
            return ynew, ynew

        return jnp.concatenate((ys[:1], jax.lax.scan(f, ys[0], jnp.arange(N - 1))[1]))

    def _norm(x: jnp.ndarray) -> jnp.ndarray:
        # mean per-time-point Euclidean norm, normalized by sqrt(state dimension)
        return jax.vmap(jnp.linalg.norm)(x).mean() / jnp.sqrt(x.shape[1])

    def cond(
        val: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
    ) -> jnp.ndarray:
        _ys_before, _ys_after, diff, k = val
        # parareal reproduces the serial fine solution exactly after N_intervals
        # sweeps, so never iterate beyond that even if maxiters is larger
        return jnp.logical_and(diff > tol, k < min(maxiters, N_intervals))

    def body(
        val: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        ys_before, _, _, k = val
        ys_after = parareal_step(ts, ys_before)
        diff = _norm(ys_before - ys_after)
        return ys_after, ys_before, diff, k + 1

    ys, _, diff, k = jax.lax.while_loop(
        cond, body, (ys, jnp.zeros_like(ys), jnp.array(1e9), 1)
    )

    info_dict = {"iterations": k, "last_change_norm": diff}

    return ys, info_dict