Source code for fridom.framework.time_steppers.runge_kutta

from copy import deepcopy
import fridom.framework as fr
import numpy as np
from enum import Enum
from typing import Union


class ButcherTableau:
    """
    Butcher tableau for Runge-Kutta time stepping methods.

    Parameters
    ----------
    `A` : `np.ndarray`
        Runge-Kutta matrix. Row ``i`` holds the weights of the previously
        computed stage tendencies that are combined to form stage ``i``.
    `b` : `np.ndarray`
        Weights that combine the stage tendencies into the final update.
    `c` : `np.ndarray`
        Nodes, i.e. the time offset of every stage as a fraction of the
        time step size.
    `b_error` : `np.ndarray` or None
        Weights of the embedded error estimate (typically the difference
        between the two solutions of an embedded pair). ``None`` for
        methods without error estimation.

    Attributes
    ----------
    `order` : `int`
        Number of stages, ``len(b)``. Despite its name this is the stage
        count, not the order of accuracy of the method.
    `stages` : `int`
        Same value as `order`, under an unambiguous name.

    Raises
    ------
    ValueError
        If ``A``, ``c`` or ``b_error`` do not provide one entry per stage.
    """

    def __init__(self, A, b, c, b_error=None):
        # asarray is a no-op for ndarrays and lets plain lists be passed too.
        self.A = np.asarray(A)
        self.b = np.asarray(b)
        self.c = np.asarray(c)
        self.b_error = None if b_error is None else np.asarray(b_error)

        n_stages = len(self.b)
        # Kept as "order" for backward compatibility: the stepper iterates
        # over this many stages.
        self.order = n_stages
        self.stages = n_stages

        # Catch malformed tableaus here instead of failing mid-step.
        checks = [("A", self.A), ("c", self.c)]
        if self.b_error is not None:
            checks.append(("b_error", self.b_error))
        for label, coeffs in checks:
            if len(coeffs) != n_stages:
                raise ValueError(
                    f"Butcher tableau: '{label}' has {len(coeffs)} entries, "
                    f"expected {n_stages} (one per stage)."
                )
        return
class RKMethods(Enum):
    """
    Enumeration of predefined explicit Runge-Kutta methods.

    Every member's value is a `ButcherTableau`. Members whose tableau has a
    ``b_error`` vector provide an embedded error estimate, which enables
    adaptive step size control in the Runge-Kutta stepper.

    - `Euler`: forward Euler (1 stage).
    - `RK2`: explicit midpoint method (2 stages).
    - `RK3`: Kutta's third-order method (3 stages).
    - `RK4`: classical fourth-order method (4 stages).
    - `RK4_38`: Kutta's 3/8-rule fourth-order method (4 stages).
    - `HEUN_EULER`: Heun's method with an embedded Euler error estimate.
    - `BOGACKI_SHAMPINE`: Bogacki-Shampine 3(2) pair; the step is advanced
      with the third-order weights.
    - `RKF45`: Runge-Kutta-Fehlberg 4(5) pair; the step is advanced with the
      fifth-order weights.
    """
    Euler = ButcherTableau(
        # 2-D like every other tableau (one row per stage).
        A = np.array([[0]]),
        b = np.array([1]),
        c = np.array([0])
    )
    RK2 = ButcherTableau(
        A = np.array([[0, 0], [1/2, 0]]),
        b = np.array([0, 1]),
        c = np.array([0, 1/2])
    )
    RK3 = ButcherTableau(
        A = np.array([[0, 0, 0], [1/2, 0, 0], [-1, 2, 0]]),
        b = np.array([1/6, 2/3, 1/6]),
        c = np.array([0, 1/2, 1])
    )
    RK4 = ButcherTableau(
        A = np.array([[0, 0, 0, 0],
                      [1/2, 0, 0, 0],
                      [0, 1/2, 0, 0],
                      [0, 0, 1, 0]]),
        b = np.array([1/6, 1/3, 1/3, 1/6]),
        c = np.array([0, 1/2, 1/2, 1])
    )
    RK4_38 = ButcherTableau(
        A = np.array([[0, 0, 0, 0],
                      [1/3, 0, 0, 0],
                      [-1/3, 1, 0, 0],
                      [1, -1, 1, 0]]),
        b = np.array([1/8, 3/8, 3/8, 1/8]),
        c = np.array([0, 1/3, 2/3, 1])
    )
    HEUN_EULER = ButcherTableau(
        A = np.array([[0, 0], [1, 0]]),
        b = np.array([1/2, 1/2]),
        c = np.array([0, 1]),
        # Heun weights minus Euler weights (sign is irrelevant for the norm).
        b_error=np.array([1/2, -1/2])
    )
    BOGACKI_SHAMPINE = ButcherTableau(
        A = np.array([[0, 0, 0, 0],
                      [1/2, 0, 0, 0],
                      [0, 3/4, 0, 0],
                      [2/9, 1/3, 4/9, 0]]),
        # Third-order solution. The previous weights [7/24, 1/4, 1/3, 1/8]
        # are the embedded second-order solution and degraded the method.
        b = np.array([2/9, 1/3, 4/9, 0]),
        c = np.array([0, 1/2, 3/4, 1]),
        # Third-order minus second-order weights.
        b_error=np.array([2/9-7/24, 1/3-1/4, 4/9-1/3, -1/8])
    )
    RKF45 = ButcherTableau(
        A = np.array([[0, 0, 0, 0, 0, 0],
                      [1/4, 0, 0, 0, 0, 0],
                      [3/32, 9/32, 0, 0, 0, 0],
                      [1932/2197, -7200/2197, 7296/2197, 0, 0, 0],
                      [439/216, -8, 3680/513, -845/4104, 0, 0],
                      [-8/27, 2, -3544/2565, 1859/4104, -11/40, 0]]),
        b = np.array([16/135, 0, 6656/12825, 28561/56430, -9/50, 2/55]),
        c = np.array([0, 1/4, 3/8, 12/13, 1, 1/2]),
        # Fourth-order minus fifth-order weights.
        b_error = np.array([-1/360, 0, 128/4275, 2197/75240, -1/50, -2/55])
    )
@fr.utils.jaxjit
def sum_product(coeefs, dt, k):
    """
    Return the weighted sum ``sum_i coeefs[i] * dt * k[i]``.

    Only the first ``len(k)`` coefficients are read, so a complete tableau
    row may be passed while only some stage tendencies exist yet. An empty
    ``k`` yields ``0``.
    """
    # Index into coeefs instead of zipping: a tableau row may be a scalar
    # when no stage tendency is available yet (k empty).
    weighted_terms = (coeefs[idx] * dt * stage for idx, stage in enumerate(k))
    return sum(weighted_terms)
class RungeKutta(fr.time_steppers.TimeStepper):
    """
    Explicit Runge-Kutta time stepper with optional adaptive step size.

    Parameters
    ----------
    `dt` : `np.timedelta64 | float`
        Time step size (the initial step size for adaptive methods).
    `method` : `RKMethods`
        Runge-Kutta scheme to use (default: classical RK4).
    `max_dt` : `np.timedelta64 | float | None`
        Upper bound for the adaptive step size. When given, the step size
        is reset to this value (see the `max_dt` setter).
    `tol` : `float`
        Tolerance on the embedded error estimate. Only used by methods whose
        tableau defines ``b_error``.
    """
    name = "Runge-Kutta"

    # Safety factor applied to the optimal step size estimate.
    _SAFETY_FACTOR = 0.9
    # Step growth used when the error estimate is exactly zero, where the
    # usual formula would divide by zero.
    _MAX_GROWTH = 5.0

    def __init__(self,
                 dt: Union[np.timedelta64, float] = 1,
                 method: RKMethods = RKMethods.RK4,
                 max_dt: Union[np.timedelta64, float, None] = None,
                 tol: float = 1e-6):
        super().__init__()
        self.method = method.value
        self.dt = dt
        # Note: a non-None max_dt also overwrites dt (see the setter).
        self.max_dt = max_dt
        self.dz_list = None
        self.tol = tol
        return

    @fr.modules.module_method
    def setup(self, mset: 'fr.ModelSettingsBase') -> None:
        """Allocate one tendency buffer per stage of the chosen method."""
        super().setup(mset)
        self.dz_list = [self.mset.state_constructor()
                        for _ in range(self.method.order)]
        return

    def calculate_tendency(self, mz: 'fr.ModelState') -> 'fr.StateBase':
        """Return the tendency ``dz`` computed by the model's tendency modules."""
        return self.mset.tendencies.update(mz).dz

    @fr.modules.module_method
    def update(self, mz: 'fr.ModelState') -> 'fr.ModelState':
        """
        Update the model state to the next time level.

        For methods with an error estimate, the step is repeated with a
        smaller step size until the estimated error is within `tol`. In
        every case, `dt` is updated to the suggested size for the next step.

        Parameters
        ----------
        `mz` : `ModelState`
            Model state. It is advanced in place: `z`, `clock` and `it`.

        Returns
        -------
        `ModelState`
            The same (updated) model state `mz`.
        """
        method = self.method
        # Scratch state used for the stage evaluations.
        mod_state = fr.ModelState(self.mset, clock=deepcopy(mz.clock))

        # `while True` (rather than `while error > tol`) guarantees at least
        # one attempt, even when tol >= 1.
        while True:
            dt = self.dt
            k = self._stage_tendencies(mz, mod_state, dt)
            if method.b_error is None:
                break  # fixed-step method: always accept
            te = sum_product(method.b_error, dt, k)
            error = sum(f.norm_l2() for f in te.field_list)
            self.dt = self._next_dt(dt, error)
            # `not >` (rather than `<=`) accepts a NaN error instead of
            # retrying forever, matching the previous loop condition.
            if not error > self.tol:
                break

        mz.z += sum_product(method.b, dt, k)
        mz.clock.tick(dt)
        mz.it += 1
        return mz

    def _stage_tendencies(self, mz, mod_state, dt) -> list:
        """Evaluate all stage tendencies for a step of size ``dt`` from ``mz``."""
        method = self.method
        k = []
        for i in range(method.order):
            # Stage i lives at t_n + c[i] * dt. Restart from the model clock
            # each time; the c[i] are absolute offsets, not increments.
            mod_state.clock = deepcopy(mz.clock)
            mod_state.clock.tick(method.c[i] * dt)
            mod_state.z = mz.z + sum_product(method.A[i], dt, k)
            mod_state.dz = self.dz_list[i]
            k.append(self.calculate_tendency(mod_state))
        return k

    def _next_dt(self, dt: float, error) -> float:
        """Suggest the next step size from the error estimate of a step ``dt``."""
        order = self.method.order
        if error == 0:
            # Exact step: the standard formula would divide by zero.
            new_dt = float(self._MAX_GROWTH * dt)
        else:
            # NOTE(review): the exponent uses the stage count (`order`), not
            # the order of the embedded pair; kept for compatibility.
            new_dt = float(self._SAFETY_FACTOR * dt
                           * (self.tol / error) ** (1 / order))
        if self.max_dt is not None:
            new_dt = min(new_dt, self.max_dt)
        return new_dt

    @property
    def max_dt(self) -> Union[float, None]:
        """
        Upper bound for the adaptive time step size, or None.

        `np.timedelta64` values are stored as seconds. Assigning a non-None
        value also resets `dt` to that value.
        """
        return self._max_dt

    @max_dt.setter
    def max_dt(self, value: Union[np.timedelta64, float, None]) -> None:
        if value is None:
            self._max_dt = None
            return
        if isinstance(value, (float, int)):
            self._max_dt = value
        else:
            self._max_dt = fr.config.dtype_real(value / np.timedelta64(1, 's'))
        self.dt = self._max_dt