Source code for fridom.framework.modules.closures.harmonic_diffusion
"""Harmonic diffusion module."""
from __future__ import annotations
from functools import partial
import fridom.framework as fr
[docs]
@partial(fr.utils.jaxify, dynamic=("_diffusion_coefficients",))
class HarmonicDiffusion(fr.modules.Module):
    r"""
    Harmonic diffusion module.

    Description
    -----------
    The harmonic diffusion operator :math:`\mathcal{H}` on a scalar field
    :math:`u` is given by:

    .. math::
        \mathcal{H}(u) = \nabla \cdot \left (\mathbf{A} \cdot \nabla u \right)

    with the diagonal diffusion tensor :math:`\mathbf{A}` given by:

    .. math::
        \mathbf{A} = \begin{pmatrix} \kappa_1 & \dots & 0 \\
                                     \vdots & \ddots & \vdots \\
                                     0 & \dots & \kappa_n \end{pmatrix}

    where :math:`\kappa_i` is the harmonic diffusion coefficient in the
    :math:`i`-th direction.

    Parameters
    ----------
    field_flags : list[str]
        A list of strings that indicate which fields should be diffused.
        For example, if `field_flags=["ENABLE_MIXING"]`, all fields with the
        flag "ENABLE_MIXING" will be diffused. When several flags are given,
        a field is diffused as soon as any one of them is set. For more
        information on possible flags, see
        :py:mod:`fridom.framework.ScalarField`.
    diffusion_coefficients : list[float | fr.ScalarField]
        A sequence of diffusion coefficients, one per direction. Its length
        must match the number of gradient components (i.e. the number of
        dimensions of the grid); this is checked when the operator is applied.

    Attributes
    ----------
    name : str
        Name of the module (class attribute, "Harmonic Diffusion").
    """

    name = "Harmonic Diffusion"

    def __init__(self,
                 field_flags: list[str],
                 diffusion_coefficients: list[float | fr.ScalarField]) -> None:
        """Store the field flags and the per-direction diffusion coefficients."""
        super().__init__()
        self.field_flags = field_flags
        self.diffusion_coefficients = diffusion_coefficients

    @fr.utils.jaxjit
    def diffusion_operator(self, u: fr.ScalarField) -> fr.ScalarField:
        r"""
        Apply the harmonic diffusion operator on a scalar field :math:`u`.

        Parameters
        ----------
        u : fr.ScalarField
            The field to which the operator is applied.

        Returns
        -------
        fr.ScalarField
            The divergence of the coefficient-weighted gradient of `u`.

        Raises
        ------
        ValueError
            If the number of diffusion coefficients differs from the number
            of gradient components.
        """
        # compute the gradient of the field
        grad_u = list(self.diff_module.grad(u))

        # Without this check a short coefficient list would silently leave the
        # trailing directions unscaled (an implicit coefficient of 1), and a
        # long one would fail with an uninformative IndexError.
        coefficients = self.diffusion_coefficients
        if len(coefficients) != len(grad_u):
            msg = (
                f"Expected {len(grad_u)} diffusion coefficients "
                f"(one per gradient component), got {len(coefficients)}."
            )
            raise ValueError(msg)

        # multiply the gradient with the diffusion coefficients
        for i, coeff in enumerate(coefficients):
            if isinstance(coeff, fr.ScalarField):
                # interpolate the diffusion coefficient to the position of
                # the gradient component it multiplies
                c = self.interp_module.interpolate(coeff, grad_u[i].position)
            else:
                c = coeff
            grad_u[i] *= c

        # compute the divergence of the weighted gradient
        return self.diff_module.div(tuple(grad_u))

    @fr.utils.jaxjit
    def diffuse(self, z: fr.VectorField, dz: fr.VectorField) -> fr.VectorField:
        """
        Add the harmonic diffusion of the flagged fields of `z` to `dz`.

        Parameters
        ----------
        z : fr.VectorField
            The state whose fields are diffused.
        dz : fr.VectorField
            The tendency to which the diffusion terms are added.

        Returns
        -------
        fr.VectorField
            `dz` with the diffusion term added to every field of `z` that has
            at least one of the configured field flags set.
        """
        # loop over all fields
        #TODO(Silvano): Use new vector field methods
        for name, field in z.fields.items():
            if not any(field.flags[flag] for flag in self.field_flags):
                # skip the field if it does not have any of the field flags
                continue
            # apply the diffusion operator
            dz.fields[name] += self.diffusion_operator(field)
        return dz

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:  # noqa: D102
        # add the diffusion term to the tendency of the model state
        mz.dz = self.diffuse(mz.z, mz.dz)
        return mz

    # ----------------------------------------------------------------
    #  Properties
    # ----------------------------------------------------------------

    @property
    def field_flags(self) -> list[str]:
        """A list of field flags that indicate which fields should be diffused."""
        return self._field_flags

    @field_flags.setter
    def field_flags(self, value: list[str]) -> None:
        self._field_flags = value

    @property
    def diffusion_coefficients(self) -> list[float | fr.ScalarField]:
        """A list of diffusion coefficients, one per direction."""
        return self._diffusion_coefficients

    @diffusion_coefficients.setter
    def diffusion_coefficients(self, value: list[float | fr.ScalarField]) -> None:
        # stored under the name that is registered as dynamic in `jaxify`
        self._diffusion_coefficients = value