"""
Base class for domain decomposition.
"""
from abc import abstractmethod
import numpy as np
from numpy import ndarray
import fridom.framework as fr
@fr.utils.jaxify
class DomainDecomposition:
    """
    Construct a grid of processors and decompose a global domain into subdomains.

    Description
    -----------
    Decompose the global domain into subdomains for parallel computing. The
    domain decomposition is done on a cartesian grid of processors and can
    be done in multiple dimensions. Axes that are shared between processors
    (i.e. not split, e.g. for FFTs) can be specified.

    ::

             +----------------+----------------+
            /                /                /|
           /                /                / |
          /                /                /  |
         /                /                /   |
        +----------------+----------------+    +
        |                |                |   /|
        |   PROCESSOR    |   PROCESSOR    |  / |
        |      0, 1      |      1, 1      | /  |
        |                |                |/   |
        +----------------+----------------+    +      ^
        |                |                |   /      /
        |   PROCESSOR    |   PROCESSOR    |  /      / shared_axis
        |      0, 0      |      1, 0      | /      /
        |                |                |/
        +----------------+----------------+

    Parameters
    ----------
    `shape` : `tuple[int]`
        The total number of grid points in each dimension.
    `halo` : `int`, optional (default=0)
        The number of halo cells (ghost cells) around the local domain
        for the exchange of boundary values.
    `periods` : `tuple[bool]`, optional (default=None)
        A tuple of booleans indicating whether the domain is periodic in
        each dimension. If None (or empty), all dimensions are periodic.
    `shared_axes` : `list[int]`, optional (default=None)
        A list of axes that are shared between processors (not split).
        It is stored on the instance (None becomes an empty list); note
        that the `shared_axes` property is derived from `p_dims` instead.
    `device_ids` : `list[int]`, optional (default=None)
        Optional list of device ids to use. If None, all devices are used
        and every process is considered active.
        This option is useful for coupled simulations.
    """

    def __init__(self,
                 shape: tuple[int],
                 halo: int = 0,
                 periods: tuple[bool] | None = None,
                 shared_axes: tuple[int] | None = None,
                 device_ids: list[int] | None = None):
        """
        Store the decomposition settings.

        See the class docstring for a description of the parameters.
        """
        self._shape = shape
        self._n_dims = len(shape)
        self._halo = halo
        # A falsy value (None or an empty tuple) means: periodic in every dimension.
        self._periods = periods or tuple(True for _ in range(self.n_dims))
        self._shared_axes = shared_axes or []
        # The base class always reports rank 0 and has no processor grid;
        # presumably subclasses overwrite _rank and _p_dims - TODO confirm.
        self._rank = 0
        self._device_ids = device_ids
        self._p_dims = None

    # ================================================================
    #  Halo exchange
    # ================================================================

    @abstractmethod
    def sync(self, arr: ndarray, flat_axes: list[int] | None = None) -> ndarray:
        """
        Synchronize the halo regions of an array across all processes.

        Parameters
        ----------
        `arr` : ndarray
            The array to synchronize.
        `flat_axes` : list[int] | None
            Dimensions which are flat (no halo exchange).
            If None, all dimensions are exchanged.
        """

    def sync_multiple(self, arr: list[ndarray]) -> list[ndarray]:
        """
        Synchronize the halo regions of multiple arrays across all processes.

        Each array is synchronized with `sync` using its default arguments.

        Parameters
        ----------
        `arr` : list[ndarray]
            The list of arrays to synchronize.

        Returns
        -------
        list[ndarray]
            The synchronized arrays, in the same order as the input.
        """
        return [self.sync(a) for a in arr]

    # ================================================================
    #  Apply Transform (e.g. FFT)
    # ================================================================

    def parallel_forward_transform(self, func: callable) -> callable:
        """
        Parallel forward transform.

        Wrap `func` so that the input array is unpadded (see `unpad`)
        before the transform is applied.

        Parameters
        ----------
        `func` : callable
            The function to apply the forward transform to.
            func(arr: ndarray, axes: list[int] | None = None) -> ndarray

        Returns
        -------
        callable
            wrapper(arr, axes=None) -> ndarray
        """
        def wrapper(arr: ndarray, axes: list[int] | None = None) -> ndarray:
            # remove the halo / padding before transforming
            arr = self.unpad(arr)
            # apply the forward transform
            arr = func(arr, axes)
            return arr
        return wrapper

    def parallel_backward_transform(self, func: callable) -> callable:
        """
        Parallel backward transform.

        Wrap `func` so that the padding (see `pad`) is added back to the
        result of the transform.

        Parameters
        ----------
        `func` : callable
            The function to apply the backward transform to.
            func(arr: ndarray, axes: list[int] | None = None) -> ndarray

        Returns
        -------
        callable
            wrapper(arr, axes=None) -> ndarray
        """
        def wrapper(arr: ndarray, axes: list[int] | None = None) -> ndarray:
            # apply the backward transform
            arr = func(arr, axes)
            # restore the halo / padding
            arr = self.pad(arr)
            return arr
        return wrapper

    # ================================================================
    #  Padding
    # ================================================================

    @abstractmethod
    def pad(self, arr: ndarray) -> ndarray:
        """
        Add padding to an array.

        Parameters
        ----------
        `arr` : ndarray
            The array to pad.
        """

    @abstractmethod
    def unpad(self, arr: ndarray) -> ndarray:
        """
        Remove padding from an array.

        Parameters
        ----------
        `arr` : ndarray
            The array to unpad.
        """

    # ----------------------------------------------------------------
    #  Spectral paddings
    # ----------------------------------------------------------------

    def pad_extend(self, arr: ndarray) -> ndarray:
        """
        Extend the array with zeros (for spectral padding).

        The base implementation does nothing and returns None; it is
        presumably overridden by subclasses.

        Parameters
        ----------
        `arr` : ndarray
            The array to pad.

        Returns
        -------
        ndarray
            The padded array.
        """

    def unpad_extend(self, arr: ndarray) -> ndarray:
        """
        Remove the extension of the array (for spectral padding).

        The base implementation does nothing and returns None; it is
        presumably overridden by subclasses.

        Parameters
        ----------
        `arr` : ndarray
            The array to unpad.

        Returns
        -------
        ndarray
            The unpadded array.
        """

    def pad_trim(self, arr: ndarray) -> ndarray:
        """
        Set the padded region to zero (for spectral padding).

        The base implementation does nothing and returns None; it is
        presumably overridden by subclasses.

        Parameters
        ----------
        `arr` : ndarray
            The array to pad.
        """

    # ================================================================
    #  Gather
    # ================================================================

    @abstractmethod
    def gather(self,
               arr: ndarray,
               slc: tuple[slice] | None = None,
               dest_rank: int | None = None,
               spectral: bool = False) -> ndarray:
        """
        Gather an array to a single process.

        Parameters
        ----------
        `arr` : ndarray
            The array to gather.
        `slc` : tuple[slice] (default=None)
            The slice of the array to gather.
            If None, gather the entire array.
        `dest_rank` : int (default=None)
            The rank of the process to gather to.
            If None, gather to all processes.
        `spectral` : bool
            Whether the array is in spectral space.
        """

    # ================================================================
    #  Array creation
    # ================================================================

    @abstractmethod
    def create_array(self,
                     pad: bool = True,
                     spectral: bool = False,
                     topo: tuple[bool] | None = None) -> ndarray:
        """
        Create an array.

        Parameters
        ----------
        `pad` : bool
            Whether to add padding to the array.
        `spectral` : bool
            Whether the array is in spectral space.
        `topo` : tuple[bool] | None
            The topology of the array. Axes with False are flat (only one grid point).
        """

    @abstractmethod
    def create_random_array(self,
                            seed: int = 1234,
                            pad: bool = True,
                            spectral: bool = False,
                            topo: tuple[bool] | None = None
                            ) -> ndarray:
        """
        Create a random array.

        Parameters
        ----------
        `seed` : int
            The seed for the random number generator.
        `pad` : bool
            Whether to add padding to the array.
        `spectral` : bool
            Whether the array is in spectral space.
        `topo` : tuple[bool] | None
            The topology of the array. Axes with False are flat (only one grid point).
        """

    @abstractmethod
    def create_meshgrid(self,
                        *args: ndarray,
                        pad: bool = True,
                        spectral: bool = False) -> tuple[ndarray]:
        """
        Create a meshgrid of arrays.

        Parameters
        ----------
        `args` : ndarray
            The arrays to meshgrid.
        `pad` : bool
            Whether to add padding to the meshgrid.
        `spectral` : bool
            Whether the meshgrid is in spectral space.
        """

    # ================================================================
    #  Array operations
    # ================================================================

    @abstractmethod
    def sum(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """
        Sum an array across specified axes.

        Parameters
        ----------
        `arr` : ndarray
            The array to sum.
        `axes` : list[int] | None
            The axes to sum across.
            If None, sum across all axes.
        `spectral` : bool
            Whether the array is in spectral space.
        """

    @abstractmethod
    def max(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """
        Find the maximum value of an array across specified axes.

        Parameters
        ----------
        `arr` : ndarray
            The array to find the maximum value of.
        `axes` : list[int] | None
            The axes to find the maximum value across.
            If None, find the maximum value across all axes.
        `spectral` : bool
            Whether the array is in spectral space.
        """

    @abstractmethod
    def min(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """
        Find the minimum value of an array across specified axes.

        Parameters
        ----------
        `arr` : ndarray
            The array to find the minimum value of.
        `axes` : list[int] | None
            The axes to find the minimum value across.
            If None, find the minimum value across all axes.
        `spectral` : bool
            Whether the array is in spectral space.
        """

    # ================================================================
    #  Helper functions
    # ================================================================

    def shard_map(self, func: callable) -> callable:
        """
        Decorator to apply a function to the active processes only.

        The base implementation returns `func` unchanged.

        Parameters
        ----------
        `func` : callable
            The function to apply.
        """
        return func

    # ================================================================
    #  Properties
    # ================================================================

    @property
    def n_dims(self) -> int:
        """
        Number of dimensions (length of `shape`).
        """
        return self._n_dims

    @property
    def shape(self) -> tuple[int]:
        """
        Shape of the domain (number of grid points).
        """
        return self._shape

    @property
    def halo(self) -> int:
        """
        Width of the halo region (same for all dimensions).
        """
        return self._halo

    @property
    def periods(self) -> tuple[bool]:
        """
        Periodic boundaries of the domain (one flag per dimension).
        """
        return self._periods

    @property
    def parallel(self) -> bool:
        """
        Whether the domain is parallel (more than one process).
        """
        return self.size > 1

    @property
    def rank(self) -> int:
        """
        Rank of the current process.
        """
        return self._rank

    @property
    def size(self) -> int:
        """
        Number of processes (product of `p_dims`).
        """
        return np.prod(self.p_dims)

    @property
    def device_ids(self) -> list[int] | None:
        """
        List of device ids, or None if all devices are used.
        """
        return self._device_ids

    @property
    def i_am_active(self) -> bool:
        """
        Whether the current process is active in this domain.

        If `device_ids` is None, all devices are used and every process
        is active.
        """
        # device_ids=None means "use all devices"; without this guard the
        # membership test below would raise TypeError on None.
        if self.device_ids is None:
            return True
        return self.rank in self.device_ids

    @property
    def p_dims(self) -> tuple[int]:
        """
        Number of processes in each dimension (None in the base class).
        """
        return self._p_dims

    @property
    def shared_axes(self) -> list[int]:
        """
        Axes that are not split across processes (those with p_dims == 1).

        Requires `p_dims` to be set.
        """
        return [i for i, x in enumerate(self.p_dims) if x == 1]
def get_default_domain_decomposition() -> type[DomainDecomposition]:
    """
    Get the default domain decomposition class for the current configuration.

    The choice is driven by ``fr.config``:

    - If ``fr.config.enable_parallel`` is falsy, ``SingleDecomposition``
      is returned.
    - If the backend is jax and more than one device is available,
      ``JaxDecomposition`` is returned.
    - In every other case (jax with a single device, or a non-jax
      backend), ``SingleDecomposition`` is returned.

    Returns
    -------
    type[DomainDecomposition]
        The domain decomposition class (not an instance).
    """
    fall_back = fr.domain_decomposition.SingleDecomposition
    # if the parallel flag is not set, use the fall back
    if not fr.config.enable_parallel:
        return fall_back
    # if the backend is jax, use the jax decomposition when possible
    if fr.config.backend_is_jax:
        # imported lazily so jax is only required for the jax backend
        import jax
        n_devices = jax.device_count()
        # with only one available device, use the single decomposition
        if n_devices == 1:
            return fall_back
        return fr.domain_decomposition.JaxDecomposition
    # no parallel decomposition exists for this backend: avoid returning None
    return fall_back