Source code for fridom.framework.grid.cartesian.spectral_diff

from copy import deepcopy
import fridom.framework as fr
from functools import partial

@fr.utils.jaxify
class SpectralDiff(fr.grid.DiffModule):
    r"""
    Differentiation module that operates in spectral space.

    Description
    -----------
    A derivative in physical space becomes a multiplication by the
    imaginary unit times the wavenumber in spectral space:

    .. math::
        u = U e^{ikx} \Rightarrow \partial_x u = ik u
    """
    name = "Spectral Difference"

    @fr.modules.module_method
    def setup(self, mset: 'fr.ModelSettingsBase') -> None:
        """
        Set up the module and verify that the grid type is supported.

        Parameters
        ----------
        mset : fr.ModelSettingsBase
            Model settings, forwarded to the parent ``setup``.

        Raises
        ------
        ValueError
            If the grid of the model settings is neither a
            ``fr.grid.spectral.Grid`` nor a ``fr.grid.cartesian.Grid``.
        """
        super().setup(mset)

        # Only spectral and cartesian grids are accepted by this module.
        grid = self.mset.grid
        supported_grids = (fr.grid.spectral.Grid, fr.grid.cartesian.Grid)
        if isinstance(grid, supported_grids):
            return
        raise ValueError("SpectralDiff requires a spectral or cartesian grid")

    @partial(fr.utils.jaxjit, static_argnames=('axis', 'order'))
    def diff(self,
             f: fr.ScalarField,
             axis: int,
             order: int = 1) -> fr.ScalarField:
        """
        Differentiate a field along one axis by multiplying with (ik)^order.

        Parameters
        ----------
        f : fr.ScalarField
            Field to differentiate. If it is not spectral, a warning is
            logged, the field is transformed with ``fft()``, and the result
            is passed through ``fft()`` once more before it is returned.
        axis : int
            Index of the axis along which to differentiate (static argument
            of the jit wrapper).
        order : int, optional
            Order of the derivative, default 1 (static argument of the jit
            wrapper).

        Returns
        -------
        fr.ScalarField
            New field holding the derivative. For odd ``order`` a Dirichlet
            boundary type on ``axis`` is replaced by Neumann and vice versa;
            all other boundary types are copied from ``f``.
        """
        # Work on a spectral field; remember whether the transform has to be
        # undone before returning.
        came_from_physical = not f.is_spectral
        if came_from_physical:
            fr.log.warning("Called diff on a non-spectral field.")
            fr.log.warning("Fourier transforming the field to spectral space, differentiating, and transforming back.")
            f = f.fft()

        # An odd number of derivatives exchanges Dirichlet and Neumann on the
        # differentiated axis; any other boundary type is left unchanged.
        new_bc_types = list(f.bc_types)
        if order % 2 == 1:
            current_bc = f.bc_types[axis]
            if current_bc == fr.grid.BCType.DIRICHLET:
                new_bc_types[axis] = fr.grid.BCType.NEUMANN
            elif current_bc == fr.grid.BCType.NEUMANN:
                new_bc_types[axis] = fr.grid.BCType.DIRICHLET

        # Build the result on a copy of the metadata so that f is not altered.
        derivative = fr.ScalarField(mset=f.mset, mdata=deepcopy(f.mdata))
        derivative.bc_types = tuple(new_bc_types)

        # Multiply by (i k)^order, with k the spectral mesh along `axis`.
        wavenumber = self.grid.get_mesh(spectral=True)[axis]
        spectral_factor = (1j * wavenumber) ** order
        derivative.arr = f.arr * spectral_factor

        # Undo the initial transform if the input was not spectral.
        if came_from_physical:
            return derivative.fft()
        return derivative