# Source code for fridom.framework.grid.cartesian.weno_interpolation

"""WENO interpolation module."""
from __future__ import annotations

from copy import deepcopy
from typing import TYPE_CHECKING

import fridom.framework as fr

if TYPE_CHECKING:  # pragma: no cover
    import numpy as np

@fr.utils.jaxify
class InterWENO(fr.grid.cartesian.InterENO):
    """
    WENO interpolation module.

    Description
    -----------
    Weighted essentially non-oscillatory (WENO) reconstruction on periodic
    cartesian grids. With ``k = (order + 1) // 2``, the field is
    reconstructed from ``k`` candidate stencils of ``k`` cells each. The
    candidate reconstructions are blended with nonlinear weights::

        alpha_r = d_r / (eps + beta_r)**2
        omega_r = alpha_r / sum_s(alpha_s)

    ``d_r`` are the linear weights (``compute_d_coeffs``) and ``beta_r`` are
    the smoothness indicators of the candidate stencils
    (``compute_smoothness_indicators``). ``reconstruct`` returns a
    left-biased and a right-biased value, which should be combined with a
    flux limiter to obtain the final reconstruction.

    Parameters
    ----------
    order : int
        Order of the WENO interpolation (must be odd), default is 5.
        Supported orders are 3 and 5.
    eps : float
        Small value to avoid division by zero, default is 1e-10.
    """

    name = "WENO Interpolation"
    # Highest order for which smoothness-indicator coefficients are
    # tabulated in ``compute_beta_coeffs`` (stencil size 3).
    max_implementation_order = 5
    # Lowest supported order: ``compute_beta_coeffs`` has no entry for a
    # single-cell stencil (order 1).
    min_implementation_order = 3

    def __init__(self, order: int = 5, eps: float = 1e-10) -> None:
        super().__init__()
        if order % 2 == 0:
            msg = f"Order {order} is not odd. Please use an odd order for WENO."
            raise ValueError(msg)
        if order > self.max_implementation_order:
            msg = f"Order {order} is too high. "
            msg += f"Please use an order <= {self.max_implementation_order}."
            raise ValueError(msg)
        if order < self.min_implementation_order:
            # Without this check, order 1 would only fail later in
            # ``_on_setup`` (KeyError in ``compute_beta_coeffs``), and
            # non-positive orders would produce empty stencils.
            msg = f"Order {order} is too low. "
            msg += f"Please use an order >= {self.min_implementation_order}."
            raise ValueError(msg)
        self.order = order
        # Number of cells per candidate stencil (= number of stencils).
        self.stencil_size = (order + 1) // 2
        self.required_halo = self.stencil_size
        self.eps = eps
        # The following attributes are populated in ``_on_setup``.
        self.pol_coeffs = None      # coefficients for polynomial reconstruction
        self.stencil_slices = None  # per-axis slices of the candidate stencils
        self.d_coeffs = None        # linear weights d_r
        self.beta_coeffs = None     # (inner, outer) smoothness coefficients
        self.beta_slices = None     # per-axis slices of the 2k-1 cell window
        # Never assigned anywhere else; kept only for backward compatibility.
        self.beta_inner = None
        self.beta_outer = None

    def _on_setup(self) -> None:
        """Precompute coefficients and slices that depend on the grid."""
        # check that the grid is periodic in all dimensions
        if not all(self.grid.periodic_bounds):
            msg = "WENO only works on periodic grids"
            raise ValueError(msg)
        # NOTE(review): the parent class's ``_on_setup`` is not called here;
        # confirm that this is intended.
        # Initialize polynomial coefficients
        self.pol_coeffs = self.compute_polynomial_coefficients_cell_average(
            self.stencil_size)
        # Precompute the axis slices for the stencil
        self.stencil_slices = self._compute_stencil_slices(
            self.stencil_size, self.grid.n_dims)
        # Precompute the d coefficients for the WENO reconstruction
        self.d_coeffs = self.compute_d_coeffs(self.stencil_size)
        # Precompute the beta coefficients and slices
        self.beta_coeffs = self.compute_beta_coeffs(self.stencil_size)
        self.beta_slices = self._compute_b_slices(
            self.stencil_size, self.grid.n_dims)

    def _compute_stencil_slices(self, stencil_size: int, ndim: int,
                                ) -> list[list[tuple[slice, ...]]]:
        """
        Precompute the slices for each stencil.

        The first index of the output list corresponds to the axis, and the
        second index corresponds to the stencil offset ``i`` in
        ``range(stencil_size)``. Along the chosen axis, each slice drops
        ``i`` leading and ``stencil_size - 1 - i`` trailing entries, so all
        slices yield arrays of equal length.
        """
        k = stencil_size
        # Precompute the axis slices for the stencil
        axis_slices = [
            slice(i if i != 0 else None,
                  -k + i + 1 if i != k - 1 else None)
            for i in range(k)
        ]
        # Build full slice tuples per stencil offset and per axis
        return [[
            tuple(slice(None) if ax != axis else axis_slices[i]
                  for ax in range(ndim))
            for i in range(k)
        ] for axis in range(ndim)]

    def _compute_b_slices(self, stencil_size: int, ndim: int,
                          ) -> list[list[tuple[slice, ...]]]:
        """
        Precompute the slices for each stencil for the b coefficients.

        The first index of the output list corresponds to the axis, and the
        second index corresponds to the offset ``i`` in
        ``range(2 * stencil_size - 1)``. That is the full window spanned by
        all candidate stencils. Along the chosen axis, each slice drops ``i``
        leading and ``2 * (stencil_size - 1) - i`` trailing entries.
        """
        k = stencil_size
        # Precompute the axis slices for the stencil
        axis_slices = [
            slice(i if i != 0 else None,
                  -2*(k-1)+i if i != 2*(k-1) else None)
            for i in range(2*k-1)
        ]
        # Build full slice tuples per stencil offset and per axis
        return [[
            tuple(slice(None) if ax != axis else axis_slices[i]
                  for ax in range(ndim))
            for i in range(2*k-1)
        ] for axis in range(ndim)]

    def compute_d_coeffs(self, stencil_size: int) -> np.ndarray:
        """
        Compute the linear weights d_i for the WENO reconstruction.

        Values are tabulated for stencil sizes 1 to 5. Any other size raises
        ``KeyError``. Each row sums to one.
        """
        tabulated_coeffs = {
            1: [1.0],
            2: [1.0/3.0, 2.0/3.0],
            3: [1.0/10.0, 6.0/10.0, 3.0/10.0],
            4: [1.0/35.0, 12.0/35.0, 18.0/35.0, 4.0/35.0],
            5: [1.0/126.0, 20.0/126.0, 60.0/126.0, 40.0/126.0, 5.0/126.0],
        }[stencil_size]
        return fr.config.ncp.asarray(tabulated_coeffs,
                                     dtype=fr.config.dtype_real)

    def compute_beta_coeffs(self, stencil_size: int,
                            ) -> tuple[np.ndarray, np.ndarray]:
        """
        Coefficients that are used to compute the smoothness indicators.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            ``(inner, outer)``:

            - ``inner[n, j, i]`` is the weight of the i-th cell of stencil
              ``n`` in the j-th difference term. Its shape is
              ``(k, k - 1, k)``.
            - ``outer[j]`` is the factor applied to the squared j-th term.
              Its shape is ``(k - 1,)``.

            Only stencil sizes 2 and 3 are tabulated. Any other size raises
            ``KeyError``.
        """
        beta_order3 = [[[-1, 1]],
                       [[-1, 1]]]
        beta_order5 = [[[1, -2, 1], [1, -4, 3]],
                       [[1, -2, 1], [1, 0, -1]],
                       [[1, -2, 1], [3, -4, 1]]]
        beta_inner = {2: beta_order3, 3: beta_order5}[stencil_size]
        beta_outer = {
            2: [1.0],
            3: [13.0/12.0, 1.0/4.0],
        }[stencil_size]
        return (fr.config.ncp.asarray(beta_inner, dtype=fr.config.dtype_real),
                fr.config.ncp.asarray(beta_outer, dtype=fr.config.dtype_real))

    @fr.utils.jaxjit
    def reconstruct(self,
                    f: fr.ScalarField,
                    destination: fr.grid.Position,
                    ) -> tuple[fr.ScalarField, fr.ScalarField]:
        """
        Reconstruct the field at the destination points using WENO interpolation.

        Parameters
        ----------
        f : fr.ScalarField
            The field to be reconstructed.
        destination : fr.grid.Position
            The position where the field is to be reconstructed.
            (Can only reconstruct in one dimension at a time.)

        Returns
        -------
        tuple[fr.ScalarField, fr.ScalarField]
            Left- and right-biased reconstructed fields. (These should be
            combined with a flux limiter to get the final reconstructed
            field.) Entries outside the region written by
            ``fr.utils.modify_array`` keep whatever values the freshly
            constructed ``fr.ScalarField`` holds.

        Raises
        ------
        ValueError
            Propagated from ``_get_reconstruction_axis``.
        """
        # First, we need to find out the axis in which we are reconstructing
        axis = self._get_reconstruction_axis(f, destination)
        # Get the interpolation from the individual stencils
        interpolations = self._interpolate_stencils(f, axis)
        # Compute the smoothness indicators
        beta = self.compute_smoothness_indicators(f, axis)
        # Compute the left- and right-biased weights. The right-biased
        # weights use the mirrored linear weights.
        left_weights = self._compute_weights(self.d_coeffs, beta)
        right_weights = self._compute_weights(self.d_coeffs[::-1], beta)

        # Blend the candidate reconstructions. The right-biased value uses
        # the candidate values shifted by one entry (interpolations[i + 1]).
        slices = self.stencil_slices[axis]
        left_arr = sum(left_weights[i] * interpolations[i][slices[i]]
                       for i in range(self.stencil_size))
        right_arr = sum(right_weights[i] * interpolations[i+1][slices[i]]
                        for i in range(self.stencil_size))

        # Create the reconstructed scalar fields
        left = fr.ScalarField(f.mset, mdata=deepcopy(f.mdata))
        right = fr.ScalarField(f.mset, mdata=deepcopy(f.mdata))

        # If the destination is the face, we interpolate one to the right
        offset = 0
        if destination.positions[axis] == fr.grid.AxisPosition.FACE:
            offset = 1

        # Set the reconstructed values
        right.arr = fr.utils.modify_array(
            right.arr,
            self.beta_slices[axis][self.stencil_size-1-offset],
            right_arr)
        left.arr = fr.utils.modify_array(
            left.arr,
            self.beta_slices[axis][self.stencil_size-offset],
            left_arr)

        # Set the positions of the reconstructed fields
        left.position = destination
        right.position = destination
        return left, right

    def _get_reconstruction_axis(self,
                                 f: fr.ScalarField,
                                 destination: fr.grid.Position) -> int:
        """
        Get the axis in which we are reconstructing.

        Parameters
        ----------
        f : fr.ScalarField
            The field to be reconstructed.
        destination : fr.grid.Position
            The position where the field is to be reconstructed.
            (Can only reconstruct in one dimension at a time.)

        Returns
        -------
        int
            The axis in which we are reconstructing.

        Raises
        ------
        ValueError
            1. If the reconstruction is not done in one dimension.
            2. If the destination position is the same as the field position.
        """
        # First, we need to find out the axis in which we are reconstructing
        diff_axes = [i for i, (pos, dest) in enumerate(
            zip(f.position.positions, destination.positions)) if pos != dest]
        if len(diff_axes) > 1:
            msg = "Reconstruction can only be done in one dimension at a time"
            raise ValueError(msg)
        if len(diff_axes) == 0:
            msg = "Destination position is the same as the field position"
            raise ValueError(msg)
        return diff_axes[0]

    def _interpolate_stencils(self,
                              f: fr.ScalarField,
                              axis: int) -> list[np.ndarray]:
        """
        Interpolate to each point in the stencil.

        Parameters
        ----------
        f : fr.ScalarField
            The field to be reconstructed.
        axis : int
            The axis in which we are reconstructing.

        Returns
        -------
        list[np.ndarray]
            ``stencil_size + 1`` arrays, in reversed row order of
            ``self.pol_coeffs``. Array ``n`` is built as
            ``sum_i pol_coeffs[n, i] * f.arr[stencil_slices[axis][i]]``.
            NOTE(review): presumably row ``n`` of ``pol_coeffs`` evaluates
            the stencil polynomial at its n-th interface; confirm against
            ``InterENO.compute_polynomial_coefficients_cell_average``.
        """
        k = self.stencil_size
        slices = self.stencil_slices[axis]
        coeffs = self.pol_coeffs
        return [sum(coeffs[n, i] * f.arr[slices[i]] for i in range(k))
                for n in range(k+1)][::-1]  # reverse the order of the stencils

    def compute_smoothness_indicators(self,
                                      f: fr.ScalarField,
                                      axis: int) -> list[np.ndarray]:
        """
        Compute the smoothness indicators for each stencil.

        For stencil ``n``::

            beta_n = sum_j outer[j] * (sum_i inner[n, j, i] * f[n + i]) ** 2

        ``f[m]`` denotes ``f.arr`` sliced with
        ``self.beta_slices[axis][m]``. One array is returned per stencil.
        """
        stencil_size = self.stencil_size
        bi, co = self.beta_coeffs
        slices = self.beta_slices[axis]
        return [
            sum(
                co[j] * sum(
                    bi[n, j, i] * f.arr[slices[i+n]]
                    for i in range(stencil_size)  # loop over each point in stencil
                ) ** 2  # square the derivative
                for j in range(stencil_size - 1))  # derivative loop
            for n in range(stencil_size)  # loop over stencils
        ]

    def _compute_weights(self,
                         d: np.ndarray,
                         beta: list[np.ndarray]) -> list[np.ndarray]:
        """
        Compute normalized nonlinear WENO weights.

        ``alpha_i = d[i] / (eps + beta[i])**2`` and
        ``omega_i = alpha_i / sum(alpha)``. The returned weights sum to one
        pointwise.
        """
        # compute the weights
        weights = [d[i] / ((self.eps + beta[i])**2)
                   for i in range(self.stencil_size)]
        # normalize the weights
        weights_sum = sum(weights)
        return [w / weights_sum for w in weights]