import fridom.framework as fr
# Import external modules
import numpy as np
from functools import partial
# Import internal modules
from fridom.framework import config, utils
def _create_kn_mesh(N: int):
ncp = config.ncp
n = ncp.arange(0, N)
n, k = ncp.meshgrid(n, n, indexing="ij")
return n, k
def _apply_weights(x, weights, axis):
ncp = config.ncp
y = ncp.tensordot(x, weights, axes=([axis], [0]))
y = ncp.moveaxis(y, -1, axis)
return y
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def dct_type2(x, axis, N):
ncp = config.ncp
n, k = _create_kn_mesh(N)
weights = 2 * ncp.cos((ncp.pi / N) * k * (n + 0.5))
return _apply_weights(x, weights, axis)
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def idct_type2(x, axis, N):
ncp = config.ncp
k, n = _create_kn_mesh(N)
weights = 2 * ncp.cos((ncp.pi / N) * k * (n + 0.5))
weights = utils.modify_array(weights, (0, slice(None)), 1)
return _apply_weights(x, weights, axis) / (2 * N)
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def dst_type1(x, axis, N):
# we assume that the position of the variable is at the cell edges
# |-----x-----|-----x-----|-----x-----|-----x-----|
# ^ ^ ^ ^
# x0 x1 x2 x(N-1)
# A function f with frequency k is given by:
# f(xi) = sin(k*(xi+dx/2))
# = -i/2 * (exp(i*k*(xi+dx/2)) - exp(-i*k*(xi+dx/2)))
# we only consider positive frequencies in the sine transform
# f(xi) = -i/2 * exp(i*k*(xi+dx/2))
# = -i/2 * exp(i*k*dx/2) * exp(i*k*xi)
# the factor 1/2 does not matter, but we need the rotation by -i*exp(i*k*dx/2)
# so that the sine transform is consistent with fourier transforms
# Note that dx is given by pi/N
ncp = config.ncp
n, k = _create_kn_mesh(N)
weights = 2 * ncp.sin(ncp.pi * k * (n+1) / N)
# apply the rotation factor
weights = weights * -1j * ncp.exp(1j*k*ncp.pi/(2*N))
return _apply_weights(x, weights, axis)
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def idst_type1(x, axis, N):
ncp = config.ncp
k, n = _create_kn_mesh(N)
weights = 2 * ncp.sin(ncp.pi * k * (n+1) / N)
# similar as the dst1, we need to apply the inverse rotation factor
weights = weights * 1j * ncp.exp(-1j*k*ncp.pi/(2 * N))
return _apply_weights(x, weights, axis) / (2 * N)
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def dst_type2(x, axis, N):
ncp = config.ncp
n, k = _create_kn_mesh(N)
weights = -2j * ncp.sin(ncp.pi * k * (2*n+1) / (2*N))
return _apply_weights(x, weights, axis)
@partial(utils.jaxjit, static_argnames=['axis', 'N'])
def idst_type2(x, axis, N):
ncp = config.ncp
k, n = _create_kn_mesh(N)
weights = 2j * ncp.sin(ncp.pi * k * (2*n+1) / (2*N))
return _apply_weights(x, weights, axis) / (2 * N)
[docs]
@utils.jaxify
class FFT:
"""
Class for performing fourier transforms on a cartesian grid.
Description
-----------
Model grids that have periodic boundary conditions in some directions,
and non-periodic boundary conditions in other directions, require a
combination of fast fourier transforms and discrete cosine transforms.
This class provides a method to transform an array from physical space to
spectral space and back. For the discrete cosine transform, the type 2
transform is used. This means that the variable must be located at the
cell centers in that direction.
Parameters
----------
`periodic` : `tuple[bool]`
A list of booleans that indicate whether the axis is periodic.
If True, the axis is periodic, if False, the axis is non-periodic.
Examples
--------
.. code-block:: python
import numpy as np
from fridom.framework.grid.cartesian import FFT
fft = FFT(periodic=(True, True, False))
u = np.random.rand(*(32, 32, 8))
v = fft.forward(u)
w = fft.backward(v).real
assert np.allclose(u, w)
"""
[docs]
def __init__(self,
periodic: tuple[bool]) -> None:
# --------------------------------------------------------------
# Check which axis to apply fft, dct
# --------------------------------------------------------------
fft_axes = [] # Periodic axes (fast fourier transform)
dct_axes = [] # Non-periodic axes (discrete cosine transform)
for i in range(len(periodic)):
if periodic[i]:
fft_axes.append(i)
else:
dct_axes.append(i)
# --------------------------------------------------------------
# Set the attributes
# --------------------------------------------------------------
# private attributes
self._periodic = periodic
self._fft_axes = fft_axes
self._dct_axes = dct_axes
return
[docs]
def get_freq(self, shape: tuple[int],
dx: tuple[float], ) -> tuple[np.ndarray]:
"""
Get the frequencies for the given shape and dx.
Description
-----------
This method calculates the frequencies for the given shape and dx. The
returned frequencies could be used to construct wavenumber meshgrids.
Parameters
----------
`shape` : `tuple[int]`
The global shape (number of grid points in each direction).
`dx` : `tuple[float]`
The grid spacing in each direction.
Returns
-------
`tuple[np.ndarray]`
The frequencies in each direction.
Examples
--------
.. code-block:: python
import numpy as np
from fridom.framework.grid.cartesian import FFT
fft = FFT(periodic=(True, True, False))
shape = (32, 32, 8) # Number of grid points in x,y,z
dx = (0.1, 0.1, 0.1) # Grid spacing in x,y,z
kx, ky, kz = fft.get_freq(shape, dx)
KX, KY, KZ = np.meshgrid(kx, ky, kz, indexing='ij')
"""
ncp = config.ncp
k = []
for i in range(len(shape)):
if self._periodic[i]:
k.append(ncp.fft.fftfreq(shape[i], dx[i]/(2*ncp.pi)))
else:
k.append(ncp.linspace(0, ncp.pi/dx[i], shape[i], endpoint=False))
return tuple(k)
[docs]
@partial(utils.jaxjit, static_argnames=['axes', 'bc_types', 'positions'])
def forward(self,
u: np.ndarray,
axes: list[int] | None = None,
bc_types: tuple[fr.grid.BCType] | None = None,
positions: tuple[fr.grid.AxisPosition] | None = None,
) -> np.ndarray:
"""
Forward transform from physical space to spectral space.
Parameters
----------
`u` : `np.ndarray`
The array to transform from physical space to spectral space.
`axes` : `list[int] | None`
The axes to transform. If None, all axes are transformed.
`bc_types` : `tuple[fr.grid.BCType] | None`
The type of boundary conditions for each axis.
`positions` : `tuple[fr.grid.AxisPosition] | None`
The position of the variable in each direction.
Returns
-------
`np.ndarray`
The transformed array in spectral space. If all dimensions are
periodic, the obtained array is real, else it is complex.
"""
ncp = config.ncp; scp = config.scp
# Get the axes to apply fft, dct
if axes is None:
fft_axes = self._fft_axes
dct_axes = self._dct_axes
else:
fft_axes = list(set(axes) & set(self._fft_axes))
dct_axes = list(set(axes) & set(self._dct_axes))
u_hat = u
if bc_types is None:
bc_types = tuple(fr.grid.BCType.NEUMANN for _ in range(u.ndim))
if positions is None:
positions = tuple(fr.grid.AxisPosition.CENTER for _ in range(u.ndim))
# discrete cosine transform
for axis in dct_axes:
if bc_types[axis] == fr.grid.BCType.NEUMANN:
if config.backend_is_jax:
u_hat = dct_type2(u_hat, axis, u_hat.shape[axis])
else:
u_hat = scp.fft.dct(u_hat, axis=axis)
if bc_types[axis] == fr.grid.BCType.DIRICHLET:
if positions[axis] == fr.grid.AxisPosition.CENTER:
u_hat = dst_type2(u_hat, axis, u_hat.shape[axis])
if positions[axis] == fr.grid.AxisPosition.FACE:
u_hat = dst_type1(u_hat, axis, u_hat.shape[axis])
# fourier transform for periodic boundary conditions
u_hat = ncp.fft.fftn(u_hat, axes=fft_axes)
return u_hat
[docs]
@partial(utils.jaxjit, static_argnames=['axes', 'bc_types', 'positions'])
def backward(self, u_hat: np.ndarray,
axes: list[int] | None = None,
bc_types: tuple[fr.grid.BCType] | None = None,
positions: tuple[fr.grid.AxisPosition] | None = None,
) -> np.ndarray:
"""
Backward transform from spectral space to physical space.
Parameters
----------
`u_hat` : `np.ndarray`
The array to transform from spectral space to physical space.
`axes` : `list[int] | None`
The axes to transform. If None, all axes are transformed.
`bc_types` : `tuple[fr.grid.BCType] | None`
The type of boundary conditions for each axis.
`positions` : `tuple[fr.grid.AxisPosition] | None`
The position of the variable in each direction.
Returns
-------
`np.ndarray`
The transformed array in physical space.
"""
ncp = config.ncp; scp = config.scp
if axes is None:
fft_axes = self._fft_axes
dct_axes = self._dct_axes
else:
fft_axes = list(set(axes) & set(self._fft_axes))
dct_axes = list(set(axes) & set(self._dct_axes))
# fourier transform for periodic boundary conditions
u = ncp.fft.ifftn(u_hat, axes=fft_axes)
if bc_types is None:
bc_types = tuple(fr.grid.BCType.NEUMANN for _ in range(u.ndim))
if positions is None:
positions = tuple(fr.grid.AxisPosition.CENTER for _ in range(u.ndim))
# discrete cosine transform
for axis in dct_axes:
if bc_types[axis] == fr.grid.BCType.NEUMANN:
if config.backend_is_jax:
u = idct_type2(u, axis, u.shape[axis])
else:
u = scp.fft.idct(u, axis=axis)
if bc_types[axis] == fr.grid.BCType.DIRICHLET:
if positions[axis] == fr.grid.AxisPosition.CENTER:
u = idst_type2(u, axis, u.shape[axis])
if positions[axis] == fr.grid.AxisPosition.FACE:
u = idst_type1(u, axis, u.shape[axis])
return u