# Source code for fridom.framework.modules.closures.harmonic_diffusion
import fridom.framework as fr
from functools import partial
[docs]
@partial(fr.utils.jaxify, dynamic=("_diffusion_coefficients",))
class HarmonicDiffusion(fr.modules.Module):
    r"""
    Harmonic diffusion module

    Description
    -----------
    The harmonic diffusion operator :math:`\mathcal{H}` on a scalar field
    :math:`u` is given by:

    .. math::
        \mathcal{H}(u) = \nabla \cdot \left (\mathbf{A} \cdot \nabla u \right)

    with the diagonal diffusion tensor :math:`\mathbf{A}` given by:

    .. math::
        \mathbf{A} = \begin{pmatrix} \kappa_1 & \dots & 0 \\
                                     \vdots & \ddots & \vdots \\
                                     0 & \dots & \kappa_n \end{pmatrix}

    where :math:`\kappa_i` is the harmonic diffusion coefficient in the
    :math:`i`-th direction.

    Parameters
    ----------
    `field_flags` : `list[str]`
        A list of strings that indicate which fields should be diffused.
        For example, if `field_flags=["ENABLE_MIXING"]`, all fields with the
        flag "ENABLE_MIXING" will be diffused. For more information on possible
        flags, see :py:mod:`fridom.framework.FieldVariable`.
    `diffusion_coefficients` : `list[float | fr.FieldVariable]`
        The diffusion coefficients :math:`\kappa_i`, one per direction. The
        number of coefficients must match the number of dimensions of the
        grid; this is checked when the operator is applied. Coefficients
        given as `fr.FieldVariable` are interpolated to the position of the
        corresponding gradient component before they are applied.

    Attributes
    ----------
    `name` : `str`
        Name of the module ("Harmonic Diffusion"). This is a class attribute,
        not a constructor argument.
    """
    name = "Harmonic Diffusion"

    def __init__(self,
                 field_flags: list[str],
                 diffusion_coefficients: list[float | fr.FieldVariable]):
        super().__init__()
        self.field_flags = field_flags
        self.diffusion_coefficients = diffusion_coefficients

    @fr.utils.jaxjit
    def diffusion_operator(self, u: fr.FieldVariable) -> fr.FieldVariable:
        r"""
        Applies the harmonic diffusion operator on a scalar field :math:`u`.

        Parameters
        ----------
        `u` : `fr.FieldVariable`
            The field to which the operator is applied.

        Returns
        -------
        `fr.FieldVariable`
            The divergence of the coefficient-weighted gradient of `u`.

        Raises
        ------
        `ValueError`
            If the number of diffusion coefficients differs from the number
            of gradient components returned by the differentiation module.
        """
        # compute the gradient of the field
        grad_u = list(self.diff_module.grad(u))

        # A shorter coefficient list would silently leave the remaining
        # gradient components unscaled (i.e. a diffusivity of one), and a
        # longer one would fail with an unrelated IndexError: reject both.
        coefficients = self.diffusion_coefficients
        if len(coefficients) != len(grad_u):
            raise ValueError(
                f"{self.name}: got {len(coefficients)} diffusion "
                f"coefficient(s) for {len(grad_u)} gradient component(s); "
                "exactly one coefficient per direction is required."
            )

        # multiply the gradient with the diffusion coefficients
        for i, coeff in enumerate(coefficients):
            if isinstance(coeff, fr.FieldVariable):
                # interpolate the diffusion coefficient to the position of
                # the gradient component it is multiplied with
                c = self.interp_module.interpolate(coeff, grad_u[i].position)
            else:
                c = coeff
            grad_u[i] *= c

        # compute the divergence of the weighted gradient
        return self.diff_module.div(tuple(grad_u))

    @fr.utils.jaxjit
    def diffuse(self, z: fr.StateBase, dz: fr.StateBase) -> fr.StateBase:
        """
        Adds the harmonic diffusion of all flagged fields to the tendency.

        Parameters
        ----------
        `z` : `fr.StateBase`
            The state whose fields are diffused.
        `dz` : `fr.StateBase`
            The tendency state; the diffusion term of each selected field is
            added to the entry of `dz.fields` with the same name.

        Returns
        -------
        `fr.StateBase`
            The updated tendency `dz`.
        """
        # loop over all fields
        for name, field in z.fields.items():
            # skip the field if none of the requested flags is set on it
            if not any(field.flags[flag] for flag in self.field_flags):
                continue
            # apply the diffusion operator
            dz.fields[name] += self.diffusion_operator(field)
        return dz

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """
        Adds the diffusion term computed from `mz.z` to the tendency `mz.dz`
        and returns the model state.
        """
        mz.dz = self.diffuse(mz.z, mz.dz)
        return mz

    # ----------------------------------------------------------------
    #  Properties
    # ----------------------------------------------------------------

    @property
    def field_flags(self) -> list[str]:
        """A list of field flags that indicate which fields should be diffused."""
        return self._field_flags

    @field_flags.setter
    def field_flags(self, value):
        self._field_flags = value

    @property
    def diffusion_coefficients(self) -> list[float | fr.FieldVariable]:
        """A list of diffusion coefficients (one per direction)."""
        return self._diffusion_coefficients

    @diffusion_coefficients.setter
    def diffusion_coefficients(self, value):
        # Stored under the attribute name that the class decorator declares
        # as dynamic; keep the two in sync when renaming.
        self._diffusion_coefficients = value