Source code for fridom.framework.modules.forcings.relaxation

"""A relaxation module."""
from __future__ import annotations

from functools import partial

import fridom.framework as fr


@partial(fr.utils.jaxify, dynamic=("target", "domain"))
class Relaxation(fr.modules.Module):
    r"""
    A relaxation module for scalar fields.

    Description
    -----------
    This module implements the relaxation operator
    :math:`\mathcal{R}(\phi)` for a scalar field :math:`\phi`. The relaxation
    operator is defined as:

    .. math::
        \mathcal{R}(\phi) = \frac{\phi^* - \phi}{\tau} \delta_\Omega

    where :math:`\phi^*` is the target value of the field, :math:`\tau` is
    the relaxation time scale, and :math:`\delta_\Omega` is one on the domain
    :math:`\Omega` and zero elsewhere. At each time step,
    :math:`\mathcal{R}(\phi)` is added to the tendency of the field
    :math:`\phi`.

    The analytical solution of the relaxation operator with no other forcing
    terms is:

    .. math::
        \partial_t \phi = \frac{\phi^* - \phi}{\tau}
        \Rightarrow \phi(t) = \phi^* + C e^{-t/\tau}

    with a constant :math:`C`.

    The relaxation operator can be used to add heating or cooling to a
    fluid, but also for example for wind stress forcing.

    Parameters
    ----------
    tau : float
        The relaxation time scale :math:`\tau`.
    field_name : str
        The name of the field that should be relaxed.
    target : float | fr.ScalarField
        The target value of the field. When a ``ScalarField`` is given, only
        its underlying array (``target.arr``) is stored.
    domain_function : callable
        A function that takes the mesh as input and returns a boolean array
        that indicates the domain where the relaxation should be applied.
    """

    name = "Relaxation"

    def __init__(self,
                 tau: float,
                 field_name: str,
                 target: float | fr.ScalarField,
                 domain_function: callable) -> None:
        """Store the relaxation settings; the domain mask is built on setup."""
        super().__init__()
        self.tau = tau
        self.field_name = field_name
        # isinstance (rather than an exact type comparison) so that subclasses
        # of ScalarField are unwrapped to their raw array as well.
        if isinstance(target, fr.ScalarField):
            target = target.arr
        self.target = target
        self.domain_function = domain_function
        # Filled in by _on_setup, which evaluates domain_function on the mesh.
        self.domain = None

    def _on_setup(self) -> None:
        """Evaluate ``domain_function`` on the mesh of the relaxed field."""
        # A temporary state is constructed only to look up the mesh that
        # belongs to the relaxed field; it is dropped again right afterwards.
        z = self.mset.state_constructor()
        mesh = z[self.field_name].get_mesh()
        self.domain = self.domain_function(mesh)
        del z

    # Adds the relaxation term to the tendency stored in the model state and
    # hands the (updated) model state back to the caller.
    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:  # noqa: D102
        mz.dz = self.relax(mz.z, mz.dz)
        return mz

    @fr.utils.jaxjit
    def relax(self, z: fr.VectorField, dz: fr.VectorField) -> fr.VectorField:
        """
        Add the relaxation term to the tendency of the relaxed field.

        Parameters
        ----------
        z : fr.VectorField
            The current state; only the field ``field_name`` is read.
        dz : fr.VectorField
            The tendency; its ``field_name`` entry is incremented.

        Returns
        -------
        fr.VectorField
            The tendency ``dz`` including the relaxation contribution.
        """
        ncp = fr.config.ncp
        # (target - current) / tau, evaluated everywhere ...
        delta = (self.target - z[self.field_name].arr) / self.tau
        # ... but only applied where the domain mask is set.
        dz[self.field_name].arr += ncp.where(self.domain, delta, 0)
        return dz