Source code for fridom.framework.time_steppers.adam_bashforth

"""Adam Bashforth time stepping up to 4th order."""
from __future__ import annotations

from functools import partial

import numpy as np

import fridom.framework as fr

# Highest supported Adams-Bashforth order: AdamBashforth tabulates the
# coefficients of orders 1 to 4 (AB1..AB4) and rejects larger orders.
MAX_ORDER = 4

@partial(
    fr.utils.jaxify,
    dynamic=("dz_list", "pointer", "it_count", "coeff_AB", "coeffs"),
)
class AdamBashforth(fr.time_steppers.TimeStepper):
    r"""
    Adams-Bashforth time stepping up to 4th order.

    Parameters
    ----------
    dt : float
        Time step size. (default 1)
    order : int
        Order of the time stepping, between 1 and 4. (default 3)
    eps : float
        Correction of the 2nd order Adams-Bashforth coefficients. It is used
        when ``order == 2`` and, for higher orders, on the second step of the
        start-up ramp. (default 0.01)

    Raises
    ------
    ValueError
        If ``order`` is smaller than 1 or larger than 4.

    Description
    -----------
    The Adams-Bashforth time stepping scheme is a multi-step explicit time
    stepping scheme. It solves a given PDE

    .. math::
        \partial_t \boldsymbol{z} = \boldsymbol{F}(\boldsymbol{z}, t)

    by using the following scheme of order :math:`n`

    .. math::
        \boldsymbol{z}^{n+1} = \boldsymbol{z}^n + \Delta t \sum_{j=0}^{n-1}
        \alpha_j \boldsymbol{F}(\boldsymbol{z}^{n-j}, t^{n-j})

    where :math:`\alpha_j` are the Adams-Bashforth coefficients,
    :math:`\Delta t` is the time step size, and :math:`\boldsymbol{z}^j` is
    the state at time :math:`t^j = t_0 + j \Delta t`. The coefficients for
    orders 1 to 4 are given in the table below.

    +-------+-------------------+-------------------+-------------------+-------------------+
    | Order | :math:`\alpha_0`  | :math:`\alpha_1`  | :math:`\alpha_2`  | :math:`\alpha_3`  |
    +=======+===================+===================+===================+===================+
    | 1     | 1                 |                   |                   |                   |
    +-------+-------------------+-------------------+-------------------+-------------------+
    | 2     | 3/2 + eps         | -1/2 - eps        |                   |                   |
    +-------+-------------------+-------------------+-------------------+-------------------+
    | 3     | 23/12             | -4/3              | 5/12              |                   |
    +-------+-------------------+-------------------+-------------------+-------------------+
    | 4     | 55/24             | -59/24            | 37/24             | -3/8              |
    +-------+-------------------+-------------------+-------------------+-------------------+

    After a (re)start (setup, reset, or a change of the time step), not
    enough past tendencies are available. The order is therefore ramped up:
    the first step uses the first order coefficients, the second step the
    second order coefficients, and so on until ``order`` is reached.

    Stability Analysis
    ******************
    Let :math:`\lambda` be the eigenvalues of the right-hand side of the PDE,
    e.g.:

    .. math::
        \partial_t \boldsymbol{z} = \boldsymbol{F}(\boldsymbol{z}, t)
        = -i \lambda \boldsymbol{z}

    Inserting this into the Adams-Bashforth scheme gives:

    .. math::
        \boldsymbol{z}^{n+1} = \sum_{j=0}^{n-1} c_j \boldsymbol{z}^{n-j}

    where

    .. math::
        c_j = \begin{cases}
            1 - i \Delta t \lambda \alpha_0 & \text{if } j = 0 \\
            -i \Delta t \lambda \alpha_j & \text{if } j > 0
        \end{cases}

    We now insert the Ansatz:

    .. math::
        \boldsymbol{z}^n = \boldsymbol{z}_0 e^{-i \omega n \Delta t}
        = \boldsymbol{z}_0 x^n

    with :math:`x = e^{-i \omega \Delta t}`. This yields a polynomial equation
    for :math:`x`:

    .. math::
        x^{n+1} = \sum_{j=0}^{n-1} c_j x^{n-j}

    Finally, we find the frequencies of the time stepping scheme by solving
    the polynomial equation for :math:`x` numerically, taking the root of
    largest modulus (the dominant mode), and taking the logarithm:

    .. math::
        \omega = i \log(x) / \Delta t
    """

    name = "Adam Bashforth"

    def __init__(self, dt: float = 1, order: int = 3, eps: float = 0.01) -> None:
        # check that the order is within the supported range
        if order > MAX_ORDER:
            msg = f"Only support orders up to {MAX_ORDER}."
            raise ValueError(msg)
        if order < 1:
            # order 0 or below would create empty tendency/coefficient buffers
            # and fail later with an obscure IndexError
            msg = f"The order must be at least 1, got {order}."
            raise ValueError(msg)
        super().__init__()
        self.order = order
        self.eps = eps
        # Adams-Bashforth coefficients alpha_j of orders 1 to 4
        # (j = 0 belongs to the newest tendency)
        self.AB1 = [1]
        self.AB2 = [3/2 + eps, -1/2 - eps]
        self.AB3 = [23/12, -4/3, 5/12]
        self.AB4 = [55/24, -59/24, 37/24, -3/8]
        self.it_count = None
        self.dt = dt

    def _on_setup(self) -> None:
        """(Re)initialize coefficients, pointers and tendency buffers."""
        ncp = fr.config.ncp
        dtype = fr.config.dtype_real
        # Adams-Bashforth coefficients of every order, including the time step
        self.coeffs = [
            ncp.asarray(ab, dtype=dtype) * self.dt
            for ab in (self.AB1, self.AB2, self.AB3, self.AB4)
        ]
        # coefficients used in the current step (ramped up after a restart)
        self.coeff_AB = ncp.zeros(self.order, dtype=dtype)
        # After update_pointer, pointer[k] is the index into dz_list of the
        # tendency k steps back (pointer[0] is the current one). It is a plain
        # numpy array since it is only used for Python list indexing.
        self.pointer = np.arange(self.order, dtype=np.int32)
        # tendencies of the last `order` time levels
        self.dz_list = [self.mset.state_constructor() for _ in range(self.order)]
        # restarting the count also restarts the coefficient ramp-up
        self.it_count = 0

    def _on_reset(self) -> None:
        self._on_setup()

    @fr.utils.jaxjit
    def _update_state(self,
                      z: fr.VectorField,
                      dz_list: list[fr.VectorField],
                      ) -> fr.VectorField:
        """
        Jax jitted time stepping function for Adams-Bashforth.

        Parameters
        ----------
        z : State
            The state at the current time level.
        dz_list : list[State]
            List of tendency terms, ordered from the newest to the oldest.

        Returns
        -------
        State : The updated state.
        """
        # coeff_AB[i] already contains the time step size
        for i, dz in enumerate(dz_list):
            z += dz * self.coeff_AB[i]
        return z

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """Update the time stepper."""
        # advance coefficients and rotate the tendency ring buffer first, so
        # that self.dz refers to the slot of the oldest (now unused) tendency
        self._update_tendency()

        # compute the new tendency into that slot
        mz.dz = self.dz
        mz = self.mset.tendencies.update(mz)
        self.dz = mz.dz

        # tendencies ordered from the newest to the oldest time level
        dz_list = [self.dz_list[p] for p in self.pointer]
        mz.z = self._update_state(mz.z, dz_list)

        self.it_count += 1
        mz.clock.tick(self.dt)
        return mz

    def _update_tendency(self) -> None:
        # The coefficients only change during the ramp-up after a restart;
        # afterwards update_coeff_AB would keep setting the same values.
        if self.it_count <= self.order + 1:
            self.update_coeff_AB()
        self.update_pointer()

    def update_pointer(self) -> None:
        """Update pointer for Adams-Bashforth time stepping."""
        # rotating by one makes the slot of the oldest tendency the current one
        self.pointer = np.roll(self.pointer, 1)

    def update_coeff_AB(self) -> None:
        """Upward ramping of Adams-Bashforth coefficients after restart."""
        # current time level (ctl)
        # maximum ctl is the number of time levels - 1
        ctl = min(self.it_count, self.order - 1)
        # list of Adams-Bashforth coefficients
        coeffs = self.coeffs
        # use the coefficients of order ctl+1 and zero out the rest
        self.coeff_AB = fr.utils.modify_array(self.coeff_AB, slice(None), 0)
        self.coeff_AB = fr.utils.modify_array(
            self.coeff_AB, slice(ctl + 1), coeffs[ctl])

    def time_discretization_effect(self, omega: np.ndarray) -> np.ndarray:
        r"""
        Compute the effective frequencies of the discretized dynamics.

        Parameters
        ----------
        omega : ndarray
            Frequencies :math:`\lambda` of the continuous problem
            :math:`\partial_t z = -i \lambda z`. Scalars are accepted too.

        Returns
        -------
        ndarray
            Effective frequencies :math:`i \log(x) / \Delta t` of the
            Adams-Bashforth scheme, with the same shape as ``omega``. Here
            :math:`x` is the root of largest modulus of the characteristic
            polynomial of the scheme.
        """
        # shorthand notation
        ncp = fr.config.ncp
        # cast omega to ndarray
        omega = ncp.array(omega)

        # Adams-Bashforth coefficients (without dt) of the configured order
        ab_coefficients = [self.AB1, self.AB2, self.AB3, self.AB4]
        coeff = ncp.array(ab_coefficients[self.order - 1])

        # one set of coefficients per omega value:
        # coeff -> (*omega.shape, order), omega -> (*omega.shape, 1)
        new_shape = (*tuple(omega.shape), 1)
        coeff = ncp.tile(coeff, new_shape)
        omega = omega[..., ncp.newaxis]

        # polynomial coefficients i * omega * dt * alpha_j ...
        coeff = ncp.multiply(omega, coeff) * 1j * self.dt
        # ... minus 1 for the current time level (j = 0)
        last_col = (..., 0)
        coeff = fr.utils.modify_array(coeff, last_col, coeff[last_col] - 1)

        # prepend the leading coefficient 1
        paddings = [(0, 0)] * len(coeff.shape)
        paddings[-1] = (1, 0)
        coeff = ncp.pad(coeff, paddings, "constant", constant_values=(1, 0))

        # Reverse the order of the coefficients. np.roots expects the highest
        # power first, so the reversed polynomial has the roots r = 1/x.
        coeff = coeff[..., ::-1]

        def dominant_root(c: np.ndarray) -> complex:
            """
            Return the root of the reversed polynomial of the dominant mode.

            The roots r of the reversed polynomial are the inverse
            amplification factors 1/x, so the root of smallest modulus belongs
            to the mode of largest |x|. np.roots does not guarantee any order
            of the returned roots, hence the explicit selection.
            """
            roots = np.roots(c)
            if roots.size == 0:
                # degenerate polynomial (a nonzero constant): no finite root
                return complex(np.nan, np.nan)
            return roots[np.argmin(np.abs(roots))]

        # find the roots of the polynomial
        # root finding only works on the CPU
        coeff = fr.utils.to_numpy(coeff)
        roots = ncp.array(np.apply_along_axis(dominant_root, -1, coeff))
        # -i log(1/x) / dt equals i log(x) / dt (away from the branch cut)
        return -1j * ncp.log(roots) / self.dt

    # ================================================================
    #  Properties
    # ================================================================

    def _on_time_step_change(self) -> None:
        if not self.is_setup:
            return
        # the coefficients contain dt, so the setup has to be redone; this
        # also restarts the order ramp-up
        self._on_setup()

    @property
    def info(self) -> dict:  # noqa: D102
        second_order = 2
        res = super().info
        res["order"] = self.order
        if self.order == second_order:
            res["eps"] = self.eps
        return res

    @property
    def dz(self) -> fr.VectorField:
        """Pointer on the current tendency term."""
        return self.dz_list[self.pointer[0]]

    @dz.setter
    def dz(self, value: fr.VectorField) -> None:
        self.dz_list[self.pointer[0]] = value