Source code for fridom.nonhydro.modules.pressure_solvers.spectral_pressure_solver
"""Solver for the pressure field using a spectral method."""
from __future__ import annotations
from functools import partial
import numpy as np
import fridom.framework as fr
import fridom.nonhydro as nh
[docs]
@partial(fr.utils.jaxify, dynamic=("k_squared_inv",))
class SpectralPressureSolver(fr.modules.Module):
"""Solve for the pressure field with a spectral solver."""
name = "Spectral Pressure Solver"
def _on_setup(self) -> None:
match type(self.mset.grid):
case nh.grid.cartesian.Grid:
fft_required = True
dso = fr.grid.cartesian.discrete_spectral_operators
k2 = [dso.k_hat_squared(kx, dx, use_discrete=True)
for (kx,dx) in zip(self.grid.K, self.grid.dx)]
case nh.grid.spectral.Grid:
fft_required = False
k2 = [kx**2 for kx in self.grid.K]
case _:
msg = "The spectral solver does not support this grid type."
raise ValueError(msg)
# scaled discretized wave number squared
k_squared = k2[0] + k2[1] + k2[2] / self.mset.dsqr
# Compute the inverse of the wave number squared
with np.errstate(divide="ignore", invalid="ignore"):
k_squared_inv = 1 / k_squared
# Set k2_hat_inv to zero where k2_hat is zero
self.k_squared_inv = fr.config.ncp.where(k_squared == 0, 0, k_squared_inv)
self.fft_required = fft_required
@fr.utils.jaxjit
def _solve_for_pressure(self, div: fr.ScalarField) -> fr.ScalarField:
if self.fft_required:
return ( - div.fft() * self.k_squared_inv).ifft()
return - div * self.k_squared_inv
[docs]
@fr.modules.module_method
def update(self, mz: fr.ModelState) -> fr.ModelState: # noqa: D102
mz.z_diag.p.arr = self._solve_for_pressure(mz.z_diag.div).arr
return mz
@property
def info(self) -> dict: # noqa: D102
res = super().info
res["Solver"] = "Spectral"
return res