Source code for fridom.framework.grid.cartesian.finite_differences
from copy import deepcopy
import fridom.framework as fr
from functools import partial
[docs]
@partial(fr.utils.jaxify, dynamic=('_dx1', 'water_mask'))
class FiniteDifferences(fr.grid.DiffModule):
name = "Finite Differences"
[docs]
def __init__(self) -> None:
super().__init__()
# ----------------------------------------------------------------
# Set attributes
# ----------------------------------------------------------------
self.required_halo = 1
self._dx1 = None
self.water_mask = None
[docs]
@fr.modules.module_method
def setup(self, mset: 'fr.ModelSettingsBase') -> None:
super().setup(mset)
from .grid import Grid
if not isinstance(self.mset.grid, Grid):
raise ValueError("Finite differences only work with Cartesian grids.")
conf = fr.config
self._dx1 = 1 / conf.ncp.array(self.mset.grid.dx, dtype=conf.dtype_real)
self.water_mask = self.mset.grid.water_mask
return
[docs]
@partial(fr.utils.jaxjit, static_argnames=('axis', 'order'))
def diff(self,
f: fr.ScalarField,
axis: int,
order: int = 1) -> fr.ScalarField:
# differentiate the field
match f.position[axis]:
case fr.grid.AxisPosition.CENTER:
f = self._diff_forward(f, axis)
case fr.grid.AxisPosition.FACE:
f = self._diff_backward(f, axis)
# check if we need to differentiate more
if order == 1:
return f
else:
return self.diff(f, axis, order-1)
@partial(fr.utils.jaxjit, static_argnames=('axis',))
def _diff_forward(self,
f: fr.ScalarField,
axis: int) -> fr.ScalarField:
res = fr.ScalarField(mset=f.mset, mdata=deepcopy(f.mdata))
new_pos = f.position.shift(axis)
mask = self.water_mask.get_mask(new_pos)
next = tuple(slice(1, None) if i == axis else slice(None)
for i in range(f.arr.ndim))
prev = tuple(slice(None, -1) if i == axis else slice(None)
for i in range(f.arr.ndim))
@self.grid.domain_decomp.shard_map
def _diff(arr):
diff = (arr[next] - arr[prev]) * self._dx1[axis]
return fr.utils.modify_array(arr, prev, diff)
res.arr = _diff(f.arr) * mask
res.position = new_pos
return res
@partial(fr.utils.jaxjit, static_argnames=('axis',))
def _diff_backward(self,
f: fr.ScalarField,
axis: int) -> fr.ScalarField:
res = fr.ScalarField(mset=f.mset, mdata=deepcopy(f.mdata))
new_pos = f.position.shift(axis)
mask = self.water_mask.get_mask(new_pos)
next = tuple(slice(1, None) if i == axis else slice(None)
for i in range(f.arr.ndim))
prev = tuple(slice(None, -1) if i == axis else slice(None)
for i in range(f.arr.ndim))
@self.grid.domain_decomp.shard_map
def _diff(arr):
diff = (arr[next] - arr[prev]) * self._dx1[axis]
return fr.utils.modify_array(arr, next, diff)
res.arr = _diff(f.arr) * mask
res.position = new_pos
return res