Source code for fridom.framework.scalar_field

"""Scalar field class definition."""
from __future__ import annotations

from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Literal

import numpy as np

import fridom.framework as fr
from fridom.framework.grid.fft_padding import FFTPadding

if TYPE_CHECKING:  # pragma: no cover
    import xarray as xr
    from numpy import ndarray

[docs] @partial(fr.utils.jaxify, dynamic=("_arr", "_mdata")) class ScalarField(fr.FieldBase): """ A scalar mapping from grid space to real / complex numbers. Description ----------- A scalar field is the most basic field in FRIDOM. It is a mapping from the grid space to real or complex numbers. It is used to represent scalar quantities like pressure, temperature, etc. Essentially, a scalar field is wrapper around a numpy-like array with additional metadata and methods. Parameters ---------- mset : fr.ModelSettingsBase The model settings object. mdata : fr.FieldMetadata, optional The metadata object for the field. If not provided, a new one is created. arr : ndarray, optional The underlying array of the field. If not provided, a new array is created. """
[docs] def __init__(self, mset: fr.ModelSettingsBase, mdata: fr.FieldMetadata | None = None, arr: ndarray | None = None, **kwargs: any) -> None: super().__init__(mset=mset) # create new metadata object if not provided mdata = mdata or fr.FieldMetadata() # set default values of the metadata mdata.set_default(mset) # set the attributes from the kwargs for key, value in kwargs.items(): setattr(mdata, key, value) # The underlying array if arr is None: data = mset.grid.create_array( pad=True, spectral=mdata.is_spectral, topo=tuple(mdata.topo)) else: conf = fr.config dtype = conf.dtype_comp if mdata.is_spectral else conf.dtype_real data = conf.ncp.array(arr, dtype=dtype) # ---------------------------------------------------------------- # Set attributes # ---------------------------------------------------------------- self._mdata = mdata self._arr = data
# ================================================================ # General methods # ================================================================
[docs] def fft(self, # noqa: D102 padding: FFTPadding = FFTPadding.NOPADDING, ) -> ScalarField: self._fft_possible() # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() transformed_arr = self.grid.fft( arr=self.arr, padding=padding, bc_types=self.bc_types, positions=self.position.positions) conf = fr.config transformed_arr = conf.ncp.array(transformed_arr, dtype=conf.dtype_comp) return ScalarField(self.mset, arr=transformed_arr, mdata=deepcopy(self.mdata), is_spectral=True)
[docs] def ifft(self, # noqa: D102 padding: FFTPadding = FFTPadding.NOPADDING, ) -> ScalarField: self._ifft_possible() # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() transformed_arr = self.grid.ifft( arr=self.arr, padding=padding, bc_types=self.bc_types, positions=self.position.positions) # only keep the real part conf = fr.config transformed_arr = conf.ncp.array(transformed_arr.real, dtype=conf.dtype_real) return ScalarField(self.mset, arr=transformed_arr, mdata=deepcopy(self.mdata), is_spectral=False)
[docs] def sync(self) -> ScalarField: # noqa: D102 if self.is_spectral: # nothing to synchronize in spectral space return self # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # synchronize the array self.arr = self.grid.sync(self.arr) self.apply_water_mask() return self
[docs] def apply_water_mask(self) -> ScalarField: # noqa: D102 # the water mask is defined on the 3D grid so we can't apply it # for non full domain fields # TODO(Silvano): Maybe we can assign custom water masks for scalar fields # so that we can apply them to non full domain fields fr.exceptions.PartialDomainError.check(self) self._check_not_spectral() self.arr *= self.grid.water_mask.get_mask(self.position) return self
[docs] def has_nan(self) -> bool: # noqa: D102 ncp = fr.config.ncp return ncp.any(ncp.isnan(self.arr))
[docs] def set_random(self, seed: int = 1234) -> ScalarField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # create the random array and set it self.arr = self.grid.create_random_array(seed=seed, spectral=self.is_spectral) return self
def __copy__(self) -> ScalarField: # copy the array and the metadata but not the model settings arr = deepcopy(self.arr) mdata = deepcopy(self.mdata) return ScalarField(mset=self.mset, mdata=mdata, arr=arr)
[docs] def unpad(self) -> ndarray: """ Remove padding from the Scalar Field. Returns ------- ndarray The unpadded array. """ # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Make this work for spectral fields if self.is_spectral: msg = "Cannot unpad spectral field" raise ValueError(msg) return self.grid.unpad(self.arr)
[docs] def get_mesh(self) -> tuple[ndarray]: """ Get the meshgrid of the ScalarField. Description ----------- This method returns the meshgrid of the ScalarField. It returns a tuple of ndarrays, where each ndarray represents the meshgrid in one direction. For example, a 3D field that is extended in x, z but not in y would return a tuple of 2 ndarrays (x, z). Returns ------- tuple[ndarray] The meshgrid of the ScalarField for each direction that is extended. """ # TODO(Silvano): Make this work for non full domain fields fr.exceptions.PartialDomainError.check(self) return self.grid.get_mesh(self.position, self.is_spectral)
[docs] def interpolate(self, destination: fr.grid.Position) -> ScalarField: """ Interpolate the field to the destination position. Parameters ---------- destination : fr.grid.Position The position to interpolate to. Returns ------- ScalarField The interpolated field. """ # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Make this work for spectral fields self._check_not_spectral() return self.grid.interp_module.interpolate(self, destination)
# ================================================================ # Check Methods (for internal use) # ================================================================ def _check_full_domain(self) -> None: if not all(self.topo): msg = "Operation not available for non full domain fields" raise NotImplementedError(msg) def _check_not_spectral(self) -> None: if self.is_spectral: msg = "Operation not available for spectral fields" raise NotImplementedError(msg) def _check_axes_argument(self, axes: tuple[int] | None) -> None: if axes is not None: msg = "Operation not available for specific axes" raise NotImplementedError(msg) # ================================================================ # Differential Operators # ================================================================
[docs] def diff(self, axis: int, order: int = 1) -> ScalarField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Make this work for spectral fields self._check_not_spectral() return self.grid.diff_module.diff(self, axis, order)
[docs] def grad(self, axes: list[int] | None = None ) -> fr.VectorField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Make this work for spectral fields self._check_not_spectral() return self.grid.diff_module.grad(self, axes)
[docs] def laplacian(self, # noqa: D102 axes: tuple[int] | None = None, ) -> ScalarField: # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Make this work for spectral fields self._check_not_spectral() return self.grid.diff_module.laplacian(self, axes)
[docs] def div(self, axes: list[int] | None = None) -> None: # noqa: D102 _ = axes msg = "Divergence is not defined for scalar fields" raise ValueError(msg)
[docs] def cumulative_integral(self, # noqa: D102 axis: int, direction: Literal["forward", "backward"] = "forward", ) -> ScalarField: return self.grid.cumulative_integral(self, axis, direction)
# ================================================================ # xarray Interface # ================================================================ def _convert_slice_to_xarray(self, key: int | slice | tuple[int | slice], ) -> xr.DataArray: # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() import xarray as xr # normalize the key key = self._normalize_slice_key(key) # gather the array on the root process arr = self.grid.domain_decomp.gather( self.arr, key, spectral=self.is_spectral) # get the dimensions and coordinates of the slice dims, coords = self._get_sliced_coords(key, arr.shape) # reverse the dimensions dims.reverse() # get all attributes all_attrs = self.mdata.to_serializable() # create the xarray DataArray dv = xr.DataArray( fr.utils.to_numpy(np.squeeze(arr).T), coords=coords, dims=tuple(dims), name=self.name, attrs=all_attrs) # add the additional attributes to the coordinates x_unit = "1/m" if self.is_spectral else "m" for dim in dims: dv[dim].attrs["units"] = x_unit # add the slice key as an attribute dv.attrs["slice_key"] = str(key) return dv def _normalize_slice_key(self, key: int | slice | tuple[int | slice], ) -> tuple[int | slice]: """Normalize the slice key to the number of dimensions.""" ndim = self.grid.n_dims # convert key to list key = [key] if not isinstance(key, (tuple, list)) else list(key) # extend the key to the number of dimensions key += [slice(None)] * (ndim - len(key)) for i in range(ndim): # set non-extended axes to 0 if not self.topo[i]: key[i] = slice(0, 1) # convert negative indices to slices if isinstance(key[i], int): if key[i] < 0: key[i] = slice(key[i]-1, key[i]) else: key[i] = slice(key[i], key[i]+1) return tuple(key) def _get_sliced_coords(self, key: tuple[int | slice], shape: tuple[int], ) -> tuple[list, dict]: """Get the coordinates for a slice of the ScalarField.""" ndim = self.grid.n_dims realistic_dims = 3 # get the coordinates if ndim <= realistic_dims: dim_names = ["kx", "ky", "kz"] if self.is_spectral else ["x", "y", "z"] all_dims = tuple(dim_names[:ndim]) else: prefix = "k" if self.is_spectral else "x" all_dims = tuple(f"{prefix}{i}" for i in range(ndim)) mesh = self.grid.k_global if self.is_spectral else self.grid.x_global dims = [] coords = {} for axis in range(self.grid.n_dims): if shape[axis] == 1: # skip non-extended axes continue dim = all_dims[axis] dims.append(dim) coords[dim] = fr.utils.to_numpy(mesh[axis][key[axis]]) return dims, coords @property def xr(self) -> xr.DataArray: # noqa: D102 return self.xrs[:] @property def xrs(self) -> fr.utils.SliceableAttribute[xr.DataArray]: # noqa: D102 return fr.utils.SliceableAttribute(self._convert_slice_to_xarray)
[docs] @classmethod def from_xarray(cls, # noqa: D102 mset: fr.ModelSettingsBase, ds: xr.DataArray, ) -> ScalarField: conf = fr.config # read in the slice key # in general, eval poses a security risk, we eliminate this risk by # setting the __builtins__ to None and only allowing the slice function # to be used # this means that statements like "__import__('os')" will not work slice_key = eval(ds.attrs["slice_key"], # noqa: S307 {"__builtins__": None}, {"slice": slice}) # TODO(Silvano): Add option to read from sliced dataarrays # This could for example be implemented by creating the full array # and then setting the slice region to the values of the dataarray # another option would be to allow for local fields that lives in a subregion # of the domain. For now, we don't allow for sliced dataarrays to be # converted back to ScalarFields if not isinstance(slice_key, tuple): slice_key = (slice_key,) for key in slice_key: if key != slice(None): msg = "Cannot convert sliced dataarray to ScalarField" raise ValueError(msg) # load the metadata mdata = fr.FieldMetadata.from_serializable(ds.attrs) # convert the array to backend arr = ds.to_numpy().T # if the array is loaded with xarray from a netcdf file, complex arrays # are stored as two separate arrays for the real and imaginary part # we check if the array has a "r" and "i" key and if so, we combine them try: arr["r"] separate = True except IndexError: separate = False if separate: arr_real = conf.ncp.array(arr["r"]) arr_imag = conf.ncp.array(arr["i"]) arr = conf.ncp.array(arr_real + 1j * arr_imag, dtype=conf.dtype_comp) else: dtype = conf.dtype_comp if mdata.is_spectral else conf.dtype_real arr = conf.ncp.array(arr, dtype=dtype) if not mdata.is_spectral: # pad the array arr = mset.grid.pad(arr) # create the ScalarField field = cls(mset=mset, mdata=mdata, arr=arr) # synchronize the field return field.sync()
[docs] @classmethod def from_netcdf(cls, # noqa: D102 mset: fr.ModelSettingsBase, path: str) -> ScalarField: import xarray as xr ds = xr.open_dataarray(path) return cls.from_xarray(mset, ds)
@property def value(self) -> complex | float: """ The value of the constant ScalarField. Description ----------- This property returns the value of a constant ScalarField. Constant means that the ScalarField has no extension in any direction. If the ScalarField is not constant, a ValueError is raised. Returns ------- complex | float The value of the constant ScalarField. Raises ------ ValueError If the ScalarField is not constant. """ if not self.is_constant: msg = "The field is not constant" raise ValueError(msg) return self.arr.item() # ================================================================== # SLICING # ================================================================== def __getitem__(self, key: slice | tuple[slice | int]) -> ScalarField: msg = "Slicing is currently not supported for ScalarFields" raise NotImplementedError(msg) def __setitem__(self, key: slice | tuple[slice | int], value: ScalarField | float) -> None: msg = "Slicing is currently not supported for ScalarFields" raise NotImplementedError(msg) # ================================================================ # Properties # ================================================================ @property def info(self) -> dict: """Dictionary with information about the field.""" res = {} if not any(self.topo): res["value"] = self.arr.item() res["name"] = self.name res["long_name"] = self.long_name res["units"] = self.units res["is_spectral"] = self.is_spectral res["position"] = self.position res["topo"] = self.topo res["bc_types"] = self.bc_types enabled_flags = [key for key, value in self.flags.items() if value] res["enabled_flags"] = enabled_flags return res @property def arr(self) -> ndarray: """The underlying array.""" return self._arr @arr.setter def arr(self, arr: ndarray) -> None: self._arr = arr @property def mdata(self) -> fr.FieldMetadata: """The metadata of the ScalarField.""" return self._mdata @mdata.setter def mdata(self, mdata: fr.FieldMetadata) -> None: self._mdata = mdata @property def name(self) -> str: """The name of the ScalarField.""" return self.mdata.name @name.setter def name(self, name: str) -> None: self.mdata.name = name @property def long_name(self) -> str: """The long name of the ScalarField.""" return self.mdata.long_name @long_name.setter def long_name(self, long_name: str) -> None: self.mdata.long_name = long_name @property def units(self) -> str: """The unit of the ScalarField.""" return self.mdata.units @units.setter def units(self, units: str) -> None: self.mdata.units = units @property def nc_attrs(self) -> dict: """Dictionary with additional attributes for the NetCDF file or xarray.""" return self.mdata.nc_attrs @nc_attrs.setter def nc_attrs(self, nc_attrs: dict) -> None: self.mdata.nc_attrs = nc_attrs @property def is_spectral(self) -> bool: """True if the ScalarField is in spectral space.""" return self.mdata.is_spectral @property def topo(self) -> tuple[bool]: """ Topology of the ScalarField. Description ----------- Scalar fields do not have to be extended in all directions. For example, one might want to create a 2D forcing field for a 3D simulation, that only depends on x and y. In this case, the topo of the ScalarField would be (True, True, False). """ return self.mdata.topo @property def is_constant(self) -> bool: # noqa: D102 return not any(self.topo) @property def position(self) -> fr.grid.Position: """The position of the ScalarField on the staggered grid.""" return self.mdata.position @position.setter def position(self, position: fr.grid.Position) -> None: self.mdata.position = position @property def bc_types(self) -> tuple[fr.grid.BCType] | None: """The boundary condition types for the ScalarField.""" return self.mdata.bc_types @bc_types.setter def bc_types(self, bc_types: tuple[fr.grid.BCType] | None) -> None: self.mdata.bc_types = bc_types @property def flags(self) -> dict: """Dictionary with flag options for the ScalarField.""" return self.mdata.flags @flags.setter def flags(self, flags: dict) -> None: self.mdata.flags = flags # ================================================================ # Shrink / Extend operations # ================================================================
[docs] def extend(self, topo: tuple[bool]) -> ScalarField: # noqa: D102 # check if the topology is valid (no shrinking) old_topo = self.topo for (old, new) in zip(old_topo, topo): if old and not new: msg = "Cannot shrink the field in any direction" raise ValueError(msg) # TODO(Silvano): The grid.extend method is not implemented yet msg = "The grid.extend method is not implemented yet" raise NotImplementedError(msg) return self.grid.extend(self, topo)
def _set_shrinked_field(self, arr: ndarray, axes: tuple[int] | None) -> ScalarField: """Shrink the ScalarField in the specified axes and set the new array.""" if axes is None: axes = tuple(i for i in range(self.grid.n_dims)) new_mdata = deepcopy(self.mdata) # set the new topo topo = list(new_mdata.topo) for axis in axes: topo[axis] = False new_mdata.topo = tuple(topo) return ScalarField(mset=self.mset, mdata=new_mdata, arr=arr)
[docs] def sum(self, axes: tuple[int] | None = None) -> ScalarField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Implement sum over specific axes self._check_axes_argument(axes) # TODO(Silvano): This should call the grid.sum method domain = self.grid.domain_decomp result = domain.sum(self.arr, axes=axes, spectral=self.is_spectral) # result must be a n-dimensional array shape = tuple([1] * self.grid.n_dims) result = fr.config.ncp.full(shape, result) return self._set_shrinked_field(arr=result, axes=axes)
[docs] def max(self, axes: tuple[int] | None = None) -> ScalarField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Implement max over specific axes self._check_axes_argument(axes) # TODO(Silvano): This should call the grid.max method domain = self.grid.domain_decomp result = domain.max(self.arr, axes=axes, spectral=self.is_spectral) # result must be a n-dimensional array shape = tuple([1] * self.grid.n_dims) result = fr.config.ncp.full(shape, result) return self._set_shrinked_field(arr=result, axes=axes)
[docs] def min(self, axes: tuple[int] | None = None) -> ScalarField: # noqa: D102 # TODO(Silvano): Make this work for non full domain fields self._check_full_domain() # TODO(Silvano): Implement min over specific axes self._check_axes_argument(axes) # TODO(Silvano): This should call the grid.min method domain = self.grid.domain_decomp result = domain.min(self.arr, axes=axes, spectral=self.is_spectral) # result must be a n-dimensional array shape = tuple([1] * self.grid.n_dims) result = fr.config.ncp.full(shape, result) return self._set_shrinked_field(arr=result, axes=axes)
[docs] def integrate(self, axes: tuple[int] | None = None) -> ScalarField: # noqa: D102 if isinstance(axes, list): axes = tuple(axes) return self.grid.integrate(self, axes=axes)
[docs] def mean(self, axes: tuple[int] | None = None) -> ScalarField: # noqa: D102 if isinstance(axes, list): axes = tuple(axes) volume = self.grid.characteristic_function.integrate(axes=axes) return self.integrate(axes=axes) / volume
# ================================================================ # Arithmetic operations # ================================================================
[docs] def abs(self) -> ScalarField: # noqa: D102 arr = fr.config.ncp.abs(self.arr) return ScalarField(mset=self.mset, mdata=deepcopy(self.mdata), arr=arr)
[docs] def dot(self, # noqa: D102 other: ScalarField | fr.VectorField | fr.TensorField, ) -> ScalarField | fr.VectorField | fr.TensorField: # check that the spectral flag is the same if self.is_spectral != other.is_spectral: msg = "Cannot take dot product of spectral and real fields" raise ValueError(msg) # complex conjugate the other field if it is spectral if other.is_spectral: other = other.conj() # compute the dot product return other * self
[docs] def conj(self) -> ScalarField: # noqa: D102 return ScalarField(mset=self.mset, mdata=deepcopy(self.mdata), arr=self.arr.conj())
@staticmethod def _apply_operation(op: callable, field: ScalarField, other: ScalarField | complex | np.number, ) -> ScalarField: new_mdata = deepcopy(field.mdata) if isinstance(other, ScalarField): topo = [p or q for p, q in zip(field.topo, other.topo)] new_mdata.topo = topo result = op(field.arr, other.arr) elif isinstance(other, (int, float, complex, np.number, fr.config.ncp.ndarray)) or other is None: result = op(field.arr, other) else: return NotImplemented return ScalarField(mset=field.mset, mdata=new_mdata, arr=result)