Source code for fridom.framework.modules.netcdf_writer

import os
from typing import Union
from netCDF4 import Dataset
import numpy as np
import fridom.framework as fr

class NetCDFWriter(fr.modules.Module):
    """
    Writing model output to NetCDF files.

    Every time interval given in `np.timedelta64` / `np.datetime64` is
    converted to seconds with `fr.utils.to_seconds`; plain numbers are used
    unchanged and compared directly against `mz.clock.time`.

    Parameters
    ----------
    `write_interval` : `np.timedelta64 | float`
        The interval at which the data should be written to the file.
    `filename` : `str`, optional
        The name of the file to write to. Default is "snap" (no directory).
        The final file name is `<filename>_<time><ext>`, where `<ext>` is
        the extension of `filename` if it is ".nc" or ".cdf" and ".cdf"
        otherwise.
    `directory` : `str`, optional
        The directory where the files should be stored.
        Default is "snapshots".
    `start_time` : `np.datetime64 | float`, optional
        The time at which the first file should be written. Default is 0.
    `end_time` : `np.datetime64 | float`, optional
        The time at which the last file should be written. Default is None
        (no upper limit).
    `restart_interval` : `np.timedelta64 | float`, optional
        The interval at which a new file should be created. Must be
        positive. Default is None (a single file is used).
    `snap_slice` : `tuple`, optional
        The slice of the grid that should be written to the file.
        Default is None (the whole grid). Any other value currently raises
        `NotImplementedError`.
    `get_variables` : `callable`, (default: None)
        A function that returns a list of field variables that should be
        written to the file. If None, all fields of the State object will be
        written (`mz.z.field_list`). The function signature of get_variables
        is: `get_variables(mz: 'ModelState') -> list[FieldVariable]`

    Raises
    ------
    `NotImplementedError`
        If `snap_slice` is not None.
    `ValueError`
        If `restart_interval` is not positive.

    Examples
    --------
    The following example shows how to create a netCDF output from a
    nonhydrostatic model using the SingleWave initial condition.

    .. code-block:: python

        import fridom.nonhydro as nh
        import numpy as np
        import matplotlib.pyplot as plt

        # create a netCDF writer that outputs u, v, w, b, and p
        nc_writer = nh.modules.NetCDFWriter(
            get_variables = lambda mz: [mz.u, mz.v, mz.w, mz.b, mz.z_diag.p],
            write_interval = np.timedelta64(1, 's'))

        # create the model
        grid = nh.grid.cartesian.Grid(
            N=[128]*3, L=[1]*3, periodic_bounds=(True, True, True))
        mset = nh.ModelSettings(grid=grid, dsqr=0.02, Ro=0.0)
        mset.time_stepper.dt = np.timedelta64(10, 'ms')
        # add the netCDF writer to the diagnostics
        mset.diagnostics.add_module(nc_writer)
        mset.setup()
        z = nh.initial_conditions.SingleWave(mset, kx=2, ky=0, kz=1)
        model = nh.Model(mset)
        model.z = z
        model.run(runlen=np.timedelta64(10, 's'))
    """
    name = "NetCDFWriter"

    def __init__(self,
                 write_interval: Union[np.timedelta64, float],
                 filename: str = "snap",
                 start_time: Union[np.datetime64, float] = 0,
                 end_time: Union[np.datetime64, float, None] = None,
                 restart_interval: Union[np.timedelta64, float, None] = None,
                 snap_slice: tuple | None = None,
                 directory: str | None = None,
                 get_variables: 'callable | None' = None,
                 ):
        super().__init__()
        directory = directory or "snapshots"
        filename = os.path.join(directory, filename)
        self.execute_at_start = True

        # Convert the times to seconds
        if isinstance(write_interval, np.timedelta64):
            write_interval = fr.utils.to_seconds(write_interval)
        if isinstance(restart_interval, np.timedelta64):
            restart_interval = fr.utils.to_seconds(restart_interval)
        if isinstance(start_time, np.datetime64):
            start_time = fr.utils.to_seconds(start_time)
        if isinstance(end_time, np.datetime64):
            end_time = fr.utils.to_seconds(end_time)

        # A non-positive restart interval would make the file-start search
        # in `update` loop forever, so reject it right away.
        if restart_interval is not None and restart_interval <= 0:
            raise ValueError(
                f"restart_interval must be positive, got {restart_interval!r}")

        if get_variables is None:
            def get_variables(mz: 'fr.ModelState'):
                return mz.z.field_list

        if snap_slice is not None:
            raise NotImplementedError("snap_slice is not implemented yet.")

        # ----------------------------------------------------------------
        #  Set Attributes
        # ----------------------------------------------------------------
        self.directory = directory
        self.filename = filename
        self.start_time = start_time
        self.end_time = end_time
        self.write_interval = write_interval
        self.restart_interval = restart_interval
        self.snap_slice = snap_slice
        self.get_variables = get_variables

        # private attributes
        self._current_start_time = None     # start time of the current file
        self._last_checkpoint_time = None   # model time of the last update
        self._last_write_time = None        # model time of the last write
        self._file_is_open = False
        self._ncfile = None

    @fr.modules.module_method
    def setup(self, mset: 'fr.ModelSettingsBase') -> None:
        """
        Create the snapshot directory and set the default snap slice.
        """
        super().setup(mset)
        # create snapshot folder if it doesn't exist
        fr.log.verbose(f"Touching snapshot directory: {self.directory}")
        os.makedirs(self.directory, exist_ok=True)

        # snap slice: default to the full extent along every grid dimension
        if self.snap_slice is None:
            self.snap_slice = tuple([slice(None)] * self.grid.n_dims)

    @fr.modules.module_method
    def start(self):
        """
        Close a file that is still open from a previous run (with a warning).
        """
        if self._file_is_open:
            fr.log.warning(
                "NetCDFWriter: start() called while a file is already open. Continue with closing the file")
            self._close_file()

    @fr.modules.module_method
    def stop(self):
        """
        Close the current file (if any) and reset the time bookkeeping.
        """
        if self._file_is_open:
            self._close_file()
        self._current_start_time = None
        self._last_checkpoint_time = None
        self._last_write_time = None

    @fr.modules.module_method
    def update(self, mz: 'fr.ModelState') -> 'fr.ModelState':
        """
        Write the selected fields to the NetCDF file if a write is due.

        Parameters
        ----------
        `mz` : `ModelState`
            The current model state; `mz.clock.time` is the model time used
            for all checks.

        Returns
        -------
        `ModelState`
            The unchanged model state `mz`.
        """
        time = mz.clock.time

        # ----------------------------------------------------------------
        #  Check if the model time is in the writing range
        # ----------------------------------------------------------------
        # before the start time: nothing to do
        if self.start_time is not None and time < self.start_time:
            return mz

        # after the end time: close a file that is still open, then stop
        if self.end_time is not None and time > self.end_time:
            if self._file_is_open:
                self._close_file()
            return mz

        # ----------------------------------------------------------------
        #  Check if it is time to write
        # ----------------------------------------------------------------
        if self._last_write_time is None or self._last_checkpoint_time is None:
            # first call inside the writing range always writes
            time_to_write = True
        else:
            # write only on the update that first reaches the next write time
            next_write_time = self._last_write_time + self.write_interval
            time_to_write = (self._last_checkpoint_time < next_write_time
                             and time >= next_write_time)

        self._last_checkpoint_time = time
        if not time_to_write:
            return mz

        # ----------------------------------------------------------------
        #  Check if the current start time is set
        # ----------------------------------------------------------------
        if self._current_start_time is None:
            self._current_start_time = time

        # ----------------------------------------------------------------
        #  Check if the file should be closed
        # ----------------------------------------------------------------
        if self.restart_interval is not None:
            next_restart_time = self._current_start_time + self.restart_interval
            if time >= next_restart_time and self._file_is_open:
                self._close_file()

        # ----------------------------------------------------------------
        #  Create a new file if the current file is not open
        # ----------------------------------------------------------------
        if not self._file_is_open:
            start_time = self._current_start_time
            if self.restart_interval is not None:
                # advance the file start time in whole restart intervals
                # until the current time falls into its window
                while time - start_time >= self.restart_interval:
                    start_time += self.restart_interval
            self._current_start_time = start_time
            self._create_file(mz)

        # ----------------------------------------------------------------
        #  Write data
        # ----------------------------------------------------------------
        self._write_data(mz)
        self._last_write_time = time
        return mz

    def _create_file(self, mz: 'fr.ModelState'):
        """
        Create a new NetCDF file and register it as the open file.

        The file name is `<base>_<time><ext>`, where `<time>` is derived from
        `clock.get_total_time(clock.passed_time)`. If defining the file
        contents fails, the file is closed again before the error is
        re-raised.
        """
        # ----------------------------------------------------------------
        #  Create the filename
        # ----------------------------------------------------------------
        base, ext = os.path.splitext(self.filename)
        ext = ext.lower()
        if ext not in (".nc", ".cdf"):
            # not a NetCDF extension: keep the full name and append ".cdf"
            base, ext = self.filename, ".cdf"

        clock = mz.clock
        tot_time = clock.get_total_time(clock.passed_time)
        if not isinstance(tot_time, np.datetime64):
            tot_time = fr.utils.humanize_number(tot_time, unit="seconds")
            tot_time = tot_time.replace(" ", "-")
        filename = f"{base}_{tot_time}{ext}"

        # ----------------------------------------------------------------
        #  Create the NetCDF file
        # ----------------------------------------------------------------
        fr.log.info(f"Creating NetCDF file: {filename}")
        # check if the model is running in parallel
        parallel = self.grid.domain_decomp.parallel
        ncfile = Dataset(filename, "w", format="NETCDF4", parallel=parallel)
        try:
            self._define_file_contents(ncfile, mz, parallel)
        except Exception:
            # do not leak a half-initialised file that nobody can close
            ncfile.close()
            raise

        # ----------------------------------------------------------------
        #  Store the attributes
        # ----------------------------------------------------------------
        self._file_is_open = True
        self._ncfile = ncfile

    def _define_file_contents(self, ncfile, mz: 'fr.ModelState',
                              parallel) -> None:
        """
        Define attributes, dimensions, coordinates and output variables.

        `parallel` is the flag the file was opened with; when it is true,
        the time variable and all output variables are set to collective
        access.
        """
        import time as system_time

        dtype = fr.config.dtype_real
        n_dims = self.grid.n_dims
        if n_dims <= 3:
            x_names = ['x', 'y', 'z'][:n_dims]
        else:
            x_names = [f"x{i}" for i in range(n_dims)]
        # coordinate values restricted to the snap slice, one per dimension
        coords = [self.grid.x_global[i][self.snap_slice[i]]
                  for i in range(n_dims)]

        # ----------------------------------------------------------------
        #  General attributes
        # ----------------------------------------------------------------
        ncfile.description = f"fridom: {self.mset.model_name}"
        ncfile.created = system_time.ctime(system_time.time())

        # ----------------------------------------------------------------
        #  Create the dimensions
        # ----------------------------------------------------------------
        for dim_name, coord in zip(x_names, coords):
            ncfile.createDimension(dim_name, len(coord))
        # size None makes the time dimension unlimited
        ncfile.createDimension('time', None)

        # ----------------------------------------------------------------
        #  Create the variables
        # ----------------------------------------------------------------
        # Coordinate variables
        x_vars = [ncfile.createVariable(dim_name, dtype, (dim_name,))
                  for dim_name in x_names]
        time_var = ncfile.createVariable("time", dtype, ("time",))
        for x_var, dim_name in zip(x_vars, x_names):
            x_var.units = "m"
            x_var.long_name = f"{dim_name} coordinate"
        time_var.units = "seconds"
        time_var.long_name = "UTC time"
        time_var.calendar = "standard"
        time_var.standard_name = "time"
        if parallel:
            time_var.set_collective(True)

        # store the coordinates
        for x_var, coord in zip(x_vars, coords):
            x_var[:] = fr.utils.to_numpy(coord)

        # create the output variables; spatial dimensions are stored in
        # reversed order (matching the transposed arrays in `_write_data`)
        for var in self.get_variables(mz):
            nc_var = ncfile.createVariable(
                var.name, dtype, ("time", *x_names[::-1]))
            nc_var.units = var.units
            nc_var.long_name = var.long_name
            for key, value in var.nc_attrs.items():
                setattr(nc_var, key, value)
            if parallel:
                nc_var.set_collective(True)

    def _write_data(self, mz: 'fr.ModelState'):
        """
        Append the current time and all selected fields as a new record.
        """
        time_var = self._ncfile.variables["time"]
        # the current length of the time axis is the index of the new record
        time_ind = time_var.size
        ind = (time_ind, *(slice(None) for _ in range(self.grid.n_dims)))
        time_var[time_ind] = mz.clock.passed_time
        for var in self.get_variables(mz):
            nc_var = self._ncfile.variables[var.name]
            arr = var.unpad()
            nc_var[ind] = fr.utils.to_numpy(arr.T)

    def _close_file(self):
        """
        Close the current NetCDF file (if any) and mark the writer as closed.
        """
        if self._ncfile is not None:
            self._ncfile.close()
        self._ncfile = None
        self._file_is_open = False