fridom.framework.utils.jax_utils.jaxify



fridom.framework.utils.jax_utils.jaxify(cls: Generic[T], dynamic: tuple[str] | None = None) → T

Add JAX pytree support to a class (for jit compilation).

Description

To use jax.jit on a custom class, the class must be registered with JAX as a pytree. This decorator adds the methods needed to make the class compatible with jax.jit. By default, all attributes of an object are considered static, i.e., they will not be traced by JAX. Attributes that should be dynamic must be specified with the dynamic argument.
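The static/dynamic split described above can be illustrated with plain JAX. The sketch below is not fridom's implementation: `my_jaxify`, `MyClass` and `raise_to_power` are hypothetical names, and it assumes that jaxify works roughly like this, turning dynamic attributes into pytree leaves and storing all other attributes as static auxiliary data via `jax.tree_util.register_pytree_node`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import tree_util


def my_jaxify(cls, dynamic=None):
    """Hypothetical sketch: register `cls` as a JAX pytree.

    Attributes listed in `dynamic` become pytree leaves, which jax.jit
    traces. All other instance attributes are stored as static auxiliary
    data, so they must be hashable.
    """
    dynamic = tuple(dynamic or ())

    def flatten(obj):
        # Dynamic attributes are the children (leaves) traced by jax.jit.
        children = tuple(getattr(obj, name) for name in dynamic)
        # Remaining attributes form the static, hashable auxiliary data.
        static = tuple(sorted(
            ((k, v) for k, v in vars(obj).items() if k not in dynamic),
            key=lambda item: item[0],
        ))
        return children, static

    def unflatten(static, children):
        # Rebuild an instance without calling __init__.
        obj = object.__new__(cls)
        for name, value in static:
            setattr(obj, name, value)
        for name, value in zip(dynamic, children):
            setattr(obj, name, value)
        return obj

    tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@partial(my_jaxify, dynamic=("arr",))
class MyClass:
    def __init__(self, arr, power):
        self.power = power  # static: stored as auxiliary data
        self.arr = arr      # dynamic: a leaf traced by jax.jit


@jax.jit
def raise_to_power(obj):
    return obj.arr ** obj.power


obj = MyClass(jnp.arange(3.0), 2)
result = raise_to_power(obj)  # [0., 1., 4.]
```

Because the static attributes live in the auxiliary data, they are available as ordinary Python values inside the jitted function, while only `arr` is replaced by a tracer.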

Note

The dynamic argument must be a tuple of attribute names. If you only have one dynamic attribute, use dynamic=('attr',) instead of dynamic=('attr'); without the trailing comma, ('attr') is just the string 'attr', not a tuple.
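The pitfall in the note above is plain Python behavior: parentheses alone do not create a tuple, the trailing comma does. How jaxify reacts to a bare string is not documented here; the last two lines only show what iterating over each value yields, which is what any code looping over `dynamic` would see.

```python
# Parentheses alone do not make a tuple; the trailing comma does.
not_a_tuple = ('attr')
one_tuple = ('attr',)

print(type(not_a_tuple).__name__)  # str
print(type(one_tuple).__name__)    # tuple

# Iterating over each value: the string yields single characters.
print(list(not_a_tuple))  # ['a', 't', 't', 'r']
print(list(one_tuple))    # ['attr']
```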

Note

If the value of a static attribute changes, all jit-compiled methods of the class must be recompiled, because JAX treats static attributes as part of the compilation cache key. Such attributes should therefore be marked as dynamic. However, marking an attribute as dynamic increases the computational cost, so it is advisable to mark as dynamic only those attributes that actually change during the simulation.
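The effect of changing a static versus a dynamic attribute can be observed with a trace counter. This is a plain-JAX sketch, assuming jaxify stores static attributes as pytree auxiliary data; `Scaler`, `apply` and `trace_count` are hypothetical names. The Python side effect inside `apply` runs only while JAX traces the function, so the counter shows how often it is traced (and hence compiled).

```python
import jax
import jax.numpy as jnp
from jax import tree_util

trace_count = 0  # incremented each time `apply` is traced


@tree_util.register_pytree_node_class
class Scaler:
    """Toy pytree: `factor` is dynamic (a leaf), `power` is static (aux data)."""

    def __init__(self, factor, power):
        self.factor = factor
        self.power = power

    def tree_flatten(self):
        return (self.factor,), self.power

    @classmethod
    def tree_unflatten(cls, power, children):
        return cls(children[0], power)


@jax.jit
def apply(obj, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return obj.factor * x ** obj.power


x = jnp.arange(3.0)
apply(Scaler(2.0, 2), x)  # first call: traced and compiled
apply(Scaler(3.0, 2), x)  # only a dynamic value changed: cached version reused
apply(Scaler(3.0, 3), x)  # static value changed: traced and compiled again
print(trace_count)  # 2
```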

Warning

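Methods that are jit-compiled with fr.utils.jaxjit do not modify the object in place: attribute assignments made inside such a method are not visible on the original object after the call.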
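The following plain-JAX sketch illustrates this behavior, assuming fr.utils.jaxjit behaves like jax.jit in this respect; `Counter`, `increment_in_place` and `increment` are hypothetical names. Inside a jitted function the object is rebuilt from traced leaves, so an assignment only affects that internal copy. The usual remedy is to return a new object.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Counter:
    """Toy pytree with a single dynamic attribute `value`."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


@jax.jit
def increment_in_place(obj):
    # `obj` here is a copy rebuilt from traced leaves, so this
    # assignment never reaches the caller's object.
    obj.value = obj.value + 1


@jax.jit
def increment(obj):
    # Functional style: return an updated object instead.
    return Counter(obj.value + 1)


c = Counter(jnp.array(0))
increment_in_place(c)
unchanged = int(c.value)  # still 0
c = increment(c)
updated = int(c.value)    # 1
```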

Parameters

cls : type

The class to add jax support to.

dynamic : tuple[str] | None (default=None)

A tuple of attribute names that should be considered dynamic.

Examples

A class with no dynamic attributes:

import fridom.framework as fr

@fr.utils.jaxify
class MyClass:
    def __init__(self, power):
        # Static attribute: not traced by JAX
        self.power = power

    @fr.utils.jaxjit
    def raise_to_power(self, arr):
        # arr is a regular function argument; self.power is static
        return arr**self.power

A class with dynamic attributes:

import fridom.framework as fr
from functools import partial

@partial(fr.utils.jaxify, dynamic=('arr',))
class MyClass:
    def __init__(self, arr, power):
        self.power = power  # static attribute
        self.arr = arr      # dynamic attribute (traced by JAX)

    @fr.utils.jaxjit
    def raise_to_power(self):
        return self.arr**self.power