Source code for fridom.framework.grid.cartesian.polynomial_interpolation
from copy import deepcopy
import fridom.framework as fr
from functools import partial
[docs]
@partial(fr.utils.jaxify, dynamic=('water_mask', ))
class PolynomialInterpolation(fr.grid.InterpolationModule):
r"""
Polynomial interpolation for cartesian grids.
Description
-----------
Consider the following grid points:
.. math::
x_i = (i - n/2) \Delta x, \quad i = 0, 1, \ldots, n
where :math:`n` is the (odd) order of the polynomial interpolation. For
example for :math:`n = 3` we have the following grid points:
::
We want to interpolate the field to this point (x=0)
↓
| x_0 | x_1 | x_2 | x_3 |
x/dx = -3/2 -1/2 1/2 3/2
Let :math:`f_i` be the field values at :math:`x_i`. We define the
continuous extension of the field as:
.. math::
f(x) = \sum_{i=0}^{n} \left(
\prod_{j=0, j \neq i}^{n} \left(
\frac{x - x_j}{x_i - x_j} f_i
\right)
\right)
By definition, :math:`f(x_i) = f_i` holds. Finally, to interpolate the
field to the point :math:`x=0`, we insert :math:`x=0` into the above
expression. Note that the grid spacing :math:`\Delta x` cancels out.
.. math::
f(0) = \sum_{i=0}^{n} c_i f_i
with the coefficients :math:`c_i` given by:
.. math::
c_i = \prod_{j=0, j \neq i}^{n} \frac{j-n/2}{j - i}
"""
name = "Polynomial Interpolation"
[docs]
def __init__(self, order: int = 1):
super().__init__()
# order must be an odd number
assert order % 2 == 1
self.required_halo = order // 2 + 1
self.order = order
self._coeffs = None
self._slices = None
self._nexts = None
self._prevs = None
self.water_mask = None
return
[docs]
@fr.modules.module_method
def setup(self, mset: 'fr.ModelSettingsBase') -> None:
super().setup(mset)
self.ndim = ndim = self.mset.grid.n_dims
# coefficients for the polynomial interpolation
order = self.order
coeffs = []
for i in range(order+1):
c = fr.config.dtype_real(1)
for j in range(order+1):
if j != i:
c *= (j - order/2) / (j - i)
coeffs.append(c)
self._coeffs = coeffs
# slices to get certain parts of the array
slices = [slice(i, -order + i) for i in range(order)]
slices.append(slice(order, None))
all_slices = []
for axis in range(ndim):
sl = []
for sli in slices:
s = [slice(None)] * ndim
s[axis] = sli
sl.append(tuple(s))
all_slices.append(sl)
self._slices = all_slices
self._nexts = tuple(self._get_slices(axis)[0] for axis in range(ndim))
self._prevs = tuple(self._get_slices(axis)[1] for axis in range(ndim))
# water mask
self.water_mask = self.mset.grid.water_mask
return
[docs]
@fr.utils.jaxjit
def interpolate(self,
f: fr.ScalarField,
destination: fr.grid.Position) -> fr.ScalarField:
for axis in range(f.arr.ndim):
f = self.interpolate_axis(f, axis, destination.positions[axis])
mask = self.water_mask.get_mask(destination)
f.arr *= mask
return f
[docs]
@partial(fr.utils.jaxjit, static_argnames=('axis', 'destination'))
def interpolate_axis(self,
f: fr.ScalarField,
axis: int,
destination: fr.grid.AxisPosition) -> fr.ScalarField:
if not f.topo[axis]:
# no interpolation when the field has no extend along the axis
return f
if f.position[axis] == destination:
# no interpolation needed
return f
res = fr.ScalarField(mset=f.mset, mdata=deepcopy(f.mdata))
# get the destination slice
match destination:
case fr.grid.AxisPosition.CENTER:
dest_slice = self._nexts[axis]
case fr.grid.AxisPosition.FACE:
dest_slice = self._prevs[axis]
@self.grid.domain_decomp.shard_map
def interpolate(arr):
average = sum(arr[s] * self._coeffs[i]
for i, s in enumerate(self._slices[axis]))
return fr.utils.modify_array(arr, dest_slice, average)
res.arr = interpolate(f.arr)
res.position = f.position.shift(axis)
return res
def _get_slices(self, axis):
n = self.order // 2
if n == 0:
end = None
else:
end = -n
next = tuple(slice(n+1, end) if i == axis else slice(None)
for i in range(self.ndim))
prev = tuple(slice(n, -1-n) if i == axis else slice(None)
for i in range(self.ndim))
return next, prev