Pairing Parareax with Flax: Neural Networks as Coarse Propagators¶
The parareal algorithm is remarkably flexible: the propagators don't need to be traditional ODE solvers but can be any function that advances a state forward in time. This opens exciting possibilities for hybrid approaches combining machine learning and physics-based methods.
Why Neural Networks in Parareal?¶
Neural networks can serve as coarse propagators in several scenarios:
- Learned dynamics: Networks trained on ODE solution data can provide fast approximations
- Physics-informed models: PINNs can encode known physics while learning complex behaviors
- Multi-fidelity modeling: Use cheap ML surrogates corrected by expensive simulations
What This Example Shows¶
In this notebook, we demonstrate this flexibility by using an untrained neural network as a coarse propagator. While this might seem counterintuitive, it showcases parareal's robustness:
- Even with a terrible coarse approximation, parareal's correction mechanism works
- The algorithm converges to the fine solver's accuracy regardless of coarse quality
- This illustrates the method's reliability for real-world hybrid ML-physics applications
Note: In practice, you'd train the neural network on relevant data to get both accuracy and speed benefits. Here we use an untrained network purely to demonstrate the algorithm's flexibility and error-correction capabilities.
Setup¶
First, let's import the necessary libraries:
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from flax import linen as nn
from jax.experimental.ode import odeint
from parareax import run_parareal
# Enable double precision
jax.config.update("jax_enable_x64", True)
if not backend.has_been_selected:
    backend.select("jax")
Define the Problem¶
For this example, we will solve the FitzHugh-Nagumo system, a classical model from neuroscience that describes the activation and deactivation of a spiking neuron. The problem setup is provided by the diffeqzoo package, and one common form of the equations is shown below.
The system consists of two variables:
- V: Membrane voltage (fast variable)
- W: Recovery variable (slow variable)
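For reference, a common textbook form of the system reads (the exact parameterization used by diffeqzoo may differ in its constants):

$$\dot V = V - \frac{V^3}{3} - W + I_{\mathrm{ext}}, \qquad \dot W = \varepsilon\,(V + a - bW),$$

where $I_{\mathrm{ext}}$ is an external stimulus and $\varepsilon$, $a$, $b$ are model parameters.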
f, y0, (t0, t1), args = ivps.fitzhugh_nagumo()
@jax.jit
def vf(y, t, p):
    return f(y, *p)
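As a quick sanity check, we can evaluate the wrapper once at the initial state. It follows the func(y, t, *args) calling convention that jax.experimental.ode.odeint expects, with the diffeqzoo parameters passed through as a single tuple:
# Evaluate the vector field once at the initial condition.
print(vf(y0, t0, args))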
Define Propagators¶
The parareal algorithm requires two propagators:
- A coarse propagator: Fast but less accurate
- A fine propagator: Slow but more accurate
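For context, the two propagators are combined by the standard parareal update. Writing the coarse and fine propagators as $\mathcal{G}$ and $\mathcal{F}$, iteration $k+1$ computes

$$U_{n+1}^{k+1} = \mathcal{G}(t_n, t_{n+1}, U_n^{k+1}) + \mathcal{F}(t_n, t_{n+1}, U_n^{k}) - \mathcal{G}(t_n, t_{n+1}, U_n^{k}),$$

where the cheap $\mathcal{G}$ is applied sequentially while the expensive $\mathcal{F}$ evaluations are independent across subintervals and can run in parallel. The difference term corrects the coarse prediction, which is why the quality of $\mathcal{G}$ affects the iteration count but not the final accuracy. (This is the textbook formulation; the exact update used internally by parareax may differ in detail.)

We start with the coarse propagator: a small, untrained Flax network whose output we use like a learned vector field in an Euler-style step.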
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=16)(x)
        x = nn.relu(x)
        x = nn.Dense(features=len(y0))(x)
        return nn.tanh(x)
key = jax.random.PRNGKey(42)
model = SimpleNN()
params = model.init(key, y0)
@jax.jit
def coarse_step(y0, t0, t1):
    return y0 + (t1 - t0) * model.apply(params, y0)
In any realistic scenario, this network should of course actually be trained to produce good predictions, but for this simple proof-of-concept we will just use the untrained network.
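For illustration, one possible training setup would regress this Euler-style coarse step onto consecutive states of a reference trajectory. The sketch below is purely hypothetical and is not executed in this notebook; it assumes optax is available (it is not imported above) and uses a fine odeint trajectory as training data:
import optax

# Hypothetical training data: a fine reference trajectory on a dense grid.
ts_train = jnp.linspace(t0, t1, 401)
ys_ref = odeint(vf, y0, ts_train, args)
dts = ts_train[1:] - ts_train[:-1]

def loss_fn(p):
    # Euler-style coarse prediction from each state to the next grid point.
    pred = ys_ref[:-1] + dts[:, None] * model.apply(p, ys_ref[:-1])
    return jnp.mean((pred - ys_ref[1:]) ** 2)

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(p, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(p)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(p, updates), opt_state, loss
After a few hundred such steps, the trained parameters could be plugged back into coarse_step; with a reasonable coarse propagator, parareal typically converges in far fewer iterations, as discussed below.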
Fine propagator: A standard Dormand-Prince Runge-Kutta method¶
@jax.jit
def fine_step(y0, t0, t1):
    return odeint(vf, y0, jnp.array([t0, t1]), args)[-1]
Run the Parareal Algorithm¶
Now we'll solve our problem using the parareal algorithm:
# Define the time points
ts = jnp.linspace(t0, t1, 201)
# Solve using Parareal
tol = 1e-8
sol, info = run_parareal(coarse_step, fine_step, y0=y0, ts=ts, tol=tol)
# Print convergence information
print(f"Parareal converged in {info['iterations']} iterations")
print(f"Final error norm: {info['last_change_norm']:.2e}")
Parareal converged in 188 iterations
Final error norm: 1.30e-14
That is a lot of iterations! But it is expected: parareal is guaranteed to reproduce the fine solution after at most as many iterations as there are subintervals (here 200), and since our untrained neural network provides essentially no useful information, the algorithm needs close to that worst-case count.
In practice with trained networks:
- Convergence typically occurs in 2-5 iterations
- Networks trained on ODE data can provide excellent coarse approximations
- The combination offers both speed and accuracy benefits
Let's check that despite the many iterations, we still get the correct result:
plt.plot(ts, sol);
This looks as expected! Let us also quickly evaluate the resulting errors.
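Below we use parareax.utils.stepper_to_solver, which turns a two-point stepper into a solver that propagates an initial state sequentially across a whole time grid. A hand-rolled equivalent might look roughly like the following sketch (the actual parareax implementation may differ, for example in how it stacks the outputs):
def naive_stepper_to_solver(stepper):
    # Hypothetical stand-in for parareax.utils.stepper_to_solver.
    def solver(y0, ts):
        def scan_fn(y, t_pair):
            y_next = stepper(y, t_pair[0], t_pair[1])
            return y_next, y_next

        t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=1)
        _, ys = jax.lax.scan(scan_fn, y0, t_pairs)
        return jnp.concatenate([y0[None], ys], axis=0)

    return solver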
from parareax.utils import stepper_to_solver
coarse_solver = stepper_to_solver(coarse_step)
fine_solver = stepper_to_solver(fine_step)
ref = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, p: f(y, *p)),
    diffrax.Dopri8(),
    t0=t0,
    t1=t1,
    y0=y0,
    dt0=1e-5,
    max_steps=None,
    saveat=diffrax.SaveAt(ts=ts),
    stepsize_controller=diffrax.PIDController(rtol=1e-18, atol=1e-18),
    args=args,
)
sol_coarse = coarse_solver(y0, ts)
sol_fine = fine_solver(y0, ts)
coarse_rmse = jnp.sqrt(jnp.mean((sol_coarse.ravel() - ref.ys.ravel()) ** 2))
fine_rmse = jnp.sqrt(jnp.mean((sol_fine.ravel() - ref.ys.ravel()) ** 2))
parareal_rmse = jnp.sqrt(jnp.mean((sol.ravel() - ref.ys.ravel()) ** 2))
print("\nAccuracy Comparison (RMSE vs reference):")
print(f" Neural network only: {coarse_rmse:.2e}")
print(f" Fine solver (Dormand-Prince): {fine_rmse:.2e}")
print(f" Parareal (hybrid): {parareal_rmse:.2e}")
print(
    f"\nKey insight: Even with a terrible coarse propagator ({coarse_rmse:.1e} error),"
)
print(f"parareal achieves fine solver accuracy ({parareal_rmse:.1e} error)")
print(f"- Improvement over neural network: {coarse_rmse / parareal_rmse:.0f}x better")
print(f"- Comparable to fine solver: {parareal_rmse / fine_rmse:.1f}x its error")
Accuracy Comparison (RMSE vs reference):
  Neural network only: 5.09e+00
  Fine solver (Dormand-Prince): 9.21e-07
  Parareal (hybrid): 9.21e-07

Key insight: Even with a terrible coarse propagator (5.1e+00 error),
parareal achieves fine solver accuracy (9.2e-07 error)
- Improvement over neural network: 5529798x better
- Comparable to fine solver: 1.0x its error
Perfect! As expected, parareal achieves the accuracy of the fine solver despite using a terrible neural network as the coarse propagator. This demonstrates the algorithm's robustness and potential for hybrid ML-physics approaches.
Conclusion¶
This example demonstrates the flexibility and robustness of the parareal algorithm:
Key Takeaways:¶
- ML-Physics Integration: Neural networks can serve as coarse propagators in parareal
- Error Correction: Even poor approximations get corrected to fine solver accuracy
- Broad Applicability: Opens doors for hybrid scientific computing approaches
Real-World Potential:¶
In practice, a trained neural network could provide:
- Fast approximations of complex dynamics
- Dramatic speedup for long-time integrations
- Learned representations of expensive physics simulations
This combination of machine learning and parallel-in-time methods represents an exciting frontier in computational science.
Next Steps¶
- Getting Started: Return to basics with traditional solvers
- Diffrax Integration: See parareal with sophisticated ODE solvers
- API Reference: Complete function documentation