Source code for fridom.nonhydro.modules.advection.weno_advection

"""A WENO advection scheme following S. Mishra et al. (2021)."""
from __future__ import annotations

import fridom.framework as fr
import fridom.nonhydro as nh


@fr.utils.jaxify
class WENO(nh.modules.advection.AdvectionBase):
    r"""
    Weighted Essentially Non-Oscillatory (WENO) advection scheme.

    Description
    -----------
    This class implements the WENO advection scheme following
    S. Mishra et al. (2021). It is designed to handle discontinuities and
    sharp gradients in the solution while maintaining high accuracy.

    Momentum is advected in flux form: for every velocity component ``v1``
    and every direction ``axis`` (with velocity component ``v2``) the
    tendency ``-scaling * d(v1 * v2) / d(axis)`` is added. Tracers are
    advected in advective form, ``-scaling * v * d(tracer) / d(axis)``, with
    the velocity interpolated to the cell centers. Face values are obtained
    from the WENO reconstruction.

    Parameters
    ----------
    order : int, optional
        Order used for the default interpolation and WENO modules. Must be
        an odd integer >= 1 (default: 5).
    inter : fr.grid.InterpolationModule, optional
        Interpolation module used to move velocities between grid
        positions. Defaults to
        ``fr.grid.cartesian.PolynomialInterpolation(order=order)``.
    weno : fr.grid.cartesian.InterWENO, optional
        Module providing the left/right face reconstructions. Defaults to
        ``fr.grid.cartesian.InterWENO(order=order)``.
    flux_function : fr.modules.flux_functions.FluxFunctionBase, optional
        Flux function. Defaults to ``fr.modules.flux_functions.Upwind()``.
        It is set up together with the other sub-modules, but
        :meth:`advect_state` currently does not apply it (see the note
        there).

    Raises
    ------
    ValueError
        If ``order`` is even or smaller than 1.

    References
    ----------
    .. [1] S. Mishra, C. Pares-Pulido, and K. G. Pressel, "Arbitrarily
       high-order (weighted) essentially non-oscillatory finite difference
       schemes for anelastic flows on staggered meshes,"
       *Communications in Computational Physics*, 2021.
    """

    name = "WENO Advection"

    def __init__(
        self,
        order: int = 5,
        inter: fr.grid.InterpolationModule | None = None,
        weno: fr.grid.cartesian.InterWENO | None = None,
        flux_function: fr.modules.flux_functions.FluxFunctionBase | None = None,
    ) -> None:
        super().__init__()
        # Only odd orders are valid for the (WENO) stencils.
        if order % 2 == 0 or order < 1:
            msg = f"Invalid order {order}. Only odd orders >= 1 are allowed."
            raise ValueError(msg)
        self.order = order
        # Explicit ``is None`` checks: only a missing argument selects the
        # default, never an object that happens to be falsy.
        if inter is None:
            inter = fr.grid.cartesian.PolynomialInterpolation(order=order)
        if weno is None:
            weno = fr.grid.cartesian.InterWENO(order=order)
        if flux_function is None:
            flux_function = fr.modules.flux_functions.Upwind()
        self.inter = inter
        self.weno = weno
        self.flux_function = flux_function

    def _on_setup(self) -> None:
        """Set up the interpolation, WENO and flux-function sub-modules with ``self.mset``."""
        self.inter.setup(self.mset)
        self.weno.setup(self.mset)
        self.flux_function.setup(self.mset)

    @fr.utils.jaxjit
    def advect_state(self, z: nh.State, dz: nh.State) -> nh.State:
        """
        Add the WENO advection tendencies of ``z`` to ``dz``.

        Parameters
        ----------
        z : nh.State
            State whose velocity components and tracers are advected.
            Tracers whose ``"NO_ADV"`` flag is set are skipped.
        dz : nh.State
            Tendency state; the advection terms are subtracted from its
            fields via item assignment (``dz[name] -= ...``).

        Returns
        -------
        nh.State
            ``dz`` with the advection tendencies added.
        """
        ip = self.inter.interpolate
        weno = self.weno.reconstruct
        diff = self.diff_module

        # ----------------------------------------------------------------
        # Momentum advection (flux form)
        # ----------------------------------------------------------------
        for v1 in z.velocity:
            for axis, v2 in enumerate(z.velocity):
                # Interpolate the advecting velocity v2 to the position of v1
                v2_at_v1 = ip(v2, v1.position)
                # Reconstruct v1*v2 at the face of v1 along ``axis``
                flux_left, flux_right = weno(
                    v1 * v2_at_v1, v1.position.shift(axis=axis))
                # NOTE(review): the configured ``self.flux_function`` is not
                # applied; the face value is the plain average of the left
                # and right WENO reconstructions. The previously intended call
                # was ``self.flux_function.compute(flux_left=flux_left,
                # flux_right=flux_right, velocity=v2_at_v1)`` -- confirm
                # whether the central averaging is deliberate.
                flux = 0.5 * (flux_left + flux_right)
                # Differencing the face values brings the result back to the
                # position of v1.
                dz[v1.name] -= self.scaling * diff.diff(flux, axis=axis)

        # ----------------------------------------------------------------
        # Tracer advection (advective form)
        # ----------------------------------------------------------------
        # Velocity at the cell centers. This must be a concrete sequence:
        # it is iterated once per tracer, and a generator would be exhausted
        # after the first tracer, silently skipping every later one.
        vel_at_center = tuple(
            ip(v, self.grid.cell_center) for v in z.velocity)

        for field in z.tracers:
            if field.flags["NO_ADV"]:
                continue
            for axis, v in enumerate(vel_at_center):
                # Reconstruct the tracer at the face along ``axis``
                tracer_left, tracer_right = weno(
                    field, field.position.shift(axis=axis))
                # NOTE(review): as above, ``self.flux_function`` is not used;
                # the face value is the central average.
                tracer = 0.5 * (tracer_left + tracer_right)
                # Add the advective term v * d(tracer)/d(axis)
                dz[field.name] -= self.scaling * diff.diff(tracer, axis=axis) * v

        return dz

    @property
    def required_halo(self) -> int:
        """The required halo size based on the interpolation modules."""
        return self.inter.required_halo + self.weno.required_halo