Source code for fridom.framework.grid.grid_base

import fridom.framework as fr
from numpy import ndarray
from abc import abstractmethod
from functools import partial


[docs] @partial(fr.utils.jaxify, dynamic=('_X', '_x_global', '_K', '_k_global')) class GridBase: """ Base class for all grids in the framework. Description ----------- This class does not implement any functionality, but provides a template for all grid classes in the framework. The base class also sets default flags and attributes that are common to all grids. Child classes should override these flags and attributes as needed. Flags ----- `fourier_transform_available` : `bool` Indicates whether the grid supports fast fourier transforms. `mpi_available` : `bool` Indicates whether the grid supports MPI parallelization. """
[docs] def __init__(self, n_dims: int) -> None: self.name = "GridBase" self._n_dims = n_dims self._N = None self._L = None self._total_grid_points = None self._periodic_bounds = None self._X = None self._x_global = None self._x_local = None self._dx = None self._dV = None self._mset = None self._water_mask = fr.grid.WaterMask() # The domain decomposition self._domain_decomposition = None # The cell center CENTER = fr.grid.AxisPosition.CENTER self._cell_center = fr.grid.Position(tuple([CENTER] * n_dims)) # spectral properties self._K = None self._k_global = None self._k_local = None self._omega_analytical = None self._omega_space_discrete = None self._omega_time_discrete = None # operator modules self._diff_module: fr.grid.DiffModule = None self._interp_module: fr.grid.InterpolationModule = None # prepare for numpy conversion (the numpy copy will be stored here) self._cpu = None # --------------------------------------------------------------------- # Set default flags # --------------------------------------------------------------------- self._fourier_transform_available = False self._mpi_available = False self._spectral_grid = False return
[docs] def setup(self, mset: fr.ModelSettingsBase) -> None: """ Initialize the grid from the model settings. Parameters ---------- `mset` : `ModelSettingsBase` The model settings object. This is for example needed to determine the required halo size. """ self._diff_module.setup(mset=mset) self._interp_module.setup(mset=mset) return
[docs] def get_mesh(self, position: fr.grid.Position | None = None, spectral: bool = False ) -> tuple[ndarray]: """ Get the meshgrid of the grid points. Parameters ---------- `position` : `Position` or `None` (default: `None`) The position of the field. `spectral` : `bool` (default: `False`) Whether to return the meshgrid of the spectral domain. Returns ------- `tuple[ndarray]` The meshgrid of the grid points. """ if position is None: position = self.cell_center if position != self.cell_center: raise NotImplementedError("Not implemented for this grid") if spectral: return self._K return self._X
# ---------------------------------------------------------------- # Fourier Transform Methods # ----------------------------------------------------------------
[docs] @abstractmethod def fft(self, arr: ndarray, padding = fr.grid.FFTPadding.NOPADDING, bc_types: tuple[fr.grid.BCType] | None = None, positions: tuple[fr.grid.AxisPosition] | None = None, axes: tuple[int] | None = None, ) -> ndarray: """ Perform a (fast) fourier transform on the input array. Parameters ---------- `arr` : `ndarray` The input array. `padding` : `FFTPadding` (default: `FFTPadding.NOPADDING`) The padding to apply to the array. `bc_types` : `tuple[BCType]` or `None` (default: `None`) The boundary conditions to apply to each axis. `positions` : `tuple[AxisPosition]` or `None` (default: `None`) The position of the field. `axes` : `tuple[int]` or `None` (default: `None`) The axes to transform. Returns ------- `ndarray` The transformed array. """ raise NotImplementedError
[docs] @abstractmethod def ifft(self, arr: ndarray, padding = fr.grid.FFTPadding.NOPADDING, bc_types: tuple[fr.grid.BCType] | None = None, positions: tuple[fr.grid.AxisPosition] | None = None, axes: tuple[int] | None = None, ) -> ndarray: """ Perform an inverse (fast) fourier transform on the input array. Parameters ---------- `arr` : `ndarray` The input array. `padding` : `FFTPadding` (default: `FFTPadding.NOPADDING`) The padding to apply to the array. `bc_types` : `tuple[BCType]` or `None` (default: `None`) The boundary conditions to apply to each axis. `positions` : `tuple[AxisPosition]` or `None` (default: `None`) The position of the field. `axes` : `tuple[int]` or `None` (default: `None`) The axes to transform. Returns ------- `ndarray` The transformed array. """ raise NotImplementedError
# ---------------------------------------------------------------- # Spectral Analysis Tools # ----------------------------------------------------------------
[docs] @abstractmethod def omega(self, k: tuple[float] | tuple[ndarray], use_discrete: bool = False ) -> ndarray: """ Compute the dispersion relation of the model. Parameters ---------- `k` : `tuple[float] | tuple[ndarray]` The wave numbers `use_discrete` : `bool` (default: False) Whether to include space-discretization effects. Returns ------- `ndarray` The dispersion relation (omega(k)). """ raise NotImplementedError
[docs] @abstractmethod def vec_q(self, s: int, use_discrete: bool = True) -> fr.StateBase: """ Computes the eigenvector of the linear operator of the mode `s`. Parameters ---------- `s` : `int` The mode (which eigenvalue / eigenvector to compute). `use_discrete` : `bool` (default: True) Whether to include space-discretization effects. Returns ------- `StateBase` The eigenvector of the linear operator. """ raise NotImplementedError
[docs] @abstractmethod def vec_p(self, s: int, use_discrete: bool = True) -> fr.StateBase: """ Computes the projection vector of the linear operator of the mode `s`. Parameters ---------- `s` : `int` The mode (which eigenvalue / eigenvector to compute). `use_discrete` : `bool` (default: True) Whether to include space-discretization effects. Returns ------- `StateBase` The projection vector of the linear operator. """ raise NotImplementedError
@property def omega_analytical(self) -> ndarray: """ Analytical dispersion relation. """ if self._omega_analytical is None: self._omega_analytical = self.omega(self.K, use_discrete=False) return self._omega_analytical @property def omega_space_discrete(self) -> ndarray: """ Dispersion relation with space-discretization effects. """ if self._omega_space_discrete is None: self._omega_space_discrete = self.omega(self.K, use_discrete=True) return self._omega_space_discrete @property def omega_time_discrete(self): """ Dispersion relation with space-time-discretization effects. Warning: The computation may be very slow. """ if self._omega_time_discrete is None: om_space_discrete = self.omega_space_discrete ts = self.mset.time_stepper om = ts.time_discretization_effect(om_space_discrete) self._omega_time_discrete = om return self._omega_time_discrete # ---------------------------------------------------------------- # Domain Decomposition Methods # ----------------------------------------------------------------
[docs] def sync(self, arr: ndarray, flat_axes: list[int] | None = None) -> ndarray: """ Synchronize the halo (boundary) points of an array across all MPI ranks. Parameters ---------- `arr` : `ndarray` The array to synchronize. Returns ------- `ndarray` The synchronized array. """ return self.domain_decomp.sync(arr, flat_axes=flat_axes)
[docs] @fr.utils.jaxjit def sync_multi(self, arrs: tuple[ndarray]) -> tuple[ndarray]: """ Synchronize the halo (boundary) points of multiple arrays across all MPI ranks. Parameters ---------- `arrs` : `list[ndarray]` The list of arrays to synchronize. Returns ------- `list[ndarray]` The synchronized list of arrays. """ return self.domain_decomp.sync_multiple(arrs)
[docs] @fr.utils.jaxjit def unpad(self, arr: ndarray) -> ndarray: """ Remove the halo padding from an array. Parameters ---------- `arr` : `ndarray` The padded array. Returns ------- `ndarray` The unpadded array. """ return self.domain_decomp.unpad(arr)
[docs] @fr.utils.jaxjit def pad(self, arr: ndarray) -> ndarray: """ Add halo padding to an array. Parameters ---------- `arr` : `ndarray` The unpadded array. Returns ------- `ndarray` The padded array. """ return self.domain_decomp.pad(arr)
[docs] @partial(fr.utils.jaxjit, static_argnames=('pad', 'spectral', 'topo')) def create_array(self, pad: bool = True, spectral: bool = False, topo: tuple[bool] | None = None) -> ndarray: """ Create an array. Parameters ---------- `pad` : bool Whether to add padding to the array. `spectral` : bool Whether the array is in spectral space. `topo` : tuple[bool] | None The topology of the array. Axes with false are flat (only one grid point) """ return self.domain_decomp.create_array( pad=pad, spectral=spectral, topo=topo)
[docs] def create_random_array(self, seed: int = 1234, pad: bool = True, spectral: bool = False, topo: tuple[bool] | None = None ) -> ndarray: """ Create a random array. Parameters ---------- `seed` : int The seed for the random number generator. `pad` : bool Whether to add padding to the array. `spectral` : bool Whether the array is in spectral space. `topo` : tuple[bool] | None The topology of the array. Axes with false are flat (only one grid point) """ return self.domain_decomp.create_random_array( seed=seed, pad=pad, spectral=spectral, topo=topo)
# ---------------------------------------------------------------- # Display methods # ---------------------------------------------------------------- @property def info(self) -> dict: """ Return a dictionary with information about the grid. Description ----------- This method should be overridden by the child class to return a dictionary with information about the grid. This information is used to print the grid in the `__repr__` method. """ return {} def __repr__(self) -> str: """ String representation of the grid. """ res = self.name for key, value in self.info.items(): res += "\n - {}: {}".format(key, value) return res # ---------------------------------------------------------------- # Grid Modules # ---------------------------------------------------------------- @property def diff_module(self) -> fr.grid.DiffModule: """The differential operator module.""" return self._diff_module @diff_module.setter def diff_module(self, value: fr.grid.DiffModule) -> None: if not isinstance(value, fr.grid.DiffModule): raise ValueError("The differential operator module must be a DiffBase object") self._diff_module = value return @property def interp_module(self) -> fr.grid.InterpolationModule: """The interpolation operator module.""" return self._interp_module @interp_module.setter def interp_module(self, value: fr.grid.InterpolationModule) -> None: if not isinstance(value, fr.grid.InterpolationModule): raise ValueError("The interpolation operator module must be an InterpolationBase object") self._interp_module = value return @property def water_mask(self) -> fr.grid.WaterMask: """ Get the water mask. """ return self._water_mask @water_mask.setter def water_mask(self, value: fr.grid.WaterMask) -> None: self._water_mask = value return # ---------------------------------------------------------------- # Properties # ---------------------------------------------------------------- @property def mset(self) -> fr.ModelSettingsBase | None: """The model settings object.""" return self._mset @property def domain_decomp(self) -> fr.domain_decomposition.DomainDecomposition: """The domain decomposition object.""" return self._domain_decomp @property def n_dims(self) -> int: """The number of dimensions of the grid.""" return self._n_dims @property def N(self) -> tuple[int]: """The number of grid points in each dimension.""" return self._N @property def L(self) -> tuple[float]: """The length of the grid in each dimension.""" return self._L @property def total_grid_points(self) -> int: """The total number of grid points in the grid.""" return self._total_grid_points @property def periodic_bounds(self) -> list[bool]: """A tuple of booleans indicating whether the grid is periodic in each dimension.""" return self._periodic_bounds @property def cell_center(self) -> fr.grid.Position: """The position of the cell centers.""" return self._cell_center @property def X(self) -> tuple[ndarray]: """The meshgrid of the grid points.""" return self._X @property def x_global(self) -> tuple[ndarray]: """The x-vector of the global grid points.""" return self._x_global @property def K(self) -> ndarray: """The wavenumber of the grid.""" return self._K @property def k_global(self) -> ndarray: """The global wavenumber of the grid.""" return self._k_global @property def dx(self) -> tuple[ndarray]: """The grid spacing in each dimension.""" return self._dx @property def dV(self) -> ndarray: """The volume element of the grid.""" return self._dV # ================================================================ # Flags # ================================================================ @property def fourier_transform_available(self) -> bool: """Indicates whether the grid supports fast fourier transforms.""" return self._fourier_transform_available @fourier_transform_available.setter def fourier_transform_available(self, value: bool) -> None: self._fourier_transform_available = value @property def mpi_available(self) -> bool: """Indicates whether the grid supports MPI parallelization.""" return self._mpi_available @mpi_available.setter def mpi_available(self, value: bool) -> None: self._mpi_available = value @property def spectral_grid(self) -> bool: """Indicates whether the grid is a spectral grid.""" return self._spectral_grid @spectral_grid.setter def spectral_grid(self, value: bool) -> None: self._spectral_grid = value return