# Source code for fridom.nonhydro.modules.advection.flux_divergence_base

"""Base class for flux divergence advection schemes."""
from __future__ import annotations

from abc import abstractmethod

import fridom.framework as fr
import fridom.nonhydro as nh


class FluxDivergenceBase(fr.modules.advection.AdvectionBase):
    r"""
    Base class for flux divergence advection schemes.

    Description
    -----------
    Assuming that the velocity field is divergence-free, the advection of
    momentum can be written as the divergence of the momentum flux tensor:

    .. math::
        (\mathbf{u} \cdot \nabla) \mathbf{u} = \nabla \cdot \mathbf{F}

    where :math:`\mathbf{F}` is the momentum flux tensor, i.e. the outer
    product of the velocity with itself:

    .. math::
        \mathbf{F} = \mathbf{u} \mathbf{u}^T = \begin{bmatrix}
            u^2 & uv & uw \\
            uv & v^2 & vw \\
            uw & vw & w^2
        \end{bmatrix}

    Child classes must implement the ``compute_momentum_flux_tensor`` and
    ``compute_tracer_advection`` methods.
    """

    @abstractmethod
    def compute_momentum_flux_tensor(
        self,
        velocity: nh.VectorField,
    ) -> tuple[nh.VectorField, nh.VectorField, nh.VectorField]:
        """
        Compute the momentum flux tensor.

        Parameters
        ----------
        velocity : nh.VectorField
            Velocity field (u, v, w).

        Returns
        -------
        tuple[nh.VectorField, nh.VectorField, nh.VectorField]
            Momentum flux tensor as three vector fields (vel*u, vel*v, vel*w).
            ``advect_state`` takes the divergence of each of them to obtain
            the u, v and w momentum tendencies, respectively.
        """
        msg = "Subclasses must implement the compute_momentum_flux_tensor method."
        raise NotImplementedError(msg)

    @abstractmethod
    def compute_tracer_advection(
        self,
        velocity: nh.VectorField,
        field: nh.ScalarField,
    ) -> nh.ScalarField:
        """
        Compute the advection of a tracer field.

        Parameters
        ----------
        velocity : nh.VectorField
            Velocity field (u, v, w).
        field : nh.ScalarField
            Tracer field to be advected.

        Returns
        -------
        nh.ScalarField
            Advection term for the tracer field.
        """
        msg = "Subclasses must implement the compute_tracer_advection method."
        raise NotImplementedError(msg)

    @fr.utils.jaxjit
    def advect_state(self, z: nh.State, dz: nh.State) -> nh.State:
        """
        Advect the state vector.

        Adds ``self.scaling`` times the divergence of the momentum flux tensor
        to the velocity components of ``dz``. It also adds ``self.scaling``
        times the tracer advection term for every tracer in ``z.tracers``
        whose ``NO_ADV`` flag is not set.

        Parameters
        ----------
        z : nh.State
            State providing the velocity and the tracer fields.
        dz : nh.State
            Tendency state to which the advection terms are added.

        Returns
        -------
        nh.State
            The tendency state ``dz`` with the advection terms added.
        """
        # differential operator providing the divergence
        diff = self.diff_module
        # the same velocity advects both momentum and tracers: fetch it once
        velocity = z.velocity

        # momentum: divergence of each row of the momentum flux tensor
        fu, fv, fw = self.compute_momentum_flux_tensor(velocity)
        dz.u += self.scaling * diff.div(fu)
        dz.v += self.scaling * diff.div(fv)
        dz.w += self.scaling * diff.div(fw)

        # tracers: skip fields that are explicitly excluded from advection
        for field in z.tracers:
            if field.flags["NO_ADV"]:
                continue
            dz[field.name] += self.scaling * self.compute_tracer_advection(
                velocity, field)
        return dz