Source code for fridom.framework.model_state
"""model_state.py - The base class for model states."""
from functools import partial
import fridom.framework as fr
# pylint: disable=too-many-instance-attributes
[docs]
@partial(fr.utils.jaxify, dynamic=('_z', '_z_diag', '_dz', '_it', '_clock'))
class ModelState:
    """
    Stores the model state variables and the time information.

    Description
    -----------
    The base class for model states. It contains the state vector, the time
    step and the model time. Child classes may add more attributes, for
    example the diagnostic variables needed for the model.
    All model state variables should be stored in this class.

    Parameters
    ----------
    `mset` : `ModelSettings`
        The model settings object.
    `clock` : `Clock`, optional
        The clock object to keep track of the model time. When omitted
        (`None`), a new `fr.Clock()` is created.
    """

    def __init__(self,
                 mset: 'fr.ModelSettingsBase',
                 clock: fr.Clock | None = None) -> None:
        self.mset = mset
        # Assign through the property setters so that each state is passed
        # through the same space check as any later assignment.
        self.z = mset.state_constructor()
        self.z_diag = mset.diagnostic_state_constructor()
        self.dz = None
        self.it = 0
        # Explicit None check instead of `clock or ...`: a clock supplied by
        # the caller is kept even if it happens to evaluate as falsy.
        self._clock = clock if clock is not None else fr.Clock()
        # flag to cancel the model run in case something goes wrong
        self.panicked = False

    def reset(self) -> None:
        """
        Reset the model state.

        Scales the state and diagnostic state by zero, drops the tendency,
        sets the iteration number to zero and calls `reset()` on the clock.
        """
        # NOTE(review): scaling by zero keeps NaN/inf entries (NaN * 0 is
        # NaN). If reset() is meant to recover from a blown-up run, the
        # states may have to be rebuilt via the `mset` constructors - confirm.
        self._z *= 0.0
        self._z_diag *= 0.0
        self._dz = None
        self._it = 0
        self._clock.reset()
        # NOTE(review): the `panicked` flag is not cleared here - confirm
        # whether that is intended.

    # ================================================================
    #  xarray conversion
    # ================================================================
    @property
    def xr(self):
        """
        Model State as xarray dataset
        """
        return self.xrs[:]

    @property
    def xrs(self):
        """
        Model State of sliced domain as xarray dataset

        Returns a `fr.utils.SliceableAttribute`; indexing it merges the
        equally sliced datasets of the state and the diagnostic state.

        Raises
        ------
        ImportError
            If xarray is not installed.
        """
        # xarray sometimes takes a long time to load, so we only import it here
        # if it is actually needed
        try:
            import xarray as xr  # pylint: disable=import-outside-toplevel
        except ImportError as e:
            raise ImportError(
                "xarray is not installed. Please install it to use this feature."
            ) from e

        def slicer(key):
            # apply the same slice to both states, then combine them
            ds_z = self.z.xrs[key]
            ds_zd = self.z_diag.xrs[key]
            return xr.merge([ds_z, ds_zd])

        return fr.utils.SliceableAttribute(slicer)

    # ================================================================
    #  Helpers
    # ================================================================
    @staticmethod
    def _to_grid_space(value: 'fr.StateBase') -> 'fr.StateBase':
        """
        Return `value` after the space check shared by all state setters.

        `value.fft()` is applied when the `is_spectral` flag of `value`
        disagrees with `value.grid.spectral_grid`; otherwise `value` is
        returned unchanged.
        """
        if value.is_spectral != value.grid.spectral_grid:
            return value.fft()
        return value

    # ================================================================
    #  Properties
    # ================================================================
    @property
    def z(self) -> 'fr.StateBase':
        """
        The state vector.
        """
        return self._z

    @z.setter
    def z(self, value: 'fr.StateBase') -> None:
        # convert to correct space
        self._z = self._to_grid_space(value)

    @property
    def z_diag(self) -> 'fr.StateBase':
        """
        The diagnostic state vector.
        """
        return self._z_diag

    @z_diag.setter
    def z_diag(self, value: 'fr.StateBase') -> None:
        # convert to correct space
        self._z_diag = self._to_grid_space(value)

    @property
    def dz(self) -> 'fr.StateBase | None':
        """The tendency vector (`None` when no tendency is stored)."""
        return self._dz

    @dz.setter
    def dz(self, value: 'fr.StateBase | None') -> None:
        # convert to correct space; None (no tendency) is stored as is
        if value is not None:
            value = self._to_grid_space(value)
        self._dz = value

    @property
    def it(self) -> int:
        """The iteration number."""
        return self._it

    @it.setter
    def it(self, value: int) -> None:
        self._it = value

    @property
    def clock(self) -> 'fr.Clock':
        """
        The clock of the model.
        """
        return self._clock

    @property
    def panicked(self) -> bool:
        """Flag to cancel the model run in case something goes wrong."""
        return self._panicked

    @panicked.setter
    def panicked(self, value: bool) -> None:
        self._panicked = value