# Source code for fridom.shallowwater.modules.advection.spectral_advection
"""Spectral advection scheme for the shallow water equations."""
from __future__ import annotations
import fridom.framework as fr
import fridom.shallowwater as sw
@fr.utils.jaxify
class SpectralAdvection(fr.modules.advection.AdvectionBase):
    r"""
    Pseudo-spectral advection scheme for the shallow water equations.

    Field products are evaluated in physical space and transformed back to
    spectral space. Every forward and inverse transform uses the same
    ``padding``.

    Parameters
    ----------
    padding : fr.grid.FFTPadding
        Padding passed to each FFT / inverse FFT. Defaults to
        ``fr.grid.FFTPadding.TRIM``. NOTE(review): presumably chosen for
        dealiasing; confirm against ``fr.grid.FFTPadding``.
    """

    name = "Spectral Advection"

    def __init__(
        self,
        padding: fr.grid.FFTPadding = fr.grid.FFTPadding.TRIM,
    ) -> None:
        super().__init__()
        self.padding = padding

    def _on_setup(self) -> None:
        # Keep the inherited scaling unless the model settings expose ``Ro``
        # (presumably the Rossby number).
        if not hasattr(self.mset, "Ro"):
            return
        self.scaling = self.mset.Ro

    @fr.utils.jaxjit
    def advect_state(self, z: sw.State, dz: sw.State) -> sw.State:
        """
        Subtract the scaled nonlinear advection terms from ``dz``.

        Terms subtracted, each multiplied by ``self.scaling``:

        - ``u`` and ``v`` receive the velocity contracted with their
          gradient (advective form).
        - ``p`` receives the divergence of ``velocity * p`` (flux form).
        - Every tracer of ``z`` whose ``"NO_ADV"`` flag is false is advected
          like ``u``. The ``@`` operator between the physical velocity and a
          gradient is assumed to be the dot product.

        Parameters
        ----------
        z : sw.State
            Current state. Assumed to hold spectral coefficients, since its
            fields are inverse-transformed before any product is formed.
            TODO confirm against the caller.
        dz : sw.State
            Tendency whose ``u``, ``v``, ``p`` and tracer entries are updated.

        Returns
        -------
        sw.State
            The updated ``dz``.
        """
        pad = self.padding
        grad = self.diff_module.grad
        div = self.diff_module.div

        def to_physical(field: fr.FieldBase) -> fr.FieldBase:
            return field.ifft(padding=pad)

        def to_spectral(field: fr.FieldBase) -> fr.FieldBase:
            return field.fft(padding=pad)

        # Velocity in physical space, shared by every product below.
        velocity = to_physical(z.velocity)

        def transport(quantity: fr.FieldBase) -> fr.FieldBase:
            # Gradient taken spectrally, product formed in physical space.
            return to_spectral(velocity @ to_physical(grad(quantity)))

        scale = self.scaling
        dz.u -= scale * transport(z.u)
        dz.v -= scale * transport(z.v)
        dz.p -= scale * div(to_spectral(velocity * to_physical(z.p)))

        advected_tracers = (t for t in z.tracers if not t.flags["NO_ADV"])
        for tracer in advected_tracers:
            dz[tracer.name] -= scale * transport(tracer)
        return dz