Source code for fridom.framework.modules.flux_functions.upwind

"""Upwind flux function."""
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING

import fridom.framework as fr

if TYPE_CHECKING:  # pragma: no cover
    import numpy as np


@fr.utils.jaxify
class Upwind(fr.modules.flux_functions.FluxFunctionBase):
    r"""
    Upwind flux function.

    Description
    -----------
    Suppose we want the divergence of an advective flux

    .. math::
        \partial_x F \quad \text{where } F = u q

    where :math:`u` is the advection velocity and :math:`q` is the advected
    quantity. Suppose further that two estimates of the flux are available:
    :math:`F_L`, which is biased to the left (i.e. its computation is
    influenced more by values on the left side of the cell), and
    :math:`F_R`, which is biased to the right.

    The upwind flux picks one of the two estimates based on the sign of the
    velocity :math:`u`:

    .. math::
        F = \begin{cases}
            F_L & \text{if } u \ge 0 \\
            F_R & \text{if } u < 0
        \end{cases}
    """

    name = "Upwind"

    @fr.utils.jaxjit
    def compute(self,
                flux_left: fr.ScalarField,
                flux_right: fr.ScalarField,
                velocity: fr.ScalarField) -> fr.ScalarField:
        """
        Select the upwind flux pointwise from a left- and right-biased estimate.

        Parameters
        ----------
        flux_left : fr.ScalarField
            Left-biased flux estimate. The returned field is built from its
            ``mset`` and a deep copy of its ``mdata``.
        flux_right : fr.ScalarField
            Right-biased flux estimate. Assumed to live at the same position
            as ``flux_left``.
        velocity : fr.ScalarField
            Advection velocity. It is interpolated to ``flux_left.position``
            before its sign is evaluated.

        Returns
        -------
        fr.ScalarField
            Field whose array equals ``flux_left`` where the interpolated
            velocity is ``>= 0`` and ``flux_right`` elsewhere.
        """
        # Output field: same mset as the left flux, plus its own copy of the
        # metadata so it does not share ``mdata`` with the input field.
        upwind = fr.ScalarField(mset=flux_left.mset,
                                mdata=deepcopy(flux_left.mdata))

        # Both flux estimates are assumed to share one position; move the
        # velocity onto it so the sign test is evaluated pointwise.
        vel = self.interp_module.interpolate(velocity, flux_left.position)

        @self.grid.domain_decomp.shard_map
        def _select(flux_r: np.ndarray,
                    flux_l: np.ndarray,
                    vel_arr: np.ndarray) -> np.ndarray:
            # Non-negative velocity -> take the left-biased value,
            # otherwise the right-biased one.
            take_left = vel_arr >= 0
            return fr.config.ncp.where(take_left, flux_l, flux_r)

        upwind.arr = _select(flux_right.arr, flux_left.arr, vel.arr)
        return upwind