fridom.framework.utils.jax_utils.jaxjit

fridom.framework.utils.jax_utils.jaxjit(fun: callable, *args, **kwargs) -> callable

Decorator for JAX JIT compilation.

Description

This decorator is a thin wrapper around jax.jit. If JAX is not installed, the function is returned unchanged.

Parameters

fun : callable

The function to JIT compile.

Returns

callable

The JIT-compiled function, or the original function if JAX is not installed.

Examples

>>> import fridom.framework as fr
>>> @fr.utils.jaxjit
... def my_function(x):
...     return x**2
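
The fallback behavior described above can be illustrated with a minimal sketch. This is not fridom's actual implementation: the name `jaxjit_sketch` is hypothetical, and forwarding `*args` and `**kwargs` to `jax.jit` is an assumption based only on the signature shown on this page. The sketch runs with or without JAX installed.

```python
import functools

# Detect JAX once at import time; fall back gracefully if it is missing.
try:
    import jax
    _HAS_JAX = True
except ImportError:
    jax = None
    _HAS_JAX = False


def jaxjit_sketch(fun, *args, **kwargs):
    """Hypothetical stand-in for jaxjit (not fridom's code).

    Returns jax.jit(fun, ...) when JAX is available, otherwise returns
    the function unchanged. Forwarding the extra arguments to jax.jit
    is an assumption made for this sketch.
    """
    if not _HAS_JAX:
        return fun  # no JAX: plain Python function, extra arguments ignored
    return jax.jit(fun, *args, **kwargs)


# Bare decorator usage, mirroring the example above.
@jaxjit_sketch
def square(x):
    return x ** 2


# Passing options through functools.partial; static_argnames is a real
# jax.jit keyword that marks `n` as a compile-time constant.
@functools.partial(jaxjit_sketch, static_argnames=("n",))
def power(x, n):
    return x ** n


square_result = float(square(3.0))
power_result = float(power(2.0, n=3))
```

With this pattern the decorated code does not need to know whether JAX is present: callers get a compiled function when it is, and the original function otherwise.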