import fridom.framework as fr
import numpy as np
from numpy import ndarray
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jaxdecomp
import fridom.framework as fr
from functools import partial, cached_property
# ================================================================
# Custom partitioning to avoid unnecessary communication
# ================================================================
def _supported_sharding(sharding: NamedSharding, shape):
    """Return a variant of ``sharding`` whose last array axis is replicated.

    Used by the custom-partitioning hooks so that the wrapped function always
    sees the full extent of the last axis locally.

    Parameters
    ----------
    sharding : NamedSharding
        Sharding of the operand.
    shape
        An object with a ``.shape`` attribute (e.g. a shape/dtype struct or
        an array); only ``len(shape.shape)`` (the rank) is used.

    Returns
    -------
    NamedSharding
        Same mesh; the first ``min(len(spec), rank - 1)`` entries of the
        original spec are kept and every remaining axis is unpartitioned.
    """
    ndim = len(shape.shape)
    n_kept = min(len(sharding.spec), ndim - 1)
    kept = tuple(sharding.spec[:n_kept])
    # pad with ``None`` up to the full rank -> at least the last axis is replicated
    padding = (None,) * (ndim - n_kept)
    return NamedSharding(sharding.mesh, P(*kept, *padding))
def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    """``custom_partitioning`` hook: infer the sharding of the result.

    The result is sharded like the first operand, with its last axis
    replicated (see ``_supported_sharding``). ``mesh`` and ``result_shape``
    are part of the hook signature but are not used here.
    """
    shardings = jax.tree.map(lambda leaf: leaf.sharding, arg_shapes)
    first_operand = arg_shapes[0]
    return _supported_sharding(shardings[0], first_operand)
def _partitionate_function(f: callable,
                           in_shardings: NamedSharding,
                           out_shardings: NamedSharding):
    """Wrap a single-operand function with custom partitioning and jit it.

    The partition rule gives operand and result the same sharding: the first
    operand's sharding with its last axis replicated (``_supported_sharding``),
    so ``f`` runs on local shards that contain the full last axis.

    Parameters
    ----------
    f : callable
        Function of one array argument.
    in_shardings, out_shardings : NamedSharding
        Forwarded to ``jax.jit``.

    Returns
    -------
    callable
        The jitted, custom-partitioned version of ``f``.
    """
    def _lower(mesh, arg_shapes, result_shape):
        shardings = jax.tree.map(lambda leaf: leaf.sharding, arg_shapes)
        target = _supported_sharding(shardings[0], arg_shapes[0])
        # (mesh, lowered function, result sharding, operand shardings)
        return mesh, f, target, (target,)

    partitioned = custom_partitioning(f)
    partitioned.def_partition(
        infer_sharding_from_operands=_infer_sharding_from_operands,
        partition=_lower)
    return jax.jit(partitioned,
                   in_shardings=in_shardings,
                   out_shardings=out_shardings)
[docs]
@fr.utils.jaxify
class JaxDecomposition(fr.domain_decomposition.DomainDecomposition):
    """Decompose a 3D domain over a 2D mesh of JAX devices.

    Physical arrays are sharded with ``PartitionSpec('x', 'y', None)`` and
    spectral arrays with ``PartitionSpec('y', None, 'x')`` over a device mesh
    whose axis names are ``('x', 'y')``. Halo exchange between devices and the
    pencil transposes used by the spectral transforms are delegated to
    ``jaxdecomp``.
    """

    def __init__(self,
                 shape: tuple[int],
                 halo: int = 0,
                 periods: tuple[bool] | None = None,
                 p_dims: tuple[int] | None = None,
                 shared_axes: tuple[int] | None = None,
                 device_ids: list[int] | None = None):
        """Set up the device mesh, shardings and halo slices.

        Parameters
        ----------
        shape : tuple[int]
            Global grid shape without halo; must be three-dimensional.
        halo : int
            Width of the halo added on every side of each local shard.
        periods : tuple[bool] | None
            Periodicity per axis, forwarded to the base class.
        p_dims : tuple[int] | None
            Processor grid, one entry per mesh axis (``'x'``, ``'y'``).
            Defaults to ``(n_devices, 1)``.
        shared_axes : tuple[int] | None
            Forwarded to the base class; ``self._shared_axes`` is then
            recomputed from ``p_dims``.
        device_ids : list[int] | None
            Devices to build the mesh from; defaults to ``jax.devices()``.

        Raises
        ------
        ValueError
            If ``shape`` is not 3D, or if the processor grid size does not
            match the number of devices.
        """
        super().__init__(shape, halo, periods, shared_axes, device_ids)
        # create a device array
        # NOTE(review): annotated as ``list[int]`` but forwarded to
        # ``mesh_utils.create_device_mesh(devices=...)``, whose default here is
        # ``jax.devices()`` (device objects) - confirm ints are never passed.
        device_ids = device_ids or jax.devices()
        size = len(device_ids)
        if len(shape) != 3:
            raise ValueError("Only 3D domains are supported.")
        self._p_dims = p_dims or (size, 1)
        # ``self.size`` / ``self.p_dims`` come from the base class
        # (presumably derived from ``self._p_dims`` - confirm).
        if self.size != size:
            raise ValueError(f"Number of available devices: {size} does not match to the processor grid: {self.p_dims}")
        # Mesh axes with a single device hold their full extent locally.
        # NOTE(review): ``p_dims`` has two entries, so axis 2 (never
        # partitioned, see ``p_phys``) is never listed here and ``sync``
        # therefore never refreshes its halo - confirm this is intended.
        self._shared_axes = [i for i, x in enumerate(self.p_dims) if x == 1]
        all_axes = set(range(3))
        self._z_ffts = self._shared_axes
        self._y_ffts = list(all_axes - set(self._shared_axes))
        # NOTE(review): always empty, since ``_y_ffts`` already contains every
        # non-shared axis.
        self._x_ffts = list(all_axes - set(self._shared_axes) - set(self._y_ffts))
        # ----------------------------------------------------------------
        # Create the device mesh
        # ----------------------------------------------------------------
        devices = mesh_utils.create_device_mesh(self.p_dims, devices=device_ids)
        self.mesh = Mesh(devices, axis_names=('x', 'y'))
        self.p_phys = P('x', 'y', None)
        self.p_spec = P('y', None, 'x')
        self.shard_phys = NamedSharding(self.mesh, self.p_phys)
        self.shard_spec = NamedSharding(self.mesh, self.p_spec)
        # ----------------------------------------------------------------
        # Halo exchange slices and paddings
        # ----------------------------------------------------------------
        def _make_slice_tuple(slc):
            # One index tuple per axis that selects ``slc`` along that axis only.
            slice_list = []
            for i in range(self.n_dims):
                full_slice = [slice(None)]*self.n_dims
                full_slice[i] = slc
                slice_list.append(tuple(full_slice))
            return tuple(slice_list)

        # create slices for halo exchange
        # (only meaningful for halo > 0: with halo == 0 they degenerate,
        # e.g. ``slice(-0, None)`` selects the whole axis)
        self._inner_slice = tuple([slice(halo, -halo)]*self.n_dims)
        self._inner = _make_slice_tuple(slice(halo, -halo))
        self._send_to_next = _make_slice_tuple(slice(-2*halo, -halo))
        self._send_to_prev = _make_slice_tuple(slice(halo, 2*halo))
        self._recv_from_next = _make_slice_tuple(slice(-halo, None))
        self._recv_from_prev = _make_slice_tuple(slice(None, halo))
        pw = self.halo
        self._padding = ((pw, pw), (pw, pw), (pw, pw))
        # halo widths along the two decomposed mesh axes, for jaxdecomp
        self._halo_extents = (pw, pw)

    # ================================================================
    # Halo exchange
    # ================================================================
    @partial(fr.utils.jaxjit, static_argnames=['flat_axes'])
    def sync(self, arr: ndarray, flat_axes: list[int] | None = None) -> ndarray:
        """Refresh the halo regions of a padded physical array.

        Axes 0 and 1 are exchanged with ``jaxdecomp.halo_exchange`` using the
        periodicity of the first two axes; every shared axis is then wrapped
        locally by ``_sync_shared_axis``. Without halo there is nothing to
        refresh and ``arr`` is returned unchanged.

        Raises
        ------
        NotImplementedError
            If ``flat_axes`` is given.
        """
        if flat_axes is not None:
            raise NotImplementedError("Flat axes not supported in JaxDecomposition.")
        if self.halo == 0:
            # With zero halo the shared-axis slices degenerate (the receive
            # slice covers the whole axis while the send slice is empty), so
            # the local wrap would fail with a shape mismatch.
            return arr
        arr = jaxdecomp.halo_exchange(
            arr, halo_extents=self._halo_extents, halo_periods=self._periods[:-1])
        for axis in self.shared_axes:
            arr = self._sync_shared_axis(arr, axis)
        return arr

    @partial(fr.utils.jaxjit, static_argnames=['axis'])
    def _sync_shared_axis(self, arr: ndarray, axis: int,) -> ndarray:
        """Fill the halo of a non-decomposed axis by wrapping each shard locally.

        The halo at the end of ``axis`` receives the first inner cells and the
        halo at the start receives the last inner cells.
        NOTE(review): the wrap is always periodic, independent of
        ``self._periods`` - confirm for non-periodic shared axes.
        """
        if self.halo == 0:
            # degenerate slices, see ``sync``
            return arr

        @partial(shard_map, mesh=self.mesh, in_specs=self.p_phys, out_specs=self.p_phys)
        def _sync(arr: ndarray) -> ndarray:
            rfn = self._recv_from_next[axis]
            rfp = self._recv_from_prev[axis]
            stn = self._send_to_next[axis]
            stp = self._send_to_prev[axis]
            arr = arr.at[rfn].set(arr[stp])
            arr = arr.at[rfp].set(arr[stn])
            return arr

        return _sync(arr)

    # ================================================================
    # Apply Transform (e.g. FFT)
    # ================================================================
    # NOTE: disabled alternative implementations, kept for reference.
    # def parallel_forward_transform(self, func: callable) -> callable:
    #     @partial(jax.jit, static_argnames=['axes'])
    #     def my_transform(arr, axes: tuple[int] | None = None):
    #         # unpad the array
    #         arr = self.unpad(arr)
    #         arr = func(arr, axes=(0, ))
    #         arr = func(arr, axes=(1, ))
    #         arr = func(arr, axes=(2, ))
    #         return arr
    #     return my_transform
    #     return self._xy_pencil_forward(func)
    # def parallel_backward_transform(self, func: callable) -> callable:
    #     @partial(jax.jit, static_argnames=['axes'])
    #     def my_transform(arr, axes: tuple[int] | None = None):
    #         arr = func(arr, axes=(0, ))
    #         arr = func(arr, axes=(1, ))
    #         arr = func(arr, axes=(2, ))
    #         return self.pad(arr)
    #     return my_transform
    #     return self._xy_pencil_backward(func)

    def _pencil_axis_operations(self, func: callable) -> tuple[callable, callable, callable]:
        """Wrap ``func`` into the per-axis operations of the pencil transforms.

        Returns ``(func_axis_0, func_axis_1, func_axis_2)``. Each is wrapped
        with ``_partitionate_function`` (last axis kept unsharded). Note that
        each calls ``func`` on the axis that was the last axis of its input:
        the transposes only move that axis to the position ``func`` is told to
        operate on, and move it back afterwards.
        """
        # operation in the z-axis
        @partial(_partitionate_function,
                 in_shardings=NamedSharding(self.mesh, P('x', 'y', None)),
                 out_shardings=NamedSharding(self.mesh, P('x', 'y', None)))
        def func_axis_2(x):
            return func(x, axes=(2,))

        # operation in the x-axis
        @partial(_partitionate_function,
                 in_shardings=NamedSharding(self.mesh, P('y', 'x', None)),
                 out_shardings=NamedSharding(self.mesh, P('y', 'x', None)))
        def func_axis_0(x):
            x = jax.numpy.transpose(x, (2, 0, 1))
            x = func(x, axes=(0,))
            x = jax.numpy.transpose(x, (1, 2, 0))
            return x

        # operation in the y-pencil
        @partial(_partitionate_function,
                 in_shardings=NamedSharding(self.mesh, P('x', 'y', None)),
                 out_shardings=NamedSharding(self.mesh, P('x', 'y', None)))
        def func_axis_1(x):
            x = jax.numpy.transpose(x, (1, 2, 0))
            x = func(x, axes=(1,))
            x = jax.numpy.transpose(x, (2, 0, 1))
            return x

        return func_axis_0, func_axis_1, func_axis_2

    def _xy_pencil_forward(self, func: callable) -> callable:
        """Build a jitted forward transform over x/y pencils.

        The halo is removed, ``func`` is applied along z, the array is moved
        with ``jaxdecomp.transposeZtoY`` and transformed along x, then moved
        with ``jaxdecomp.transposeYtoX`` and transformed along y, and finally
        transposed with ``(1, 2, 0)``.

        Returns
        -------
        callable
            ``transform(arr, axes=None)``. ``axes`` is a static argument (must
            be hashable, e.g. a tuple) selecting the transformed axes; ``None``
            or an empty sequence means all axes. The pencil transposes are
            applied regardless of ``axes``.
        """
        func_axis_0, func_axis_1, func_axis_2 = self._pencil_axis_operations(func)

        @partial(jax.jit, static_argnames=['axes'])
        def my_forward_transform(arr, axes: list[int] | None = None):
            axes = axes or list(range(self.n_dims))
            # unpad the array
            arr = self.unpad(arr)
            # apply the forward transform in the z-axis
            if 2 in axes:
                arr = func_axis_2(arr)
            # make the x-axis the shared axis
            arr = jaxdecomp.transposeZtoY(arr)
            # apply the forward transform in the x-axis
            if 0 in axes:
                arr = func_axis_0(arr)
            # make the y-axis the shared axis
            arr = jaxdecomp.transposeYtoX(arr)
            # apply the forward transform in the y-axis
            if 1 in axes:
                arr = func_axis_1(arr)
            # transpose the array back to the original shape
            arr = jax.numpy.transpose(arr, (1, 2, 0))
            return arr

        return my_forward_transform

    def _xy_pencil_backward(self, func: callable) -> callable:
        """Build the jitted inverse of ``_xy_pencil_forward``.

        The array is transposed with ``(2, 0, 1)``, transformed along y,
        moved with ``jaxdecomp.transposeXtoY`` and transformed along x, moved
        with ``jaxdecomp.transposeYtoZ`` and transformed along z, then padded
        and halo-synchronized.

        Returns
        -------
        callable
            ``transform(arr, axes=None)`` with the same ``axes`` semantics as
            the forward transform.
        """
        func_axis_0, func_axis_1, func_axis_2 = self._pencil_axis_operations(func)

        @partial(jax.jit, static_argnames=['axes'])
        def my_backward_transform(arr, axes: list[int] | None = None):
            axes = axes or list(range(self.n_dims))
            # transpose the array to match jaxdecomp
            arr = jax.numpy.transpose(arr, (2, 0, 1))
            # apply the backward transform in the y-axis
            if 1 in axes:
                arr = func_axis_1(arr)
            # make the x-axis the shared axis
            arr = jaxdecomp.transposeXtoY(arr)
            # apply the backward transform in the x-axis
            if 0 in axes:
                arr = func_axis_0(arr)
            # make the z-axis the shared axis
            arr = jaxdecomp.transposeYtoZ(arr)
            # apply the backward transform in the z-axis
            if 2 in axes:
                arr = func_axis_2(arr)
            # restore the padding and refresh the halo
            arr = self.pad(arr)
            arr = self.sync(arr)
            return arr

        return my_backward_transform

    # ================================================================
    # Padding
    # ================================================================
    @cached_property
    def pad(self) -> callable:
        """Jitted function adding a ``mode="wrap"`` halo to every local shard.

        The returned callable has signature ``pad(arr, flat_axes=None)``. Each
        shard is padded from its own data, so halos of decomposed axes are only
        correct after a subsequent ``sync``.
        """
        @partial(shard_map, mesh=self.mesh, in_specs=self.p_phys, out_specs=self.p_phys)
        def _pad(arr: ndarray) -> ndarray:
            ncp = fr.config.ncp
            if self.halo == 0:
                return arr
            arr = ncp.pad(arr, pad_width=self._padding, mode="wrap")
            return arr

        @partial(jax.jit, static_argnames=['flat_axes'])
        def pad(arr: ndarray, flat_axes: list[int] | None = None) -> ndarray:
            if flat_axes is not None:
                raise NotImplementedError("Flat axes not supported in JaxDecomposition.")
            return _pad(arr)

        return pad

    @cached_property
    def unpad(self) -> callable:
        """Jitted function stripping the halo from every local shard.

        The returned callable has signature ``unpad(arr, flat_axes=None)``.
        The ``flat_axes`` check lives outside ``shard_map`` (as in ``pad``):
        a ``shard_map``-wrapped function only takes positional array
        arguments, so the check could otherwise never be reached.
        """
        @partial(shard_map, mesh=self.mesh, in_specs=self.p_phys, out_specs=self.p_phys)
        def _unpad(arr: ndarray) -> ndarray:
            if self.halo == 0:
                return arr
            return arr[self._inner_slice]

        @partial(jax.jit, static_argnames=['flat_axes'])
        def unpad(arr: ndarray, flat_axes: tuple[int] | None = None) -> ndarray:
            if flat_axes is not None:
                raise NotImplementedError("Flat axes not supported in JaxDecomposition.")
            return _unpad(arr)

        return unpad

    # ----------------------------------------------------------------
    # Spectral paddings
    # ----------------------------------------------------------------
    def pad_extend(self, arr: ndarray) -> ndarray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Spectral padding not supported in JaxDecomposition.")

    def unpad_extend(self, arr: ndarray) -> ndarray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Spectral padding not supported in JaxDecomposition.")

    def pad_trim(self, arr: ndarray) -> ndarray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Spectral padding not supported in JaxDecomposition.")

    # ================================================================
    # Gather
    # ================================================================
    def gather(self,
               arr: ndarray,
               slc: tuple[slice] | None = None,
               dest_rank: int | None = None,
               spectral: bool = False) -> ndarray:
        """Collect (a slice of) a distributed array.

        Physical arrays are unpadded first; spectral arrays are used as-is.
        With a single process the result is fetched with ``jax.device_get``
        (``dest_rank`` is ignored). With several processes it is all-gathered
        with ``multihost_utils.process_allgather(..., tiled=True)``.

        Raises
        ------
        ValueError
            If ``dest_rank`` is given while running with multiple processes.
        """
        # first we unpad the array
        if not spectral:
            arr = self.unpad(arr)
        # gather the array
        slc = slc or (slice(None), )*self.n_dims
        if jax.process_count() == 1:
            # single process
            return jax.device_get(arr[slc])
        # multiple processes
        if dest_rank is not None:
            raise ValueError("dest_rank is not supported in JaxDecomposition.")
        return multihost_utils.process_allgather(arr[slc], tiled=True)

    # ================================================================
    # Array creation
    # ================================================================
    @partial(jax.jit, static_argnames=['pad', 'spectral'])
    def create_array(self,
                     pad: bool = True,
                     spectral: bool = False) -> ndarray:
        """Create a sharded zero array of the global ``shape``.

        Uses ``fr.config.dtype_comp`` and the spectral sharding if ``spectral``,
        else ``fr.config.dtype_real`` and the physical sharding. Physical arrays
        are padded and halo-synchronized when ``pad`` is True.
        """
        dtype = fr.config.dtype_comp if spectral else fr.config.dtype_real
        sharding = self.shard_spec if spectral else self.shard_phys

        @partial(jax.jit, out_shardings=sharding)
        def create_zeros():
            return jax.numpy.zeros(self.shape, dtype=dtype)

        arr = create_zeros()
        if pad and not spectral:
            arr = self.pad(arr)
            arr = self.sync(arr)
        return arr

    def create_random_array(self,
                            seed: int = 1234,
                            pad: bool = True,
                            spectral: bool = False) -> ndarray:
        """Create a sharded array of standard-normal samples.

        The real part is drawn with ``PRNGKey(seed)``; spectral arrays get an
        imaginary part drawn with ``PRNGKey(2*seed + 3)``. Physical arrays are
        padded and halo-synchronized when ``pad`` is True.
        """
        dtype = fr.config.dtype_comp if spectral else fr.config.dtype_real
        sharding = self.shard_spec if spectral else self.shard_phys

        @partial(jax.jit, out_shardings=sharding)
        def _make_random():
            real = jax.random.normal(jax.random.PRNGKey(seed), self.shape)
            if not spectral:
                return real.astype(dtype)
            imag = jax.random.normal(jax.random.PRNGKey(2*seed+3), self.shape)
            return jax.numpy.array(real + 1j*imag, dtype=dtype)

        arr = _make_random()
        if pad and not spectral:
            arr = self.pad(arr)
            arr = self.sync(arr)
        return arr

    def create_meshgrid(self,
                        *args: ndarray,
                        pad: bool = True,
                        spectral: bool = False) -> tuple[ndarray]:
        """Create sharded ``indexing='ij'`` meshgrid arrays.

        Parameters
        ----------
        *args : ndarray
            Coordinate vectors, one per axis.
        pad : bool
            Pad and halo-synchronize physical arrays (as ``create_array`` does).
        spectral : bool
            Use the spectral sharding (never padded).

        Returns
        -------
        tuple[ndarray]
            One array per coordinate vector.
        """
        sharding = self.shard_spec if spectral else self.shard_phys
        shardings = [sharding]*len(args)

        @partial(jax.jit, out_shardings=shardings)
        def create_meshgrid():
            return jax.numpy.meshgrid(*args, indexing='ij')

        arrs = create_meshgrid()
        if pad and not spectral:
            # ``pad`` wraps each shard locally; ``sync`` then replaces the halo
            # of decomposed axes with the neighbouring shards' data.
            return tuple(self.sync(self.pad(arr)) for arr in arrs)
        return tuple(arrs)

    # ================================================================
    # Array operations
    # ================================================================
    def _reduce(self, reducer: callable, arr: ndarray,
                axes: list[int] | None, spectral: bool) -> ndarray:
        """Apply the ``jnp`` reduction ``reducer`` to the halo-free data."""
        # Only physical arrays carry a halo (``gather`` makes the same
        # distinction); unpadding a spectral array would drop real data.
        if not spectral:
            arr = self.unpad(arr)
        # ``axis`` of jnp reductions is a static argument and must be hashable.
        axis = tuple(axes) if isinstance(axes, (list, tuple)) else axes
        return reducer(arr, axis=axis)

    def sum(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """Sum over ``axes`` (all if None), excluding halos of physical arrays."""
        return self._reduce(jax.numpy.sum, arr, axes, spectral)

    def max(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """Maximum over ``axes`` (all if None), excluding halos of physical arrays."""
        return self._reduce(jax.numpy.max, arr, axes, spectral)

    def min(self,
            arr: ndarray,
            axes: list[int] | None = None,
            spectral: bool = False) -> ndarray:
        """Minimum over ``axes`` (all if None), excluding halos of physical arrays."""
        return self._reduce(jax.numpy.min, arr, axes, spectral)

    # ================================================================
    # Helper functions
    # ================================================================
    def shard_map(self, func: callable) -> callable:
        """Wrap ``func`` with ``shard_map`` over the physical sharding.

        ``func`` receives each local shard (``PartitionSpec('x', 'y', None)``)
        and its output is assembled with the same spec.
        """
        # ``shard_map`` here resolves to the module-level import, not this method.
        return shard_map(func,
                         mesh=self.mesh,
                         in_specs=self.p_phys,
                         out_specs=self.p_phys)