Source code for fridom.shallowwater.modules.linear_tendency
"""Linear tendency module for the shallow water model."""
from __future__ import annotations
from functools import partial
import fridom.framework as fr
import fridom.shallowwater as sw
@partial(fr.utils.jaxify, dynamic=("csqr", "f_coriolis"))
class LinearTendency(fr.modules.Module):
    r"""
    Computes the linear tendency of the shallow water model.

    With :math:`\boldsymbol{u} = (u, v)`, the linear tendency is given by:

    .. math::
        \partial_t u = f v - \partial_x p
        ~, \quad
        \partial_t v = - f u - \partial_y p
        ~, \quad
        \partial_t p = - \nabla \cdot \left( c^2 \boldsymbol{u} \right)

    where :math:`f` is the Coriolis parameter (``mset.f_coriolis``) and
    :math:`c^2` is the field ``mset.csqr_field``. :math:`c^2` is interpolated
    to the velocity points, so it may vary in space. For constant
    :math:`c^2` the pressure equation reduces to
    :math:`\partial_t p = -c^2 \nabla \cdot \boldsymbol{u}`.
    """

    name = "Linear Tendency"

    def _on_setup(self) -> None:
        # Cache the Coriolis parameter and the c^2 field from the model
        # settings. Both attributes are listed as ``dynamic`` in the
        # ``jaxify`` decorator above.
        self.f_coriolis = self.mset.f_coriolis
        self.csqr = self.mset.csqr_field

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """Add the linear tendency of ``mz.z`` to ``mz.dz`` and return ``mz``."""
        mz.dz = self.linear_tendency(mz.z, mz.dz)
        return mz

    @fr.utils.jaxjit
    def linear_tendency(self, z: sw.State, dz: sw.State) -> sw.State:
        """
        Compute the linear tendency term.

        Parameters
        ----------
        z : sw.State
            The current state; its ``u``, ``v`` and ``p`` fields are read.
        dz : sw.State
            The tendency state; the linear terms are added to its ``u``,
            ``v`` and ``p`` fields.

        Returns
        -------
        sw.State
            ``dz`` with the linear tendency added.
        """
        interp = self.interp_module.interpolate
        diff = self.diff_module.diff
        div = self.diff_module.div

        # grid positions of the two velocity components
        upos = z.u.position
        vpos = z.v.position
        csqr = self.csqr

        # Coriolis parameter at the u-points. Both momentum equations use it:
        # the v-equation forms the product f*u at the u-points first and only
        # then interpolates that product to the v-points.
        f = interp(self.f_coriolis, upos)

        # u-tendency:  f v - dp/dx   (derivative along axis 0)
        dz.u += interp(z.v, upos) * f - diff(z.p, axis=0)
        # v-tendency: -f u - dp/dy   (derivative along axis 1)
        dz.v += - interp(z.u * f, vpos) - diff(z.p, axis=1)
        # p-tendency: -div(c^2 u), with c^2 interpolated to each velocity point
        # so the flux c^2 u is formed where the velocity lives.
        dz.p += - div((interp(csqr, upos) * z.u, interp(csqr, vpos) * z.v))
        return dz