"""netcdf_writer.py - Writing model output to NetCDF files."""
from __future__ import annotations
from pathlib import Path
from typing import Callable
import numpy as np
from netCDF4 import Dataset
import fridom.framework as fr
class NetCDFWriter(fr.modules.Module):
    """
    Writing model output to NetCDF files.

    Parameters
    ----------
    write_trigger : fr.ClockTrigger, optional
        The trigger that determines when the data should be written to the file.
        Default is None which means that the data will be written at every time step.
    restart_trigger : fr.ClockTrigger, optional
        The trigger that determines when a new file should be created.
        Default is None which means that only one file will be created.
    filename : str, optional
        The name of the file to write to. Default is "snap". It is placed
        inside ``directory``. A ``.nc`` or ``.cdf`` suffix is kept. Any other
        suffix is treated as part of the name, and ``.cdf`` is appended.
    directory : str, optional
        The directory where the files should be stored. Default is "snapshots".
    get_variables : callable, (default: None)
        A function that returns a list of scalar fields that should be written
        to the file. If None, all fields of the State object will be written.
        The function signature of get_variables is:
        `get_variables(mz: 'ModelState') -> list[ScalarField]`

    Notes
    -----
    Files are opened in ``"w"`` mode. If ``add_timestamp`` is set to False and
    a restart trigger is used, every new file gets the same name and
    overwrites the previous one.
    """

    name = "NetCDFWriter"

    def __init__(self,
                 write_trigger: fr.ClockTrigger | None = None,
                 restart_trigger: fr.ClockTrigger | None = None,
                 filename: str = "snap",
                 directory: str | None = None,
                 get_variables: Callable | None = None,
                 ) -> None:
        super().__init__()
        directory = directory or "snapshots"
        filename = Path(directory) / filename
        self.execute_at_start = True

        if get_variables is None:
            def get_variables(mz: fr.ModelState) -> list[fr.ScalarField]:
                # default: write every field of the model state
                return mz.z.field_list

        # ----------------------------------------------------------------
        #  Set Attributes
        # ----------------------------------------------------------------
        self.directory = directory
        self.filename = filename
        self.write_trigger = write_trigger or fr.ClockTrigger()
        self.restart_trigger = restart_trigger
        # filled in by _on_setup once the grid dimensionality is known
        self._snap_slice = None
        self._add_timestamp = True
        self.get_variables = get_variables

        # private attributes
        self._file_is_open = False
        self._ncfile = None

    def _on_setup(self) -> None:
        # create snapshot folder if it doesn't exist
        fr.log.verbose(f"Touching snapshot directory: {self.directory}")
        Path(self.directory).mkdir(parents=True, exist_ok=True)

        # default snap slice: the full extent along every grid axis
        if self._snap_slice is None:
            self._snap_slice = (slice(None),) * self.grid.n_dims

    @fr.modules.module_method
    def start(self) -> None:
        """Close any file left open from a previous run, with a warning."""
        if self._file_is_open:
            msg = "NetCDFWriter: start() called while a file is already open."
            fr.log.warning(msg)
            self._close_file()

    @fr.modules.module_method
    def stop(self) -> None:
        """Close the currently open output file, if there is one."""
        if self._file_is_open:
            self._close_file()

    def _on_reset(self) -> None:
        self.write_trigger.reset()
        if self.restart_trigger is not None:
            self.restart_trigger.reset()

    @fr.modules.module_method
    def update(self, mz: fr.ModelState) -> fr.ModelState:
        """
        Write one snapshot of ``mz`` if the write trigger fires.

        The restart trigger is only evaluated on steps where the write trigger
        fires. When it fires, the current file is closed and a new one is
        created before the data is written.

        Parameters
        ----------
        mz : fr.ModelState
            The model state to write.

        Returns
        -------
        fr.ModelState
            The unchanged model state.
        """
        # ----------------------------------------------------------------
        #  Check if it is time to write
        # ----------------------------------------------------------------
        if not self.write_trigger.check(mz.clock):
            return mz

        # ----------------------------------------------------------------
        #  Check if the file should be restarted
        # ----------------------------------------------------------------
        if self.restart_trigger is not None and self.restart_trigger.check(mz.clock):
            self._close_file()

        # ----------------------------------------------------------------
        #  Create a new file if the current file is not open
        # ----------------------------------------------------------------
        if not self._file_is_open:
            self._create_file(mz)

        # ----------------------------------------------------------------
        #  Write data
        # ----------------------------------------------------------------
        self._write_data(mz)
        return mz

    def _format_filename(self, clock: fr.Clock) -> Path:
        """
        Build the output path, adding a timestamp to the name if enabled.

        A ``.nc`` / ``.cdf`` suffix (case-insensitive) is kept in lower case.
        Any other suffix, or no suffix, is treated as part of the name, and
        ``.cdf`` is appended. The original name is never truncated, so
        ``"snap.2.nc"`` becomes ``"snap.2_<time>.nc"``, not ``"snap_<time>.nc"``.

        Parameters
        ----------
        clock : fr.Clock
            The clock of the model with the current time.

        Returns
        -------
        Path
            The formatted filename (in the same directory as ``self.filename``).
        """
        suffix = self.filename.suffix.lower()
        if suffix in (".nc", ".cdf"):
            stem = self.filename.stem
        else:
            # unknown / missing suffix: keep the full name, append ".cdf"
            stem = self.filename.name
            suffix = ".cdf"

        if not self.add_timestamp:
            return self.filename.with_name(f"{stem}{suffix}")

        tot_time = clock.get_total_time()
        if isinstance(tot_time, np.datetime64):
            # numpy scalars have no str methods, so convert explicitly
            time_stamp = str(tot_time)
        else:
            time_stamp = fr.utils.humanize_number(tot_time, unit="seconds")
        time_stamp = time_stamp.replace(" ", "_")
        return self.filename.with_name(f"{stem}_{time_stamp}{suffix}")

    def _create_file(self, mz: fr.ModelState) -> None:
        """Open a fresh NetCDF file and define its dimensions and variables."""
        # make sure that there is no file open
        self._close_file()

        filename = self._format_filename(mz.clock)
        fr.log.info(f"Creating NetCDF file: {filename}")

        # check if the model is running in parallel
        parallel = self.grid.domain_decomp.parallel
        ncfile = Dataset(filename, "w", format="NETCDF4", parallel=parallel)
        try:
            self._write_header(ncfile, mz, parallel=parallel)
        except BaseException:
            # Don't leak the handle: at this point it is not yet tracked in
            # self._ncfile, so _close_file() would never reach it.
            ncfile.close()
            raise

        self._file_is_open = True
        self._ncfile = ncfile

    def _write_header(self,
                      ncfile: Dataset,
                      mz: fr.ModelState,
                      *,
                      parallel: bool) -> None:
        """Define global attributes, dimensions, coordinates and output variables."""
        import time as system_time

        dtype = fr.config.dtype_real
        n_dims = self.grid.n_dims
        if n_dims <= 3:  # noqa: PLR2004
            x_names = ["x", "y", "z"][:n_dims]
        else:
            x_names = [f"x{i}" for i in range(n_dims)]

        # ----------------------------------------------------------------
        #  General attributes
        # ----------------------------------------------------------------
        ncfile.description = f"fridom: {self.mset.model_name}"
        ncfile.created = system_time.ctime(system_time.time())

        # ----------------------------------------------------------------
        #  Dimensions: one per (sliced) spatial axis + unlimited time
        # ----------------------------------------------------------------
        for i, name in enumerate(x_names):
            nx = len(self.grid.x_global[i][self.snap_slice[i]])
            ncfile.createDimension(name, nx)
        ncfile.createDimension("time", None)

        # ----------------------------------------------------------------
        #  Coordinate variables
        # ----------------------------------------------------------------
        x_vars = [ncfile.createVariable(name, dtype, (name,)) for name in x_names]
        time_var = ncfile.createVariable("time", dtype, ("time",))

        for x_var, name in zip(x_vars, x_names):
            x_var.units = "m"
            x_var.long_name = f"{name} coordinate"

        time_var.units = "seconds"
        time_var.long_name = "UTC time"
        time_var.calendar = "standard"
        time_var.standard_name = "time"
        if parallel:
            time_var.set_collective(True)

        # store the coordinates
        for i, x_var in enumerate(x_vars):
            x_var[:] = fr.utils.to_numpy(self.grid.x_global[i][self.snap_slice[i]])

        # ----------------------------------------------------------------
        #  Output variables
        # ----------------------------------------------------------------
        # spatial dimensions are stored in reversed order; _write_data
        # transposes the arrays to match
        for var in self.get_variables(mz):
            nc_var = ncfile.createVariable(
                var.name, dtype, ("time", *x_names[::-1]))
            nc_var.units = var.units
            nc_var.long_name = var.long_name
            for key, value in var.nc_attrs.items():
                setattr(nc_var, key, value)
            if parallel:
                nc_var.set_collective(True)

    def _write_data(self, mz: fr.ModelState) -> None:
        """Append one time record of every output variable to the open file."""
        time_var = self._ncfile.variables["time"]
        # next free index along the unlimited time dimension
        time_ind = time_var.size
        ind = (time_ind,) + (slice(None),) * self.grid.n_dims

        time_var[time_ind] = mz.clock.passed_time
        for var in self.get_variables(mz):
            nc_var = self._ncfile.variables[var.name]
            arr = var.unpad()[self.snap_slice]
            # transpose to match the reversed dimension order set in _write_header
            nc_var[ind] = fr.utils.to_numpy(arr.T)

    def _close_file(self) -> None:
        """Close the open file, if any, and mark the writer as closed."""
        # detach first so a failing close() cannot leave a dangling handle
        ncfile, self._ncfile = self._ncfile, None
        self._file_is_open = False
        if ncfile is not None:
            fr.log.debug(f"Closing NetCDF file: {ncfile.filepath()}")
            ncfile.close()

    # ----------------------------------------------------------------
    #  Properties
    # ----------------------------------------------------------------
    @property
    def snap_slice(self) -> tuple[slice, ...]:
        """The slice of the grid that should be written to the file."""
        return self._snap_slice

    @snap_slice.setter
    def snap_slice(self, value: tuple[slice, ...]) -> None:
        # the file's dimensions depend on the slice, so start a new file
        self._close_file()
        self._snap_slice = value

    @property
    def add_timestamp(self) -> bool:
        """Whether a timestamp should be added to the filename."""
        return self._add_timestamp

    @add_timestamp.setter
    def add_timestamp(self, value: bool) -> None:
        self._add_timestamp = value