# Source code for fridom.nonhydro.modules.advection.spectral_advection
"""Spectral advection for the nonhydrostatic model."""
import fridom.framework as fr
import fridom.nonhydro as nh
@fr.utils.jaxify
class SpectralAdvection(fr.modules.advection.AdvectionBase):
    """
    Advection Scheme for the spectral grid.

    The nonlinear advection terms are evaluated pseudo-spectrally: the state
    is inverse-transformed, the products of the fields with the velocity are
    formed there, and the resulting fluxes are transformed back before their
    divergence is taken.

    Parameters
    ----------
    padding : fr.grid.FFTPadding
        Padding to use for the FFT operations.
    """

    name = "Spectral Advection"

    def __init__(self, padding: fr.grid.FFTPadding = fr.grid.FFTPadding.TRIM) -> None:
        super().__init__()
        self.padding = padding

    def _on_setup(self) -> None:
        # Scale the advection term by the Rossby number when the model
        # settings define one; otherwise ``self.scaling`` is left untouched.
        # NOTE(review): this relies on AdvectionBase providing a default
        # ``scaling`` attribute -- confirm, otherwise advect_state would fail
        # with AttributeError for settings without ``Ro``.
        if hasattr(self.mset, "Ro"):
            self.scaling = self.mset.Ro

    @nh.utils.jaxjit
    def advect_state(self, z: nh.State, dz: nh.State) -> nh.State:
        """
        Subtract the scaled advection terms of ``z`` from the tendency ``dz``.

        Parameters
        ----------
        z : nh.State
            State whose fields are advected. It is inverse-transformed here,
            so it is presumably given in spectral space.
        dz : nh.State
            Tendency from which the advection terms are subtracted (in place).

        Returns
        -------
        nh.State
            ``dz`` with the momentum and tracer advection terms subtracted.
        """
        divergence = self.diff_module.div
        padding = self.padding

        # Inverse-transform once; all nonlinear products are formed on zp.
        zp = z.ifft(padding=padding)
        velocity = zp.velocity

        # Momentum flux tensor u_i * u_j. It is symmetric, so only its six
        # independent components are computed and transformed; the
        # remaining three entries reuse the already transformed ones.
        uu, uv, uw = (zp.u * velocity).fft(padding=padding)
        vu, vv, vw = uv, *(zp.v * velocity[1:]).fft(padding=padding)
        wu, wv, ww = uw, vw, (zp.w * zp.w).fft(padding=padding)

        dz.u -= self.scaling * divergence([uu, vu, wu])
        dz.v -= self.scaling * divergence([uv, vv, wv])
        dz.w -= self.scaling * divergence([uw, vw, ww])

        # Tracer advection: flux = tracer * velocity. Tracers flagged
        # NO_ADV are not advected.
        for field in zp.tracers:
            if field.flags["NO_ADV"]:
                continue
            flux = (field * velocity).fft(padding=padding)
            dz[field.name] -= self.scaling * divergence(flux)

        return dz

    # ================================================================
    #  Properties
    # ================================================================

    @property
    def padding(self) -> fr.grid.FFTPadding:
        """Padding to use for the FFT operations."""
        return self._padding

    @padding.setter
    def padding(self, padding: fr.grid.FFTPadding) -> None:
        self._padding = padding