# Source code for fridom.framework.modules.module

"""The base class for all modules."""
from __future__ import annotations

from abc import abstractmethod
from functools import wraps
from typing import Literal, TypeVar

import fridom.framework as fr

# Generic type variable: lets `module_method` be annotated as returning the
# same type as the method it decorates.
T = TypeVar("T")

def module_method(method: T) -> T:
    """
    Decorate the start, update and stop method of a module.

    Description
    -----------
    If the module is enabled, the wrapped method is executed with the
    module's log level (if one is set). Once the module is set up, the call
    is timed with the model-settings timer registered under the module's
    name. The previous log level is restored afterwards, also when the
    wrapped method raises.

    If the module is disabled, the wrapped method is not executed. A disabled
    ``update`` returns the model state it was given (passed either
    positionally or as ``mz=``); every other disabled method returns ``None``.

    Parameters
    ----------
    method : T
        The module method to wrap.

    Returns
    -------
    T
        The wrapped method.
    """
    @wraps(method)
    def wrapper(self: Module, *args: object, **kwargs: object) -> object:
        if not self.is_enabled():
            # A disabled update must hand the model state back unchanged so
            # the caller's state chain is not broken. The state may be passed
            # positionally or as the ``mz`` keyword.
            if method.__name__ == "update":
                return args[0] if args else kwargs.get("mz")
            return None

        # Decide once whether the log level is changed, so that it is restored
        # exactly when it was changed (even if the method alters log_level).
        change_log_level = self.log_level is not None
        if change_log_level:
            old_log_level = fr.log.level
            # Accept an object exposing ``.value`` as well as the raw str/int
            # level that `Module.log_level` is annotated with.
            new_level = getattr(self.log_level, "value", self.log_level)
            fr.config.set_log_level(new_level)
        try:
            fr.log.debug(
                f"Calling '{method.__name__}' of: {self.name}")
            # The timer lives on the model settings, which are only available
            # once the module is set up.
            if not self.is_setup:
                return method(self, *args, **kwargs)
            with self.mset.timer[self.name]:
                return method(self, *args, **kwargs)
        finally:
            # restore the previous log level, even if the method raised
            if change_log_level:
                fr.config.set_log_level(old_log_level)
    return wrapper
class Module:
    """
    Base class for all modules.

    Description
    -----------
    A module is a component of the model that is executed at each time step.
    It can for example be a tendency term, a parameterization, or a
    diagnostic as for example outputting the model state to a file.

    Required methods:

    1. `__init__(self, ...) -> None`: Always call the parent constructor with
       `super().__init__()` (it takes no arguments) before setting custom
       attributes. The module name is taken from the class attribute `name`.
       It is used as the key in the timing module and should not be too long.
    2. `update(self, mz: ModelState) -> ModelState`: This method is called by
       the model at each time step. It can for example update the tendency
       state `mz.dz` based on the model state `mz`, or write the model state
       to a file. It returns the model state. Make sure to wrap the method
       with the `@module_method` decorator.

    Optional methods:

    1. `start(self) -> None`: This method is called by the model when the
       module is started. It can for example open an output file. Make sure
       to wrap the method with the `@module_method` decorator.
    2. `stop(self) -> None`: This method is called by the model when the
       module is stopped. It can for example close an output file. Make sure
       to wrap the method with the `@module_method` decorator.
    3. `_on_setup(self) -> None` / `_on_reset(self) -> None`: Hooks that are
       called at the end of `setup` and `reset`, respectively.

    Attributes
    ----------
    name : str
        The name of the module (class attribute).

    Flags
    -----
    required_halo : int
        The number of halo points required by the module.
    mpi_available : bool
        Whether the module can be run in parallel.
    execute_at_start : bool
        Whether the module should be executed before the first time step.

    Examples
    --------
    .. code-block:: python

        import fridom.framework as fr
        from fridom.framework.modules.module import module_method

        class Increment(fr.modules.Module):
            name = "Increment"

            def __init__(self):
                super().__init__()
                self.number = None

            @module_method
            def start(self):
                self.number = 0  # sets the number to 0

            @module_method
            def update(self, mz: fr.ModelState) -> fr.ModelState:
                self.number += 1  # increments the number by 1
                return mz

            @module_method
            def stop(self):
                self.number = None  # sets the number to None
    """

    name = "Base Module"
    # When True, `setup` does not set up differentiation/interpolation
    # submodules (presumably set by those submodule classes themselves).
    _is_mod_submodule = False

    def __init__(self) -> None:
        """Initialize the module with its default flags (takes no arguments)."""
        # The module is enabled by default
        self.__enabled = True
        # The log level applied while `@module_method`-decorated methods run
        # (None keeps the current log level)
        self.log_level: str | int | None = None
        # ---- Flags ----
        # The number of halo points required by the module (None: derive it
        # from the differentiation and interpolation modules)
        self._required_halo = None
        # Whether the module can be run in parallel
        self.mpi_available = True
        # Whether the module should be executed before the first time step
        self.execute_at_start = False
        # Whether the module is set up
        self.is_setup = False
        # The model settings are set by `setup`
        self.mset: fr.ModelSettingsBase | None = None
        # Differentiation and interpolation modules (None: the grid's modules
        # are used once the module is set up)
        self._diff_module: fr.grid.DiffModule | None = None
        self._interp_module: fr.grid.InterpolationModule | None = None

    @module_method
    def setup(self,
              mset: fr.ModelSettingsBase,
              setup_mode: Literal["default", "forced"] = "default",
              ) -> None:
        """
        Set the module up.

        Description
        -----------
        This method is called by the ModelSettings.setup() and sets the
        ModelSettings as well as the differentiation and interpolation
        modules. Finally, the `_on_setup` hook is called.

        Parameters
        ----------
        mset : fr.ModelSettingsBase
            The model settings object.
        setup_mode : Literal["default", "forced"]
            The setup mode. If the setup mode is "default" and the module is
            already setup, the method will return. If the setup mode is
            "forced", the module will be setup again.
        """
        if self.is_setup and setup_mode == "default":
            return
        self.is_setup = True
        self.mset = mset
        if not self._is_mod_submodule:
            # setup the differentiation and interpolation modules
            self._setup_submodule("diff_module", mset)
            self._setup_submodule("interp_module", mset)
        self._on_setup()

    def _setup_submodule(self, name: str, mset: fr.ModelSettingsBase) -> None:
        """
        Set up the submodule stored in the property `name`.

        If no submodule was assigned, the grid's submodule of the same name is
        assigned instead. Otherwise the assigned submodule is set up with
        `mset`.
        """
        submodule = getattr(self, name)
        if submodule is None:
            setattr(self, name, getattr(mset.grid, name))
        else:
            submodule.setup(mset=mset)

    # NOTE: `Module` does not use `abc.ABCMeta`, so `@abstractmethod` is not
    # enforced; the base implementations of the hooks simply do nothing.
    @abstractmethod
    def _on_setup(self) -> None:
        """Is called by the setup method."""

    def start(self) -> None:
        """
        Start the module.

        Description
        -----------
        This method is called at the beginning of the model run. Child
        classes that require a start method (for example to start an output
        writer) should overwrite this method. Make sure to decorate the
        method with the `@module_method` decorator.
        """

    def stop(self) -> None:
        """
        Stop the module.

        Description
        -----------
        This method is called by the model at the end of the model run or
        when the model is reset. Child classes that require a stop method
        (for example to close an output file) should overwrite this method.
        Make sure to decorate the method with the `@module_method` decorator.
        """

    @module_method
    def reset(self) -> None:
        """Stop and start the module, then call the `_on_reset` hook."""
        self.stop()
        self.start()
        self._on_reset()

    @abstractmethod
    def _on_reset(self) -> None:
        """Is called by the reset method."""

    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """
        Update the model state.

        Description
        -----------
        This method is called by the model at each time step. Child classes
        should overwrite this method to update the module. Make sure to
        decorate the method with the `@module_method` decorator.

        Parameters
        ----------
        mz : fr.ModelState
            The model state at the current time step.

        Returns
        -------
        fr.ModelState
            The updated model state.
        """
        return mz

    def enable(self) -> None:
        """
        Enable the module.

        Description
        -----------
        Enabling the module means that it will be executed at each time step.
        Disabled modules are neither initialized nor updated.
        """
        self.__enabled = True

    def disable(self) -> None:
        """
        Disable the module.

        Description
        -----------
        Disabling the module means that it will no longer be executed at each
        time step: disabled modules are neither initialized nor updated, and
        their `@module_method`-decorated methods are skipped.
        """
        self.__enabled = False

    def is_enabled(self) -> bool:
        """Whether the module is enabled or not."""
        return self.__enabled

    def __repr__(self) -> str:
        """Format the module info to a string."""
        header = self.name if self.__enabled else f"{self.name} (disabled)"
        details = "".join(
            f"\n - {key}: {value}" for key, value in self.info.items())
        return header + details

    # ================================================================
    #  Properties
    # ================================================================

    @property
    def info(self) -> dict:
        """
        Return a dictionary with information about the module.

        Description
        -----------
        This method should be overridden by the child class to return a
        dictionary with information about the module. This information is
        used to print the module in the `__repr__` method.
        """
        info = {}
        # Only print the required halo if it was set explicitly
        if self._required_halo is not None:
            info["Required Halo"] = self.required_halo
        return info

    @property
    def is_setup(self) -> bool:
        """Whether the module is set up."""
        return self._is_setup

    @is_setup.setter
    def is_setup(self, value: bool) -> None:
        self._is_setup = value

    @property
    def mset(self) -> fr.ModelSettingsBase:
        """The model settings."""
        fr.exceptions.NotSetUpError.check(self, "mset")
        return self._mset

    @mset.setter
    def mset(self, mset: fr.ModelSettingsBase) -> None:
        self._mset = mset

    @property
    def grid(self) -> fr.grid.GridBase:
        """The grid of the model settings."""
        return self.mset.grid

    @property
    def diff_module(self) -> fr.grid.DiffModule | None:
        """The differentiation module to be used by this module."""
        return self._diff_module

    @diff_module.setter
    def diff_module(self, value: fr.grid.DiffModule | None) -> None:
        self._diff_module = value
        # A module assigned after setup is set up with the same settings.
        # None (allowed by the getter) has nothing to set up.
        if self.is_setup and value is not None:
            value.setup(mset=self.mset)

    @property
    def interp_module(self) -> fr.grid.InterpolationModule | None:
        """The interpolation module to be used by this module."""
        return self._interp_module

    @interp_module.setter
    def interp_module(self, value: fr.grid.InterpolationModule | None) -> None:
        # NOTE(review): unlike `diff_module`, the value is not set up here when
        # the module is already set up — confirm whether this is intentional.
        self._interp_module = value

    @property
    def required_halo(self) -> int:
        """The required halo points for this module."""
        # Return the required halo if it is set
        if self._required_halo is not None:
            return self._required_halo
        # Otherwise take the maximum of the differentiation and interpolation
        # modules' requirements (0 if neither is set)
        req_halo = 0
        if self.diff_module is not None:
            req_halo = self.diff_module.required_halo
        if self.interp_module is not None:
            req_halo = max(req_halo, self.interp_module.required_halo)
        return req_halo

    @required_halo.setter
    def required_halo(self, value: int) -> None:
        self._required_halo = value