JaxDecomposition

class fridom.framework.domain_decomposition.JaxDecomposition(shape: tuple[int], halo: int = 0, periods: tuple[bool] | None = None, p_dims: tuple[int] | None = None, shared_axes: tuple[int] | None = None, device_ids: list[int] | None = None)

Bases: DomainDecomposition

__init__(shape: tuple[int], halo: int = 0, periods: tuple[bool] | None = None, p_dims: tuple[int] | None = None, shared_axes: tuple[int] | None = None, device_ids: list[int] | None = None)
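The constructor arguments determine how the global grid is split. As a minimal sketch, assuming each axis length in `shape` divides evenly by the matching entry of `p_dims` and that a halo of width `halo` surrounds each subdomain on both sides of every axis (consistent with the `halo` attribute being the same for all dimensions), the size of one process's local array could be computed as follows. `local_shape` is a hypothetical helper, not part of fridom.

```python
def local_shape(shape, p_dims, halo=0):
    """Hypothetical helper: local array shape of one process, halos included.

    Assumes each global axis length is divisible by the number of processes
    along that axis and that a halo of width `halo` sits on both sides of
    every axis.
    """
    if len(shape) != len(p_dims):
        raise ValueError("shape and p_dims must have the same length")
    local = []
    for n, p in zip(shape, p_dims):
        if n % p != 0:
            raise ValueError(f"axis of length {n} cannot be split into {p} equal parts")
        local.append(n // p + 2 * halo)
    return tuple(local)


# A 128 x 64 grid split over 2 x 1 processes with a halo of width 2:
# each process holds 64 x 64 interior points plus 2 halo cells per side.
print(local_shape((128, 64), (2, 1), halo=2))  # (68, 68)
```

With `p_dims` set to all ones, the local shape is simply the global shape plus the halo on each side.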

Methods

__init__(shape[, halo, periods, p_dims, ...])

create_array([pad, spectral])

Create an array.

create_meshgrid(*args[, pad, spectral])

Create a meshgrid of arrays.

create_random_array([seed, pad, spectral])

Create a random array.

cumsum(arr, axis)

Cumulative sum of an array along a specified axis.

gather(arr[, slc, dest_rank, spectral])

Gather an array to a single process.

inv_cumsum(arr, axis)

Inverse cumulative sum of an array along a specified axis.

max(arr[, axes, spectral])

Find the maximum value of an array across specified axes.

min(arr[, axes, spectral])

Find the minimum value of an array across specified axes.

pad_extend(arr)

Extend the array with zeros (for spectral padding).

pad_trim(arr)

Set the padded region to zero (for spectral padding).

parallel_backward_transform(func)

Parallel backward transform.

parallel_forward_transform(func)

Parallel forward transform.

roll(arr, shift, axis)

Roll an array along specified axes.

shard_map(func)

Decorator to apply a function to the active processes only.

sum(arr[, axes, spectral])

Sum an array across specified axes.

sync(arr[, flat_axes])

Synchronize the halo regions of an array across all processes.

sync_multiple(arr)

Synchronize the halo regions of multiple arrays across all processes.

unpad_extend(arr)

Remove the extension of the array (for spectral padding).
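A cumulative sum along a decomposed axis cannot be computed purely locally, because each subdomain needs the running total of all subdomains before it. The NumPy sketch below simulates one common technique on a single process: every block computes its local cumulative sum and then adds an exclusive prefix sum of the block totals. It illustrates the idea only; how `JaxDecomposition.cumsum` communicates between devices is not documented here, and halos are ignored.

```python
import numpy as np


def distributed_cumsum(blocks):
    """Cumulative sum of a 1-D array split into contiguous blocks.

    Each block stands in for the part of the array owned by one process.
    """
    totals = np.array([b.sum() for b in blocks])
    # offsets[i] = sum of the totals of blocks 0 .. i-1 (exclusive prefix sum)
    offsets = np.concatenate(([0], np.cumsum(totals)[:-1]))
    return [np.cumsum(b) + off for b, off in zip(blocks, offsets)]


data = np.arange(1, 9)        # [1, 2, ..., 8]
blocks = np.split(data, 2)    # two simulated processes, 4 values each
result = np.concatenate(distributed_cumsum(blocks))
assert np.array_equal(result, np.cumsum(data))
```

Only the block totals have to cross process boundaries, which keeps the communication small compared with moving the whole array.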

Attributes

device_ids

List of device ids.

halo

Width of the halo region (same for all dimensions).

i_am_active

Whether the current process is active in this domain.

n_dims

Number of dimensions.

p_dims

Number of processes in each dimension.

pad

Add padding to an array.

parallel

Whether the domain is parallel.

periods

Periodic boundaries of the domain.

rank

Rank of the current process.

shape

Shape of the domain (number of grid points).

shared_axes

Axes shared by all processes.

size

Number of processes.

unpad

Remove padding from an array.

sync(arr: ndarray, flat_axes: list[int] | None = None) → ndarray

Synchronize the halo regions of an array across all processes.

Parameters

arr : ndarray

The array to synchronize.

flat_axes : list[int] | None

Dimensions which are flat (no halo exchange). If None, all dimensions are exchanged.
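What halo synchronization achieves can be pictured on a single process with a periodic domain: each halo cell receives the value of the interior cell on the opposite side. The NumPy sketch below assumes a halo of equal width on both sides of every axis and fully periodic boundaries; `sync_periodic` is a hypothetical illustration, whereas the real `sync` exchanges halo data between processes instead of copying within one array.

```python
import numpy as np


def _axis_slice(ndim, axis, start, stop):
    """Index tuple selecting start:stop along `axis` and everything elsewhere."""
    idx = [slice(None)] * ndim
    idx[axis] = slice(start, stop)
    return tuple(idx)


def sync_periodic(arr, halo, flat_axes=None):
    """Fill the halo cells of `arr` from the opposite side of the interior.

    Single-process stand-in for a periodic halo exchange. Axes listed in
    `flat_axes` are left untouched.
    """
    out = arr.copy()
    if halo == 0:
        return out
    skip = set(flat_axes or [])
    for axis in range(out.ndim):
        if axis in skip:
            continue
        n = out.shape[axis]
        # Left halo <- last `halo` interior cells.
        out[_axis_slice(out.ndim, axis, 0, halo)] = \
            out[_axis_slice(out.ndim, axis, n - 2 * halo, n - halo)]
        # Right halo <- first `halo` interior cells.
        out[_axis_slice(out.ndim, axis, n - halo, n)] = \
            out[_axis_slice(out.ndim, axis, halo, 2 * halo)]
    return out


# Interior [1, 2, 3, 4] with one halo cell on each side (initially zero).
print(sync_periodic(np.array([0, 1, 2, 3, 4, 0]), halo=1))  # [4 1 2 3 4 1]
```

Processing the axes one after another also fills the corner cells of a multi-dimensional array, because later axes copy halo rows that earlier axes have already filled.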

property pad: callable

Add padding to an array.

Parameters

arr : ndarray

The array to pad.

property unpad: callable

Remove padding from an array.

Parameters

arr : ndarray

The array to unpad.

pad_extend(arr: ndarray) → ndarray

Extend the array with zeros (for spectral padding).

Parameters

arr : ndarray

The array to pad.

Returns

ndarray

The padded array.

unpad_extend(arr: ndarray) → ndarray

Remove the extension of the array (for spectral padding).

Parameters

arr : ndarray

The array to unpad.

Returns

ndarray

The unpadded array.
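The docstrings describe `pad_extend` and `unpad_extend` only as extending an array with zeros and removing that extension again. One standard form of spectral padding is the 3/2 rule, in which a spectrum is enlarged by inserting zeros at the highest wavenumbers so that quadratic products evaluated on the finer physical grid are free of aliasing. The 1-D NumPy sketch below assumes that scheme, `numpy.fft` coefficient ordering, and a factor of 1.5; the helper names are hypothetical, and fridom may use a different layout or factor.

```python
import numpy as np


def pad_extend_1d(spec, factor=1.5):
    """Zero-extend a 1-D spectrum stored in numpy.fft order (hypothetical).

    Non-negative wavenumbers stay at the front, negative ones move to the
    back, and the new high-wavenumber slots in between are zero.
    """
    n = spec.size
    m = int(round(n * factor))
    half = n // 2
    out = np.zeros(m, dtype=spec.dtype)
    out[:half] = spec[:half]
    out[m - (n - half):] = spec[half:]
    return out


def unpad_extend_1d(spec_ext, n):
    """Undo pad_extend_1d, keeping the n original coefficients."""
    m = spec_ext.size
    half = n // 2
    return np.concatenate((spec_ext[:half], spec_ext[m - (n - half):]))


spec = np.fft.fft(np.sin(2 * np.pi * np.arange(8) / 8))
ext = pad_extend_1d(spec)  # 12 coefficients, 4 of them new zeros
assert np.allclose(unpad_extend_1d(ext, spec.size), spec)
```

For an even length, this sketch keeps the Nyquist coefficient on the negative-wavenumber side; exact spectral interpolation would instead split it between the two sides.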

pad_trim(arr: ndarray) → ndarray

Set the padded region to zero (for spectral padding).

Parameters

arr : ndarray

The array to pad.
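Setting a padded region to zero is what the classic 2/3 dealiasing rule does: coefficients with wavenumbers above two thirds of the Nyquist wavenumber are zeroed in place, so the array keeps its size. The sketch below assumes that `pad_trim` follows this rule and that coefficients are in `numpy.fft` order; `pad_trim_1d` is a hypothetical 1-D helper, not fridom's implementation.

```python
import numpy as np


def pad_trim_1d(spec):
    """Zero the highest third of wavenumbers of a 1-D spectrum (2/3 rule).

    Hypothetical helper; assumes numpy.fft coefficient ordering.
    """
    n = spec.size
    # Integer wavenumbers in numpy.fft order: 0, 1, ..., -2, -1
    k = np.concatenate((np.arange(0, (n + 1) // 2), np.arange(-(n // 2), 0)))
    out = spec.copy()
    out[np.abs(k) > n / 3] = 0  # keep |k| <= (2/3) * (n/2)
    return out


print(pad_trim_1d(np.ones(12)))
```

Building the wavenumbers with integer `arange` avoids floating-point round-off right at the cutoff, which a comparison on scaled `np.fft.fftfreq` output could suffer from.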

gather(arr: ndarray, slc: tuple[slice] | None = None, dest_rank: int | None = None, spectral: bool = False) → ndarray

Gather an array to a single process, or to all processes if dest_rank is None.

Parameters

arr : ndarray

The array to gather.

slc : tuple[slice] (default=None)

The slice of the array to gather. If None, gather the entire array.

dest_rank : int (default=None)

The rank of the process to gather to. If None, gather to all processes.

spectral : bool

Whether the array is in spectral space.
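Conceptually, gathering strips each process's halo and stitches the interiors back together in the order of the process grid. The sketch below simulates that on one process for a 2-D domain using NumPy's `np.block`; the nested-list layout and `gather_blocks` are assumptions for illustration, and the communication performed by `JaxDecomposition.gather` (including `dest_rank` and `slc`) is not modeled.

```python
import numpy as np


def gather_blocks(blocks, halo):
    """Stitch per-process 2-D blocks back into one global array (hypothetical).

    `blocks` is a nested list laid out like the process grid: blocks[i][j]
    belongs to the process at position (i, j). Each block carries a halo of
    width `halo` on all sides, which is stripped before stitching.
    """
    inner = (slice(halo, -halo),) * 2 if halo > 0 else (slice(None),) * 2
    return np.block([[b[inner] for b in row] for row in blocks])


# Split a 4 x 4 array over a 2 x 2 process grid, add a zero halo of width 1,
# and gather it back.
glob = np.arange(16).reshape(4, 4)
blocks = [[np.pad(glob[2 * i:2 * i + 2, 2 * j:2 * j + 2], 1) for j in range(2)]
          for i in range(2)]
assert np.array_equal(gather_blocks(blocks, halo=1), glob)
```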

create_array(pad: bool = True, spectral: bool = False) → ndarray

Create an array.

Parameters

pad : bool

Whether to add padding to the array.

spectral : bool

Whether the array is in spectral space.

topo : tuple[bool] | None

The topology of the array. Axes set to False are flat (only one grid point).

create_random_array(seed: int = 1234, pad: bool = True, spectral: bool = False) → ndarray

Create a random array.

Parameters

seed : int

The seed for the random number generator.

pad : bool

Whether to add padding to the array.

spectral : bool

Whether the array is in spectral space.

topo : tuple[bool] | None

The topology of the array. Axes set to False are flat (only one grid point).

create_meshgrid(*args: ndarray, pad: bool = True, spectral: bool = False) → tuple[ndarray]

Create a meshgrid of arrays.

Parameters

*args : ndarray

The arrays from which to build the meshgrid.

pad : bool

Whether to add padding to the meshgrid.

spectral : bool

Whether the meshgrid is in spectral space.
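A meshgrid turns one coordinate array per axis into full arrays with the shape of the domain. The single-process NumPy analogue below assumes matrix ("ij") indexing, so that the i-th input varies along axis i of every output; this is the convention that matches a domain `shape`, but whether `create_meshgrid` uses it is an assumption here, and padding and decomposition are not shown.

```python
import numpy as np

# 1-D coordinates, one array per axis of a 4 x 3 domain.
x = np.linspace(0.0, 1.0, 4)
y = np.linspace(0.0, 2.0, 3)

# "ij" indexing: X varies along axis 0 and Y along axis 1, so both outputs
# have the domain's shape (4, 3). NumPy's default "xy" indexing would
# return arrays of shape (3, 4) instead.
X, Y = np.meshgrid(x, y, indexing="ij")
print(X.shape, Y.shape)  # (4, 3) (4, 3)
```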

sum(arr: ndarray, axes: list[int] | None = None, spectral: bool = False) → ndarray

Sum an array across specified axes.

Parameters

arr : ndarray

The array to sum.

axes : list[int] | None

The axes to sum across. If None, sum across all axes.

spectral : bool

Whether the array is in spectral space.
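A global sum over a decomposed array proceeds in two stages: each process reduces its own interior, and the partial results are then combined across processes (an all-reduce). The sketch below assumes halo cells are excluded from the local stage, since they duplicate a neighbour's data and would otherwise be counted twice. `global_sum` and the list of blocks standing in for processes are hypothetical, and the example covers only `axes=None`.

```python
import numpy as np


def global_sum(local_blocks, halo):
    """Sum over all simulated processes, skipping halo cells (hypothetical)."""
    partial = []
    for b in local_blocks:
        inner = (slice(halo, -halo) if halo > 0 else slice(None),) * b.ndim
        partial.append(b[inner].sum())  # stage 1: local reduction
    return sum(partial)                 # stage 2: combine across processes


# Periodic 1-D field 1..8 split between two processes with halo 1.
# Each block's halo holds copies of its neighbours' edge values.
padded = np.pad(np.arange(1.0, 9.0), 1, mode="wrap")  # [8, 1, 2, ..., 8, 1]
blocks = [padded[0:6], padded[4:10]]
print(global_sum(blocks, halo=1))    # 36.0
print(sum(b.sum() for b in blocks))  # 54.0, halo cells counted twice
```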

max(arr: ndarray, axes: list[int] | None = None, spectral: bool = False) → ndarray

Find the maximum value of an array across specified axes.

Parameters

arr : ndarray

The array to find the maximum value of.

axes : list[int] | None

The axes to find the maximum value across. If None, find the maximum value across all axes.

spectral : bool

Whether the array is in spectral space.

min(arr: ndarray, axes: list[int] | None = None, spectral: bool = False) → ndarray

Find the minimum value of an array across specified axes.

Parameters

arr : ndarray

The array to find the minimum value of.

axes : list[int] | None

The axes to find the minimum value across. If None, find the minimum value across all axes.

spectral : bool

Whether the array is in spectral space.
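A reduction such as `max` or `min` over a decomposed array splits into a local stage on each process and a combining stage across processes, and the `axes` argument decides whether the second stage is needed at all. The sketch below assumes a 2-D domain split only along axis 0 (as with `p_dims=(2, 1)`), simulates the two processes as NumPy blocks, and omits halos; it illustrates the reasoning, not fridom's implementation.

```python
import numpy as np

# Global 4 x 3 array split along axis 0 between two simulated processes.
data = np.array([[3, 9, 1],
                 [7, 2, 5],
                 [4, 8, 6],
                 [0, 1, 2]])
blocks = np.split(data, 2, axis=0)

# Along the undecomposed axis 1: purely local, no communication needed.
max_axis1 = np.concatenate([b.max(axis=1) for b in blocks])

# Along the decomposed axis 0: reduce locally, then combine the per-process
# results element-wise across processes.
min_axis0 = np.minimum.reduce([b.min(axis=0) for b in blocks])

# Over all axes: reduce each block to a scalar, then combine the scalars.
global_max = max(b.max() for b in blocks)

assert np.array_equal(max_axis1, data.max(axis=1))
assert np.array_equal(min_axis0, data.min(axis=0))
assert global_max == data.max()
```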

shard_map(func: callable) → callable

Decorator to apply a function to the active processes only.

Parameters

func : callable

The function to apply.
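The decorator's documented behaviour, running a function only on processes that are active in the domain, can be sketched with a plain Python class that carries an `i_am_active` flag like the attribute listed above. `ToyDecomposition` is hypothetical, and returning None on inactive processes is an arbitrary choice for the sketch. The method's name suggests that the real implementation also builds on JAX's `shard_map` to run the function on each device's shard; that part is not modeled here.

```python
import functools


class ToyDecomposition:
    """Hypothetical stand-in that only knows whether this process is active."""

    def __init__(self, i_am_active):
        self.i_am_active = i_am_active

    def shard_map(self, func):
        """Call `func` only on active processes; inactive ones get None."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not self.i_am_active:
                return None
            return func(*args, **kwargs)
        return wrapper


active, idle = ToyDecomposition(True), ToyDecomposition(False)


def double(x):
    return 2 * x


print(active.shard_map(double)(21))  # 42
print(idle.shard_map(double)(21))    # None
```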