Source code for fridom.nonhydro.modules.pressure_solvers.spectral_pressure_solver
import fridom.framework as fr
import fridom.nonhydro as nh
import numpy as np
[docs]
@fr.utils.jaxify
class SpectralPressureSolver(fr.modules.Module):
    """
    Pressure solver that operates in spectral space.

    Description
    -----------
    The pressure is obtained from the divergence by multiplying it with the
    negative inverse of a scaled squared wave number::

        p_hat = - div_hat / (k2[0] + k2[1] + k2[2] / dsqr)

    Entries where the scaled squared wave number is zero are set to zero
    instead of being divided. On a cartesian grid the divergence is
    transformed with ``fft()`` before and after the multiplication; on a
    spectral grid the multiplication is applied directly.
    """
    name = "Spectral Pressure Solver"

    @fr.modules.module_method
    def setup(self, mset: 'nh.ModelSettings') -> None:
        """
        Precompute the inverse squared wave number for the given settings.

        Parameters
        ----------
        `mset` : `nh.ModelSettings`
            Model settings; ``mset.grid`` selects the wave number definition
            and ``mset.dsqr`` scales the third wave number component.

        Raises
        ------
        `ValueError`
            If the type of ``mset.grid`` is neither the cartesian nor the
            spectral grid class of ``nh.grid``.
        """
        super().setup(mset)

        # The comparison is on the exact type, so subclasses of the two
        # supported grid classes end up in the error branch.
        grid_type = type(mset.grid)
        if grid_type == nh.grid.cartesian.Grid:
            from fridom.framework.grid.cartesian import (
                discrete_spectral_operators as dso,
            )
            needs_fft = True
            k2_per_axis = [
                dso.k_hat_squared(k_axis, d_axis, use_discrete=True)
                for k_axis, d_axis in zip(self.grid.K, self.grid.dx)
            ]
        elif grid_type == nh.grid.spectral.Grid:
            needs_fft = False
            k2_per_axis = [k_axis**2 for k_axis in self.grid.K]
        else:
            raise ValueError("The spectral solver does not support this grid type.")

        # Scaled squared wave number: only the third component is divided
        # by mset.dsqr.
        scaled_k2 = k2_per_axis[0] + k2_per_axis[1] + k2_per_axis[2] / mset.dsqr

        # Zero entries would trigger NumPy floating-point warnings on
        # division; silence them here, the affected entries are overwritten
        # with zero right below.
        with np.errstate(divide='ignore', invalid='ignore'):
            raw_inverse = 1 / scaled_k2

        self.k_squared_inv = fr.config.ncp.where(scaled_k2 == 0, 0, raw_inverse)
        self.fft_required = needs_fft

    @fr.utils.jaxjit
    def solve_for_pressure(self, div: fr.FieldVariable) -> fr.FieldVariable:
        """
        Compute the pressure that belongs to a divergence field.

        Parameters
        ----------
        `div` : `fr.FieldVariable`
            The divergence field.

        Returns
        -------
        `fr.FieldVariable`
            ``-div`` multiplied by the precomputed inverse squared wave
            number, wrapped in two ``fft()`` calls when the grid requires it.
        """
        if not self.fft_required:
            # The field can be multiplied with the inverse directly.
            return - div * self.k_squared_inv

        # NOTE(review): the second fft() is applied to the product of the
        # first one; presumably FieldVariable.fft() transforms back to
        # physical space in that case - confirm against FieldVariable.
        pressure_hat = - div.fft() * self.k_squared_inv
        return pressure_hat.fft()

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """
        Write the pressure derived from the diagnosed divergence into `mz`.

        Parameters
        ----------
        `mz` : `fr.ModelState`
            Model state; ``mz.z_diag.div`` is read and the array of
            ``mz.z_diag.p`` is overwritten.

        Returns
        -------
        `fr.ModelState`
            The same model state object that was passed in.
        """
        pressure = self.solve_for_pressure(mz.z_diag.div)
        mz.z_diag.p.arr = pressure.arr
        return mz

    @property
    def info(self) -> dict:
        """
        Info dictionary of the parent class extended by a "Solver" entry.
        """
        details = super().info
        details["Solver"] = "Spectral"
        return details