Source code for fridom.framework.model_state

"""model_state.py - The base class for model states."""
from functools import partial
import fridom.framework as fr


# pylint: disable=too-many-instance-attributes
@partial(fr.utils.jaxify, dynamic=('_z', '_z_diag', '_dz', '_it', '_clock'))
class ModelState:
    """
    Stores the model state variables and the time information.

    Description
    -----------
    The base class for model states. It contains the state vector, the
    diagnostic state vector, the tendency vector and the clock that keeps
    track of the model time. Child classes may add more attributes, for
    example the diagnostic variables needed for the model. All model state
    variables should be stored in this class.

    The vector setters (`z`, `z_diag`, `dz`) convert every assigned vector
    to the space (physical or spectral) that matches its grid.

    Parameters
    ----------
    `mset` : `ModelSettings`
        The model settings object.
    `clock` : `Clock`, optional
        The clock object to keep track of the model time. A new `Clock` is
        created when none is given.
    """

    def __init__(self,
                 mset: 'fr.ModelSettingsBase',
                 clock: fr.Clock | None = None) -> None:
        self.mset = mset
        # assign through the setters so the vectors end up in the space
        # (physical / spectral) that matches their grid
        self.z = mset.state_constructor()
        self.z_diag = mset.diagnostic_state_constructor()
        self.dz = None
        # explicit None check: never replace a user supplied clock just
        # because it happens to evaluate as falsy
        self._clock = clock if clock is not None else fr.Clock()
        # backing field of the deprecated `it` property; it is also listed as
        # a dynamic field in the jaxify decorator, so it must always exist.
        # NOTE(review): 0 assumed as the initial iteration - confirm.
        self._it = 0
        # flag to cancel the model run in case something goes wrong
        self.panicked = False

    def reset(self) -> None:
        """
        Reset the model state.

        Zeros the state and diagnostic vectors, drops the tendency vector,
        resets the deprecated iteration counter and resets the clock.
        """
        self._z *= 0.0
        self._z_diag *= 0.0
        self._dz = None
        self._it = 0
        self._clock.reset()
        # NOTE(review): `panicked` is intentionally left untouched here (as
        # before); confirm whether a reset should also clear the panic flag.

    # ================================================================
    #  xarray conversion
    # ================================================================

    @property
    def xr(self):
        """Model state of the full domain as an xarray dataset."""
        return self.xrs[:]

    @property
    def xrs(self):
        """
        Model state of a sliced domain as an xarray dataset.

        Returns a sliceable object: indexing it with a key merges the sliced
        state vector and the sliced diagnostic vector into one dataset.

        Raises
        ------
        ImportError
            If xarray is not installed.
        """
        # xarray sometimes takes a long time to load, so we only import it
        # here if it is actually needed
        try:
            import xarray as xr  # pylint: disable=import-outside-toplevel
        except ImportError as e:
            raise ImportError(
                "xarray is not installed. Please install it to use this feature."
            ) from e

        def slicer(key):
            ds_z = self.z.xrs[key]
            ds_zd = self.z_diag.xrs[key]
            return xr.merge([ds_z, ds_zd])

        return fr.utils.SliceableAttribute(slicer)

    # ================================================================
    #  Helpers
    # ================================================================

    @staticmethod
    def _match_grid_space(value: 'fr.VectorField') -> 'fr.VectorField':
        """
        Return `value` in the space (physical or spectral) of its grid.

        Empty vectors (``vector_dim == 0``) are returned unchanged. Otherwise
        the vector is transformed with `fft` / `ifft` if its current space
        does not match ``value.grid.spectral_grid``.
        """
        if value.vector_dim == 0:
            return value
        spectral_grid = value.grid.spectral_grid
        if spectral_grid and not value.is_spectral:
            return value.fft()
        if not spectral_grid and value.is_spectral:
            return value.ifft()
        return value

    # ================================================================
    #  Properties
    # ================================================================

    @property
    def z(self) -> 'fr.VectorField':
        """The state vector."""
        return self._z

    @z.setter
    def z(self, value: 'fr.VectorField') -> None:
        self._z = self._match_grid_space(value)

    @property
    def z_diag(self) -> 'fr.VectorField':
        """The diagnostic state vector."""
        return self._z_diag

    @z_diag.setter
    def z_diag(self, value: 'fr.VectorField') -> None:
        self._z_diag = self._match_grid_space(value)

    @property
    def dz(self) -> 'fr.VectorField | None':
        """The tendency vector (``None`` until it has been set)."""
        return self._dz

    @dz.setter
    def dz(self, value: 'fr.VectorField | None') -> None:
        # None is a valid value meaning "no tendency available"
        self._dz = None if value is None else self._match_grid_space(value)

    @property
    def it(self) -> int:
        """The iteration number (deprecated, use `clock.it` instead)."""
        fr.log.warning(
            "The iteration number is deprecated. Use the clock.it attribute instead."
        )
        return self._it

    @it.setter
    def it(self, value: int) -> None:
        fr.log.warning(
            "The iteration number is deprecated. Use the clock.it attribute instead."
        )
        self._it = value

    @property
    def clock(self) -> 'fr.Clock':
        """The clock of the model."""
        return self._clock

    @property
    def panicked(self) -> bool:
        """Flag to cancel the model run in case something goes wrong."""
        return self._panicked

    @panicked.setter
    def panicked(self, value: bool) -> None:
        self._panicked = value