Source code for fridom.framework.grid.water_mask

import fridom.framework as fr
from numpy import ndarray
import itertools
from functools import partial


[docs] @partial(fr.utils.jaxify, dynamic=('_water_mask', '_cache')) class WaterMask: """ Water mask for the grid cells (for boundary conditions). Description ----------- Let's consider the following staggered grid with periodic boundaries: :: -----e---------e---------e----- | | | | | x o x o x o | | | | -----e---------e---------e----- | | | | | (x) o (x) o x o | | | | -----e---------e---------e----- | | | | | (x) o x o x o | | | | ------------------------------- where `x` represents the cell center and `o` and `e` represent the cell faces. Grid cells denoted with `(x)` are the cells with land. The water mask is a boolean array that indicates whether a cell is water (1) or land (0). For the above grid, the water masks would be: :: [1, 1, 1] [1, 1, 1] [0, 1, 1] x = [0, 0, 1] o = [0, 0, 0] e = [0, 0, 1] [0, 1, 1] [0, 1, 0] [0, 0, 1] """
[docs] def __init__(self): self.name = "Water Mask" self._water_mask = None self._cache = {} self._domain_decomposition: fr.domain_decomposition.DomainDecomposition = None self._periodic_bounds = None return
[docs] def setup(self, mset: fr.ModelSettingsBase) -> None: # we can't set mset or grid as attributes due to recursion issues # with jaxjit, so we only set the attributes we need self._domain_decomposition = mset.grid.domain_decomp self._periodic_bounds = mset.grid.periodic_bounds self.water_mask = (mset.grid.domain_decomp.create_array(pad=True)+1).astype(bool) # self.water_mask = fr.config.ncp.ones(mset.grid.X[0].shape, dtype=bool) return
[docs] def get_mask(self, position: fr.grid.Position) -> ndarray: """ Get the water mask at the given position. """ id = hash(position) if id not in self._cache: self._cache[id] = self.create_mask_at_position(position) return self._cache[id]
[docs] def create_mask_at_position(self, position: fr.grid.Position) -> ndarray: """ Create a water mask at the given position. """ new_mask = self._water_mask for axis, axpos in enumerate(position.positions): new_mask = self.shift_mask_along_axis(new_mask, axis, axpos) return new_mask
[docs] def shift_mask_along_axis( self, mask: ndarray, axis: int, axpos: fr.grid.AxisPosition) -> ndarray: """ Shift the mask along the given axis to the new position. Description ----------- Let's say we have a mask given at the cell centers, denoted with `x` below. And we want to shift the mask to the right, denoted with `|` below. The overline represents the position of land in the mask. :: _______ _______ x | x | x | x | x | x | Hence the water mask would be: :: [0, 1, 1, 0, 1, 1] The new mask at the right position is only water if both neighboring cells are water. Hence the new mask at the right cell faces would be: :: [0, 1, 0, 0, 1, ?] where `?` depends on the neighboring cells. To find the new mask algorithmically, we first determine the left and right cell centers of the corresponding cell face. And then we check if both neighboring cells are water by multiplying the left and right cell centers. Finally we synchronize the mask across the processors and fill the halo cells with land. Parameters ---------- mask : ndarray The mask to shift (located at the center at the given axis) axis : int The axis along which the mask should be shifted axpos : fr.grid.AxisPosition The new position of the mask along the axis """ match axpos: case fr.grid.AxisPosition.CENTER: # nothing to do new_mask = mask case fr.grid.AxisPosition.FACE: # find out left and right side of the mask @self._domain_decomposition.shard_map def roll(arr): left_side = arr right_side = fr.config.ncp.roll(arr, -1, axis) # both sides must be water (True) for the new mask to be water return right_side * left_side new_mask = roll(mask) new_mask = self._sync_mask(new_mask) return new_mask
def _sync_mask(self, mask: ndarray) -> ndarray: """ Synchronize the mask across the processors. """ # for some reason, sync does not work on boolean arrays # so we convert the mask to integers, sync it and then convert it back mask = mask.astype(int) mask = self._domain_decomposition.sync(mask) mask = mask.astype(bool) return mask @property def water_mask(self) -> ndarray: """ Get the water mask. """ return self._water_mask @water_mask.setter def water_mask(self, mask: ndarray) -> None: mask = self._sync_mask(mask) self._water_mask = mask # clear the cache self._cache = {} # construct all possible masks ndim = mask.ndim CENTER = fr.grid.AxisPosition.CENTER; FACE = fr.grid.AxisPosition.FACE for position in itertools.product([CENTER, FACE], repeat=ndim): self.get_mask(fr.grid.Position(position)) return