# Source code for fridom.framework.projection.optimal_balance

import fridom.framework as fr
from typing import Union
from copy import copy, deepcopy
import numpy as np


class OptimalBalance(fr.projection.Projection):
    """
    Nonlinear balancing using the optimal balance method.

    Each iteration ramps the current state backward in time from the
    nonlinear model to the linear model (the advection scaling is ramped
    down to zero), projects it onto the base point, ramps it forward in
    time back to the nonlinear model, and finally replaces the base-point
    component of the result by the one of the input state. Iterations stop
    after `max_it` iterations, when the difference between two consecutive
    iterates drops below `stop_criterion`, or when that difference increases.

    Parameters
    ----------
    `mset` : `ModelSettings`
        The model settings. A deep copy is stored, so the settings modified
        during the ramping (e.g. the advection scaling) are not the caller's
        object.
    `base_proj` : `Projection`
        The projection onto the base point.
    `ramp_period` : `np.timedelta64 | float | int`
        The ramping period (required). It is converted into the number of
        time steps of size `mset.time_stepper.dt` closest to this period.
        The backward ramping uses the same number of steps.
    `mset_backwards` : `ModelSettings` (default: None)
        The model settings for the backward ramping. If None, the (copied)
        forward model settings are used. This option is useful when the
        backwards ramping should be done with a different setup
        (e.g. negative viscosity). A deep copy is stored as well.
    `ramp_type` : `str` (default: "exp")
        The ramping type. Choose from "exp", "pow", "cos", "lin".
    `update_base_point` : `bool` (default: True)
        Whether to update the base point after each iteration. This has no
        effect on OB. But it matters for OBTA. Should be True for OBTA.
    `max_it` : `int` (default: 3)
        Maximum number of iterations.
    `stop_criterion` : `float` (default: 1e-9)
        Stop once the difference between two consecutive iterates is
        smaller than this value.
    `disable_diagnostic` : `bool` (default: True)
        Whether to disable the diagnostic tendencies during the iterations.
    `return_details` : `bool` (default: False)
        If True, calling the projection returns
        ``(z_balanced, (iterations, errors))`` instead of only the balanced
        state.
    """

    def __init__(self,
                 mset: 'fr.ModelSettingsBase',
                 base_proj: 'fr.projection.Projection',
                 ramp_period: Union[np.timedelta64, float, int, None],
                 mset_backwards: Union['fr.ModelSettingsBase', None] = None,
                 ramp_type: str = "exp",
                 update_base_point: bool = True,
                 max_it: int = 3,
                 stop_criterion: float = 1e-9,
                 disable_diagnostic: bool = True,
                 return_details: bool = False) -> None:
        # Work on copies: the ramping mutates the model settings
        # (advection scaling), which must not leak into the caller's objects.
        mset = deepcopy(mset)
        super().__init__(mset)
        if mset_backwards is None:
            self.mset_backwards = mset
        else:
            self.mset_backwards = deepcopy(mset_backwards)
        self.base_proj = base_proj
        self.return_details = return_details

        # initialize the models
        self.model_forward = fr.Model(self.mset)
        self.model_backward = fr.Model(self.mset_backwards)
        if disable_diagnostic:
            self.model_forward.diagnostics.disable()
            self.model_backward.diagnostics.disable()

        # save the parameters
        self.ramp_period = ramp_period
        # Round instead of truncating: e.g. 0.3 / 0.1 == 2.999... would
        # otherwise silently drop one ramping step.
        self.ramp_steps = int(np.round(ramp_period / mset.time_stepper.dt))
        if self.ramp_steps < 1:
            fr.log.warning(
                f"ramp_period={ramp_period} yields no ramping steps; "
                "the ramping will be skipped.")
        self.ramp_func = OptimalBalance.get_ramp_func(ramp_type)
        self.max_it = max_it
        self.stop_criterion = stop_criterion
        self.update_base_point = update_base_point
        # The advection scaling of the fully nonlinear model.
        self.default_scaling = mset.tendencies.advection.scaling

        # prepare the balancing
        self.z_base = None

    def calc_base_coord(self, z: 'fr.StateBase') -> None:
        """Store the base-point projection of `z` in `self.z_base`."""
        self.z_base = self.base_proj(z)

    def _ramp_model(self,
                    model: 'fr.Model',
                    z: 'fr.StateBase',
                    forward_in_time: bool,
                    to_nonlinear: bool) -> 'fr.StateBase':
        """
        Integrate `model` from `z` for `ramp_steps` steps while ramping
        the advection scaling.

        Parameters
        ----------
        `model` : `Model`
            The model to integrate (it is reset first).
        `z` : `State`
            The initial state (copied before integration).
        `forward_in_time` : `bool`
            If True the time step is made positive, otherwise negative.
        `to_nonlinear` : `bool`
            If True the scaling follows ramp_func(n / N) (linear -> nonlinear),
            otherwise ramp_func(1 - n / N) (nonlinear -> linear).

        Returns
        -------
        `State`
            The model state after the ramping.
        """
        model.reset()
        mset = model.mset
        dt = np.abs(model.time_stepper.dt)
        model.time_stepper.dt = dt if forward_in_time else -dt
        model.z = copy(z)
        for n in range(self.ramp_steps):
            theta = n / self.ramp_steps
            if not to_nonlinear:
                theta = 1 - theta
            mset.tendencies.advection.scaling = (
                self.ramp_func(theta) * self.default_scaling)
            model.step()
        return model.z

    def forward_to_nonlinear(self, z: 'fr.StateBase') -> 'fr.StateBase':
        """
        Perform forward ramping from linear model to nonlinear model.
        """
        return self._ramp_model(self.model_forward, z,
                                forward_in_time=True, to_nonlinear=True)

    def backward_to_linear(self, z: 'fr.StateBase') -> 'fr.StateBase':
        """
        Perform backward ramping from nonlinear model to linear model.
        """
        return self._ramp_model(self.model_backward, z,
                                forward_in_time=False, to_nonlinear=False)

    def forward_to_linear(self, z: 'fr.StateBase') -> 'fr.StateBase':
        """
        Perform forward ramping from nonlinear model to linear model.
        """
        # Previously ramped linear -> nonlinear (copy of forward_to_nonlinear).
        return self._ramp_model(self.model_forward, z,
                                forward_in_time=True, to_nonlinear=False)

    def backward_to_nonlinear(self, z: 'fr.StateBase') -> 'fr.StateBase':
        """
        Perform backward ramping from linear model to nonlinear model.
        """
        # Previously ramped nonlinear -> linear (copy of backward_to_linear).
        return self._ramp_model(self.model_backward, z,
                                forward_in_time=False, to_nonlinear=True)

    @staticmethod
    def get_ramp_func(ramp_type):
        """
        Return the ramp function for `ramp_type`.

        Every ramp maps theta = 0 to 0 and theta = 1 to 1.

        Parameters
        ----------
        `ramp_type` : `str`
            One of "exp", "pow", "cos", "lin".

        Raises
        ------
        `ValueError`
            If `ramp_type` is not one of the supported types.
        """
        if ramp_type == "exp":
            def ramp_func(theta):
                # The 1e-32 floor avoids a division by zero at theta = 0 / 1.
                t1 = 1./np.maximum(1e-32, theta)
                t2 = 1./np.maximum(1e-32, 1.-theta)
                return np.exp(-t1)/(np.exp(-t1)+np.exp(-t2))
        elif ramp_type == "pow":
            def ramp_func(theta):
                return theta**3/(theta**3+(1.-theta)**3)
        elif ramp_type == "cos":
            def ramp_func(theta):
                return 0.5*(1.-np.cos(np.pi*theta))
        elif ramp_type == "lin":
            def ramp_func(theta):
                return theta
        else:
            raise ValueError(
                "Invalid ramp type. Choose from 'exp', 'pow', 'cos', 'lin'.")
        return ramp_func

    def __call__(self, z: 'fr.StateBase') -> 'fr.StateBase':
        """
        Project a state to the balanced subspace using optimal balance.

        Parameters
        ----------
        `z` : `State`
            The state to project.

        Returns
        -------
        `State`
            The projection of the state onto the balanced subspace. If
            `return_details` is True, a tuple
            ``(state, (iterations, errors))`` is returned instead, where
            `errors[i]` is the difference between iterate i and its
            predecessor (entries of iterations not reached stay 1.0).
        """
        iterations = np.arange(self.max_it)
        errors = np.ones(self.max_it)

        # save the base coordinate
        self.calc_base_coord(z)

        z_res = copy(z)

        # start the iterations
        fr.log.info("Starting optimal balance iterations")
        for it in iterations:
            fr.log.verbose(f"Starting iteration {it}")

            # backward ramping
            fr.log.verbose("Performing backward ramping")
            z_lin = self.backward_to_linear(z_res)

            # project to the base point
            z_lin = self.base_proj(z_lin)

            # forward ramping
            fr.log.verbose("Performing forward ramping")
            z_bal = self.forward_to_nonlinear(z_lin)

            # exchange base point coordinate
            z_new = z_bal - self.base_proj(z_bal) + self.z_base

            # calculate the error
            errors[it] = error = z_new.norm_of_diff(z_res)
            fr.log.verbose(f"Difference to previous iteration: {error:.2e}")

            # update the state
            z_res = z_new

            # check the stopping criterion
            if error < self.stop_criterion:
                fr.log.info("Stopping criterion reached.")
                break

            # check if the error is increasing
            if it > 0 and error > errors[it-1]:
                fr.log.warning("Error is increasing. Stopping iterations.")
                break

            # recalculate the base coordinate if needed (not after the
            # last iteration, where it would be unused)
            if self.update_base_point and it < self.max_it - 1:
                self.calc_base_coord(z_res)

        if self.return_details:
            return z_res, (iterations, errors)
        return z_res