Source code for fridom.framework.time_steppers.runge_kutta

"""Runge Kutta time stepping methods."""
from __future__ import annotations

from copy import deepcopy
from enum import Enum

import numpy as np

import fridom.framework as fr


[docs] class ButcherTableau: """ Butcher tableau for Runge-Kutta time stepping methods. Parameters ---------- A : np.ndarray Matrix of coefficients. b : np.ndarray Vector of coefficients. c : np.ndarray Vector of coefficients. """
[docs] def __init__(self, A: np.ndarray, b: np.ndarray, c: np.ndarray, b_error: np.ndarray | None = None) -> None: self.A = A self.b = b self.c = c self.b_error = b_error self.order = len(b)
[docs] class RKMethods(Enum): """Enumeration of Runge-Kutta methods.""" # ---------------------------------------------------------------- # Explicit Rung-Kutta methods (fixed time step) # ---------------------------------------------------------------- Euler = ButcherTableau( A = np.array([0]), b = np.array([1]), c = np.array([0]), ) RK2 = ButcherTableau( A = np.array([[0, 0], [1/2, 0]]), b = np.array([0, 1]), c = np.array([0, 1/2]), ) RK3 = ButcherTableau( A = np.array([[ 0, 0, 0], [1/2, 0, 0], [ -1, 2, 0]]), b = np.array([1/6, 2/3, 1/6]), c = np.array([ 0, 1/2, 1]), ) RK4 = ButcherTableau( A = np.array([[ 0, 0, 0, 0], [1/2, 0, 0, 0], [ 0, 1/2, 0, 0], [ 0, 0, 1, 0]]), b = np.array([1/6, 1/3, 1/3, 1/6]), c = np.array([ 0, 1/2, 1/2, 1]), ) RK4_38 = ButcherTableau( A = np.array([[ 0, 0, 0, 0], [ 1/3, 0, 0, 0], [-1/3, 1, 0, 0], [ 1, -1, 1, 0]]), b = np.array([1/8, 3/8, 3/8, 1/8]), c = np.array([ 0, 1/3, 2/3, 1]), ) # ---------------------------------------------------------------- # Adaptive Runge-Kutta methods # ---------------------------------------------------------------- HEUN_EULER = ButcherTableau( A = np.array([[0, 0], [1, 0]]), b = np.array([1/2, 1/2]), c = np.array([ 0, 1]), b_error=np.array([1/2, -1/2]), ) BOGACKI_SHAMPINE = ButcherTableau( A = np.array([[ 0, 0, 0, 0], [1/2, 0, 0, 0], [ 0, 3/4, 0, 0], [2/9, 1/3, 4/9, 0]]), b = np.array([7/24, 1/4, 1/3, 1/8]), c = np.array([ 0, 1/2, 3/4, 1]), b_error=np.array([2/9-7/24, 1/3-1/4, 4/9-1/3, -1/8]), ) RKF45 = ButcherTableau( A = np.array([[ 0, 0, 0, 0, 0, 0], [ 1/4, 0, 0, 0, 0, 0], [ 3/32, 9/32, 0, 0, 0, 0], [1932/2197, -7200/2197, 7296/2197, 0, 0, 0], [ 439/216, -8, 3680/513, -845/4104, 0, 0], [ -8/27, 2, -3544/2565, 1859/4104, -11/40, 0]]), b = np.array([16/135, 0, 6656/12825, 28561/56430, -9/50, 2/55]), c = np.array([ 0, 1/4, 3/8, 12/13, 1, 1/2]), b_error = np.array([-1/360, 0, 128/4275, 2197/75240, -1/50, -2/55]), )
@fr.utils.jaxjit
def sum_product(coeefs, dt, k):
    """
    Weighted sum of stage tendencies: ``sum_i coeefs[i] * dt * k[i]``.

    Only the first ``len(k)`` entries of ``coeefs`` are read, so a full
    tableau row may be passed while fewer stages are available. An empty
    ``k`` yields ``0``.
    """
    return sum(coeefs[idx] * dt * stage for idx, stage in enumerate(k))

#TODO(Silvano): Jaxify this class
class RungeKutta(fr.time_steppers.TimeStepper):
    """
    Explicit Runge-Kutta time stepper driven by a Butcher tableau.

    Parameters
    ----------
    dt : np.timedelta64 | float
        Initial time step size (default ``1``). For adaptive methods it is
        replaced after every step by the step size proposed by the error
        controller. A non-``None`` ``max_dt`` overrides it.
    method : RKMethods
        Runge-Kutta scheme (default ``RKMethods.RK4``). Methods whose
        tableau defines ``b_error`` use adaptive time stepping.
    max_dt : np.timedelta64 | float | None
        Upper bound for the adaptive time step; ``None`` means no bound.
    tol : float
        Tolerance for the local error estimate of adaptive methods. A step
        is repeated with a smaller time step while the estimate exceeds it.
        Must be positive for adaptive methods.
    """

    name = "Runge-Kutta"

    # Safety factor of the step size controller.
    _SAFETY = 0.9
    # Growth factor used when the error estimate is exactly zero, where the
    # controller formula would divide by zero.
    _ZERO_ERROR_GROWTH = 5.0

    def __init__(self,
                 dt: np.timedelta64 | float = 1,
                 method: RKMethods = RKMethods.RK4,
                 max_dt: np.timedelta64 | float | None = None,
                 tol: float = 1e-6) -> None:
        super().__init__()
        self.method = method.value
        self.dt = dt
        self.max_dt = max_dt
        # One tendency buffer per stage, allocated in ``_on_setup``.
        self.dz_list = None
        self.tol = tol

    def _on_setup(self) -> None:
        self.dz_list = [self.mset.state_constructor()
                        for _ in range(self.method.order)]

    def _calculate_tendency(self, mz: fr.ModelState) -> fr.VectorField:
        return self.mset.tendencies.update(mz).dz

    def _propose_dt(self, dt: float, error, error_order: int) -> float:
        """
        Return the step size suggested by the error controller.

        ``dt_new = 0.9 * dt * (tol / error) ** (1 / error_order)``, capped at
        ``max_dt`` when one is set. A zero error grows ``dt`` by a bounded
        factor instead of dividing by zero.
        """
        if error == 0:
            new_dt = float(self._ZERO_ERROR_GROWTH * dt)
        else:
            new_dt = float(
                self._SAFETY * dt * (self.tol / error) ** (1 / error_order))
        if self.max_dt is not None:
            new_dt = min(new_dt, self.max_dt)
        return new_dt

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """
        Update the model state to the next time level.

        Parameters
        ----------
        mz : ModelState
            Model state. Its ``z`` and ``clock`` are advanced in place.

        Returns
        -------
        ModelState
            The updated model state ``mz``.
        """
        method = self.method
        n_stages = method.order  # a tableau's ``order`` is its stage count
        adaptive = method.b_error is not None
        # Exponent of the step size controller. Tableaus that do not declare
        # the order of their error estimate fall back to the stage count.
        error_order = getattr(method, "error_order", None) or n_stages

        # Stages are evaluated on a scratch state with its own clock copy;
        # ``mz`` itself is only modified once the step is accepted.
        mod_state = fr.ModelState(self.mset, clock=deepcopy(mz.clock))
        # How far the scratch clock is ahead of ``mz.clock``. ``tick`` moves
        # the clock relative to its current time, so tracking the offset
        # places every stage at exactly ``t + c[i] * dt``, also across
        # repeated attempts.
        # NOTE(review): for non-monotone nodes (RKF45's last node is 1/2
        # after 1) this ticks backwards; assumes ``tick`` accepts that.
        clock_offset = 0.0

        while True:
            dt = self.dt
            k = []
            for i in range(n_stages):
                stage_offset = method.c[i] * dt
                mod_state.clock.tick(stage_offset - clock_offset)
                clock_offset = stage_offset
                mod_state.z = mz.z + sum_product(method.A[i], dt, k)
                mod_state.dz = self.dz_list[i]
                k.append(self._calculate_tendency(mod_state))

            if not adaptive:
                break

            te = sum_product(method.b_error, dt, k)
            error = sum(f.norm_l2() for f in te.field_list)
            self.dt = self._propose_dt(dt, error, error_order)
            # ``not >`` (rather than ``<=``) accepts a NaN estimate instead
            # of retrying forever.
            if not error > self.tol:
                break

        mz.z += sum_product(method.b, dt, k)
        mz.clock.tick(dt)
        return mz

    @property
    def max_dt(self) -> float | None:
        """
        Upper bound for the adaptive time step, or ``None`` if unbounded.

        Numeric values are stored as given; ``np.timedelta64`` values are
        converted to seconds.
        """
        return self._max_dt

    @max_dt.setter
    def max_dt(self, value: np.timedelta64 | float | None) -> None:
        if value is None:
            self._max_dt = None
            return
        if isinstance(value, (float, int)):
            self._max_dt = value
        else:
            self._max_dt = fr.config.dtype_real(value / np.timedelta64(1, "s"))
        # Setting a bound also resets the current step size to that bound.
        self.dt = self._max_dt