run_parareal
Parareal Algorithm for Time-Parallel State Propagation.
run_parareal(coarse_step, fine_step, y0, ts, maxiters=1000, tol=1e-09)
Run the parareal algorithm for time-parallel state propagation.
The parareal algorithm is a parallel-in-time method for propagating a state through a sequence of time points. Although it is commonly used for differential equations, this implementation is general and works with any pair of state propagation functions.
It uses two state propagation functions:
- A coarse step function \(G(y_n, t_n, t_{n+1})\) which is computationally efficient but potentially less accurate
- A fine step function \(F(y_n, t_n, t_{n+1})\) which is more accurate but computationally expensive
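To make the `Callable[[ndarray, float, float], ndarray]` contract concrete, here is a minimal NumPy sketch of a propagator pair for an assumed test problem \(dy/dt = -y\). The names `rhs`, `coarse_step`, `fine_step` and the `substeps` parameter are illustrative only and are not part of parareax; both propagators use explicit Euler, with the fine one simply taking many more substeps.

```python
import numpy as np


def rhs(y: np.ndarray, t: float) -> np.ndarray:
    """Right-hand side of the assumed test ODE dy/dt = -y."""
    return -y


def coarse_step(y: np.ndarray, t_start: float, t_end: float) -> np.ndarray:
    """Cheap propagator G: one explicit Euler step across the whole interval."""
    return y + (t_end - t_start) * rhs(y, t_start)


def fine_step(y: np.ndarray, t_start: float, t_end: float, substeps: int = 100) -> np.ndarray:
    """Expensive propagator F: `substeps` explicit Euler steps across the same interval."""
    dt = (t_end - t_start) / substeps
    t = t_start
    for _ in range(substeps):
        y = y + dt * rhs(y, t)
        t += dt
    return y


y = np.array([1.0])
print(coarse_step(y, 0.0, 0.5))  # [0.5]: a single large step gives 1 - 0.5
print(fine_step(y, 0.0, 0.5))    # close to exp(-0.5) ≈ 0.6065
```

Both functions accept `(y, t_start, t_end)` and return a state of the same shape, which is all the interface requires; the default `substeps` keeps `fine_step` callable with three arguments.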
The algorithm works as follows:
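- Initialize with a sequential coarse sweep over the time points \(t_0, \dots, t_N\) in `ts`, starting from \(y_0^0 = y_0\): $$ y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n) \quad \text{for} \quad n = 1, \dots, N $$
- For each iteration \(k = 1, 2, \dots\), keep \(y_0^k = y_0\) and correct the coarse prediction with the fine result from the previous iterate: $$ y_n^k = G(y_{n-1}^k, t_{n-1}, t_n) + F(y_{n-1}^{k-1}, t_{n-1}, t_n) - G(y_{n-1}^{k-1}, t_{n-1}, t_n) $$ The terms evaluated at \(y_{n-1}^{k-1}\) depend only on the previous iterate, so the expensive fine steps can be computed concurrently for all \(n\); only the cheap coarse sweep is sequential.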
- Terminate when \(\|y^k - y^{k-1}\| < \text{tol}\) or \(k\) reaches `maxiters`
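The iteration above can be sketched serially in NumPy. `parareal_sketch` is a hypothetical reference version, not the parareax implementation: it evaluates the fine propagator in a plain loop rather than in parallel, and its stopping metric (mean elementwise relative change) is only an assumed reading of the "mean normalized difference" used for `tol`. It returns `(ys, info)` in the documented shape.

```python
from typing import Callable, Dict, Tuple, Union

import numpy as np

# Propagator signature from the parameter table: (state, t_start, t_end) -> state.
Propagator = Callable[[np.ndarray, float, float], np.ndarray]


def parareal_sketch(
    coarse_step: Propagator,
    fine_step: Propagator,
    y0: np.ndarray,
    ts: np.ndarray,
    maxiters: int = 1000,
    tol: float = 1e-9,
) -> Tuple[np.ndarray, Dict[str, Union[int, float]]]:
    """Serial reference version of the parareal iteration (hypothetical, not the library code)."""
    y0 = np.asarray(y0, dtype=float)
    n_points = len(ts)
    # Iteration 0: sequential coarse sweep y_n^0 = G(y_{n-1}^0, t_{n-1}, t_n).
    ys = np.empty((n_points, y0.size))
    ys[0] = y0
    for n in range(1, n_points):
        ys[n] = coarse_step(ys[n - 1], ts[n - 1], ts[n])

    iterations, change = 0, float("inf")
    for k in range(1, maxiters + 1):
        # F and G applied to the previous iterate are independent across intervals:
        # this is the work parareal distributes in parallel (done serially here).
        fine_prev = [fine_step(ys[n - 1], ts[n - 1], ts[n]) for n in range(1, n_points)]
        coarse_prev = [coarse_step(ys[n - 1], ts[n - 1], ts[n]) for n in range(1, n_points)]
        # Sequential correction sweep using only the cheap coarse propagator.
        new = np.empty_like(ys)
        new[0] = y0
        for n in range(1, n_points):
            new[n] = coarse_step(new[n - 1], ts[n - 1], ts[n]) + fine_prev[n - 1] - coarse_prev[n - 1]
        # Assumed stopping metric: mean elementwise relative change between iterates.
        change = float(np.mean(np.abs(new - ys) / (np.abs(ys) + 1e-12)))
        ys, iterations = new, k
        if change < tol:
            break
    return ys, {"iterations": iterations, "last_change_norm": change}


# Example: dy/dt = -y, one explicit Euler step as G and the exact solution as F.
def euler_coarse(y, t0, t1):
    return y * (1.0 - (t1 - t0))


def exact_fine(y, t0, t1):
    return y * np.exp(-(t1 - t0))


ts = np.linspace(0.0, 2.0, 11)
ys, info = parareal_sketch(euler_coarse, exact_fine, np.array([1.0]), ts)
```

Because the fine propagator here is exact, `ys[:, 0]` agrees with `np.exp(-ts)`. In exact arithmetic the \(k\)-th iterate already matches the sequential fine solution at \(t_0, \dots, t_k\), so after at most \(N\) iterations (one per interval) the loop has nothing left to correct and stops far below `maxiters`.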
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `coarse_step` | `Callable[[ndarray, float, float], ndarray]` | Coarse propagator `G(y, t_start, t_end)` that takes a state vector, start time, and end time and returns the propagated state. Should be computationally efficient. | *required* |
| `fine_step` | `Callable[[ndarray, float, float], ndarray]` | Fine propagator `F(y, t_start, t_end)` that takes a state vector, start time, and end time and returns the propagated state. May be computationally expensive but should be more accurate than the coarse step. | *required* |
| `y0` | `ndarray` | Initial state vector at time `ts[0]`. | *required* |
| `ts` | `ndarray` | Array of time points at which the solution is desired; must be strictly increasing. | *required* |
| `maxiters` | `int` | Maximum number of parareal iterations. | `1000` |
| `tol` | `float` | Convergence tolerance, measured as the mean normalized difference between successive iterations. | `1e-9` |
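The table's constraints on `y0` and `ts` can be checked cheaply before any expensive fine propagation starts. `validate_parareal_inputs` is a hypothetical helper, not part of parareax, and treating `y0` as a 1-D vector is an assumption inferred from the documented output shape `(len(ts), len(y0))`.

```python
import numpy as np


def validate_parareal_inputs(y0, ts) -> None:
    """Hypothetical pre-flight check mirroring the documented constraints on y0 and ts."""
    y0 = np.asarray(y0)
    ts = np.asarray(ts)
    # Assumption: the state is a flat vector, matching the (len(ts), len(y0)) output shape.
    if y0.ndim != 1:
        raise ValueError(f"y0 must be a 1-D state vector, got shape {y0.shape}")
    if ts.ndim != 1:
        raise ValueError(f"ts must be a 1-D array of time points, got shape {ts.shape}")
    # Documented requirement: time points must be strictly increasing.
    if not np.all(np.diff(ts) > 0):
        raise ValueError("ts must be strictly increasing")


validate_parareal_inputs(np.array([1.0, 0.0]), np.linspace(0.0, 1.0, 5))  # passes silently
```

Repeated time points (for example `[0.0, 0.5, 0.5]`) are rejected, since a zero-length interval violates the strictly increasing requirement.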
Returns:

- `ys` (`jnp.ndarray`): Solution array of shape `(len(ts), len(y0))` containing the state at each time point.
- `info` (`Dict[str, Union[int, float]]`): Dictionary containing convergence information:
    - `'iterations'`: number of iterations performed
    - `'last_change_norm'`: final change norm between successive iterations
Source code in parareax/parareax.py