"""Restart module for cluster computing."""
from __future__ import annotations
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Callable
import numpy as np
import fridom.framework as fr
class RestartModule(fr.modules.Module):
    """
    Automatically restart the job at a given interval.

    Description
    -----------
    A restart can be requested either after a given amount of real
    (wall-clock) time, measured from ``fr.config.load_time``, or by a
    model-clock trigger. The constructor passes both options to
    ``TooManyArgumentsError.check`` with ``max_args=1``, so presumably only
    one of them may be given at a time. If neither is given, the module
    disables itself.

    :meth:`should_restart` reports whether a restart condition is met,
    :meth:`restart` stops and saves the model and then runs the restart
    command, and :meth:`should_reload` looks for previously written restart
    files. Restart files are named ``<stem>_<iteration>_<rank><suffix>`` and
    live next to ``file_path``.

    Parameters
    ----------
    realtime_interval : np.timedelta64, optional
        The interval in real time at which the model should restart.
        If None, the model will not restart based on real time.
    clock_trigger : fr.ClockTrigger, optional
        Triggers the restart based on the model time.
        If None, the model will not restart based on model time.
    restart_command : str or Callable, optional
        The command to start the job.

        If the command is a string:
            The model will restart by running the command in a subprocess.
        If the command is a callable:
            The model will restart by calling the function. This function
            should simply restart the job. It should have no arguments, and
            should return nothing.
        If the command is None:
            The model will try to find the command from the environment.
            If the command is not found, the model will not be able to
            restart.
    file_path : Path, optional
        The path to the restart file.
    """

    name = "Restart Module"

    def __init__(self,
                 realtime_interval: np.timedelta64 | None = None,
                 clock_trigger: fr.ClockTrigger | None = None,
                 restart_command: str | Callable | None = None,
                 file_path: Path | str = Path("restart/model.dill")) -> None:
        super().__init__()

        # ----------------------------------------------------------------
        #  Check the restart interval input
        # ----------------------------------------------------------------
        fr.exceptions.TooManyArgumentsError.check(
            max_args=1,
            realtime_interval=realtime_interval,
            clock_trigger=clock_trigger)

        if realtime_interval is None and clock_trigger is None:
            # No interval is set. Disable the module.
            fr.log.verbose(
                "No interval is set in RestartModule. Disabling the module.")
            self.disable()

        # ----------------------------------------------------------------
        #  Set attributes
        # ----------------------------------------------------------------
        self.realtime_interval = realtime_interval
        self.clock_trigger = clock_trigger
        self.file_path = file_path
        # Must be assigned after a possible ``disable()``: the setter skips
        # the environment lookup when the module is disabled.
        self.restart_command = restart_command
        # Full path of the restart file; filled in by ``set_full_filename``.
        self.file = None

    def _on_setup(self) -> None:
        self._touch_restart_directory()

    @fr.modules.module_method
    def should_restart(self, mz: fr.ModelState) -> bool:
        """
        Check if one of the restart conditions is met.

        When a condition is met, the restart filename is set for the current
        iteration (``mz.clock.it``) before returning.

        Parameters
        ----------
        mz : fr.ModelState
            The current model state; only its clock is used.

        Returns
        -------
        bool
            True if the model should restart, False otherwise.
        """
        # ----------------------------------------------------------------
        #  Realtime interval
        # ----------------------------------------------------------------
        interval = self.realtime_interval
        if interval is not None:
            # Elapsed wall-clock time, truncated to whole seconds.
            elapsed_time = np.timedelta64(
                int(time.time() - fr.config.load_time), "s")
            if elapsed_time >= interval:
                fr.log.info(
                    "Realtime restart interval reached. Model will restart.")
                self.set_full_filename(mz.clock.it)
                return True

        # ----------------------------------------------------------------
        #  Modelclock trigger
        # ----------------------------------------------------------------
        trigger = self.clock_trigger
        if trigger is not None and trigger.check(mz.clock):
            fr.log.info("Modeltime trigger reached. Model will restart.")
            self.set_full_filename(mz.clock.it)
            return True

        return False

    @fr.modules.module_method
    def should_reload(self) -> bool:
        """
        Check if the model should reload from a restart file.

        Description
        -----------
        If a restart file is found in the directory, the model will reload.
        If multiple restart files are found, the model will reload from the
        restart file with the highest iteration number.

        Only files named ``<stem>_<iteration>_...`` with a numeric iteration
        are considered; other files in the directory are ignored. A missing
        directory is treated as "no restart files".

        Returns
        -------
        bool
            True if a restart file was found (``self.file`` is then set to
            the file of the highest iteration for this rank).
        """
        fr.log.verbose("Checking if restart files exist.")
        directory = self.file_path.parent
        # The stem never contains "_" (see the ``file_path`` setter), so the
        # first field after "<stem>_" is the iteration number.
        prefix = f"{self.file_path.stem}_"

        iterations = []
        if directory.is_dir():
            for entry in os.listdir(directory):
                if not entry.startswith(prefix):
                    continue
                iteration = entry[len(prefix):].split("_")[0]
                # Skip names that merely share the prefix but do not follow
                # the restart file naming scheme.
                if iteration.isdecimal():
                    iterations.append(int(iteration))

        if iterations:
            fr.log.info("Found restart files. Model will reload.")
            self.set_full_filename(max(iterations))
            return True

        fr.log.info("No restart files found. Model will not reload.")
        return False

    def _on_reset(self) -> None:
        # The trigger is None when the module is disabled or restarts are
        # driven by the realtime interval only.
        if self.clock_trigger is not None:
            self.clock_trigger.reset()

    def set_full_filename(self, it: int) -> None:
        """
        Set the full filename with the iteration number and rank.

        Parameters
        ----------
        it : int
            The iteration number to encode in the filename.
        """
        rank = fr.utils.get_my_rank()
        filename = f"{self.file_path.stem}_{it}_{rank}{self.file_path.suffix}"
        self.file = self.file_path.parent / filename

    def restart(self, model: fr.Model) -> None:
        """
        Restart the model.

        Description
        -----------
        This function stops the model, saves the model state to a file,
        and executes the restart command.

        Parameters
        ----------
        model : fr.Model
            The model to restart.

        Raises
        ------
        ValueError
            If no restart command is available. The model has already been
            stopped and saved at that point.
        """
        fr.log.info(
            "Stopping model at it: %d, time: %f",
            model.model_state.clock.it,
            model.model_state.clock.time)
        model.stop()
        model.save(self.file)
        fr.log.info(model.mset.timer)

        command = self.restart_command
        if isinstance(command, str):
            self._restart_from_command()
            return
        if callable(command):
            command()
            return

        msg = "No restart command is set. The model can't restart."
        raise ValueError(msg)

    def _restart_from_command(self) -> None:
        """
        Restart the model from the restart command.

        Does nothing unless the restart command is a string. Otherwise the
        command is split on whitespace and run in a subprocess, its output
        is logged, and the current process exits.
        """
        if not isinstance(self.restart_command, str):
            return

        fr.log.info(f"Running restart command: {self.restart_command}")
        fr.utils.mpi_barrier()
        # NOTE(review): the command is only run when MPI is available, and
        # no rank check is visible here - confirm this is the intended
        # condition.
        if fr.utils.MPI_AVAILABLE:
            result = subprocess.run(  # noqa: S603
                self.restart_command.split(),
                capture_output=True,
                text=True,
                check=False)
            fr.log.notice(result.stdout)
            if result.stderr:
                fr.log.error(result.stderr)
        fr.utils.mpi_barrier()
        sys.exit()

    def _touch_restart_directory(self) -> None:
        """Create the restart directory (and its parents) if it is missing."""
        fr.log.verbose("Touching the restart directory.")
        self.file_path.parent.mkdir(parents=True, exist_ok=True)

    def _find_slurm_restart_command(self) -> str | None:
        """
        Try to derive the restart command from the SLURM environment.

        Reads ``SLURM_JOB_ID`` and searches the output of
        ``scontrol show job <id>`` for a ``Command=`` entry.

        Returns
        -------
        str or None
            ``"sbatch <command>"`` if a command was found, otherwise None
            (a warning is logged in that case).
        """
        job_id = os.getenv("SLURM_JOB_ID")
        if job_id is None:
            msg = "No job id is found in the environment."
            msg += " The model will not be able to restart."
            fr.log.warning(msg)
            return None
        # The job id is passed to a subprocess; accept digits only.
        if not job_id.isdigit():
            fr.log.warning("Invalid job id. The model will not be able to restart.")
            return None

        try:
            job_info = subprocess.run(  # noqa: S603
                ["/usr/bin/scontrol", "show", "job", job_id],
                capture_output=True,
                text=True,
                check=False)
        except OSError as err:
            # E.g. scontrol is not installed at this path. The lookup is
            # best-effort, so warn instead of failing the construction.
            fr.log.warning(
                "Could not run scontrol (%s)."
                " The model will not be able to restart.", err)
            return None

        # extract the restart command from the job info
        command = None
        for line in job_info.stdout.split("\n"):
            if line.strip().startswith("Command="):
                command = line.split("=", 1)[1].strip()

        if command is None:
            fr.log.warning(
                "No restart command is set. The model will not be able to restart.")
            return None
        return f"sbatch {command}"

    # ================================================================
    #  Properties
    # ================================================================
    @property
    def info(self) -> dict:
        """The base module info, extended by the restart settings if enabled."""
        res = super().info
        if not self.is_enabled:
            return res
        if self.realtime_interval is not None:
            res["Realtime Restart Interval"] = self.realtime_interval
        if self.clock_trigger is not None:
            res["Clock Trigger"] = self.clock_trigger
        if self.restart_command is not None:
            res["Restart Command"] = self.restart_command
        res["File Path"] = self.file_path
        return res

    @property
    def realtime_interval(self) -> np.timedelta64 | None:
        """The interval in real time at which the model should restart."""
        return self._realtime_interval

    @realtime_interval.setter
    def realtime_interval(self, realtime_interval: np.timedelta64 | None) -> None:
        if realtime_interval is None:
            self._realtime_interval = None
            return
        # check that the interval is a timedelta
        if not isinstance(realtime_interval, np.timedelta64):
            msg = "The interval should be a numpy timedelta64."
            raise TypeError(msg)
        self._realtime_interval = realtime_interval

    @property
    def clock_trigger(self) -> fr.ClockTrigger | None:
        """
        Triggers the restart based on the model time.

        By default, the clock trigger will not trigger on the first step.
        Any other option will be overridden. To trigger on the first step,
        set the trigger_on_first_step to True after setting the clock trigger,
        e.g.:

        .. code-block:: python

            import fridom.framework as fr
            restart_module = fr.modules.RestartModule(clock_trigger=fr.ClockTrigger())
            restart_module.clock_trigger.trigger_on_first_step = True
        """
        return self._clock_trigger

    @clock_trigger.setter
    def clock_trigger(self, clock_trigger: fr.ClockTrigger | None) -> None:
        self._clock_trigger = clock_trigger
        if isinstance(clock_trigger, fr.ClockTrigger):
            # Overrides whatever the caller configured on the trigger.
            clock_trigger.trigger_on_first_step = False

    @property
    def file_path(self) -> Path:
        """The path to the restart file."""
        return self._file_path

    @file_path.setter
    def file_path(self, file_path: Path | str) -> None:
        # cast to Path
        if isinstance(file_path, str):
            file_path = Path(file_path)
        # sanitize the filename: "_" separates stem, iteration and rank in
        # the restart filenames, so it must not occur in the stem itself
        filename = file_path.stem
        if "_" in filename:
            msg = "The filename should not contain the character '_'."
            msg += " Replacing '_' with '-' in the filename."
            fr.log.warning(msg)
            filename = filename.replace("_", "-")
        self._file_path = file_path.with_name(filename + file_path.suffix)

    @property
    def restart_command(self) -> str | Callable | None:
        """The command (or callable) to restart the job, None if unavailable."""
        return self._restart_command

    @restart_command.setter
    def restart_command(self, restart_command: str | Callable | None) -> None:
        self._restart_command = None
        # A disabled module never restarts; skip the environment lookup.
        if not self.is_enabled:
            return
        if restart_command is not None:
            self._restart_command = restart_command
            return
        # No command given: fall back to the SLURM environment. The result
        # stays None when no command can be determined.
        self._restart_command = self._find_slurm_restart_command()