# Source code for fridom.framework.utils.array_ops
"""array_ops.py - Utilities for array operations."""
from __future__ import annotations
from typing import Callable, Generic, TypeVar
import numpy as np
import fridom.framework as fr
T = TypeVar("T")
[docs]
class SliceableAttribute(Generic[T]):
    """
    Expose a callable through square-bracket indexing.

    Description
    -----------
    An instance stores a slicer function. Indexing the instance with
    ``obj[key]`` forwards ``key`` unchanged to that function and returns
    whatever the function returns.

    Parameters
    ----------
    slicer : Callable
        The function that is called with the indexing key.
    """

    def __init__(self, slicer: Callable[[int | slice | tuple[int, slice]], T]) -> None:
        # Kept as a public attribute; it is looked up again on every access.
        self.slicer = slicer

    def __getitem__(self, key: int | slice | tuple[int, slice]) -> T:
        """Forward the indexing key to the stored slicer and return its result."""
        forward = self.slicer
        return forward(key)
[docs]
def modify_array(
    arr: np.ndarray, where: slice, value: np.ndarray | float | int
) -> np.ndarray:
    """
    Return a new array with the modifications.

    Description
    -----------
    A fundamental difference between JAX and NumPy is that NumPy allows
    in-place modification of arrays, while JAX does not. This function does
    not modify the input array in place, but returns a new array with the
    modifications, regardless of the active backend.

    Parameters
    ----------
    `arr` : `np.ndarray`
        The array to modify. It is left untouched.
    `where` : `slice`
        The slice to modify.
    `value` : `np.ndarray | float | int`
        The value to set. Either a scalar or an array that can be assigned
        to ``arr[where]``.

    Returns
    -------
    `np.ndarray`
        The modified array.

    Examples
    --------
    >>> import fridom.framework as fr
    >>> x = fr.config.ncp.arange(10)  # create some array
    >>> # instead of x[2:5] = 0, we use the modify_array function
    >>> x = fr.utils.modify_array(x, slice(2,5), 0)
    """
    if fr.config.backend_is_jax:
        # JAX has no item assignment; .at[...].set(...) yields the updated array.
        return arr.at[where].set(value)
    # Other backends: assign on a copy so the caller's array is not mutated.
    res = arr.copy()
    res[where] = value
    return res
[docs]
def random_array(shape: tuple[int, ...], seed: int = 12345, **kwargs) -> np.ndarray:
    """
    Create an array of normally distributed random numbers.

    This function is deprecated; it logs a deprecation warning on every call
    unless ``ignore_warning=True`` is passed.

    Parameters
    ----------
    shape : tuple[int, ...]
        The shape of the array to create.
    seed : int
        The seed for the random number generator (default: 12345).
    **kwargs
        Only ``ignore_warning`` is inspected: a truthy value suppresses the
        deprecation warning. Any other keyword is ignored.

    Returns
    -------
    np.ndarray
        Array of the requested shape, created by the active backend
        (``jax.random.normal`` on JAX, ``standard_normal`` otherwise).
    """
    # Honour the flag's value: ignore_warning=False must not silence the warning.
    if not kwargs.get("ignore_warning", False):
        fr.log.warning("The random_array function is deprecated and will be removed in the future.")
        fr.log.warning("Please use the create array method from the grid object instead")
    if fr.config.backend_is_jax:
        # we need to import jax here since it is an optional dependency
        import jax  # pylint: disable=import-outside-toplevel
        key = jax.random.key(seed)
        return jax.random.normal(key, shape)
    rng = fr.config.ncp.random.default_rng(seed)
    return rng.standard_normal(shape)
[docs]
def array_is_constant(arr: np.ndarray) -> bool:
    """
    Check if an array is constant.

    Description
    -----------
    This function checks if all elements of an array are the same. The
    comparison is made against the first element with the backend's
    ``allclose``, so values that differ only within its default tolerances
    count as equal. An empty array has no differing elements and is
    reported as constant.

    Parameters
    ----------
    arr : np.ndarray
        The array to check.

    Returns
    -------
    bool
        True if the array is constant, False otherwise.

    Examples
    --------
    >>> import fridom.framework as fr
    >>> x = fr.config.ncp.ones(10)
    >>> fr.utils.array_is_constant(x)
    True
    >>> # item assignment is not available on the JAX backend
    >>> x = fr.utils.modify_array(x, slice(5, 6), 0)
    >>> fr.utils.array_is_constant(x)
    False
    """
    # Guard the empty case: there is no first element to compare against.
    if arr.size == 0:
        return True
    # ravel() avoids the unconditional full copy that flatten() makes.
    first = arr.ravel()[0]
    return fr.config.ncp.allclose(arr, first)