# Source code for fridom.framework.grid.spectral.grid

import fridom.framework as fr
import numpy as np
from functools import partial

@fr.utils.jaxify
class Grid(fr.grid.cartesian.Grid):
    """
    Cartesian grid configured for spectral differentiation.

    A thin specialization of ``fr.grid.cartesian.Grid`` that installs
    ``fr.grid.cartesian.SpectralDiff`` as the differentiation module and
    ``fr.grid.DummyInterpolation`` as the interpolation module.

    Mesh construction and (inverse) FFTs always use ``self.cell_center``:
    the ``position`` / ``positions`` arguments of :meth:`get_mesh`,
    :meth:`fft` and :meth:`ifft` are accepted only for signature
    compatibility with the parent grid and are otherwise ignored.

    Parameters
    ----------
    N : list[int]
        Forwarded unchanged to the parent cartesian grid.
    L : list[float]
        Forwarded unchanged to the parent cartesian grid.
    periodic_bounds : list[bool] | None
        Forwarded unchanged to the parent cartesian grid.
    """

    def __init__(self,
                 N: list[int],
                 L: list[float],
                 periodic_bounds: list[bool] | None = None,
                 ) -> None:
        super().__init__(N=N,
                         L=L,
                         periodic_bounds=periodic_bounds,
                         diff_mod=fr.grid.cartesian.SpectralDiff(),
                         interp_mod=fr.grid.DummyInterpolation())
        self.name = "Spectral Grid"
        # This grid type marks itself as not MPI-capable.
        self.mpi_available = False
        self.spectral_grid = True

    def setup(self, mset: 'fr.ModelSettingsBase') -> None:
        """
        Set up the grid for the given model settings.

        Delegates to the parent ``setup`` and requests no halo
        (``req_halo=0``).

        Parameters
        ----------
        mset : fr.ModelSettingsBase
            Model settings, forwarded to the parent ``setup``.
        """
        # NOTE(review): presumably no halo is needed because spectral
        # operators act on the whole field at once — confirm.
        super().setup(mset, req_halo=0)

    def get_mesh(self,
                 position: fr.grid.Position | None = None,
                 spectral: bool = False,
                 ) -> tuple[np.ndarray, ...]:
        """
        Return the mesh arrays of the grid.

        Parameters
        ----------
        position : fr.grid.Position | None
            Ignored. The mesh is always built at ``self.cell_center``;
            the argument is kept for compatibility with the parent grid.
        spectral : bool
            Forwarded to the parent ``get_mesh``.

        Returns
        -------
        tuple[np.ndarray, ...]
            The result of the parent ``get_mesh`` at ``self.cell_center``.
        """
        return super().get_mesh(position=self.cell_center, spectral=spectral)

    def fft(self,
            arr: np.ndarray,
            padding: fr.grid.FFTPadding = fr.grid.FFTPadding.NOPADDING,
            bc_types: tuple[fr.grid.BCType, ...] | None = None,
            positions: tuple[fr.grid.AxisPosition, ...] | None = None,
            ) -> np.ndarray:
        """
        Forward FFT of ``arr`` via the parent grid.

        Parameters
        ----------
        arr : np.ndarray
            Array to transform.
        padding : fr.grid.FFTPadding
            Forwarded to the parent ``fft``.
        bc_types : tuple[fr.grid.BCType, ...] | None
            Forwarded to the parent ``fft``.
        positions : tuple[fr.grid.AxisPosition, ...] | None
            Ignored. The parent is always called with
            ``positions=self.cell_center``.

        Returns
        -------
        np.ndarray
            The result of the parent ``fft``.
        """
        return super().fft(arr=arr,
                           padding=padding,
                           bc_types=bc_types,
                           positions=self.cell_center)

    def ifft(self,
             arr: np.ndarray,
             padding: fr.grid.FFTPadding = fr.grid.FFTPadding.NOPADDING,
             bc_types: tuple[fr.grid.BCType, ...] | None = None,
             positions: tuple[fr.grid.AxisPosition, ...] | None = None,
             ) -> np.ndarray:
        """
        Inverse FFT of ``arr`` via the parent grid.

        Parameters
        ----------
        arr : np.ndarray
            Array to transform back.
        padding : fr.grid.FFTPadding
            Forwarded to the parent ``ifft``.
        bc_types : tuple[fr.grid.BCType, ...] | None
            Forwarded to the parent ``ifft``.
        positions : tuple[fr.grid.AxisPosition, ...] | None
            Ignored. The parent is always called with
            ``positions=self.cell_center``.

        Returns
        -------
        np.ndarray
            The result of the parent ``ifft``.
        """
        return super().ifft(arr=arr,
                            padding=padding,
                            bc_types=bc_types,
                            positions=self.cell_center)