Source code for fridom.framework.model

# Import external modules
from typing import TYPE_CHECKING, Union
import numpy as np
import fridom.framework as fr
# Import internal modules
from fridom.framework import config
# Import type information
if TYPE_CHECKING:
    from fridom.framework.model_settings_base import ModelSettingsBase


[docs] class Model: """ Base class for the model. Attributes: mset (ModelSettings) : Model settings. grid (Grid) : Grid. z (State) : State variable. dz_list (list) : List of tendency terms (for time stepping). pointer (np.ndarray) : Pointer for time stepping. coeff_AB (np.ndarray) : Adam-Bashforth coefficients. timer (TimingModule) : Timer. it (int) : Iteration counter. time (float) : Model time. dz (State) : Current tendency term. Methods: run() : Run the model for a given number of steps. step() : Perform one time step. reset() : Reset the model (pointers, tendencies) """
[docs] def __init__(self, mset: 'ModelSettingsBase') -> None: """ Constructor. """ self.mset = mset # state variable from fridom.framework.model_state import ModelState self.model_state = ModelState(mset) # Timer self.timer = mset.timer # Modules self.progress_bar = mset.progress_bar self.restart_module = mset.restart_module self.tendencies = mset.tendencies self.diagnostics = mset.diagnostics # Time stepper self.time_stepper = mset.time_stepper return
[docs] def start(self): """ Prepare the model for running. """ # start all modules self.timer.total.start() self.restart_module.start() self.tendencies.start() self.diagnostics.start() self.time_stepper.start() self.model_state.panicked = False # compile the modules from time import time if fr.config.backend_is_jax: fr.log.notice("Compiling tendency modules") start_time = time() mz = fr.ModelState(self.mset) mz.dz = self.mset.state_constructor() self.tendencies.update(mz) fr.log.notice( f"Compilation finished in {time()-start_time:.2f} seconds") # start the progress bar at the very end self.progress_bar.start() return
[docs] def stop(self): """ Finish the model run. """ self.restart_module.stop() self.tendencies.stop() self.diagnostics.stop() self.time_stepper.stop() self.timer.total.stop() self.progress_bar.stop() return
[docs] def reset(self) -> None: """ Reset the model (pointers, tendencies). """ self.restart_module.reset() self.tendencies.reset() self.diagnostics.reset() self.time_stepper.reset() self.model_state.reset() self.timer.reset() # to implement in child class return
# ============================================================ # RUN MODEL # ============================================================
[docs] def run(self, steps: int | None = None, runlen: Union[np.timedelta64, float, int, None] = None, start_time: Union[np.datetime64, float, int] = 0, end_time: Union[np.datetime64, float, int, None] = None, progress_bar=True) -> None: """ Run the model Parameters ---------- `steps` : `int` (default: None) Number of steps to run. `runlen` : `np.timedelta64 | float | int` (default: None) Length of the run. `start_time` : `np.datetime64 | float | int` (default: 0) Start time of the run. `end_time` : `np.datetime64 | float | int` (default: None) End time of the run. `progress_bar` : `bool` (default: True) Show progress bar. Raises ------ `ValueError` Only one of `steps`, `runlen` or `end_time` can be given. """ # ---------------------------------------------------------------- # Check input # ---------------------------------------------------------------- # only one of steps, runlen or end_time can be given if sum([steps is not None, runlen is not None, end_time is not None]) > 1: raise ValueError("Only one of steps, runlen or end_time can be given.") # ---------------------------------------------------------------- # Convert time parameters to seconds # ---------------------------------------------------------------- self.model_state.clock.set_start(start_time) datetime_formatting = False if isinstance(start_time, np.datetime64): datetime_formatting = True start_time = fr.utils.to_seconds(start_time) if isinstance(end_time, np.datetime64): datetime_formatting = True end_time = fr.utils.to_seconds(end_time) if isinstance(runlen, np.timedelta64): runlen = fr.utils.to_seconds(runlen) # set the start time self.model_state.clock.time = start_time # ---------------------------------------------------------------- # Calculate number of steps / end time # ---------------------------------------------------------------- # calculate end time if runlen is given if runlen is not None: end_time = start_time + runlen # calculate the final iteration step if steps is given if steps is not None: main_loop_type = "for loop" start_value = self.model_state.it final_value = start_value + steps else: main_loop_type = "while loop" start_value = start_time final_value = end_time # ---------------------------------------------------------------- # Load the model # ---------------------------------------------------------------- # check if the model needs to be reloaded if self.restart_module.should_reload(): self.load(self.restart_module.file) # start the model self.start() # Set the progress bar options self.progress_bar.set_options( main_loop_type=main_loop_type, datetime_formatting=datetime_formatting, start_value=start_value, final_value=final_value) # ---------------------------------------------------------------- # Initial diagnostics # ---------------------------------------------------------------- for module in self.diagnostics.module_list: if module.execute_at_start: self.model_state = module.update(self.model_state) # ---------------------------------------------------------------- # Main loop: Given number of setps # ---------------------------------------------------------------- if steps is not None: start_it = self.model_state.it fr.log.info( f"Running model from iteration {start_value} to {final_value}") # loop over the given number of steps for _ in range(start_it, final_value): self.step() if self.model_state.panicked: fr.log.warning( "Something went wrong. Stopping model.") break # ---------------------------------------------------------------- # Main loop: Given run length # ---------------------------------------------------------------- elif end_time is not None: fr.log.info( f"Running model from {self.model_state.clock.time} to {end_time}") # loop until the end time is reached while self.model_state.clock.time < end_time: self.step() if self.model_state.panicked: fr.log.warning( "Something went wrong. Stopping model.") break # stop the model self.stop() fr.log.info( f"Model run finished at it: {self.model_state.it}, time: {self.model_state.clock.time}") fr.log.info(self.mset.timer) return
# ============================================================ # SINGLE TIME STEP # ============================================================
[docs] def step(self) -> None: """ Update the model state by one time step. """ # synchronize the state vector (ghost points) with self.timer["sync"]: self.z.sync() # perform the time step self.model_state = self.time_stepper.update(mz=self.model_state) # check if there are any nans in the state variable with self.timer["check_nan"]: if self.model_state.it % self.mset.nan_check_interval == 0: if self.model_state.z.has_nan(): fr.log.critical( "State variable contains NaNs. Stopping model.") self.model_state.panicked = True # make diagnostics self.model_state = self.diagnostics.update(mz=self.model_state) # Update the progress bar self.progress_bar.update(self.model_state) # check if the model should restart if self.restart_module.should_restart(self.model_state): self.restart()
[docs] def restart(self) -> None: fr.log.info( f"Stopping model at it: {self.model_state.it}, time: {self.model_state.clock.time}") self.stop() self.save(self.restart_module.file) fr.log.info(self.mset.timer) fr.log.info("Spawning new sbatch job:") fr.log.info(self.restart_module.restart_command) fr.utils.mpi_barrier() if fr.utils.MPI_AVAILABLE: import subprocess result = subprocess.run( self.restart_module.restart_command.split(), capture_output=True, text=True) fr.log.notice(result.stdout) if result.stderr: fr.log.error(result.stderr) fr.utils.mpi_barrier() exit()
# ============================================================ # Getters and setters # ============================================================ @property def z(self): """ Returns the current state variable. """ return self.model_state.z @z.setter def z(self, value): """ Set the current state variable. """ self.model_state.z = value return # ============================================================ # OTHER METHODS # ============================================================
[docs] def load(self, file: str) -> None: # underscores are not allowed in the filename import dill # get a list of all files in the directory that start with the filename with open(file, "rb") as f: model = dill.load(f) model.mset.grid = self.mset.grid for key, attr in vars(model).items(): setattr(self, key, attr) return
[docs] def save(self, file: str) -> None: import dill with open(file, "wb") as f: fr.log.verbose(f"Saving model to {file}") grid = self.mset.grid # remove the grid from the model before pickling self.mset.grid = None dill.dump(self, f) # restore the grid self.mset.grid = grid return