Source code for fridom.framework.modules.restart_module

import fridom.framework as fr
import os
import time
import numpy as np


class RestartModule(fr.modules.Module):
    """
    Decide when a model should restart and which restart file to use.

    The module itself does not write or read restart files. It only
    evaluates the configured interval (`should_restart`), looks for
    existing restart files (`should_reload`) and stores the resulting
    file path in ``self.file``.

    Parameters
    ----------
    realtime_interval : np.timedelta64 | None
        Restart once this much wall-clock time has elapsed since
        ``fr.config.load_time``.
    modeltime_interval : np.timedelta64 | None
        Restart once ``mz.clock.time`` has advanced by this much since
        the last restart (or since the first check).
    iteration_interval : int | None
        Restart once ``mz.it`` has advanced by this many iterations since
        the last restart (or since the first check).
    restart_command : str | None
        Command stored in ``self.restart_command``. If it is ``None`` and
        the module is enabled, the command is derived from the SLURM job
        description as ``"sbatch <Command>"``; if that fails it stays
        ``None`` and a warning is logged.
    filename : str
        Prefix of the restart files. Any ``"-"`` is replaced by ``"_"``
        and a trailing ``.dill``/``.pkl``/``.pickle`` extension is removed.
    directory : str
        Directory that holds the restart files.

    Notes
    -----
    If no interval is given the module disables itself. If more than one
    interval is given, only the first one in the order realtime,
    modeltime, iteration is evaluated by `should_restart`.
    """
    name = "Restart Module"

    def __init__(self,
                 realtime_interval: 'np.timedelta64 | None' = None,
                 modeltime_interval: 'np.timedelta64 | None' = None,
                 iteration_interval: int | None = None,
                 restart_command: str | None = None,
                 filename: str = "model",
                 directory: str = "restart") -> None:
        super().__init__()

        # ----------------------------------------------------------------
        #  Check the restart interval input
        # ----------------------------------------------------------------
        num_interval_args = sum(
            interval is not None
            for interval in (realtime_interval,
                             modeltime_interval,
                             iteration_interval))
        if num_interval_args == 0:
            # No interval is set. Disable the module.
            fr.log.verbose(
                "No interval is set in RestartModule. Disabling the module.")
            self.disable()
        elif num_interval_args > 1:
            # should_restart() walks an if/elif chain, so the additional
            # intervals would otherwise be ignored without any notice.
            fr.log.warning(
                "More than one restart interval is set in RestartModule. "
                "Only the first of realtime, modeltime and iteration "
                "interval is used.")

        # ----------------------------------------------------------------
        #  Check the filename
        # ----------------------------------------------------------------
        if "-" in filename:
            fr.log.warning(
                "The filename should not contain the character '-' "
                "Replacing '-' with '_' in the filename.")
            filename = filename.replace("-", "_")

        # remove a known serialization extension from the filename
        base, ext = os.path.splitext(filename)
        if ext.lower() in (".dill", ".pkl", ".pickle"):
            filename = base

        # ----------------------------------------------------------------
        #  Set the restart command
        # ----------------------------------------------------------------
        if restart_command is None and self.is_enabled:
            restart_command = self._detect_slurm_restart_command()
            if restart_command is None:
                fr.log.warning(
                    "No restart command is set. "
                    "The model will not be able to restart.")

        # ----------------------------------------------------------------
        #  Set attributes
        # ----------------------------------------------------------------
        self.realtime_interval = realtime_interval
        self.modeltime_interval = modeltime_interval
        self.iteration_interval = iteration_interval
        self.restart_command = restart_command
        self.filename = filename
        self.directory = directory
        self.file = None

        # private attributes
        self._last_restart_modeltime = None
        self._last_restart_iteration = None

    @staticmethod
    def _detect_slurm_restart_command() -> 'str | None':
        """
        Build ``"sbatch <Command>"`` from the description of the SLURM job.

        Returns
        -------
        str | None
            The restart command, or ``None`` when ``SLURM_JOB_ID`` is not
            set, ``scontrol`` cannot be executed, or its output contains
            no non-empty ``Command=`` entry.
        """
        job_id = os.getenv('SLURM_JOB_ID')
        if job_id is None:
            return None

        import subprocess
        try:
            job_info = subprocess.run(
                ['scontrol', 'show', 'job', job_id],
                capture_output=True, text=True)
        except OSError:
            # scontrol is missing or not executable: treat this like a
            # job description without a command instead of crashing.
            return None

        command = None
        for line in job_info.stdout.split('\n'):
            if line.strip().startswith('Command='):
                command = line.split('=', 1)[1].strip()

        # Never build "sbatch None" / "sbatch " from a missing command.
        if not command:
            return None
        return f"sbatch {command}"

    @fr.modules.module_method
    def setup(self, mset: 'fr.ModelSettingsBase') -> None:
        """Set up the module and make sure the restart directory exists."""
        super().setup(mset)
        fr.log.verbose("Touching the restart directory.")
        os.makedirs(self.directory, exist_ok=True)
        return

    @fr.modules.module_method
    def should_restart(self, mz: 'fr.ModelState') -> bool:
        """
        Check whether the configured restart interval has been reached.

        Only the first configured interval (realtime, then modeltime, then
        iteration) is evaluated. When it is reached, ``self.file`` is set
        for iteration ``mz.it``.

        Parameters
        ----------
        mz : fr.ModelState
            Model state; ``mz.it`` and ``mz.clock.time`` are read.

        Returns
        -------
        bool
            ``True`` if the model should restart now, else ``False``.
        """
        # ----------------------------------------------------------------
        #  Realtime interval
        # ----------------------------------------------------------------
        if self.realtime_interval is not None:
            # wall-clock seconds since fr.config.load_time
            elapsed_time = np.timedelta64(
                int(time.time() - fr.config.load_time), 's')
            if elapsed_time >= self.realtime_interval:
                fr.log.info(
                    "Realtime restart interval reached. Model will restart.")
                self.set_full_filename(mz.it)
                return True

        # ----------------------------------------------------------------
        #  Modeltime interval
        # ----------------------------------------------------------------
        elif self.modeltime_interval is not None:
            # the first check defines the reference time
            if self._last_restart_modeltime is None:
                self._last_restart_modeltime = mz.clock.time
            elapsed_time = mz.clock.time - self._last_restart_modeltime
            if elapsed_time >= self.modeltime_interval:
                fr.log.info(
                    "Modeltime restart interval reached. Model will restart.")
                self.set_full_filename(mz.it)
                self._last_restart_modeltime = mz.clock.time
                self._last_restart_iteration = mz.it
                return True

        # ----------------------------------------------------------------
        #  Iteration interval
        # ----------------------------------------------------------------
        elif self.iteration_interval is not None:
            # the first check defines the reference iteration
            if self._last_restart_iteration is None:
                self._last_restart_iteration = mz.it
            elapsed_iterations = mz.it - self._last_restart_iteration
            if elapsed_iterations >= self.iteration_interval:
                fr.log.info(
                    "Iteration restart interval reached. Model will restart.")
                self.set_full_filename(mz.it)
                self._last_restart_modeltime = mz.clock.time
                self._last_restart_iteration = mz.it
                return True

        return False

    @fr.modules.module_method
    def should_reload(self) -> bool:
        """
        Check whether restart files exist in the restart directory.

        If at least one file named ``<filename>_<it>_<rank>.dill`` is
        found, ``self.file`` is set for the largest iteration found.

        Returns
        -------
        bool
            ``True`` if a restart file was found, else ``False``.
        """
        fr.log.verbose("Checking if restart files exist.")
        iterations = self._find_restart_iterations()
        if not iterations:
            fr.log.info("No restart files found. Model will not reload.")
            return False

        fr.log.info("Found restart files. Model will reload.")
        self.set_full_filename(max(iterations))
        return True

    def _find_restart_iterations(self) -> 'list[int]':
        """
        Collect the iteration numbers of all restart files in the directory.

        Only entries of the exact form ``<filename>_<it>_<rank>.dill`` are
        considered. The prefix is stripped as a whole, so a filename that
        itself contains ``"_"`` is parsed correctly, and unrelated files
        that merely start with the same characters are skipped.
        """
        prefix = f"{self.filename}_"
        iterations = []
        for entry in os.listdir(self.directory):
            stem, ext = os.path.splitext(entry)
            if ext != ".dill" or not stem.startswith(prefix):
                continue
            fields = stem[len(prefix):].split("_")
            if len(fields) == 2 and all(f.isdecimal() for f in fields):
                iterations.append(int(fields[0]))
        return iterations

    @fr.modules.module_method
    def reset(self) -> None:
        """Forget the model time and iteration of the last restart."""
        self._last_restart_modeltime = None
        self._last_restart_iteration = None
        return

    def set_full_filename(self, it: int) -> None:
        """
        Set ``self.file`` to ``<directory>/<filename>_<it>_<rank>.dill``.

        Parameters
        ----------
        it : int
            Iteration number encoded in the file name.

        Notes
        -----
        ``rank`` is the rank in ``MPI.COMM_WORLD`` when
        ``fr.utils.MPI_AVAILABLE`` is true, otherwise ``0``.
        """
        if fr.utils.MPI_AVAILABLE:
            rank = fr.utils.MPI.COMM_WORLD.Get_rank()
        else:
            rank = 0
        filename = f"{self.filename}_{it}_{rank}.dill"
        self.file = os.path.join(self.directory, filename)
        return

    # ================================================================
    #  Properties
    # ================================================================
    @property
    def info(self) -> dict:
        """
        Extend the parent's info dictionary with the restart settings.

        The restart settings are only added while the module is enabled,
        and only for values that are not ``None``.
        """
        res = super().info
        if not self.is_enabled:
            return res
        if self.realtime_interval is not None:
            res["Realtime Restart Interval"] = self.realtime_interval
        if self.modeltime_interval is not None:
            res["Modeltime Restart Interval"] = self.modeltime_interval
        if self.iteration_interval is not None:
            res["Iteration Restart Interval"] = self.iteration_interval
        if self.directory is not None:
            res["Directory"] = self.directory
        if self.filename is not None:
            res["Filename"] = self.filename
        return res