Source code for fridom.framework.modules.animation.video_writer
"""Create a mp4 video from the model."""
from __future__ import annotations
import os
import warnings
from pathlib import Path
from typing import Union
import numpy as np
import fridom.framework as fr
# FIXME(Silvano): The parallel option is not working with jax anymore
[docs]
class VideoWriter(fr.modules.Module):
"""
Create a mp4 video from the model.
Description
-----------
To create a mp4 video from the model, one must provide a `ModelPlotter`
that will be used to create the figure. The video writer does not support
MPI parallelism.
Parameters
----------
model_plotter : ModelPlotter
The model plotter that will be used to create the figure.
model_time_per_second : np.timedelta64 | float
The model time that is shown per second of video. Frames are written every `model_time_per_second / fps` of model time.
filename : str, optional (default="output.mp4")
The filename of the video (will be stored in videos/filename).
fps : int, optional (default=30)
The frames per second of the video.
parallel : bool, optional (default=True)
If True, the video writer will use parallelism to create the video.
max_jobs : float, optional (default=0.2)
The maximum fraction of the available threads that will be used.
"""
name = "Video Writer"
[docs]
def __init__(self,
             model_plotter: fr.modules.animation.ModelPlotter,
             model_time_per_second: np.timedelta64 | float,
             filename: str = "output.mp4",
             fps: int = 30,
             parallel: bool = True,
             max_jobs: float = 0.2,
             ) -> None:
    """
    Store the video settings and derive the frame write interval.

    The write interval is the model time shown per second of video
    divided by the frame rate. The video is stored below the ``videos``
    folder; a filename without any extension gets ``.mp4`` appended.
    """
    super().__init__()

    # Work in seconds: timedelta inputs are converted, other values are
    # used as they are
    seconds_per_video_second = model_time_per_second
    if isinstance(seconds_per_video_second, np.timedelta64):
        seconds_per_video_second = fr.utils.to_seconds(
            seconds_per_video_second)

    # Model time that passes between two consecutive frames
    frame_interval = seconds_per_video_second / fps

    # Only names without any suffix are extended
    if Path(filename).suffix == "":
        filename = filename + ".mp4"

    # User supplied settings
    self.model_plotter = model_plotter
    self.write_interval = frame_interval
    self.filename = os.path.join("videos", filename)
    self.fps = fps
    self.max_jobs = max_jobs
    self.parallel = parallel

    # The video writer does not support MPI
    self.mpi_available = False

    # Runtime state, filled in once the module is started
    self.writer = None
    self.fig = None
    self._last_write_time = None
    self._last_checkpoint_time = None
def _on_setup(self) -> None:
    """
    Prepare the output location and the parallel job limit.

    Creates the ``videos`` folder if it is missing, removes a file that
    already exists at ``self.filename`` and, in parallel mode, derives
    ``self.maximum_jobs`` from ``self.max_jobs`` and the CPU count.
    """
    # create video folder if it does not exist (exist_ok covers the case
    # where the folder appears between the check and the call)
    if not os.path.exists("videos"):
        fr.log.info("Creating videos folder")
        os.makedirs("videos", exist_ok=True)

    # delete the file if it already exists
    if os.path.exists(self.filename):
        fr.log.notice(f"Deleting existing video file {self.filename}")
        os.remove(self.filename)

    # use at most the fraction `max_jobs` of the available threads
    if self.parallel:
        import multiprocessing as mp
        # int() truncates, so a small fraction on a machine with few cores
        # gives 0; a limit below one would leave no slot for any job to
        # ever be started, hence at least one job is always allowed.
        self.maximum_jobs = max(1, int(self.max_jobs * mp.cpu_count()))

    self.fig = None
[docs]
@fr.modules.module_method
def start(self) -> None:  # noqa: D102
    if self.writer is not None and not self.writer.closed:
        # The message is passed as ONE string. As two positional arguments
        # a logging-style API would treat the second string as a format
        # argument for the first one.
        fr.log.warning(
            "VideoWriter.start() called without closing the previous "
            "writer. Continue with the previous writer.")
        # The previous writer is kept, so the bookkeeping of jobs that are
        # still in flight is kept as well instead of being dropped.
        return

    # list for the jobs and queues for creating the figures
    self.running_jobs = []  # Processes
    self.open_queues = []  # Queues

    # start the writer (imported lazily, only needed once a video is made)
    import imageio
    self.writer = imageio.get_writer(self.filename, fps=self.fps)
[docs]
@fr.modules.module_method
def stop(self) -> None:  # noqa: D102
    try:
        # collect all figures; `running_jobs` is only created by start(),
        # so a missing list is treated like an empty one
        while getattr(self, "running_jobs", None):
            fr.log.info("Collecting remaining figures")
            self.collect_figures()
    finally:
        # Release the writer and the figure even if collecting failed,
        # otherwise the video file would stay open.
        if self.writer is not None:
            fr.log.debug("Closing the video writer")
            self.writer.close()
            fr.log.debug("Video writer closed")
        if self.fig is not None:
            import matplotlib.pyplot as plt
            plt.close(self.fig)
            self.fig = None
        # forget the timing state so the next update writes a frame again
        self._last_write_time = None
        self._last_checkpoint_time = None
[docs]
@fr.modules.module_method
def update(self, mz: fr.ModelState) -> fr.ModelState:  # noqa: D102
    now = mz.clock.time

    # ----------------------------------------------------------------
    #  Decide whether a frame is due
    # ----------------------------------------------------------------
    no_history = (self._last_write_time is None
                  or self._last_checkpoint_time is None)
    if no_history:
        # nothing has been recorded yet: always write
        frame_due = True
    else:
        # a frame is due when the next write time was still ahead at the
        # previous checkpoint and has been reached by now
        due_time = self._last_write_time + self.write_interval
        frame_due = (self._last_checkpoint_time < due_time
                     and now >= due_time)

    # the checkpoint is refreshed on every call, written or not
    self._last_checkpoint_time = now
    if not frame_due:
        return mz

    # ----------------------------------------------------------------
    #  Write the frame
    # ----------------------------------------------------------------
    self._last_write_time = now
    write_frame = (self.parallel_update if self.parallel
                   else self.single_update)
    write_frame(mz)
    return mz
[docs]
def parallel_update(self, mz: fr.ModelState) -> None:
    """
    Start a separate process that creates the figure for this state.

    Waits until fewer than ``self.maximum_jobs`` jobs are running, then
    starts a new process running ``VideoWriter.p_make_figure`` and
    records the process and its output queue.

    Parameters
    ----------
    mz : ModelState
        The model state that is passed to the plotter's
        ``prepare_arguments``.
    """
    import multiprocessing as mp

    # A limit below one could never be undercut by the number of running
    # jobs, which would turn the wait loop below into an endless loop.
    job_limit = max(1, self.maximum_jobs)

    # collect finished figures
    self.collect_figures()
    # wait until there is space for a new job
    while len(self.running_jobs) >= job_limit:
        self.collect_figures()

    # create a new figure
    q = mp.Queue()
    # The diagnostics are detached while the job is set up (presumably so
    # they are not carried into the prepared arguments -- TODO confirm).
    # They are restored in `finally`, so a failure in between cannot leave
    # the model settings without diagnostics.
    diagnostics = self.mset.diagnostics
    self.mset.diagnostics = None
    try:
        kw = {"kwargs": self.model_plotter.prepare_arguments(mz),
              "output_queue": q,
              "model_plotter": self.model_plotter}
        job = mp.Process(target=VideoWriter.p_make_figure, kwargs=kw)
    finally:
        self.mset.diagnostics = diagnostics

    # RuntimeWarnings raised while starting the process are suppressed
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        job.start()
    self.open_queues.append(q)
    self.running_jobs.append(job)
[docs]
def single_update(self, mz: fr.ModelState) -> None:
    """
    Draw the current model state and append it to the video as a frame.

    The figure is created on the first call and cleared and reused on
    every later call.

    Parameters
    ----------
    mz : ModelState
        The model state that is passed to the plotter's
        ``prepare_arguments``.
    """
    plotter = self.model_plotter

    # reuse the existing figure when there is one
    if self.fig is not None:
        self.fig.clear()
    else:
        self.fig = plotter.create_figure()

    # draw the state and turn the figure into an image
    plot_arguments = plotter.prepare_arguments(mz)
    plotter.update_figure(fig=self.fig, **plot_arguments)
    frame = plotter.convert_to_img(self.fig)

    # RuntimeWarnings raised while appending the frame are suppressed
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        self.writer.append_data(frame)
[docs]
def show_video(self, width=600):
    """
    Return an IPython ``Video`` object for the written file.

    Parameters
    ----------
    width : int, optional (default=600)
        The width that is passed on to the ``Video`` object.
    """
    # imported lazily: IPython is only needed when the video is displayed
    import IPython.display as ipython_display
    return ipython_display.Video(self.filename, embed=True, width=width)
@property
def info(self) -> dict:  # noqa: D102
    # start from the entries of the parent class and add the video settings
    details = super().info
    for key in ("filename", "fps", "max_jobs"):
        details[key] = getattr(self, key)
    return details
def __to_numpy__(self, memo):
    """Return the same result as ``__deepcopy__`` called with ``memo``."""
    converted = self.__deepcopy__(memo)
    return converted
def __deepcopy__(self, memo):
    """
    Return a deep copy that shares the writer and the figure.

    All attributes are deep-copied, except ``writer`` and ``fig``, which
    are handed over to the copy by reference.

    Parameters
    ----------
    memo : dict | None
        The memo dictionary of the running deep copy. ``None`` is treated
        like an empty dictionary.
    """
    from copy import deepcopy

    if memo is None:
        memo = {}

    # attributes that are shared with the copy instead of being copied
    dont_copy = ("writer", "fig")

    new = self.__class__.__new__(self.__class__)
    # Register the copy BEFORE copying the attributes: an attribute that
    # refers back to this object then resolves to `new` instead of
    # producing a second, independent copy. This matters in particular
    # when this method is called directly and not through copy.deepcopy.
    memo[id(self)] = new

    for key in self.__dict__:
        value = getattr(self, key)
        if key not in dont_copy:
            value = deepcopy(value, memo)
        setattr(new, key, value)
    return new
# =====================================================================
# PARALLEL FUNCTIONS
# =====================================================================