# Source code for fridom.framework.model_settings_base
"""model_settings_base.py - Base class for model settings container."""
from __future__ import annotations
from functools import partial
from typing import Literal, TypeVar
import fridom.framework as fr
# Type variable bound to ``ModelSettingsBase``; used to annotate ``setup`` so
# that it is typed as returning the same (sub)class it was called on.
T = TypeVar("T", bound="ModelSettingsBase")
[docs]
@partial(fr.utils.jaxify, dynamic=("grid",))
class ModelSettingsBase:
    """
    Base class for model settings container.

    Description
    -----------
    This class should be used as a base class for all model settings containers.
    It provides a set of attributes and methods that are common to all models.

    Child classes should override the following attributes:

    - n_dims
    - model_name
    - tendencies
    - diagnostics

    And the following methods:

    - state_constructor
    - diagnostic_state_constructor

    Examples
    --------
    Create a new model settings class by inheriting from `ModelSettingsBase`:

    .. code-block:: python

        import fridom.framework as fr

        class ModelSettings(fr.ModelSettingsBase):
            def __init__(self, grid, **kwargs):
                super().__init__(grid)
                self.model_name = "MyModel"
                # set other parameters
                self.my_parameter = 1.0
                # Finally, set attributes from keyword arguments
                self.set_attributes(**kwargs)

            # optional: override the parameters property
            @property
            def parameters(self):
                res = super().parameters
                res["my_parameter"] = self.my_parameter
                return res
    """

    model_name = "Unnamed model"

    def __init__(self, grid: fr.grid.GridBase, **kwargs: object) -> None:
        """
        Create the model settings with the default set of modules.

        Parameters
        ----------
        grid : fr.grid.GridBase
            The spatial grid of the model.
        kwargs : dict
            Attribute values, applied via `set_attributes` at the end.
        """
        self._tendencies = fr.modules.ModuleContainer("All Tendencies")
        self._diagnostics = fr.modules.ModuleContainer("All Diagnostics")
        self._time_stepper = fr.time_steppers.AdamBashforth()
        self._progress_bar = fr.modules.ProgressBar()
        self._nan_checker = fr.modules.NaNChecker()
        self._restart_module = fr.modules.RestartModule()
        self._timer = fr.timing_module.TimingModule()
        self._custom_state_fields = []
        self._custom_diagnostic_fields = []
        # None means: derive the halo from the tendencies (see `halo`).
        self._halo = None
        self._raise_error_when_something_goes_wrong = False
        self.grid = grid
        # Must exist before `set_attributes`, since several property setters
        # consult `is_setup`.
        self.is_setup = False
        self.set_attributes(**kwargs)

    def set_attributes(self, **kwargs: object) -> None:
        """
        Set model settings attributes from keyword arguments.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments to set the attributes of the model settings.

        Raises
        ------
        AttributeError
            The attribute does not exist in the model settings.
        """
        # Set attributes from keyword arguments
        for key, value in kwargs.items():
            # Only existing attributes may be set, so typos are caught early
            if not hasattr(self, key):
                message = f"ModelSettings has no attribute '{key}'"
                raise AttributeError(message)
            setattr(self, key, value)

    def setup_grid(
        self, setup_mode: Literal["default", "forced"] = "default",
    ) -> None:
        """
        Set the grid object up.

        Parameters
        ----------
        setup_mode : {"default", "forced"}
            Currently not forwarded to the grid.
        """
        # TODO(Silvano): Pass the setup mode to the grid setup
        self.grid.setup(mset=self)

    def _setup_all_modules(
        self, setup_mode: Literal["default", "forced"] = "default",
    ) -> None:
        """Set the grid's water mask and all modules (except the timer) up."""
        self.grid.water_mask.setup(mset=self)
        modules = [
            self.nan_checker, self.progress_bar, self.restart_module,
            self.tendencies, self.diagnostics, self.time_stepper,
        ]
        for module in modules:
            module.setup(mset=self, setup_mode=setup_mode)

    def setup_settings_parameters(self) -> None:
        """
        Set the model settings parameters up.

        The base implementation does nothing. `setup` calls it after the grid
        has been set up and before the modules are set up.
        """

    def setup(self: T, setup_mode: Literal["default", "forced"] = "default") -> T:
        """
        Set the model settings up.

        Description
        -----------
        This method will initialize the grid object and setup all modules.
        It must be called before accessing any attributes of the grid or modules.

        Parameters
        ----------
        setup_mode : {"default", "forced"}
            With "default", calling this on settings that are already set up
            does nothing. With "forced", the setup is always repeated.

        Returns
        -------
        ModelSettingsBase
            The model settings object
        """
        if self.is_setup and setup_mode == "default":
            # If the model settings are already set up, return
            return self
        fr.log.verbose("Setting up model settings")
        self.is_setup = True
        self.setup_grid(setup_mode=setup_mode)
        self.setup_settings_parameters()
        self._setup_all_modules(setup_mode=setup_mode)
        fr.log.info(self)
        return self

    def state_constructor(self) -> fr.VectorField:
        """Construct the state vector from this model settings."""
        return fr.VectorField(self)

    def diagnostic_state_constructor(self) -> fr.VectorField:
        """Construct the diagnostic state vector from this model settings."""
        return fr.VectorField(self)

    def __repr__(self) -> str:
        return f"""
=================================================
Model Settings:
-------------------------------------------------
# {self.model_name}
# Parameters: {self.__parameters_to_string()}
# Grid: {self.grid}
# Time Stepper: {self.time_stepper}
# {self.restart_module}
# Tendencies: {self.tendencies}
# Diagnostics: {self.diagnostics}
=================================================
"""

    # ================================================================
    #  Properties
    # ================================================================

    @property
    def parameters(self) -> dict:
        """
        Return a dictionary with all parameters of the model settings.

        Description
        -----------
        This method should be overridden by the child class to return a dictionary
        with all parameters of the model settings. This dictionary is used to print
        the model settings in the `__repr__` method.
        """
        return {}

    def __parameters_to_string(self) -> str:
        """Format `parameters` as one bullet line per entry (for `__repr__`)."""
        return "".join(
            f"\n - {key}: {value}" for key, value in self.parameters.items()
        )

    @property
    def grid(self) -> fr.grid.GridBase:
        """The spatial grid."""
        return self._grid

    @grid.setter
    def grid(self, value: fr.grid.GridBase) -> None:
        self._grid = value

    # ----------------------------------------------------------------
    #  Module properties
    # ----------------------------------------------------------------

    @property
    def time_stepper(self) -> fr.time_steppers.TimeStepper:
        """The time stepper object (default: AdamBashforth)."""
        return self._time_stepper

    @time_stepper.setter
    def time_stepper(self, value: fr.time_steppers.TimeStepper) -> None:
        self._time_stepper = value
        # A module replaced after setup must be set up immediately
        if self.is_setup:
            value.setup(mset=self)

    @property
    def progress_bar(self) -> fr.modules.ProgressBar:
        """The progress bar object (default: ProgressBar)."""
        return self._progress_bar

    @progress_bar.setter
    def progress_bar(self, value: fr.modules.ProgressBar) -> None:
        self._progress_bar = value
        if self.is_setup:
            value.setup(mset=self)

    @property
    def nan_checker(self) -> fr.modules.NaNChecker:
        """The NaN checker object (default: NaNChecker)."""
        return self._nan_checker

    @nan_checker.setter
    def nan_checker(self, value: fr.modules.NaNChecker) -> None:
        self._nan_checker = value
        if self.is_setup:
            value.setup(mset=self)

    @property
    def tendencies(self) -> fr.modules.ModuleContainer:
        """The module container for all tendencies."""
        return self._tendencies

    @tendencies.setter
    def tendencies(self, value: fr.modules.ModuleContainer) -> None:
        # Capture the halo of the *previous* tendencies before replacing them;
        # reading it afterwards would compare the new container with itself.
        old_halo = self.halo
        self._tendencies = value
        if self.is_setup:
            value.setup(mset=self)
            # The new tendencies may require a different halo, in which case
            # the grid has to be set up again.
            if old_halo != self.halo:
                self.grid.setup(mset=self)

    @property
    def diagnostics(self) -> fr.modules.ModuleContainer:
        """The module container for all diagnostics."""
        return self._diagnostics

    @diagnostics.setter
    def diagnostics(self, value: fr.modules.ModuleContainer) -> None:
        old_halo = self.halo
        self._diagnostics = value
        if self.is_setup and value is not None:
            value.setup(mset=self)
            if old_halo != self.halo:
                self.grid.setup(mset=self)

    @property
    def restart_module(self) -> fr.modules.RestartModule:
        """The restart module."""
        return self._restart_module

    @restart_module.setter
    def restart_module(self, value: fr.modules.RestartModule) -> None:
        self._restart_module = value
        if self.is_setup:
            value.setup(mset=self)

    @property
    def timer(self) -> fr.timing_module.TimingModule:
        """The timing module."""
        return self._timer

    @timer.setter
    def timer(self, value: fr.timing_module.TimingModule) -> None:
        self._timer = value

    @property
    def raise_error_when_something_goes_wrong(self) -> bool:
        """Raise an error when something goes wrong."""
        return self._raise_error_when_something_goes_wrong

    @raise_error_when_something_goes_wrong.setter
    def raise_error_when_something_goes_wrong(self, value: bool) -> None:
        self._raise_error_when_something_goes_wrong = value

    # ----------------------------------------------------------------
    #  Other properties
    # ----------------------------------------------------------------

    @property
    def is_setup(self) -> bool:
        """Return whether the model settings are set up."""
        return self._is_setup

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self._is_setup = value

    @property
    def halo(self) -> int:
        """
        Return the halo size of the model.

        An explicitly set halo takes precedence; while it is None, the halo
        required by the tendencies is used.
        """
        if self._halo is not None:
            return self._halo
        return self.tendencies.required_halo

    @halo.setter
    def halo(self, value: int) -> None:
        old_halo = self.halo
        self._halo = value
        # Only settings that are already set up need a forced re-setup; before
        # the first setup the new halo is simply used by the later `setup()`.
        if self.is_setup and old_halo != self.halo:
            self.setup(setup_mode="forced")

    @property
    def custom_state_fields(self) -> list[fr.FieldMetadata]:
        """List of custom state fields."""
        return self._custom_state_fields

    @property
    def custom_diagnostic_fields(self) -> list[fr.FieldMetadata]:
        """List of custom diagnostic fields."""
        return self._custom_diagnostic_fields