fridom.framework.utils.jax_utils.jaxify#
- fridom.framework.utils.jax_utils.jaxify(cls: Generic[T], dynamic: tuple[str] | None = None) -> T
Add JAX pytree support to a class (for jit compilation).
Description#
To use jax.jit with a custom class, the class must be registered with JAX. This decorator adds the methods the class needs to be compatible with jax.jit. By default, all attributes of an object are considered static, i.e., they will not be traced by JAX. Attributes that should be dynamic must be specified with the dynamic argument.
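The split between static and dynamic attributes can be illustrated with JAX's own pytree registration. The following is a minimal sketch, not necessarily how jaxify is implemented: the class Scaled and the choice of which attribute is dynamic are hypothetical, and only public jax.tree_util calls are used.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Scaled:
    """Hypothetical example class, not part of fridom."""

    def __init__(self, arr, power):
        self.arr = arr      # treated as dynamic: traced by JAX
        self.power = power  # treated as static: part of the compilation key

    def tree_flatten(self):
        children = (self.arr,)    # dynamic leaves
        aux_data = (self.power,)  # static data, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # rebuild without calling __init__
        (obj.arr,) = children
        (obj.power,) = aux_data
        return obj


@jax.jit
def raise_to_power(obj):
    # obj.arr is a tracer here, obj.power is an ordinary Python int
    return obj.arr ** obj.power


obj = Scaled(jnp.array([1.0, 2.0, 3.0]), 2)
result = raise_to_power(obj)
print(result)
```

In this sketch only arr appears among the pytree leaves, so calling raise_to_power with a different array of the same shape reuses the compiled function, while a different power triggers a new compilation.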
Note
The dynamic argument must be a tuple of attribute names. If you only have one dynamic attribute, use dynamic=('attr',) instead of dynamic=('attr'), since ('attr') is a plain string, not a tuple.
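The difference between the two spellings is plain Python behavior and can be checked directly. The variable names below are only for illustration.

```python
single_wrong = ('attr')   # parentheses only group: this is the str 'attr'
single_right = ('attr',)  # the trailing comma makes a one-element tuple

print(type(single_wrong).__name__)  # str
print(type(single_right).__name__)  # tuple

# Iterating over the string yields characters, not attribute names.
print(list(single_wrong))  # ['a', 't', 't', 'r']
print(list(single_right))  # ['attr']
```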
Note
If a static attribute is changed, all jit compiled functions of the class must be recompiled. Hence, attributes that change should be marked as dynamic. However, marking an attribute as dynamic increases the computational cost, so it is advisable to mark only those attributes as dynamic that actually change during the simulation.
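The recompilation cost of static values can be observed with plain jax.jit, independent of fridom. The sketch below counts how often each function is traced; the function names and the trace_count dictionary are hypothetical helpers, and static_argnums stands in for a static attribute.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = {"static": 0, "dynamic": 0}


@partial(jax.jit, static_argnums=1)
def power_static(arr, power):
    trace_count["static"] += 1  # executed only while JAX traces the function
    return arr ** power


@jax.jit
def power_dynamic(arr, power):
    trace_count["dynamic"] += 1  # executed only while JAX traces the function
    return arr ** power


arr = jnp.arange(4.0)
for p in (1, 2, 3):
    power_static(arr, p)   # each new static value causes a new trace
    power_dynamic(arr, p)  # same shapes and dtypes: the first trace is reused

print(trace_count)  # {'static': 3, 'dynamic': 1}
```

The static variant is traced once per distinct value of power, whereas the dynamic variant is traced once. In exchange, the dynamic variant cannot specialize the compiled code to a known value of power.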
Warning
Methods that are jit compiled with fr.utils.jaxjit will not modify the object in place.
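The same behavior can be reproduced with plain jax.jit on a pytree class, which may help to see why in-place updates are lost. This is a sketch under the assumption that jit compiled methods operate on a reconstructed copy of the object; the Counter class and both functions are hypothetical and not part of fridom.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Counter:
    """Hypothetical example class, not part of fridom."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None  # one dynamic leaf, no static data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def increment_in_place(counter):
    counter.value = counter.value + 1  # changes only the traced copy
    return None


@jax.jit
def increment(counter):
    return counter.value + 1  # return the new value instead


c = Counter(jnp.asarray(0))

increment_in_place(c)
unchanged = int(c.value)  # the original object was not modified
print(unchanged)  # 0

c.value = increment(c)    # assign the returned value outside of jit
updated = int(c.value)
print(updated)  # 1
```

A common pattern is therefore to return the updated values from the compiled function and assign them outside of it.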
Parameters#
- cls : type
The class to add jax support to.
- dynamic : tuple[str] | None (default=None)
A tuple of attribute names that should be considered dynamic.
Examples#
A class with no dynamic attributes:
    import fridom.framework as fr

    @fr.utils.jaxify
    class MyClass:
        _dynamic_attributes = ["x",]

        def __init__(self, power):
            self.power = power

        @fr.utils.jaxjit
        def raise_to_power(self, arr):
            return arr**self.power
A class with dynamic attributes:
    import fridom.framework as fr
    from functools import partial

    @partial(fr.utils.jaxify, dynamic=('arr',))
    class MyClass:
        def __init__(self, arr, power):
            self.power = power
            self.arr = arr

        @fr.utils.jaxjit
        def raise_to_power(self):
            return self.arr**self.power