Source code for fridom.framework.modules.progress_bar
"""A progress bar module to display the progress of the simulation."""
from __future__ import annotations
import time
import numpy as np
import fridom.framework as fr
[docs]
class ProgressBar(fr.modules.Module):
"""
A progress bar module to display the progress of the simulation.
Description
-----------
The progress bar class is a wrapper around the tqdm progress bar. It
has a custom format and handles the output to the stdout when the
stdout is a file.
Parameters
----------
disable : bool
Whether to disable the progress bar.
"""
name = "Progress Bar"
[docs]
def __init__(self) -> None:
super().__init__()
self._pbar = None
self._file_output = None
self._output = None
self._last_call = None
self._main_loop_type = None
self._datetime_formatting = None
self._start_value = None
self._final_value = None
[docs]
@fr.modules.module_method
def start(self) -> None: # noqa: D102
# only the main rank should print the progress bar
disable = not fr.utils.I_AM_MAIN_RANK
# ----------------------------------------------------------------
# Set the progress bar format
# ----------------------------------------------------------------
bar_format = "{percentage:3.2f}%|{bar}| "
bar_format += "[{elapsed}<{remaining}]{postfix}"
# ----------------------------------------------------------------
# Check if the stdout is a file
# ----------------------------------------------------------------
file_output = fr.utils.stdout_is_file()
if file_output:
# if the stdout is a file, tqdm would print to the stderr by default
# we could instead print to the stdout, but this would mess up
# the look of the progress bar due to "\r" characters
# so we create a StringIO object to capture the output
# and adjust the progress bar accordingly
import io
output = io.StringIO()
else:
import sys
output = sys.stdout
# ----------------------------------------------------------------
# Create the progress bar
# ----------------------------------------------------------------
from tqdm import tqdm
pbar = tqdm(
total=100,
disable=disable,
bar_format=bar_format,
unit="%",
file=output)
# ----------------------------------------------------------------
# Set the attributes
# ----------------------------------------------------------------
self._pbar = pbar
self._file_output = file_output
self._output = output
self._last_call = time.time()
self._main_loop_type = None
self._datetime_formatting = None
self._start_value = None
self._final_value = None
[docs]
@fr.modules.module_method
def stop(self) -> None: # noqa: D102
if self._pbar is not None:
self._pbar.close()
self._pbar = None
self._file_output = None
self._output = None
self._last_call = None
self._main_loop_type = None
self._datetime_formatting = None
self._start_value = None
self._final_value = None
[docs]
def set_options(self,
main_loop_type: str,
datetime_formatting: bool,
start_value: float,
final_value: float):
self._main_loop_type = main_loop_type
self._datetime_formatting = datetime_formatting
self._start_value = start_value
self._final_value = final_value
[docs]
@fr.modules.module_method
def update(self, mz: fr.ModelState) -> fr.ModelState: # noqa: D102
if self._start_value is None:
return mz
# get the time between the last call (in milliseconds)
now = time.time()
elapsed = now - self._last_call
self._last_call = now
elapsed = f"{int(elapsed*1e3)} ms/it"
# Get the current progress value
match self._main_loop_type:
case "for loop":
value = mz.clock.it
case "while loop":
value = mz.clock.time
# map the value to a percentage
value = 100 * ( (value - self._start_value)
/ (self._final_value - self._start_value) )
# clamp the value between 0 and 100
value = max(0, min(100, value))
# Create a postfix string for the progress bar
if self._datetime_formatting:
time_str = np.datetime64(int(mz.clock.time), "s")
else:
time_str = fr.utils.humanize_number(mz.clock.time, unit="seconds")
postfix = f"It: {mz.clock.it} - Time: {time_str}"
# update the progress bar
self._pbar.n = value
self._pbar.set_postfix_str(f"{elapsed} at {postfix}")
if not self._file_output:
return mz
# print the progress to the stdout
fr.log.info(self._output.getvalue().split("\r")[1])
# clear the output string
self._output.seek(0)
return mz