# Source code for fridom.framework.model

"""Model class for the fridom framework."""
from __future__ import annotations

import numpy as np

import fridom.framework as fr


class Model:
    """
    The main model class.

    Parameters
    ----------
    mset : ModelSettingsBase
        The model settings.
    """

    def __init__(self, mset: fr.ModelSettingsBase) -> None:
        self.mset = mset
        self.model_state = fr.ModelState(mset)

    def start(self) -> None:
        """Prepare the model for running."""
        # start the total timer and all modules, then clear any old panic
        self.timer.total.start()
        for module in self._modules:
            module.start()
        self.model_state.panicked = False

    def stop(self) -> None:
        """Finish the model run."""
        for module in self._modules:
            module.stop()
        self.timer.total.stop()
        self.progress_bar.stop()

    def reset(self) -> None:
        """Reset the model (pointers, tendencies)."""
        for module in self._modules:
            module.reset()
        self.model_state.reset()
        self.timer.reset()

    # ============================================================
    #  RUN MODEL
    # ============================================================

    def run_backward(self, steps: int | None = None) -> None:
        """
        Run the model backward in time.

        Description
        -----------
        This method runs the model backward in time for a given number of
        steps. The time step of the time stepper is forced to be negative;
        the iteration counter still increases.

        Parameters
        ----------
        steps : int
            Number of steps to run. Required despite the ``None`` default
            (kept for backward compatibility). If ``steps <= 0`` nothing
            is run, matching the behaviour of :meth:`run`.

        Raises
        ------
        TypeError
            If ``steps`` is ``None``.
        """
        # Validate before touching any model state (reload, dt sign, ...).
        if steps is None:
            msg = "run_backward() requires the number of steps to run."
            raise TypeError(msg)
        if steps <= 0:
            return
        # Prepare the model for running
        if self.restart_module.should_reload():
            self.load(self.restart_module.file)
        self.time_stepper.dt = -abs(self.time_stepper.dt)  # ensure dt < 0
        start_value = self.model_state.clock.it
        # step count always increases even when running backward
        final_value = start_value + steps
        self.start()
        # Execute the first time step
        self._execute_first_time_step()
        # Start the progress bar
        self.progress_bar.start()
        self.progress_bar.set_options(
            main_loop_type="for loop",
            datetime_formatting=False,
            start_value=start_value,
            final_value=final_value)
        # ----------------------------------------------------------------
        #  Main loop
        # ----------------------------------------------------------------
        self._main_loop_steps(start_value + 1, final_value)
        # stop the model
        self._finalize_run()

    def run(self,
            steps: int | None = None,
            runlen: np.timedelta64 | float | None = None,
            start_step: int = 0,
            start_time: np.datetime64 | float = 0,
            end_time: np.datetime64 | float | None = None) -> None:
        """
        Run the model.

        At most one of ``steps``, ``runlen`` and ``end_time`` may be given.

        Parameters
        ----------
        steps : int (default: None)
            Number of steps to run.
        runlen : np.timedelta64 | float | int (default: None)
            Length of the run.
        start_step : int (default: 0)
            Start iteration of the run.
        start_time : np.datetime64 | float | int (default: 0)
            Start time of the run.
        end_time : np.datetime64 | float | int (default: None)
            End time of the run.
        """
        # Check input
        fr.exceptions.TooManyArgumentsError.check(
            max_args=1, steps=steps, runlen=runlen, end_time=end_time)

        # Convert time parameters to seconds
        datetime_formatting = self._datetime_formatting(
            start_time, end_time, runlen)
        start_time, end_time, runlen = self._initialize_run(
            start_time, end_time, runlen, start_step)

        # Determine loop type and final values
        main_loop_type, start_value, final_value = self._determine_loop_type(
            steps, start_time, runlen, end_time)

        # check if the final value is reached
        if start_value >= final_value:
            return

        # Prepare the model for running
        if self.restart_module.should_reload():
            self.load(self.restart_module.file)
        # make sure the time stepper has a positive time step
        self.time_stepper.dt = abs(self.time_stepper.dt)
        self.start()

        # Execute the first time step
        self._execute_first_time_step()

        # Start the progress bar
        self.progress_bar.start()
        self.progress_bar.set_options(
            main_loop_type=main_loop_type,
            datetime_formatting=datetime_formatting,
            start_value=start_value,
            final_value=final_value)

        # ----------------------------------------------------------------
        #  Main loop
        # ----------------------------------------------------------------
        if main_loop_type == "for loop":
            self._main_loop_steps(start_value + 1, final_value)
        elif main_loop_type == "while loop":
            self._main_loop_time(final_value)

        # stop the model
        self._finalize_run()

    def _datetime_formatting(self,
                             start_time: np.datetime64 | float,
                             end_time: np.datetime64 | float | None,
                             runlen: np.timedelta64 | float | None) -> bool:
        """Check if some of the input parameters are in datetime format."""
        if isinstance(start_time, np.datetime64):
            return True
        if isinstance(end_time, np.datetime64):
            return True
        return bool(isinstance(runlen, np.timedelta64))

    def _initialize_run(self,
                        start_time: np.datetime64 | float,
                        end_time: np.datetime64 | float | None,
                        runlen: np.timedelta64 | float | None,
                        start_step: int) -> tuple[float, float, float]:
        """Convert time parameters to seconds and initialize the clock."""
        self.model_state.clock.set_start(start_time)
        self.model_state.clock.time = fr.utils.to_seconds(start_time)
        self.model_state.clock.it = start_step
        return (fr.utils.to_seconds(start_time),
                fr.utils.to_seconds(end_time),
                fr.utils.to_seconds(runlen))

    def _determine_loop_type(self,
                             steps: int | None,
                             start_time: float,
                             runlen: float | None,
                             end_time: float | None,
                             ) -> tuple[str, float, float]:
        """
        Determine the loop type and its start/final values.

        Returns ``("for loop", it, it + steps)`` when ``steps`` is given
        (iteration counts), otherwise ``("while loop", start_time, end_time)``
        (model times in seconds), where ``end_time`` is derived from
        ``runlen`` if that is given.
        """
        if runlen is not None:
            end_time = start_time + runlen

        if steps is not None:
            main_loop_type = "for loop"
            start_value = self.model_state.clock.it
            final_value = start_value + steps
        else:
            main_loop_type = "while loop"
            start_value = start_time
            final_value = end_time
        return main_loop_type, start_value, final_value

    def _execute_first_time_step(self) -> None:
        """
        Run the start-up diagnostics and the first time step.

        On the JAX backend the first step triggers compilation, so its
        duration is logged. The progress bar is disabled during the step
        and re-enabled afterwards, even if the step raises.
        """
        # compile the modules
        from time import time

        if fr.config.backend_is_jax:
            fr.log.notice("Compiling modules at first time step")
            start_time = time()

        # Execute modules that should run at the start
        for module in self.diagnostics.module_list:
            if module.execute_at_start:
                self.model_state = module.update(self.model_state)

        # Execute the first time step; always restore the progress bar,
        # also when _safe_step re-raises (raise_error_when_something_goes_wrong)
        self.progress_bar.disable()
        try:
            self._safe_step()
        finally:
            self.progress_bar.enable()

        # Print the compilation time
        if fr.config.backend_is_jax:
            fr.log.notice(
                f"Compilation finished in {time()-start_time:.2f} seconds")

    def _safe_step(self) -> None:
        """
        Run a single time step.

        If ``mset.raise_error_when_something_goes_wrong`` is set, exceptions
        propagate; otherwise they are logged and the model state is marked
        as panicked.
        """
        if self.mset.raise_error_when_something_goes_wrong:
            self.step()
            return
        try:
            self.step()
        except Exception as e:  # noqa: BLE001
            fr.log.error("An error occurred during the model run.")
            fr.log.error(e)
            self.model_state.panicked = True

    def _stop_if_panicked(self) -> bool:
        """Log a warning and return True if the model state has panicked."""
        if self.model_state.panicked:
            fr.log.warning("Something went wrong. Stopping model.")
            return True
        return False

    def _main_loop_steps(self, start_value: int, final_value: int) -> None:
        """Step the model until the iteration counter reaches ``final_value``."""
        start_it = self.model_state.clock.it
        fr.log.info(
            f"Running model from iteration {start_value} to {final_value}")
        # The first time step may already have failed; do not step again.
        if self._stop_if_panicked():
            return
        # loop over the remaining number of steps
        for _ in range(start_it, final_value):
            self._safe_step()
            if self._stop_if_panicked():
                break

    def _main_loop_time(self, end_time: float) -> None:
        """Step the model until the model time reaches ``end_time``."""
        fr.log.info(
            f"Running model from {self.model_state.clock.time} to {end_time}")
        # The first time step may already have failed; do not step again.
        if self._stop_if_panicked():
            return
        # loop until the end time is reached
        while self.model_state.clock.time < end_time:
            self._safe_step()
            if self._stop_if_panicked():
                break

    def _finalize_run(self) -> None:
        """Stop all modules and log the final iteration, time and timings."""
        self.stop()
        fr.log.info(
            "Model run finished at it: %d, time: %f",
            self.model_state.clock.it, self.model_state.clock.time)
        fr.log.info(self.mset.timer)

    # ============================================================
    #  SINGLE TIME STEP
    # ============================================================

    def step(self) -> None:
        """Update the model state by one time step."""
        # synchronize the state vector (ghost points)
        with self.timer["sync"]:
            self.z.sync()

        # perform the time step
        self.model_state = self.time_stepper.update(mz=self.model_state)

        # check if there are any nans in the state variable
        self.model_state = self.nan_checker.update(self.model_state)

        # make diagnostics
        self.model_state = self.diagnostics.update(mz=self.model_state)

        # Update the progress bar
        self.progress_bar.update(self.model_state)

        # check if the model should restart
        if self.restart_module.should_restart(self.model_state):
            self.restart_module.restart(self)

    # ============================================================
    #  Getters and setters
    # ============================================================

    @property
    def z(self) -> fr.VectorField:
        """Returns the current state variable."""
        return self.model_state.z

    @z.setter
    def z(self, value: fr.VectorField) -> None:
        """Set the current state variable."""
        self.model_state.z = value

    @property
    def timer(self) -> fr.Timer:
        """The timing module."""
        return self.mset.timer

    @property
    def nan_checker(self) -> fr.modules.NaNChecker:
        """The NaN checker module."""
        return self.mset.nan_checker

    @property
    def progress_bar(self) -> fr.modules.ProgressBar:
        """The progress bar module."""
        return self.mset.progress_bar

    @property
    def restart_module(self) -> fr.modules.RestartModule:
        """The restart module."""
        return self.mset.restart_module

    @property
    def time_stepper(self) -> fr.time_steppers.TimeStepper:
        """The time stepper."""
        return self.mset.time_stepper

    @property
    def tendencies(self) -> fr.modules.ModuleContainer:
        """The module container for all tendencies."""
        return self.mset.tendencies

    @property
    def diagnostics(self) -> fr.modules.ModuleContainer:
        """The module container for all diagnostics."""
        return self.mset.diagnostics

    @property
    def _modules(self) -> list[fr.modules.Module]:
        """List of all modules started/stopped/reset with the model."""
        return [
            self.nan_checker,
            self.restart_module,
            self.time_stepper,
            self.tendencies,
            self.diagnostics,
        ]

    # ============================================================
    #  OTHER METHODS
    # ============================================================

    def load(self, file: str) -> None:
        """
        Load a model from a file written by :meth:`save`.

        All attributes of the unpickled model replace the attributes of this
        instance. The grid is not stored in the file (see :meth:`save`), so
        the grid of the current settings is attached to the loaded settings.

        Security: ``dill.load`` can execute arbitrary code. Only load files
        from trusted sources.

        Parameters
        ----------
        file : str
            The filename to load the model from
        """
        from pathlib import Path

        import dill

        with Path(file).open("rb") as f:
            model = dill.load(f)  # noqa: S301
        # the grid was stripped before pickling; reuse the current one
        model.mset.grid = self.mset.grid
        for key, attr in vars(model).items():
            setattr(self, key, attr)

    def save(self, file: str) -> None:
        """
        Save the full model to a file.

        The grid is temporarily removed from the settings so that it is not
        pickled; it is restored afterwards even if pickling fails.

        Parameters
        ----------
        file : str
            The filename to save the model to
        """
        from pathlib import Path

        import dill

        with Path(file).open("wb") as f:
            fr.log.verbose(f"Saving model to {file}")
            grid = self.mset.grid
            # remove the grid from the model before pickling
            self.mset.grid = None
            try:
                dill.dump(self, f)
            finally:
                # restore the grid, even if pickling failed
                self.mset.grid = grid