# Source code for fridom.framework.grid.cartesian.grid

from __future__ import annotations

from copy import deepcopy
from typing import Literal

import fridom.framework as fr
import numpy as np
from functools import partial


[docs] @fr.utils.jaxify class Grid(fr.grid.GridBase): """ An n-dimensional cartesian grid with capabilities for fourier transforms. Description ----------- The cartesian grid is a regular grid with constant grid spacing in each direction. The grid can be periodic in some directions and non-periodic in others. Parameters ---------- `N` : `tuple[int]` Number of grid points in each direction. `L` : `tuple[float]` Domain size in meters in each direction. `periodic_bounds` : `tuple[bool]`, (default: None) A list of booleans that indicate whether the axis is periodic. If True, the axis is periodic, if False, the axis is non-periodic. Default is True for all axes. `shared_axes` : `list[int]`, (default: None) A list of integers that indicate which axes are shared among MPI ranks. Default is None, which means that no fourier transforms are available. `diff_mod` : `DiffModule`, (default: None) A module that contains the differentiation operators. If None, the finite differences module is used. `interp_mod` : `InterpolationModule`, (default: None) A module that contains the interpolation methods. If None, the linear interpolation module is used. Examples -------- .. code-block:: python import fridom.framework as fr # construct a 3D grid: grid = fr.grid.CartesianGrid( N=(32, 32, 8), # 32x32x8 grid points L=(100.0, 100.0, 10.0), # 100m x 100m x 10m domain periodic_bounds=(True, True, False), # non-periodic in z shared_axes=[0, 1] # slab decomposition, shared in x and y ) # setup the grid using the model settings mset = fr.ModelSettingsBase(grid) mset.setup() # get the meshgrids X, Y, Z = grid.X # physical meshgrid of the local domain KX, KY, KZ = grid.K # spectral meshgrid of the local domain # get the grid spacing dx, dy, dz = grid.dx """
[docs] def __init__(self, N: list[int], L: list[float], periodic_bounds: list[bool] | None = None, domain_decomp: fr.domain_decomposition.DomainDecomposition | None = None, diff_mod: fr.grid.DiffModule | None = None, interp_mod: fr.grid.InterpolationModule | None = None ) -> None: super().__init__(len(N)) self.name = "Cartesian Grid" # -------------------------------------------------------------- # Check the input # -------------------------------------------------------------- # check that N and L have the same length if len(N) != len(L): raise ValueError("N and L must have the same number of dimensions.") n_dims = len(N) # check that periodic_bounds is the right length periodic_bounds = tuple(periodic_bounds or [True] * n_dims) # default is periodic if len(periodic_bounds) != n_dims: raise ValueError( "periodic_bounds must have the same number of dimensions as N and L.") fourier_transform_available = True # -------------------------------------------------------------- # Set the flags # -------------------------------------------------------------- self.fourier_transform_available = fourier_transform_available self.mpi_available = True # -------------------------------------------------------------- # Set the attributes # -------------------------------------------------------------- # public attributes self._n_dims = n_dims # private attributes self._N = N self._L = L self._dx = tuple(L / N for L, N in zip(L, N)) self._dV = np.prod(self._dx) self._total_grid_points = int(np.prod(N)) self._periodic_bounds = periodic_bounds self._domain_decomp = domain_decomp self._fft: fr.grid.cartesian.FFT | None = None self._diff_module = diff_mod or fr.grid.cartesian.FiniteDifferences() self._interp_module = interp_mod or fr.grid.cartesian.LinearInterpolation() return
[docs] def setup(self, mset: 'fr.ModelSettingsBase', req_halo: int | None = None, fft_module: 'fr.grid.cartesian.FFT | None' = None, ) -> None: ncp = fr.config.ncp dtype = fr.config.dtype_real # -------------------------------------------------------------- # Initialize the domain decomposition # -------------------------------------------------------------- if req_halo is None: req_halo = max(self._diff_module.required_halo, self._interp_module.required_halo) req_halo = max(req_halo, mset.halo) # get the domain decomposition module if self._domain_decomp is None: # if there is no domain decomposition created yet, create a new one self._construct_domain_decomp(req_halo) if self._domain_decomp.halo != req_halo: # if there was a domain decomposition with wrong halo size, # create a new self._construct_domain_decomp(req_halo) domain_decomp = self._domain_decomp # -------------------------------------------------------------- # Initialize the fourier transform # -------------------------------------------------------------- if self.fourier_transform_available: fft = fft_module or fr.grid.cartesian.FFT(self._periodic_bounds) else: fft = None # -------------------------------------------------------------- # Initialize the meshgrids # -------------------------------------------------------------- x = tuple(ncp.linspace(0, li, ni, dtype=dtype, endpoint=False) + 0.5 * dxi for li, ni, dxi in zip(self._L, self._N, self._dx)) X = domain_decomp.create_meshgrid(*x, pad=True, spectral=False) if self.fourier_transform_available: k = fft.get_freq(self._N, self._dx) K = domain_decomp.create_meshgrid(*k, pad=False, spectral=True) else: fr.log.warning("Fourier transform not available.") k = None K = None # ---------------------------------------------------------------- # Store the attributes # ---------------------------------------------------------------- self._mset = mset self._domain_decomp = domain_decomp self._fft = fft self._X = X self._x_global = x self._K = K self._k_global 
= k # call the setup method of the base class # This is called last since some of the setup methods of the grid base # class depend on the attributes set here. super().setup(mset) return
[docs] def get_mesh(self, position: fr.grid.Position | None = None, spectral: bool = False, ) -> tuple[np.ndarray]: if spectral: return self.K # compute the offsets based on the position position = position or self.cell_center offsets = [0.5 * dx if pos == fr.grid.AxisPosition.FACE else 0 for dx, pos in zip(self.dx, position.positions)] # apply the offsets return tuple(x + offset for x, offset in zip(self.X, offsets))
def _construct_domain_decomp(self, halo: int) -> None: DomainDecomposition = fr.domain_decomposition.get_default_domain_decomposition() # construct the domain decomposition domain_decomp: fr.domain_decomposition.DomainDecomposition = DomainDecomposition( shape=tuple(self._N), halo=halo, periods=self._periodic_bounds, shared_axes=None) self._domain_decomp = domain_decomp # ================================================================ # Fourier Transforms # ================================================================
[docs] @partial(fr.utils.jaxjit, static_argnames=["bc_types", "padding", "positions"]) def fft(self, arr: np.ndarray, padding = fr.grid.FFTPadding.NOPADDING, bc_types: tuple[fr.grid.BCType] | None = None, positions: tuple[fr.grid.AxisPosition] | None = None, axes: tuple[int] | None = None, ) -> np.ndarray: # Forward transform the array f = lambda x, axes: self._fft.forward(x, axes, bc_types, positions) forward = self._domain_decomp.parallel_forward_transform(f) u_hat = forward(arr, axes) # Apply padding if necessary if padding == fr.grid.FFTPadding.EXTEND: u_hat = self.domain_decomp.unpad_extend(u_hat) return u_hat
[docs] @partial(fr.utils.jaxjit, static_argnames=["bc_types", "padding", "positions"]) def ifft(self, arr: np.ndarray, padding = fr.grid.FFTPadding.NOPADDING, bc_types: tuple[fr.grid.BCType] | None = None, positions: tuple[fr.grid.AxisPosition] | None = None, axes: tuple[int] | None = None, ) -> np.ndarray: # Apply padding if necessary match padding: case fr.grid.FFTPadding.NOPADDING: u = arr case fr.grid.FFTPadding.TRIM: u = self.domain_decomp.pad_trim(arr) case fr.grid.FFTPadding.EXTEND: u = self.domain_decomp.pad_extend(arr) f = lambda x, axes: self._fft.backward(x, axes, bc_types, positions) backward = self._domain_decomp.parallel_backward_transform(f) return backward(u, axes)
# ================================================================ # Syncing and Boundary Conditions # ================================================================
[docs] @fr.utils.jaxjit def sync_multi(self, arrs: tuple[np.ndarray]) -> tuple[np.ndarray]: return self.domain_decomp.sync_multiple(arrs)
# ================================================================ # Shrinking and Expanding # ================================================================ def _shrink_topo(self, topo: tuple[bool], axes: tuple[int]) -> tuple[bool]: new_topo = list(topo) for i in axes: new_topo[i] = False return tuple(new_topo) def _extend_topo(self, topo: tuple[bool], axes: tuple[int]) -> tuple[bool]: new_topo = list(topo) for i in axes: new_topo[i] = True return tuple(new_topo)
[docs] @partial(fr.utils.jaxjit, static_argnames=["axes"]) def sum(self, # noqa: D102 field: fr.ScalarField, axes: tuple[int] | None = None) -> fr.ScalarField: fr.exceptions.PartialDomainError.check(field) # apply the boundary mask to the field field = field.apply_water_mask() axes = axes or tuple(range(self.n_dims)) value = self.domain_decomp.sum(field.arr, axes=axes, spectral=field.is_spectral) mdata = deepcopy(field.mdata) mdata.topo = self._shrink_topo(mdata.topo, axes) return fr.ScalarField(field.mset, mdata=mdata, arr=value)
[docs] @partial(fr.utils.jaxjit, static_argnames=["axes"]) def max(self, # noqa: D102 field: fr.ScalarField, axes: tuple[int] | None = None) -> fr.ScalarField: fr.exceptions.PartialDomainError.check(field) axes = axes or tuple(range(self.n_dims)) value = self.domain_decomp.max(field.arr, axes=axes, spectral=field.is_spectral) mdata = deepcopy(field.mdata) mdata.topo = self._shrink_topo(mdata.topo, axes) return fr.ScalarField(field.mset, mdata=mdata, arr=value)
[docs] @partial(fr.utils.jaxjit, static_argnames=["axes"]) def min(self, # noqa: D102 field: fr.ScalarField, axes: tuple[int] | None = None) -> fr.ScalarField: fr.exceptions.PartialDomainError.check(field) axes = axes or tuple(range(self.n_dims)) value = self.domain_decomp.min(field.arr, axes=axes, spectral=field.is_spectral) mdata = deepcopy(field.mdata) mdata.topo = self._shrink_topo(mdata.topo, axes) return fr.ScalarField(field.mset, mdata=mdata, arr=value)
[docs] @partial(fr.utils.jaxjit, static_argnames=["axes"]) def integrate(self, # noqa: D102 field: fr.ScalarField, axes: tuple[int] | None = None) -> fr.ScalarField: fr.exceptions.FieldSpaceError.check_if_physical(field) axes = axes or tuple(range(self.n_dims)) cell_area = 1.0 for i in axes: if field.topo[i]: cell_area *= self.dx[i] else: cell_area *= self.L[i] return self.sum(field * cell_area, axes)
[docs] @partial(fr.utils.jaxjit, static_argnames=["axis", "direction"]) def cumulative_integral(self, # noqa: D102 field: fr.ScalarField, axis: int, direction: Literal["forward", "backward"] = "forward", ) -> fr.ScalarField: # 1. CHECK THE INPUT # At the moment, we only support cumulative integrals on physical fields fr.exceptions.FieldSpaceError.check_if_physical(field) # At the moment, we only support cumulative integrals on fields that # are extended in all directions fr.exceptions.PartialDomainError.check(field) # only support forward + center position and backward + face position pos = field.position.positions[axis] face = fr.grid.AxisPosition.FACE center = fr.grid.AxisPosition.CENTER if direction == "forward" and pos != center: msg = "Can only do forward cumulative integral for fields on the" msg += " cell center." raise ValueError(msg) if direction == "backward" and pos != face: msg = "Can only do backward cumulative integral for fields on the" msg += " cell face." raise ValueError(msg) # 2. DO THE CUMULATIVE INTEGRAL # Apply the boundary mask before doing the cumulative integral field = field.apply_water_mask() if direction == "forward": field.arr = self.domain_decomp.cumsum(field.arr, axis=axis) if direction == "backward": field.arr = self.domain_decomp.inv_cumsum(field.arr, axis=axis) # multiply by the cell area field *= self.dx[axis] # update the position field.position = field.position.shift(axis) return field.sync()
# ================================================================ # Properties # ================================================================ @property def info(self) -> dict: res = super().info res["N"] = f"{self.N[0]}" res["L"] = fr.utils.humanize_number(self.L[0], "meters") res["dx"] = fr.utils.humanize_number(self.dx[0], "meters") res["Periodic"] = f"{self.periodic_bounds[0]}" for i in range(1, self.n_dims): res["N"] += f" x {self.N[i]}" res["L"] += f" x {fr.utils.humanize_number(self.L[i], 'meters')}" res["dx"] += f" x {fr.utils.humanize_number(self.dx[i], 'meters')}" res["Periodic"] += f" x {self.periodic_bounds[i]}" return res @property def L(self) -> tuple: """Domain size in each direction.""" return self._L @L.setter def L(self, value: tuple): self._L = value self._dx = tuple(L / N for L, N in zip(self._L, self._N)) @property def N(self) -> tuple: """Grid points in each direction.""" return self._N @N.setter def N(self, value: tuple): self._N = value self._dx = tuple(L / N for L, N in zip(self._L, self._N)) self._dV = np.prod(self._dx) self._total_grid_points = int(np.prod(self._N)) @property def K(self) -> tuple | None: """Spectral meshgrid on the local domain.""" return self._K @property def k_local(self) -> tuple | None: """Spectral k-vectors on the local domain.""" return self._k_local @property def k_global(self) -> tuple | None: """Global spectral k-vectors.""" return self._k_global