import fridom.framework as fr
import numpy as np
from functools import partial
[docs]
@fr.utils.jaxify
class Grid(fr.grid.GridBase):
"""
An n-dimensional cartesian grid with capabilities for fourier transforms.
Description
-----------
The cartesian grid is a regular grid with constant grid spacing in each
direction. The grid can be periodic in some directions and non-periodic in
others.
Parameters
----------
`N` : `tuple[int]`
Number of grid points in each direction.
`L` : `tuple[float]`
Domain size in meters in each direction.
`periodic_bounds` : `tuple[bool]`, (default: None)
A list of booleans that indicate whether the axis is periodic.
If True, the axis is periodic, if False, the axis is non-periodic.
Default is True for all axes.
`shared_axes` : `list[int]`, (default: None)
A list of integers that indicate which axes are shared among MPI ranks.
Default is None, which means that no fourier transforms are available.
`diff_mod` : `DiffModule`, (default: None)
A module that contains the differentiation operators.
If None, the finite differences module is used.
`interp_mod` : `InterpolationModule`, (default: None)
A module that contains the interpolation methods.
If None, the linear interpolation module is used.
Examples
--------
.. code-block:: python
import fridom.framework as fr
# construct a 3D grid:
grid = fr.grid.CartesianGrid(
N=(32, 32, 8), # 32x32x8 grid points
L=(100.0, 100.0, 10.0), # 100m x 100m x 10m domain
periodic_bounds=(True, True, False), # non-periodic in z
shared_axes=[0, 1] # slab decomposition, shared in x and y
)
# setup the grid using the model settings
mset = fr.ModelSettingsBase(grid)
mset.setup()
# get the meshgrids
X, Y, Z = grid.X # physical meshgrid of the local domain
KX, KY, KZ = grid.K # spectral meshgrid of the local domain
# get the grid spacing
dx, dy, dz = grid.dx
"""
[docs]
def __init__(self,
N: list[int],
L: list[float],
periodic_bounds: list[bool] | None = None,
domain_decomp: fr.domain_decomposition.DomainDecomposition | None = None,
diff_mod: fr.grid.DiffModule | None = None,
interp_mod: fr.grid.InterpolationModule | None = None
) -> None:
super().__init__(len(N))
self.name = "Cartesian Grid"
# --------------------------------------------------------------
# Check the input
# --------------------------------------------------------------
# check that N and L have the same length
if len(N) != len(L):
raise ValueError("N and L must have the same number of dimensions.")
n_dims = len(N)
# check that periodic_bounds is the right length
periodic_bounds = tuple(periodic_bounds or [True] * n_dims) # default is periodic
if len(periodic_bounds) != n_dims:
raise ValueError(
"periodic_bounds must have the same number of dimensions as N and L.")
fourier_transform_available = True
# --------------------------------------------------------------
# Set the flags
# --------------------------------------------------------------
self.fourier_transform_available = fourier_transform_available
self.mpi_available = True
# --------------------------------------------------------------
# Set the attributes
# --------------------------------------------------------------
# public attributes
self._n_dims = n_dims
# private attributes
self._N = N
self._L = L
self._dx = tuple(L / N for L, N in zip(L, N))
self._dV = np.prod(self._dx)
self._total_grid_points = int(np.prod(N))
self._periodic_bounds = periodic_bounds
self._domain_decomp = domain_decomp
self._fft: fr.grid.cartesian.FFT | None = None
self._diff_module = diff_mod or fr.grid.cartesian.FiniteDifferences()
self._interp_module = interp_mod or fr.grid.cartesian.LinearInterpolation()
return
[docs]
def setup(self,
mset: 'fr.ModelSettingsBase',
req_halo: int | None = None,
fft_module: 'fr.grid.cartesian.FFT | None' = None,
) -> None:
ncp = fr.config.ncp
dtype = fr.config.dtype_real
# --------------------------------------------------------------
# Initialize the domain decomposition
# --------------------------------------------------------------
if req_halo is None:
req_halo = max(self._diff_module.required_halo,
self._interp_module.required_halo)
req_halo = max(req_halo, mset.halo)
# get the domain decomposition module
if self._domain_decomp is None:
DomainDecomposition = fr.domain_decomposition.get_default_domain_decomposition()
# construct the domain decomposition
domain_decomp: fr.domain_decomposition.DomainDecomposition = DomainDecomposition(
shape=tuple(self._N),
halo=req_halo,
periods=self._periodic_bounds,
shared_axes=None)
else:
domain_decomp = self._domain_decomp
# --------------------------------------------------------------
# Initialize the fourier transform
# --------------------------------------------------------------
if self.fourier_transform_available:
fft = fft_module or fr.grid.cartesian.FFT(self._periodic_bounds)
else:
fft = None
# --------------------------------------------------------------
# Initialize the meshgrids
# --------------------------------------------------------------
x = tuple(ncp.linspace(0, li, ni, dtype=dtype, endpoint=False) + 0.5 * dxi
for li, ni, dxi in zip(self._L, self._N, self._dx))
X = domain_decomp.create_meshgrid(*x, pad=True, spectral=False)
if self.fourier_transform_available:
k = fft.get_freq(self._N, self._dx)
K = domain_decomp.create_meshgrid(*k, pad=False, spectral=True)
else:
fr.log.warning("Fourier transform not available.")
k = None
K = None
# ----------------------------------------------------------------
# Store the attributes
# ----------------------------------------------------------------
self._mset = mset
self._domain_decomp = domain_decomp
self._fft = fft
self._X = X
self._x_global = x
self._K = K
self._k_global = k
# call the setup method of the base class
# This is called last since some of the setup methods of the grid base
# class depend on the attributes set here.
super().setup(mset)
return
[docs]
def get_mesh(self,
position: fr.grid.Position | None = None,
spectral: bool = False
) -> tuple[np.ndarray]:
if spectral:
return self.K
position = position or self.cell_center
X = list(self.X)
for i in range(self.n_dims):
if position.positions[i] == fr.grid.AxisPosition.FACE:
X[i] += 0.5 * self.dx[i]
return tuple(X)
# ================================================================
# Fourier Transforms
# ================================================================
[docs]
@partial(fr.utils.jaxjit,
static_argnames=["bc_types", "padding", "positions"])
def fft(self,
arr: np.ndarray,
padding = fr.grid.FFTPadding.NOPADDING,
bc_types: tuple[fr.grid.BCType] | None = None,
positions: tuple[fr.grid.AxisPosition] | None = None,
axes: tuple[int] | None = None,
) -> np.ndarray:
# Forward transform the array
f = lambda x, axes: self._fft.forward(x, axes, bc_types, positions)
forward = self._domain_decomp.parallel_forward_transform(f)
u_hat = forward(arr, axes)
# Apply padding if necessary
if padding == fr.grid.FFTPadding.EXTEND:
u_hat = self.domain_decomp.unpad_extend(u_hat)
return u_hat
[docs]
@partial(fr.utils.jaxjit,
static_argnames=["bc_types", "padding", "positions"])
def ifft(self,
arr: np.ndarray,
padding = fr.grid.FFTPadding.NOPADDING,
bc_types: tuple[fr.grid.BCType] | None = None,
positions: tuple[fr.grid.AxisPosition] | None = None,
axes: tuple[int] | None = None,
) -> np.ndarray:
# Apply padding if necessary
match padding:
case fr.grid.FFTPadding.NOPADDING:
u = arr
case fr.grid.FFTPadding.TRIM:
u = self.domain_decomp.pad_trim(arr)
case fr.grid.FFTPadding.EXTEND:
u = self.domain_decomp.pad_extend(arr)
f = lambda x, axes: self._fft.backward(x, axes, bc_types, positions)
backward = self._domain_decomp.parallel_backward_transform(f)
return backward(u, axes)
# ================================================================
# Syncing and Boundary Conditions
# ================================================================
[docs]
@fr.utils.jaxjit
def sync_multi(self,
arrs: tuple[np.ndarray]) -> tuple[np.ndarray]:
return self.domain_decomp.sync_multiple(arrs)
# ================================================================
# Properties
# ================================================================
@property
def info(self) -> dict:
res = super().info
res["N"] = f"{self.N[0]}"
res["L"] = fr.utils.humanize_number(self.L[0], "meters")
res["dx"] = fr.utils.humanize_number(self.dx[0], "meters")
res["Periodic"] = f"{self.periodic_bounds[0]}"
for i in range(1, self.n_dims):
res["N"] += f" x {self.N[i]}"
res["L"] += f" x {fr.utils.humanize_number(self.L[i], 'meters')}"
res["dx"] += f" x {fr.utils.humanize_number(self.dx[i], 'meters')}"
res["Periodic"] += f" x {self.periodic_bounds[i]}"
if self._domain_decomp is not None:
res["Processors"] = f"{self._domain_decomp.n_procs[0]}"
for i in range(1, self.n_dims):
res["Processors"] += f" x {self._domain_decomp.n_procs[i]}"
return res
@property
def L(self) -> tuple:
"""Domain size in each direction."""
return self._L
@L.setter
def L(self, value: tuple):
self._L = value
self._dx = tuple(L / N for L, N in zip(self._L, self._N))
@property
def N(self) -> tuple:
"""Grid points in each direction."""
return self._N
@N.setter
def N(self, value: tuple):
self._N = value
self._dx = tuple(L / N for L, N in zip(self._L, self._N))
self._dV = np.prod(self._dx)
self._total_grid_points = int(np.prod(self._N))
@property
def K(self) -> tuple | None:
"""Spectral meshgrid on the local domain."""
return self._K
@property
def k_local(self) -> tuple | None:
"""Spectral k-vectors on the local domain."""
return self._k_local
@property
def k_global(self) -> tuple | None:
"""Global spectral k-vectors."""
return self._k_global