Source code for fridom.framework.vector_field

"""The vector field module."""
from __future__ import annotations

from collections import OrderedDict
from copy import copy
from functools import partial
from typing import Callable, Iterator, Literal, TypeVar

import numpy as np

import fridom.framework as fr
from fridom.framework.grid.fft_padding import FFTPadding

T = TypeVar("T", bound="VectorField")

@partial(fr.utils.jaxify, dynamic=("_fields", ))
class VectorField(fr.FieldBase):
    """
    A vector mapping from grid space to the vector space.

    Description
    -----------
    A vector field is an ordered collection of scalar fields, one per vector
    component. Besides operations that are applied component by component,
    this class provides vector operations such as the dot product and
    differential operators (laplacian, cumulative integral; ``grad`` and
    ``div`` are declared but not implemented yet).

    Parameters
    ----------
    mset : fr.ModelSettingsBase
        The model settings object.
    field_list : list[fr.ScalarField] | OrderedDict[str, fr.ScalarField] | None
        The scalar fields that make up the vector field. If None,
        ``vector_dim`` default fields named ``f0``, ``f1``, ... are created.
    vector_dim : int | None
        The vector dimension. If None, it is set to the length of field_list.
    **kwargs : any
        Additional keyword arguments to pass to the default scalar fields
        (only allowed when ``field_list`` is None).
    """

    # ================================================================
    #  Constructor
    # ================================================================
[docs] def __init__(self, mset: fr.ModelSettingsBase, field_list: (list[fr.ScalarField] | OrderedDict[str, fr.ScalarField] | None) = None, vector_dim: int | None = None, **kwargs: any, ) -> None: super().__init__(mset) if isinstance(field_list, (list, OrderedDict)) and kwargs: msg = "Keyword arguments not allowed when passing a list or dict" raise TypeError(msg) if kwargs: self._check_for_valid_kwargs(kwargs) # if the input is a list, check for duplicated names and convert to dict if isinstance(field_list, list): field_names = [field.name for field in field_list] if len(field_names) != len(set(field_names)): msg = f"Duplicated field names: {field_names}" raise ValueError(msg) field_list = OrderedDict((field.name, field) for field in field_list) elif isinstance(field_list, OrderedDict): pass elif field_list is None: field_list = self._create_default_fields(mset, vector_dim, **kwargs) else: msg = f"Invalid field list type: {type(field_list)}" raise TypeError(msg) # check the vector dimension vector_dim = vector_dim or len(field_list) if vector_dim != len(field_list): msg = f"Vector dimension mismatch: {vector_dim} != {len(field_list)}" raise ValueError(msg) # set the properties self._fields = field_list self._vector_dim = len(field_list)
def _check_for_valid_kwargs(self, kwargs: any) -> None: allowed_keys = { "topo", "is_spectral", "position", "bc_types", "flags", "nc_attrs"} if not set(kwargs).issubset(allowed_keys): msg = f"Invalid keyword arguments: {set(kwargs) - allowed_keys}" raise TypeError(msg) @staticmethod def _create_default_fields(mset: fr.ModelSettingsBase, vector_dim: int | None, **kwargs: any, ) -> OrderedDict[str, fr.ScalarField]: # if no vector dimension is given, set it to 0 vector_dim = vector_dim or 0 # create a dictionary of scalar fields field_list = OrderedDict() for i in range(vector_dim): field = fr.ScalarField(mset, name=f"f{i}", **kwargs) field_list[field.name] = field return field_list @staticmethod def _add_custom_fields(mset: fr.ModelSettingsBase, fields: OrderedDict[str, fr.ScalarField], custom_fields: list[fr.FieldMetadata], **kwargs: any, ) -> OrderedDict[str, fr.ScalarField]: """ Add custom fields to the vector field. Description ----------- This method adds custom fields to the vector field. The custom fields are specified as a list of field metadata objects. The method checks if the names of the custom fields are unique and adds them to the dictionary of fields. This method is particularly useful for creating state or diagnostic vector fields. The list of custom fields can be stored in the model settings object and passed to the vector field constructor. Parameters ---------- mset : fr.ModelSettingsBase The model settings object. fields : OrderedDict[str, fr.ScalarField] The dictionary of scalar fields. custom_fields : list[fr.FieldMetadata] The list of custom fields to add. **kwargs : any Additional keyword arguments to pass to the scalar fields. Returns ------- OrderedDict[str, fr.ScalarField] The dictionary of scalar fields with the custom fields added. 
""" # check if names are unique field_names = set(fields) custom_names = {field.name for field in custom_fields} if not custom_names.isdisjoint(field_names): msg = f"Field names not unique: {custom_names & field_names}" raise ValueError(msg) # add the custom fields for mdata in custom_fields: field = fr.ScalarField(mset, mdata=mdata.copy(), **kwargs) fields[field.name] = field return fields # ================================================================ # General Methods # ================================================================
[docs] def fft(self: T, # noqa: D102 padding: FFTPadding = FFTPadding.NOPADDING, ) -> T: return self.apply_elementwise(self, lambda field: field.fft(padding=padding))
[docs] def ifft(self: T, # noqa: D102 padding: FFTPadding = FFTPadding.NOPADDING, ) -> T: return self.apply_elementwise(self, lambda field: field.ifft(padding=padding))
[docs] def project(self: T, p_vec: T, q_vec: T) -> T: r""" Project a Vector Field onto a (spectral) vector. Description ----------- The projection of the vector :math:`\boldsymbol{z}` on a P-Vector :math:`\boldsymbol{z}` and a Q-Vector :math:`\boldsymbol{q}` is defined as: .. math:: \boldsymbol{z} = \boldsymbol{q} \cdot \left( \boldsymbol{z} \cdot \boldsymbol{p} \right) The projection is done in spectral space. All vectors are transformed to spectral space before the projection and transformed back to physical space if necessary. Parameters ---------- p_vec : VectorField the projection vector :math:`\boldsymbol{p}` q_vec : VectorField the polarization vector :math:`\boldsymbol{q}` Returns ------- VectorField The projected vector field :math:`\boldsymbol{z}` """ # transform to spectral space if necessary was_spectral = self.is_spectral vec = self if was_spectral else self.fft() # check if the projection vectors are in spectral space if not p_vec.is_spectral: p_vec = p_vec.fft() if not q_vec.is_spectral: q_vec = q_vec.fft() # project vec = q_vec * (vec.dot(p_vec)) # transform back to physical space if necessary if not was_spectral: vec = vec.ifft() return vec
[docs] def sync(self: T) -> T: # noqa: D102 # TODO(Silvano): the test for spectral space should not be necessary # sync should move to the grid if self.vector_dim == 0 or self.is_spectral: # nothing to synchronize in spectral space return self # sync all arrays at once arrs = [field.arr for field in self.fields.values()] arrs = self.grid.sync_multi(arrs) # set the arrays to the fields for field, arr in zip(self.fields.values(), arrs): field.arr = arr # apply the water mask return self.apply_water_mask()
[docs] def apply_water_mask(self: T) -> T: # noqa: D102 for field in self: field.apply_water_mask() return self
[docs] def has_nan(self) -> bool: # noqa: D102 return any(field.has_nan() for field in self.fields.values())
[docs] def set_random(self: T, seed: int = 1234) -> T: # noqa: D102 for i, field in enumerate(self): field.set_random(i * seed) return self
def __copy__(self: T) -> T: # create a new vector field, but copy the fields return self.apply_elementwise(self, lambda field: copy(field)) # ================================================================ # Differential Operators # ================================================================
[docs] def diff(self: T, # noqa: D102 axis: int, order: int = 1, ) -> T: return self.apply_elementwise(self, lambda field: field.diff(axis, order))
[docs] def grad(self, # noqa: D102 axes: list[int] | None = None, ) -> fr.TensorField: # TODO(Silvano): add implementation after TensorField is implemented raise NotImplementedError("grad not implemented yet")
[docs] def laplacian(self: T, # noqa: D102 axes: tuple[int] | None = None, ) -> T: return self.apply_elementwise(self, lambda field: field.laplacian(axes))
[docs] def div(self) -> fr.ScalarField: # noqa: D102 msg = "div not implemented yet" raise NotImplementedError(msg)
[docs] def cumulative_integral(self, # noqa: D102 axis: int, direction: Literal["forward", "backward"] = "forward", ) -> VectorField: return self.apply_elementwise( self, lambda field: field.cumulative_integral(axis, direction))
# ================================================================ # xarray Interface # ================================================================ @property def xr(self) -> xr.Dataset: # noqa: D102 return self.xrs[:] @property def xrs(self) -> fr.utils.SliceableAttribute[xr.Dataset]: # noqa: D102 import xarray as xr def slicer(key: int | slice | tuple[int | slice]) -> xr.Dataset: ds = xr.Dataset({f.name: f.xrs[key] for f in self}) # we need to add the variable names in the correct order # this ensures that the order of the variables is preserved # when loading the vector field from xarray ds.attrs["var_names"] = [f.name for f in self] ds.attrs["vector_dim"] = self.vector_dim return ds return fr.utils.SliceableAttribute(slicer)
[docs] @classmethod def from_xarray(cls: type[T], # noqa: D102 mset: fr.ModelSettingsBase, ds: xr.Dataset, ) -> T: # get the list of variable names var_names = ds.attrs["var_names"] vector_dim = ds.attrs["vector_dim"] # create the field list field_list = [fr.ScalarField.from_xarray(mset, ds[var_name]) for var_name in var_names] # create the vector field return cls(mset, field_list=field_list, vector_dim=vector_dim)
# ================================================================ # Sliceable Interface # ================================================================ def __getitem__(self, key: str | int | slice[int]) -> fr.ScalarField | fr.VectorField: """Get a field or slice of the vector field.""" if isinstance(key, str): return self.fields[key] if isinstance(key, int): # get the name of the field at the index name = list(self.fields)[key] return self.fields[name] if isinstance(key, slice): # get the slice of the fields field_list = list(self.fields.values())[key] return VectorField(self.mset, field_list=field_list, vector_dim=len(field_list)) msg = f"Invalid key type: {type(key)}" raise ValueError(msg) def __setitem__(self, key: str | int | slice[int], value: fr.ScalarField | fr.VectorField) -> None: """Set a field or slice of the vector field.""" if isinstance(key, str): self._check_for_name_mismatch(key, value) self.fields[key] = value return if isinstance(key, int): # get the name of the field at the index name = list(self.fields)[key] self._check_for_name_mismatch(name, value) self.fields[name] = value return if isinstance(key, slice): # get the names of the fields in the slice names = list(self.fields)[key] for name, field in zip(names, value.fields.values()): self._check_for_name_mismatch(name, field) # set the fields in the slice self.fields[name] = field return msg = f"Invalid key type: {type(key)}" raise TypeError(msg) def __iter__(self) -> Iterator[fr.ScalarField]: """Iterate over the fields of the vector field.""" return iter(self.fields.values()) def _check_for_name_mismatch(self, name: str, field: fr.ScalarField) -> None: if field.name != name: msg = f"Field name mismatch: {field.name} != {name}" raise ValueError(msg) # ================================================================ # Properties # ================================================================ @property def info(self) -> dict: # noqa: D102 res = {} for name, field in self.fields.items(): 
res[name] = f"{field.long_name} [{field.units}]" return res @property def fields(self) -> OrderedDict[str, fr.ScalarField]: """The dictionary of scalar fields.""" return self._fields @fields.setter def fields(self, value: OrderedDict[str, fr.ScalarField]) -> None: # check that all fields have the same spectral flag if len({field.is_spectral for field in value.values()}) != 1: msg = "All fields must have the same spectral flag" raise ValueError(msg) self._fields = value @property def field_list(self) -> list[fr.ScalarField]: """The list of scalar fields.""" return list(self.fields.values()) @property def vector_dim(self) -> int: """The vector dimension.""" # should be read-only return self._vector_dim @property def is_spectral(self) -> bool: # noqa: D102 if self.vector_dim == 0: msg = "Cannot determine if vector field is spectral with 0 components" raise ValueError(msg) return next(iter(self.fields.values())).is_spectral @property def is_constant(self) -> bool: # noqa: D102 return all(field.is_constant for field in self) # ================================================================ # Shrink / Extend operations # ================================================================
[docs] def extend(self: T, topo: tuple[bool]) -> T: # noqa: D102 self.apply_elementwise(self, lambda field: field.extend(topo))
[docs] def sum(self: T, axes: tuple[int] | None = None) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.sum(axes))
[docs] def max(self: T, axes: tuple[int] | None = None) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.max(axes))
[docs] def min(self: T, axes: tuple[int] | None = None) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.min(axes))
[docs] def integrate(self: T, axes: tuple[int] | None = None) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.integrate(axes))
[docs] def mean(self: T, axes: tuple[int] | None = None) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.mean(axes))
# ================================================================ # Arithmetic Operations # ================================================================
[docs] def norm_of_diff(self, other: VectorField) -> float: r""" Norm of difference between two vector fields. Description ----------- The norm of difference computes the normalized difference between two vector fields :math:`\boldsymbol{z}` and :math:`\boldsymbol{z}'`. It is defined as: .. math:: 2 \frac{||\boldsymbol{z} - \boldsymbol{z}'||} {||\boldsymbol{z}|| + ||\boldsymbol{z}'||} where :math:`||\cdot||` is the L2 norm of the vector field. The norm of difference is in the range [0, 2]. Parameters ---------- other : VectorField The other vector field to compare with Returns ------- float The norm of difference between the two vector fields """ return 2 * (self - other).norm_l2() / (self.norm_l2() + other.norm_l2())
[docs] def dot(self: T, # noqa: D102 other: fr.ScalarField | VectorField | fr.TensorField, ) -> fr.ScalarField | VectorField | T: # check that the spectral flag is the same if self.is_spectral != other.is_spectral: msg = "Cannot take dot product of spectral and real fields" raise ValueError(msg) # complex conjugate the other field if it is spectral if other.is_spectral: other = other.conj() if isinstance(other, fr.ScalarField): return self * other if isinstance(other, fr.VectorField): return sum((self * other).fields.values()) if isinstance(other, fr.TensorField): msg = "Dot product of vector field with tensor field not possible" raise TypeError(msg) msg = f"Invalid type for dot product: {type(other)}" raise TypeError(msg)
[docs] def conj(self: T) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: field.conj())
[docs] def abs(self: T) -> T: # noqa: D102 return self.apply_elementwise(self, lambda field: abs(field))
[docs] @staticmethod def apply_elementwise(vector_field: T, op: Callable[[fr.ScalarField], fr.ScalarField], ) -> T: """ Apply an operation elementwise to the vector field. Description ----------- This method applies an operation elementwise to each scalar field in the vector field and returns a new vector field with the modified fields. Parameters ---------- vector_field : VectorField The vector field to apply the operation to. op : Callable[[fr.ScalarField], fr.ScalarField] The operation to apply to each scalar field. Should take a scalar field as input and return a scalar field. Returns ------- VectorField The new vector field with the modified fields """ cls = vector_field.__class__ # apply the operation to each field new_fields = OrderedDict( (name, op(field)) for name, field in vector_field.fields.items()) return cls(vector_field.mset, field_list=new_fields, vector_dim=vector_field.vector_dim)
@staticmethod def _apply_operation(op: Callable[[T, any], fr.FieldBase], field: T, other: any) -> T: cls = field.__class__ if isinstance(other, fr.VectorField): if field.vector_dim != other.vector_dim: msg = "Vector dimensions do not match: " msg += f"{field.vector_dim} != {other.vector_dim}" raise ValueError(msg) names = list(field.fields) fields = OrderedDict( (name, op(field.fields[name], other.fields[name])) for name in names) return cls(field.mset, field_list=fields, vector_dim=field.vector_dim) if isinstance(other, (fr.ScalarField, float, int, complex, np.number, fr.config.ncp.ndarray)) or other is None: return field.apply_elementwise(field, lambda x: op(x, other)) return NotImplemented